feat(tools/neo4j-schema): add neo4j-schema tool (#978)

This pull request introduces a new tool, `neo4j-schema`, for extracting
and processing comprehensive schema information from Neo4j databases. It
adds documentation, a caching mechanism, helper utilities for schema
transformation, and corresponding unit tests. The most important changes
are grouped by
theme below:

### Tool Integration:
- **`cmd/root.go`**: Added a blank import of the new `neo4j-schema`
package so the tool registers itself with the application at startup.

### Documentation:
- **`docs/en/resources/tools/neo4j/neo4j-schema.md`**: Added detailed
documentation for the `neo4j-schema` tool, describing its functionality,
caching behavior, and usage examples.

### Caching Implementation:
- **`internal/tools/neo4j/neo4jschema/cache/cache.go`**: Implemented a
thread-safe, in-memory cache with per-item expiration and an optional
background janitor that periodically removes expired items.
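
A minimal usage sketch of the new cache package (import path as added in this PR; the key, value, and TTL below are illustrative):

```go
package main

import (
	"fmt"
	"time"

	"github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jschema/cache"
)

func main() {
	// Create a cache whose janitor sweeps expired items every 10 minutes.
	c := cache.NewCache().WithJanitor(10 * time.Minute)
	defer c.Stop() // stop the janitor goroutine when the cache is no longer needed

	// Store a value for one hour; a TTL <= 0 means the item never expires.
	c.Set("schema", "cached-schema-placeholder", time.Hour)

	if v, ok := c.Get("schema"); ok {
		fmt.Println("cache hit:", v)
	}
}
```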

### Unit Tests:
- **`internal/tools/neo4j/neo4jschema/cache/cache_test.go`**: Added
comprehensive tests for the caching system, covering set/get behavior,
expiration, janitor cleanup, and concurrent access.

### Helper Utilities:
- **`internal/tools/neo4j/neo4jschema/helpers/helpers.go`**: Added
utility functions for processing schema data from both APOC and native
Cypher queries and for converting raw query results into structured
formats.
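
A short sketch of how these helpers fit together, assuming a raw `apoc.meta.schema` result has already been fetched from the driver (the input map below is a stand-in, shaped like the values the procedure returns):

```go
package main

import (
	"fmt"
	"log"

	"github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jschema/helpers"
)

func main() {
	// Stand-in for the map a driver returns from `CALL apoc.meta.schema(...)`.
	raw := map[string]any{
		"Person": map[string]any{
			"type":  "node",
			"count": int64(42),
			"properties": map[string]any{
				"name": map[string]any{"type": "STRING", "indexed": true},
			},
		},
	}

	// Convert the loosely typed map into the structured APOC schema type.
	apocSchema, err := helpers.MapToAPOCSchema(raw)
	if err != nil {
		log.Fatalf("failed to convert schema map: %v", err)
	}

	// Flatten into node labels, relationships, and aggregate statistics.
	nodes, rels, stats := helpers.ProcessAPOCSchema(apocSchema)
	fmt.Println(len(nodes), len(rels), stats.TotalNodes)
}
```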

---------

Co-authored-by: Yuan Teoh <45984206+Yuan325@users.noreply.github.com>
Authored by nester-neo4j on 2025-07-24 20:40:16 -04:00; committed by GitHub.
parent 7e7d55c5d1, commit be7db3dff2
11 changed files with 2303 additions and 5 deletions


@@ -85,6 +85,7 @@ import (
_ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlsql"
_ "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jcypher"
_ "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jexecutecypher"
_ "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jschema"
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresexecutesql"
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgressql"
_ "github.com/googleapis/genai-toolbox/internal/tools/redis"


@@ -0,0 +1,42 @@
---
title: "neo4j-schema"
type: "docs"
weight: 1
description: >
A "neo4j-schema" tool extracts a comprehensive schema from a Neo4j
database.
aliases:
- /resources/tools/neo4j-schema
---
## About
A `neo4j-schema` tool connects to a Neo4j database and extracts its complete schema information. It runs multiple queries concurrently to efficiently gather details about node labels, relationships, properties, constraints, and indexes.
The tool automatically detects if the APOC (Awesome Procedures on Cypher) library is available. If so, it uses APOC procedures like `apoc.meta.schema` for a highly detailed overview of the database structure; otherwise, it falls back to using native Cypher queries.
The extracted schema is **cached** to improve performance for subsequent requests. The output is a structured JSON object containing all the schema details, which can be invaluable for providing database context to an LLM. This tool is compatible with a `neo4j` source and takes no parameters.
## Example
```yaml
tools:
get_movie_db_schema:
kind: neo4j-schema
source: my-neo4j-movies-instance
description: |
Use this tool to get the full schema of the movie database.
This provides information on all available node labels (like Movie, Person),
relationships (like ACTED_IN), and the properties on each.
This tool takes no parameters.
# Optional configuration to cache the schema for 2 hours
cacheExpireMinutes: 120
```
## Reference
| **field** | **type** | **required** | **description** |
|---------------------|:----------:|:------------:|-------------------------------------------------------------------------------------------------|
| kind | string | true | Must be `neo4j-schema`. |
| source | string | true | Name of the source the schema should be extracted from. |
| description | string | true | Description of the tool that is passed to the LLM. |
| cacheExpireMinutes | integer | false | Cache expiration time in minutes. Defaults to 60. |


@@ -15,10 +15,11 @@
package mongodbaggregate_test
import (
"github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbaggregate"
"strings"
"testing"
"github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbaggregate"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"


@@ -0,0 +1,204 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
Package cache provides a simple, thread-safe, in-memory key-value store.
It features item expiration and an optional background process (janitor) that
periodically removes expired items.
*/
package cache
import (
"sync"
"time"
)
const (
// DefaultJanitorInterval is the default interval at which the janitor
// runs to clean up expired cache items.
DefaultJanitorInterval = 1 * time.Minute
// DefaultExpiration is the default cache time-to-live in minutes. The cache
// itself sets expiration per item; this constant is consumed by the
// neo4j-schema tool as the default value for cacheExpireMinutes.
DefaultExpiration = 60
)
// CacheItem represents a value stored in the cache, along with its expiration time.
type CacheItem struct {
Value any // The actual value being stored.
Expiration int64 // The time when the item expires, as a Unix nano timestamp. 0 means no expiration.
}
// isExpired checks if the cache item has passed its expiration time.
// It returns true if the item is expired, and false otherwise.
func (item CacheItem) isExpired() bool {
// If Expiration is 0, the item is considered to never expire.
if item.Expiration == 0 {
return false
}
return time.Now().UnixNano() > item.Expiration
}
// Cache is a thread-safe, in-memory key-value store with self-cleaning capabilities.
type Cache struct {
items map[string]CacheItem // The underlying map that stores the cache items.
mu sync.RWMutex // A read/write mutex to ensure thread safety for concurrent access.
stop chan struct{} // A channel used to signal the janitor goroutine to stop.
}
// NewCache creates and returns a new Cache instance.
// The janitor for cleaning up expired items is not started by default.
// Use the WithJanitor method to start the cleanup process.
//
// Example:
//
// c := cache.NewCache()
// c.Set("myKey", "myValue", 5*time.Minute)
func NewCache() *Cache {
return &Cache{
items: make(map[string]CacheItem),
}
}
// WithJanitor starts a background goroutine (janitor) that periodically cleans up
// expired items from the cache. If a janitor is already running, it will be
// stopped and a new one will be started with the specified interval.
//
// The interval parameter defines how often the janitor should run. If a non-positive
// interval is provided, it defaults to DefaultJanitorInterval (1 minute).
//
// It returns a pointer to the Cache to allow for method chaining.
//
// Example:
//
// // Create a cache that cleans itself every 10 minutes.
// c := cache.NewCache().WithJanitor(10 * time.Minute)
// defer c.Stop() // It's important to stop the janitor when the cache is no longer needed.
func (c *Cache) WithJanitor(interval time.Duration) *Cache {
c.mu.Lock()
defer c.mu.Unlock()
if c.stop != nil {
// If a janitor is already running, stop it before starting a new one.
close(c.stop)
}
c.stop = make(chan struct{})
// Use the default interval if an invalid one is provided.
if interval <= 0 {
interval = DefaultJanitorInterval
}
// Start the janitor in a new goroutine.
go c.janitor(interval, c.stop)
return c
}
// Get retrieves an item from the cache by its key.
// It returns the item's value and a boolean. The boolean is true if the key
// was found and the item has not expired. Otherwise, it is false.
//
// Example:
//
// v, found := c.Get("myKey")
// if found {
// fmt.Printf("Found value: %v\n", v)
// } else {
// fmt.Println("Key not found or expired.")
// }
func (c *Cache) Get(key string) (any, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
item, found := c.items[key]
// Return false if the item is not found or if it is found but has expired.
if !found || item.isExpired() {
return nil, false
}
return item.Value, true
}
// Set adds an item to the cache, replacing any existing item with the same key.
//
// The `ttl` (time-to-live) parameter specifies how long the item should remain
// in the cache. If `ttl` is positive, the item will expire after that duration.
// If `ttl` is zero or negative, the item will never expire.
//
// Example:
//
// // Add a key that expires in 5 minutes.
// c.Set("sessionToken", "xyz123", 5*time.Minute)
//
// // Add a key that never expires.
// c.Set("appConfig", "configValue", 0)
func (c *Cache) Set(key string, value any, ttl time.Duration) {
c.mu.Lock()
defer c.mu.Unlock()
var expiration int64
// Calculate the expiration time only if ttl is positive.
if ttl > 0 {
expiration = time.Now().Add(ttl).UnixNano()
}
c.items[key] = CacheItem{
Value: value,
Expiration: expiration,
}
}
// Stop terminates the background janitor goroutine.
// It is safe to call Stop even if the janitor was never started or has already
// been stopped. This is useful for cleaning up resources.
func (c *Cache) Stop() {
c.mu.Lock()
defer c.mu.Unlock()
if c.stop != nil {
close(c.stop)
c.stop = nil
}
}
// janitor is the background cleanup worker. It runs in a separate goroutine.
// It uses a time.Ticker to periodically trigger the deletion of expired items.
func (c *Cache) janitor(interval time.Duration, stopCh chan struct{}) {
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
// Time to clean up expired items.
c.deleteExpired()
case <-stopCh:
// Stop signal received, exit the goroutine.
return
}
}
}
// deleteExpired scans the cache and removes all items that have expired.
// This function acquires a write lock on the cache to ensure safe mutation.
func (c *Cache) deleteExpired() {
c.mu.Lock()
defer c.mu.Unlock()
for k, v := range c.items {
if v.isExpired() {
delete(c.items, k)
}
}
}


@@ -0,0 +1,170 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cache
import (
"sync"
"testing"
"time"
)
// TestCache_SetAndGet verifies the basic functionality of setting a value
// and immediately retrieving it.
func TestCache_SetAndGet(t *testing.T) {
cache := NewCache()
defer cache.Stop()
key := "testKey"
value := "testValue"
cache.Set(key, value, 1*time.Minute)
retrievedValue, found := cache.Get(key)
if !found {
t.Errorf("Expected to find key %q, but it was not found", key)
}
if retrievedValue != value {
t.Errorf("Expected value %q, but got %q", value, retrievedValue)
}
}
// TestCache_GetExpired tests that an item is not retrievable after it has expired.
func TestCache_GetExpired(t *testing.T) {
cache := NewCache()
defer cache.Stop()
key := "expiredKey"
value := "expiredValue"
// Set an item with a very short TTL.
cache.Set(key, value, 1*time.Millisecond)
time.Sleep(2 * time.Millisecond) // Wait for the item to expire.
// Attempt to get the expired item.
_, found := cache.Get(key)
if found {
t.Errorf("Expected key %q to be expired, but it was found", key)
}
}
// TestCache_SetNoExpiration ensures that an item with a TTL of 0 or less
// does not expire.
func TestCache_SetNoExpiration(t *testing.T) {
cache := NewCache()
defer cache.Stop()
key := "noExpireKey"
value := "noExpireValue"
cache.Set(key, value, 0) // Setting with 0 should mean no expiration.
time.Sleep(5 * time.Millisecond)
retrievedValue, found := cache.Get(key)
if !found {
t.Errorf("Expected to find key %q, but it was not found", key)
}
if retrievedValue != value {
t.Errorf("Expected value %q, but got %q", value, retrievedValue)
}
}
// TestCache_Janitor verifies that the janitor goroutine automatically removes
// expired items from the cache.
func TestCache_Janitor(t *testing.T) {
// Initialize cache with a very short janitor interval for quick testing.
cache := NewCache().WithJanitor(10 * time.Millisecond)
defer cache.Stop()
expiredKey := "expired"
activeKey := "active"
// Set one item that will expire and one that will not.
cache.Set(expiredKey, "value", 1*time.Millisecond)
cache.Set(activeKey, "value", 1*time.Hour)
// Wait longer than the janitor interval to ensure it has a chance to run.
time.Sleep(20 * time.Millisecond)
// Check that the expired key has been removed.
_, found := cache.Get(expiredKey)
if found {
t.Errorf("Expected janitor to clean up expired key %q, but it was found", expiredKey)
}
// Check that the active key is still present.
_, found = cache.Get(activeKey)
if !found {
t.Errorf("Expected active key %q to be present, but it was not found", activeKey)
}
}
// TestCache_Stop ensures that calling the Stop method does not cause a panic,
// regardless of whether the janitor is running or not. It also tests idempotency.
func TestCache_Stop(t *testing.T) {
t.Run("Stop without janitor", func(t *testing.T) {
cache := NewCache()
// Test that calling Stop multiple times on a cache without a janitor is safe.
cache.Stop()
cache.Stop()
})
t.Run("Stop with janitor", func(t *testing.T) {
cache := NewCache().WithJanitor(1 * time.Minute)
// Test that calling Stop multiple times on a cache with a janitor is safe.
cache.Stop()
cache.Stop()
})
}
// TestCache_Concurrent performs a stress test on the cache with concurrent
// reads and writes to check for race conditions.
func TestCache_Concurrent(t *testing.T) {
cache := NewCache().WithJanitor(100 * time.Millisecond)
defer cache.Stop()
var wg sync.WaitGroup
numGoroutines := 100
numOperations := 1000
// Start concurrent writer goroutines.
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(g int) {
defer wg.Done()
for j := 0; j < numOperations; j++ {
key := string(rune(g*numOperations + j))
value := g*numOperations + j
cache.Set(key, value, 10*time.Second)
}
}(i)
}
// Start concurrent reader goroutines.
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(g int) {
defer wg.Done()
for j := 0; j < numOperations; j++ {
key := string(rune(g*numOperations + j))
cache.Get(key) // We don't check the result, just that access is safe.
}
}(i)
}
// Wait for all goroutines to complete. If a race condition exists, the Go
// race detector (`go test -race`) will likely catch it.
wg.Wait()
}


@@ -0,0 +1,291 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package helpers provides utility functions for transforming and processing Neo4j
// schema data. It includes functions for converting raw query results from both
// APOC and native Cypher queries into a standardized, structured format.
package helpers
import (
"fmt"
"sort"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jschema/types"
)
// ConvertToStringSlice converts a slice of any type to a slice of strings.
// It uses fmt.Sprintf to perform the conversion for each element.
// Example:
//
// input: []any{"user", 123, true}
// output: []string{"user", "123", "true"}
func ConvertToStringSlice(slice []any) []string {
result := make([]string, len(slice))
for i, v := range slice {
result[i] = fmt.Sprintf("%v", v)
}
return result
}
// GetStringValue safely converts any value to its string representation.
// If the input value is nil, it returns an empty string.
func GetStringValue(val any) string {
if val == nil {
return ""
}
return fmt.Sprintf("%v", val)
}
// MapToAPOCSchema converts a raw map from a Cypher query into a structured
// APOCSchemaResult. This is a workaround for database drivers that may return
// complex nested structures as `map[string]any` instead of unmarshalling
// directly into a struct. It achieves this by marshalling the map to YAML and
// then unmarshalling into the target struct.
func MapToAPOCSchema(schemaMap map[string]any) (*types.APOCSchemaResult, error) {
schemaBytes, err := yaml.Marshal(schemaMap)
if err != nil {
return nil, fmt.Errorf("failed to marshal schema map: %w", err)
}
var entities map[string]types.APOCEntity
if err = yaml.Unmarshal(schemaBytes, &entities); err != nil {
return nil, fmt.Errorf("failed to unmarshal schema map into entities: %w", err)
}
return &types.APOCSchemaResult{Value: entities}, nil
}
// ProcessAPOCSchema transforms the nested result from the `apoc.meta.schema()`
// procedure into flat lists of node labels and relationships, along with
// aggregated database statistics. It iterates through entities, processes nodes,
// and extracts outgoing relationship information nested within those nodes.
func ProcessAPOCSchema(apocSchema *types.APOCSchemaResult) ([]types.NodeLabel, []types.Relationship, *types.Statistics) {
var nodeLabels []types.NodeLabel
relMap := make(map[string]*types.Relationship)
stats := &types.Statistics{
NodesByLabel: make(map[string]int64),
RelationshipsByType: make(map[string]int64),
PropertiesByLabel: make(map[string]int64),
PropertiesByRelType: make(map[string]int64),
}
for name, entity := range apocSchema.Value {
// We only process top-level entities of type "node". Relationship info is
// derived from the "relationships" field within each node entity.
if entity.Type != "node" {
continue
}
nodeLabel := types.NodeLabel{
Name: name,
Count: entity.Count,
Properties: extractAPOCProperties(entity.Properties),
}
nodeLabels = append(nodeLabels, nodeLabel)
// Aggregate statistics for the node.
stats.NodesByLabel[name] = entity.Count
stats.TotalNodes += entity.Count
propCount := int64(len(nodeLabel.Properties))
stats.PropertiesByLabel[name] = propCount
stats.TotalProperties += propCount * entity.Count
// Extract relationship information from the node.
for relName, relInfo := range entity.Relationships {
// Only process outgoing relationships to avoid double-counting.
if relInfo.Direction != "out" {
continue
}
rel, exists := relMap[relName]
if !exists {
rel = &types.Relationship{
Type: relName,
Properties: extractAPOCProperties(relInfo.Properties),
}
if len(relInfo.Labels) > 0 {
rel.EndNode = relInfo.Labels[0]
}
rel.StartNode = name
relMap[relName] = rel
}
rel.Count += relInfo.Count
}
}
// Consolidate the relationships from the map into a slice and update stats.
relationships := make([]types.Relationship, 0, len(relMap))
for _, rel := range relMap {
relationships = append(relationships, *rel)
stats.RelationshipsByType[rel.Type] = rel.Count
stats.TotalRelationships += rel.Count
propCount := int64(len(rel.Properties))
stats.PropertiesByRelType[rel.Type] = propCount
stats.TotalProperties += propCount * rel.Count
}
sortAndClean(nodeLabels, relationships, stats)
// Set empty maps and lists to nil for cleaner output.
if len(nodeLabels) == 0 {
nodeLabels = nil
}
if len(relationships) == 0 {
relationships = nil
}
return nodeLabels, relationships, stats
}
// ProcessNonAPOCSchema serves as an alternative to ProcessAPOCSchema for environments
// where APOC procedures are not available. It converts schema data gathered from
// multiple separate, native Cypher queries (providing node counts, property maps, etc.)
// into the same standardized, structured format.
func ProcessNonAPOCSchema(
nodeCounts map[string]int64,
nodePropsMap map[string]map[string]map[string]bool,
relCounts map[string]int64,
relPropsMap map[string]map[string]map[string]bool,
relConnectivity map[string]types.RelConnectivityInfo,
) ([]types.NodeLabel, []types.Relationship, *types.Statistics) {
stats := &types.Statistics{
NodesByLabel: make(map[string]int64),
RelationshipsByType: make(map[string]int64),
PropertiesByLabel: make(map[string]int64),
PropertiesByRelType: make(map[string]int64),
}
// Process node information.
nodeLabels := make([]types.NodeLabel, 0, len(nodeCounts))
for label, count := range nodeCounts {
properties := make([]types.PropertyInfo, 0)
if props, ok := nodePropsMap[label]; ok {
for key, typeSet := range props {
typeList := make([]string, 0, len(typeSet))
for tp := range typeSet {
typeList = append(typeList, tp)
}
sort.Strings(typeList)
properties = append(properties, types.PropertyInfo{Name: key, Types: typeList})
}
}
sort.Slice(properties, func(i, j int) bool { return properties[i].Name < properties[j].Name })
nodeLabels = append(nodeLabels, types.NodeLabel{Name: label, Count: count, Properties: properties})
// Aggregate node statistics.
stats.NodesByLabel[label] = count
stats.TotalNodes += count
propCount := int64(len(properties))
stats.PropertiesByLabel[label] = propCount
stats.TotalProperties += propCount * count
}
// Process relationship information.
relationships := make([]types.Relationship, 0, len(relCounts))
for relType, count := range relCounts {
properties := make([]types.PropertyInfo, 0)
if props, ok := relPropsMap[relType]; ok {
for key, typeSet := range props {
typeList := make([]string, 0, len(typeSet))
for tp := range typeSet {
typeList = append(typeList, tp)
}
sort.Strings(typeList)
properties = append(properties, types.PropertyInfo{Name: key, Types: typeList})
}
}
sort.Slice(properties, func(i, j int) bool { return properties[i].Name < properties[j].Name })
conn := relConnectivity[relType]
relationships = append(relationships, types.Relationship{
Type: relType,
Count: count,
StartNode: conn.StartNode,
EndNode: conn.EndNode,
Properties: properties,
})
// Aggregate relationship statistics.
stats.RelationshipsByType[relType] = count
stats.TotalRelationships += count
propCount := int64(len(properties))
stats.PropertiesByRelType[relType] = propCount
stats.TotalProperties += propCount * count
}
sortAndClean(nodeLabels, relationships, stats)
// Set empty maps and lists to nil for cleaner output.
if len(nodeLabels) == 0 {
nodeLabels = nil
}
if len(relationships) == 0 {
relationships = nil
}
return nodeLabels, relationships, stats
}
// extractAPOCProperties is a helper that converts a map of APOC property
// information into a slice of standardized PropertyInfo structs. The resulting
// slice is sorted by property name for consistent ordering.
func extractAPOCProperties(props map[string]types.APOCProperty) []types.PropertyInfo {
properties := make([]types.PropertyInfo, 0, len(props))
for name, info := range props {
properties = append(properties, types.PropertyInfo{
Name: name,
Types: []string{info.Type},
Indexed: info.Indexed,
Unique: info.Unique,
Mandatory: info.Existence,
})
}
sort.Slice(properties, func(i, j int) bool {
return properties[i].Name < properties[j].Name
})
return properties
}
// sortAndClean performs final processing on the schema data. It sorts node and
// relationship slices for consistent output, primarily by count (descending) and
// secondarily by name/type. It also sets any empty maps in the statistics
// struct to nil, which can simplify downstream serialization (e.g., omitting
// empty fields in JSON).
func sortAndClean(nodeLabels []types.NodeLabel, relationships []types.Relationship, stats *types.Statistics) {
// Sort nodes by count (desc) then name (asc).
sort.Slice(nodeLabels, func(i, j int) bool {
if nodeLabels[i].Count != nodeLabels[j].Count {
return nodeLabels[i].Count > nodeLabels[j].Count
}
return nodeLabels[i].Name < nodeLabels[j].Name
})
// Sort relationships by count (desc) then type (asc).
sort.Slice(relationships, func(i, j int) bool {
if relationships[i].Count != relationships[j].Count {
return relationships[i].Count > relationships[j].Count
}
return relationships[i].Type < relationships[j].Type
})
// Nil out empty maps for cleaner output.
if len(stats.NodesByLabel) == 0 {
stats.NodesByLabel = nil
}
if len(stats.RelationshipsByType) == 0 {
stats.RelationshipsByType = nil
}
if len(stats.PropertiesByLabel) == 0 {
stats.PropertiesByLabel = nil
}
if len(stats.PropertiesByRelType) == 0 {
stats.PropertiesByRelType = nil
}
}


@@ -0,0 +1,384 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package helpers
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jschema/types"
)
func TestHelperFunctions(t *testing.T) {
t.Run("ConvertToStringSlice", func(t *testing.T) {
tests := []struct {
name string
input []any
want []string
}{
{
name: "empty slice",
input: []any{},
want: []string{},
},
{
name: "string values",
input: []any{"a", "b", "c"},
want: []string{"a", "b", "c"},
},
{
name: "mixed types",
input: []any{"string", 123, true, 45.67},
want: []string{"string", "123", "true", "45.67"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ConvertToStringSlice(tt.input)
if diff := cmp.Diff(tt.want, got); diff != "" {
t.Errorf("ConvertToStringSlice() mismatch (-want +got):\n%s", diff)
}
})
}
})
t.Run("GetStringValue", func(t *testing.T) {
tests := []struct {
name string
input any
want string
}{
{
name: "nil value",
input: nil,
want: "",
},
{
name: "string value",
input: "test",
want: "test",
},
{
name: "int value",
input: 42,
want: "42",
},
{
name: "bool value",
input: true,
want: "true",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := GetStringValue(tt.input)
if got != tt.want {
t.Errorf("GetStringValue() got %q, want %q", got, tt.want)
}
})
}
})
}
func TestMapToAPOCSchema(t *testing.T) {
tests := []struct {
name string
input map[string]any
want *types.APOCSchemaResult
wantErr bool
}{
{
name: "simple node schema",
input: map[string]any{
"Person": map[string]any{
"type": "node",
"count": int64(150),
"properties": map[string]any{
"name": map[string]any{
"type": "STRING",
"unique": false,
"indexed": true,
"existence": false,
},
},
},
},
want: &types.APOCSchemaResult{
Value: map[string]types.APOCEntity{
"Person": {
Type: "node",
Count: 150,
Properties: map[string]types.APOCProperty{
"name": {
Type: "STRING",
Unique: false,
Indexed: true,
Existence: false,
},
},
},
},
},
wantErr: false,
},
{
name: "empty input",
input: map[string]any{},
want: &types.APOCSchemaResult{Value: map[string]types.APOCEntity{}},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := MapToAPOCSchema(tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("MapToAPOCSchema() error = %v, wantErr %v", err, tt.wantErr)
return
}
if diff := cmp.Diff(tt.want, got); diff != "" {
t.Errorf("MapToAPOCSchema() mismatch (-want +got):\n%s", diff)
}
})
}
}
func TestProcessAPOCSchema(t *testing.T) {
tests := []struct {
name string
input *types.APOCSchemaResult
wantNodes []types.NodeLabel
wantRels []types.Relationship
wantStats *types.Statistics
statsAreEmpty bool
}{
{
name: "empty schema",
input: &types.APOCSchemaResult{
Value: map[string]types.APOCEntity{},
},
wantNodes: nil,
wantRels: nil,
statsAreEmpty: true,
},
{
name: "simple node only",
input: &types.APOCSchemaResult{
Value: map[string]types.APOCEntity{
"Person": {
Type: "node",
Count: 100,
Properties: map[string]types.APOCProperty{
"name": {Type: "STRING", Indexed: true},
"age": {Type: "INTEGER"},
},
},
},
},
wantNodes: []types.NodeLabel{
{
Name: "Person",
Count: 100,
Properties: []types.PropertyInfo{
{Name: "age", Types: []string{"INTEGER"}},
{Name: "name", Types: []string{"STRING"}, Indexed: true},
},
},
},
wantRels: nil,
wantStats: &types.Statistics{
NodesByLabel: map[string]int64{"Person": 100},
PropertiesByLabel: map[string]int64{"Person": 2},
TotalNodes: 100,
TotalProperties: 200,
},
},
{
name: "nodes and relationships",
input: &types.APOCSchemaResult{
Value: map[string]types.APOCEntity{
"Person": {
Type: "node",
Count: 100,
Properties: map[string]types.APOCProperty{
"name": {Type: "STRING", Unique: true, Indexed: true, Existence: true},
},
Relationships: map[string]types.APOCRelationshipInfo{
"KNOWS": {
Direction: "out",
Count: 50,
Labels: []string{"Person"},
Properties: map[string]types.APOCProperty{
"since": {Type: "INTEGER"},
},
},
},
},
"Post": {
Type: "node",
Count: 200,
Properties: map[string]types.APOCProperty{"content": {Type: "STRING"}},
},
"FOLLOWS": {Type: "relationship", Count: 80},
},
},
wantNodes: []types.NodeLabel{
{
Name: "Post",
Count: 200,
Properties: []types.PropertyInfo{
{Name: "content", Types: []string{"STRING"}},
},
},
{
Name: "Person",
Count: 100,
Properties: []types.PropertyInfo{
{Name: "name", Types: []string{"STRING"}, Unique: true, Indexed: true, Mandatory: true},
},
},
},
wantRels: []types.Relationship{
{
Type: "KNOWS",
StartNode: "Person",
EndNode: "Person",
Count: 50,
Properties: []types.PropertyInfo{
{Name: "since", Types: []string{"INTEGER"}},
},
},
},
wantStats: &types.Statistics{
NodesByLabel: map[string]int64{"Person": 100, "Post": 200},
RelationshipsByType: map[string]int64{"KNOWS": 50},
PropertiesByLabel: map[string]int64{"Person": 1, "Post": 1},
PropertiesByRelType: map[string]int64{"KNOWS": 1},
TotalNodes: 300,
TotalRelationships: 50,
TotalProperties: 350, // (100*1 + 200*1) for nodes + (50*1) for rels
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotNodes, gotRels, gotStats := ProcessAPOCSchema(tt.input)
if diff := cmp.Diff(tt.wantNodes, gotNodes); diff != "" {
t.Errorf("ProcessAPOCSchema() node labels mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tt.wantRels, gotRels); diff != "" {
t.Errorf("ProcessAPOCSchema() relationships mismatch (-want +got):\n%s", diff)
}
if tt.statsAreEmpty {
tt.wantStats = &types.Statistics{}
}
if diff := cmp.Diff(tt.wantStats, gotStats); diff != "" {
t.Errorf("ProcessAPOCSchema() statistics mismatch (-want +got):\n%s", diff)
}
})
}
}
func TestProcessNonAPOCSchema(t *testing.T) {
t.Run("full schema processing", func(t *testing.T) {
nodeCounts := map[string]int64{"Person": 10, "City": 5}
nodePropsMap := map[string]map[string]map[string]bool{
"Person": {"name": {"STRING": true}, "age": {"INTEGER": true}},
"City": {"name": {"STRING": true, "TEXT": true}},
}
relCounts := map[string]int64{"LIVES_IN": 8}
relPropsMap := map[string]map[string]map[string]bool{
"LIVES_IN": {"since": {"DATE": true}},
}
relConnectivity := map[string]types.RelConnectivityInfo{
"LIVES_IN": {StartNode: "Person", EndNode: "City", Count: 8},
}
wantNodes := []types.NodeLabel{
{
Name: "Person",
Count: 10,
Properties: []types.PropertyInfo{
{Name: "age", Types: []string{"INTEGER"}},
{Name: "name", Types: []string{"STRING"}},
},
},
{
Name: "City",
Count: 5,
Properties: []types.PropertyInfo{
{Name: "name", Types: []string{"STRING", "TEXT"}},
},
},
}
wantRels := []types.Relationship{
{
Type: "LIVES_IN",
Count: 8,
StartNode: "Person",
EndNode: "City",
Properties: []types.PropertyInfo{
{Name: "since", Types: []string{"DATE"}},
},
},
}
wantStats := &types.Statistics{
TotalNodes: 15,
TotalRelationships: 8,
TotalProperties: 33, // (10*2 + 5*1) for nodes + (8*1) for rels
NodesByLabel: map[string]int64{"Person": 10, "City": 5},
RelationshipsByType: map[string]int64{"LIVES_IN": 8},
PropertiesByLabel: map[string]int64{"Person": 2, "City": 1},
PropertiesByRelType: map[string]int64{"LIVES_IN": 1},
}
gotNodes, gotRels, gotStats := ProcessNonAPOCSchema(nodeCounts, nodePropsMap, relCounts, relPropsMap, relConnectivity)
if diff := cmp.Diff(wantNodes, gotNodes); diff != "" {
t.Errorf("ProcessNonAPOCSchema() nodes mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(wantRels, gotRels); diff != "" {
t.Errorf("ProcessNonAPOCSchema() relationships mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(wantStats, gotStats); diff != "" {
t.Errorf("ProcessNonAPOCSchema() stats mismatch (-want +got):\n%s", diff)
}
})
t.Run("empty schema", func(t *testing.T) {
gotNodes, gotRels, gotStats := ProcessNonAPOCSchema(
map[string]int64{},
map[string]map[string]map[string]bool{},
map[string]int64{},
map[string]map[string]map[string]bool{},
map[string]types.RelConnectivityInfo{},
)
if len(gotNodes) != 0 {
t.Errorf("expected 0 nodes, got %d", len(gotNodes))
}
if len(gotRels) != 0 {
t.Errorf("expected 0 relationships, got %d", len(gotRels))
}
if diff := cmp.Diff(&types.Statistics{}, gotStats); diff != "" {
t.Errorf("ProcessNonAPOCSchema() stats mismatch (-want +got):\n%s", diff)
}
})
}


@@ -0,0 +1,712 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package neo4jschema
import (
"context"
"fmt"
"sync"
"time"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
neo4jsc "github.com/googleapis/genai-toolbox/internal/sources/neo4j"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jschema/cache"
"github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jschema/helpers"
"github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jschema/types"
"github.com/neo4j/neo4j-go-driver/v5/neo4j"
)
// kind defines the unique identifier for this tool.
const kind string = "neo4j-schema"
// init registers the tool with the application's tool registry when the package is initialized.
func init() {
if !tools.Register(kind, newConfig) {
panic(fmt.Sprintf("tool kind %q already registered", kind))
}
}
// newConfig decodes a YAML configuration into a Config struct.
// This function is called by the tool registry to create a new configuration object.
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
actual := Config{Name: name}
if err := decoder.DecodeContext(ctx, &actual); err != nil {
return nil, err
}
return actual, nil
}
// compatibleSource defines the interface a data source must implement to be used by this tool.
// It ensures that the source can provide a Neo4j driver and database name.
type compatibleSource interface {
Neo4jDriver() neo4j.DriverWithContext
Neo4jDatabase() string
}
// Statically verify that our compatible source implementation is valid.
var _ compatibleSource = &neo4jsc.Source{}
// compatibleSources lists the kinds of sources that are compatible with this tool.
var compatibleSources = [...]string{neo4jsc.SourceKind}
// Config holds the configuration settings for the Neo4j schema tool.
// These settings are typically read from a YAML file.
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Source string `yaml:"source" validate:"required"`
Description string `yaml:"description" validate:"required"`
AuthRequired []string `yaml:"authRequired"`
CacheExpireMinutes *int `yaml:"cacheExpireMinutes,omitempty"` // Cache expiration time in minutes.
}
// Statically verify that Config implements the tools.ToolConfig interface.
var _ tools.ToolConfig = Config{}
// ToolConfigKind returns the kind of this tool configuration.
func (cfg Config) ToolConfigKind() string {
return kind
}
// Initialize sets up the tool with its dependencies and returns a ready-to-use Tool instance.
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// Verify that the specified source exists.
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// Verify the source is of a compatible kind.
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
parameters := tools.Parameters{}
mcpManifest := tools.McpManifest{
Name: cfg.Name,
Description: cfg.Description,
InputSchema: parameters.McpManifest(),
}
// Set a default cache expiration if not provided in the configuration.
if cfg.CacheExpireMinutes == nil {
defaultExpiration := cache.DefaultExpiration // Default to 60 minutes
cfg.CacheExpireMinutes = &defaultExpiration
}
// Finish tool setup by creating the Tool instance.
t := Tool{
Name: cfg.Name,
Kind: kind,
AuthRequired: cfg.AuthRequired,
Driver: s.Neo4jDriver(),
Database: s.Neo4jDatabase(),
cache: cache.NewCache(),
cacheExpireMinutes: cfg.CacheExpireMinutes,
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
// Statically verify that Tool implements the tools.Tool interface.
var _ tools.Tool = Tool{}
// Tool represents the Neo4j schema extraction tool.
// It holds the Neo4j driver, database information, and a cache for the schema.
type Tool struct {
Name string `yaml:"name"`
Kind string `yaml:"kind"`
AuthRequired []string `yaml:"authRequired"`
Driver neo4j.DriverWithContext
Database string
cache *cache.Cache
cacheExpireMinutes *int
manifest tools.Manifest
mcpManifest tools.McpManifest
}
// Invoke executes the tool's main logic: fetching the Neo4j schema.
// It first checks the cache for a valid schema before extracting it from the database.
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
// Check if a valid schema is already in the cache.
if cachedSchema, ok := t.cache.Get("schema"); ok {
if schema, ok := cachedSchema.(*types.SchemaInfo); ok {
return schema, nil
}
}
// If not cached, extract the schema from the database.
schema, err := t.extractSchema(ctx)
if err != nil {
return nil, fmt.Errorf("failed to extract database schema: %w", err)
}
// Cache the newly extracted schema for future use.
expiration := time.Duration(*t.cacheExpireMinutes) * time.Minute
t.cache.Set("schema", schema, expiration)
return schema, nil
}
// ParseParams is a placeholder as this tool does not require input parameters.
func (t Tool) ParseParams(data map[string]any, claimsMap map[string]map[string]any) (tools.ParamValues, error) {
return tools.ParamValues{}, nil
}
// Manifest returns the tool's manifest, which describes its purpose and parameters.
func (t Tool) Manifest() tools.Manifest {
return t.manifest
}
// McpManifest returns the machine-consumable manifest for the tool.
func (t Tool) McpManifest() tools.McpManifest {
return t.mcpManifest
}
// Authorized checks if the tool is authorized to run based on the provided authentication services.
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
// checkAPOCProcedures verifies if essential APOC procedures are available in the database.
// It returns true only if all required procedures are found.
func (t Tool) checkAPOCProcedures(ctx context.Context) (bool, error) {
proceduresToCheck := []string{"apoc.meta.schema", "apoc.meta.cypher.types"}
session := t.Driver.NewSession(ctx, neo4j.SessionConfig{DatabaseName: t.Database})
defer session.Close(ctx)
// This query efficiently counts how many of the specified procedures exist.
query := "SHOW PROCEDURES YIELD name WHERE name IN $procs RETURN count(name) AS procCount"
params := map[string]any{"procs": proceduresToCheck}
result, err := session.Run(ctx, query, params)
if err != nil {
return false, fmt.Errorf("failed to execute procedure check query: %w", err)
}
record, err := result.Single(ctx)
if err != nil {
return false, fmt.Errorf("failed to retrieve single result for procedure check: %w", err)
}
rawCount, found := record.Get("procCount")
if !found {
return false, fmt.Errorf("field 'procCount' not found in result record")
}
procCount, ok := rawCount.(int64)
if !ok {
return false, fmt.Errorf("expected 'procCount' to be of type int64, but got %T", rawCount)
}
// Return true only if the number of found procedures matches the number we were looking for.
return procCount == int64(len(proceduresToCheck)), nil
}
// extractSchema orchestrates the concurrent extraction of different parts of the database schema.
// It runs several extraction tasks in parallel for efficiency.
func (t Tool) extractSchema(ctx context.Context) (*types.SchemaInfo, error) {
schema := &types.SchemaInfo{}
var mu sync.Mutex
// Define the different schema extraction tasks.
tasks := []struct {
name string
fn func() error
}{
{
name: "database-info",
fn: func() error {
dbInfo, err := t.extractDatabaseInfo(ctx)
if err != nil {
return fmt.Errorf("failed to extract database info: %w", err)
}
mu.Lock()
defer mu.Unlock()
schema.DatabaseInfo = *dbInfo
return nil
},
},
{
name: "schema-extraction",
fn: func() error {
// Check if APOC procedures are available.
hasAPOC, err := t.checkAPOCProcedures(ctx)
if err != nil {
return fmt.Errorf("failed to check APOC procedures: %w", err)
}
var nodeLabels []types.NodeLabel
var relationships []types.Relationship
var stats *types.Statistics
// Use APOC if available for a more detailed schema; otherwise, use native queries.
if hasAPOC {
nodeLabels, relationships, stats, err = t.GetAPOCSchema(ctx)
} else {
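// Fall back to native Cypher queries, sampling up to 100 entities per label/type to infer property types.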
nodeLabels, relationships, stats, err = t.GetSchemaWithoutAPOC(ctx, 100)
}
if err != nil {
return fmt.Errorf("failed to get schema: %w", err)
}
mu.Lock()
defer mu.Unlock()
schema.NodeLabels = nodeLabels
schema.Relationships = relationships
schema.Statistics = *stats
return nil
},
},
{
name: "constraints",
fn: func() error {
constraints, err := t.extractConstraints(ctx)
if err != nil {
return fmt.Errorf("failed to extract constraints: %w", err)
}
mu.Lock()
defer mu.Unlock()
schema.Constraints = constraints
return nil
},
},
{
name: "indexes",
fn: func() error {
indexes, err := t.extractIndexes(ctx)
if err != nil {
return fmt.Errorf("failed to extract indexes: %w", err)
}
mu.Lock()
defer mu.Unlock()
schema.Indexes = indexes
return nil
},
},
}
var wg sync.WaitGroup
errCh := make(chan error, len(tasks))
// Execute all tasks concurrently.
for _, task := range tasks {
wg.Add(1)
go func(task struct {
name string
fn func() error
}) {
defer wg.Done()
if err := task.fn(); err != nil {
errCh <- err
}
}(task)
}
wg.Wait()
close(errCh)
// Collect any errors that occurred during the concurrent tasks.
for err := range errCh {
if err != nil {
schema.Errors = append(schema.Errors, err.Error())
}
}
return schema, nil
}
// GetAPOCSchema extracts schema information using the APOC library, which provides detailed metadata.
func (t Tool) GetAPOCSchema(ctx context.Context) ([]types.NodeLabel, []types.Relationship, *types.Statistics, error) {
var nodeLabels []types.NodeLabel
var relationships []types.Relationship
stats := &types.Statistics{
NodesByLabel: make(map[string]int64),
RelationshipsByType: make(map[string]int64),
PropertiesByLabel: make(map[string]int64),
PropertiesByRelType: make(map[string]int64),
}
var mu sync.Mutex
var firstErr error
ctx, cancel := context.WithCancel(ctx)
defer cancel()
handleError := func(err error) {
mu.Lock()
defer mu.Unlock()
if firstErr == nil {
firstErr = err
cancel() // Cancel other operations on the first error.
}
}
tasks := []struct {
name string
fn func(session neo4j.SessionWithContext) error
}{
{
name: "apoc-schema",
fn: func(session neo4j.SessionWithContext) error {
result, err := session.Run(ctx, "CALL apoc.meta.schema({sample: 10}) YIELD value RETURN value", nil)
if err != nil {
return fmt.Errorf("failed to run APOC schema query: %w", err)
}
if !result.Next(ctx) {
return fmt.Errorf("no results from APOC schema query")
}
schemaMap, ok := result.Record().Values[0].(map[string]any)
if !ok {
return fmt.Errorf("unexpected result format from APOC schema query: %T", result.Record().Values[0])
}
apocSchema, err := helpers.MapToAPOCSchema(schemaMap)
if err != nil {
return fmt.Errorf("failed to convert schema map to APOCSchemaResult: %w", err)
}
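// Only node data and node statistics are taken from apoc.meta.schema here; the
// apoc-relationships task below gathers relationship details directly from the graph.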
nodes, _, apocStats := helpers.ProcessAPOCSchema(apocSchema)
mu.Lock()
defer mu.Unlock()
nodeLabels = nodes
stats.TotalNodes = apocStats.TotalNodes
stats.TotalProperties += apocStats.TotalProperties
stats.NodesByLabel = apocStats.NodesByLabel
stats.PropertiesByLabel = apocStats.PropertiesByLabel
return nil
},
},
{
name: "apoc-relationships",
fn: func(session neo4j.SessionWithContext) error {
query := `
MATCH (startNode)-[rel]->(endNode)
WITH
labels(startNode)[0] AS startNode,
type(rel) AS relType,
apoc.meta.cypher.types(rel) AS relProperties,
labels(endNode)[0] AS endNode,
count(*) AS count
RETURN relType, startNode, endNode, relProperties, count`
result, err := session.Run(ctx, query, nil)
if err != nil {
return fmt.Errorf("failed to extract relationships: %w", err)
}
for result.Next(ctx) {
record := result.Record()
relType, startNode, endNode := record.Values[0].(string), record.Values[1].(string), record.Values[2].(string)
properties, count := record.Values[3].(map[string]any), record.Values[4].(int64)
if relType == "" || count == 0 {
continue
}
relationship := types.Relationship{Type: relType, StartNode: startNode, EndNode: endNode, Count: count, Properties: []types.PropertyInfo{}}
for prop, propType := range properties {
relationship.Properties = append(relationship.Properties, types.PropertyInfo{Name: prop, Types: []string{propType.(string)}})
}
mu.Lock()
relationships = append(relationships, relationship)
stats.RelationshipsByType[relType] += count
stats.TotalRelationships += count
propCount := int64(len(relationship.Properties))
stats.TotalProperties += propCount
stats.PropertiesByRelType[relType] += propCount
mu.Unlock()
}
mu.Lock()
defer mu.Unlock()
if len(stats.RelationshipsByType) == 0 {
stats.RelationshipsByType = nil
}
if len(stats.PropertiesByRelType) == 0 {
stats.PropertiesByRelType = nil
}
return nil
},
},
}
var wg sync.WaitGroup
wg.Add(len(tasks))
for _, task := range tasks {
go func(task struct {
name string
fn func(session neo4j.SessionWithContext) error
}) {
defer wg.Done()
session := t.Driver.NewSession(ctx, neo4j.SessionConfig{DatabaseName: t.Database})
defer session.Close(ctx)
if err := task.fn(session); err != nil {
handleError(fmt.Errorf("task %s failed: %w", task.name, err))
}
}(task)
}
wg.Wait()
if firstErr != nil {
return nil, nil, nil, firstErr
}
return nodeLabels, relationships, stats, nil
}
// GetSchemaWithoutAPOC extracts schema information using native Cypher queries.
// This serves as a fallback for databases without APOC installed.
func (t Tool) GetSchemaWithoutAPOC(ctx context.Context, sampleSize int) ([]types.NodeLabel, []types.Relationship, *types.Statistics, error) {
nodePropsMap := make(map[string]map[string]map[string]bool)
relPropsMap := make(map[string]map[string]map[string]bool)
nodeCounts := make(map[string]int64)
relCounts := make(map[string]int64)
relConnectivity := make(map[string]types.RelConnectivityInfo)
var mu sync.Mutex
var firstErr error
ctx, cancel := context.WithCancel(ctx)
defer cancel()
handleError := func(err error) {
mu.Lock()
defer mu.Unlock()
if firstErr == nil {
firstErr = err
cancel()
}
}
tasks := []struct {
name string
fn func(session neo4j.SessionWithContext) error
}{
{
name: "node-schema",
fn: func(session neo4j.SessionWithContext) error {
countResult, err := session.Run(ctx, `MATCH (n) UNWIND labels(n) AS label RETURN label, count(*) AS count ORDER BY count DESC`, nil)
if err != nil {
return fmt.Errorf("node count query failed: %w", err)
}
var labelsList []string
mu.Lock()
for countResult.Next(ctx) {
record := countResult.Record()
label, count := record.Values[0].(string), record.Values[1].(int64)
nodeCounts[label] = count
labelsList = append(labelsList, label)
}
mu.Unlock()
if err = countResult.Err(); err != nil {
return fmt.Errorf("node count result error: %w", err)
}
for _, label := range labelsList {
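// Node labels cannot be bound as Cypher parameters, so the label (as returned by the database above) is interpolated into the query text.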
propQuery := fmt.Sprintf(`MATCH (n:%s) WITH n LIMIT $sampleSize UNWIND keys(n) AS key WITH key, n[key] AS value WHERE value IS NOT NULL RETURN key, COLLECT(DISTINCT valueType(value)) AS types`, label)
propResult, err := session.Run(ctx, propQuery, map[string]any{"sampleSize": sampleSize})
if err != nil {
return fmt.Errorf("node properties query for label %s failed: %w", label, err)
}
mu.Lock()
if nodePropsMap[label] == nil {
nodePropsMap[label] = make(map[string]map[string]bool)
}
for propResult.Next(ctx) {
record := propResult.Record()
key, types := record.Values[0].(string), record.Values[1].([]any)
if nodePropsMap[label][key] == nil {
nodePropsMap[label][key] = make(map[string]bool)
}
for _, tp := range types {
nodePropsMap[label][key][tp.(string)] = true
}
}
mu.Unlock()
if err = propResult.Err(); err != nil {
return fmt.Errorf("node properties result error for label %s: %w", label, err)
}
}
return nil
},
},
{
name: "relationship-schema",
fn: func(session neo4j.SessionWithContext) error {
relQuery := `
MATCH (start)-[r]->(end)
WITH type(r) AS relType, labels(start) AS startLabels, labels(end) AS endLabels, count(*) AS count
RETURN relType, CASE WHEN size(startLabels) > 0 THEN startLabels[0] ELSE null END AS startLabel, CASE WHEN size(endLabels) > 0 THEN endLabels[0] ELSE null END AS endLabel, sum(count) AS totalCount
ORDER BY totalCount DESC`
relResult, err := session.Run(ctx, relQuery, nil)
if err != nil {
return fmt.Errorf("relationship count query failed: %w", err)
}
var relTypesList []string
mu.Lock()
for relResult.Next(ctx) {
record := relResult.Record()
relType := record.Values[0].(string)
startLabel := ""
if record.Values[1] != nil {
startLabel = record.Values[1].(string)
}
endLabel := ""
if record.Values[2] != nil {
endLabel = record.Values[2].(string)
}
count := record.Values[3].(int64)
relCounts[relType] = count
relTypesList = append(relTypesList, relType)
if existing, ok := relConnectivity[relType]; !ok || count > existing.Count {
relConnectivity[relType] = types.RelConnectivityInfo{StartNode: startLabel, EndNode: endLabel, Count: count}
}
}
mu.Unlock()
if err = relResult.Err(); err != nil {
return fmt.Errorf("relationship count result error: %w", err)
}
for _, relType := range relTypesList {
propQuery := fmt.Sprintf(`MATCH ()-[r:%s]->() WITH r LIMIT $sampleSize WHERE size(keys(r)) > 0 UNWIND keys(r) AS key WITH key, r[key] AS value WHERE value IS NOT NULL RETURN key, COLLECT(DISTINCT valueType(value)) AS types`, relType)
propResult, err := session.Run(ctx, propQuery, map[string]any{"sampleSize": sampleSize})
if err != nil {
return fmt.Errorf("relationship properties query for type %s failed: %w", relType, err)
}
mu.Lock()
if relPropsMap[relType] == nil {
relPropsMap[relType] = make(map[string]map[string]bool)
}
for propResult.Next(ctx) {
record := propResult.Record()
key, propTypes := record.Values[0].(string), record.Values[1].([]any)
if relPropsMap[relType][key] == nil {
relPropsMap[relType][key] = make(map[string]bool)
}
for _, t := range propTypes {
relPropsMap[relType][key][t.(string)] = true
}
}
mu.Unlock()
if err = propResult.Err(); err != nil {
return fmt.Errorf("relationship properties result error for type %s: %w", relType, err)
}
}
return nil
},
},
}
var wg sync.WaitGroup
wg.Add(len(tasks))
for _, task := range tasks {
go func(task struct {
name string
fn func(session neo4j.SessionWithContext) error
}) {
defer wg.Done()
session := t.Driver.NewSession(ctx, neo4j.SessionConfig{DatabaseName: t.Database})
defer session.Close(ctx)
if err := task.fn(session); err != nil {
handleError(fmt.Errorf("task %s failed: %w", task.name, err))
}
}(task)
}
wg.Wait()
if firstErr != nil {
return nil, nil, nil, firstErr
}
nodeLabels, relationships, stats := helpers.ProcessNonAPOCSchema(nodeCounts, nodePropsMap, relCounts, relPropsMap, relConnectivity)
return nodeLabels, relationships, stats, nil
}
// extractDatabaseInfo retrieves general information about the Neo4j database instance.
func (t Tool) extractDatabaseInfo(ctx context.Context) (*types.DatabaseInfo, error) {
session := t.Driver.NewSession(ctx, neo4j.SessionConfig{DatabaseName: t.Database})
defer session.Close(ctx)
result, err := session.Run(ctx, "CALL dbms.components() YIELD name, versions, edition", nil)
if err != nil {
return nil, err
}
dbInfo := &types.DatabaseInfo{}
if result.Next(ctx) {
record := result.Record()
dbInfo.Name = record.Values[0].(string)
if versions, ok := record.Values[1].([]any); ok && len(versions) > 0 {
dbInfo.Version = versions[0].(string)
}
dbInfo.Edition = record.Values[2].(string)
}
return dbInfo, result.Err()
}
// extractConstraints fetches all schema constraints from the database.
func (t Tool) extractConstraints(ctx context.Context) ([]types.Constraint, error) {
session := t.Driver.NewSession(ctx, neo4j.SessionConfig{DatabaseName: t.Database})
defer session.Close(ctx)
result, err := session.Run(ctx, "SHOW CONSTRAINTS", nil)
if err != nil {
return nil, err
}
var constraints []types.Constraint
for result.Next(ctx) {
record := result.Record().AsMap()
constraint := types.Constraint{
Name: helpers.GetStringValue(record["name"]),
Type: helpers.GetStringValue(record["type"]),
EntityType: helpers.GetStringValue(record["entityType"]),
}
if labels, ok := record["labelsOrTypes"].([]any); ok && len(labels) > 0 {
constraint.Label = labels[0].(string)
}
if props, ok := record["properties"].([]any); ok {
constraint.Properties = helpers.ConvertToStringSlice(props)
}
constraints = append(constraints, constraint)
}
return constraints, result.Err()
}
// extractIndexes fetches all schema indexes from the database.
func (t Tool) extractIndexes(ctx context.Context) ([]types.Index, error) {
session := t.Driver.NewSession(ctx, neo4j.SessionConfig{DatabaseName: t.Database})
defer session.Close(ctx)
result, err := session.Run(ctx, "SHOW INDEXES", nil)
if err != nil {
return nil, err
}
var indexes []types.Index
for result.Next(ctx) {
record := result.Record().AsMap()
index := types.Index{
Name: helpers.GetStringValue(record["name"]),
State: helpers.GetStringValue(record["state"]),
Type: helpers.GetStringValue(record["type"]),
EntityType: helpers.GetStringValue(record["entityType"]),
}
if labels, ok := record["labelsOrTypes"].([]any); ok && len(labels) > 0 {
index.Label = labels[0].(string)
}
if props, ok := record["properties"].([]any); ok {
index.Properties = helpers.ConvertToStringSlice(props)
}
indexes = append(indexes, index)
}
return indexes, result.Err()
}


@@ -0,0 +1,99 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package neo4jschema
import (
"testing"
"github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
func TestParseFromYamlNeo4j(t *testing.T) {
ctx, err := testutils.ContextWithNewLogger()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
exp := 30
tcs := []struct {
desc string
in string
want server.ToolConfigs
}{
{
desc: "basic example with default cache expiration",
in: `
tools:
example_tool:
kind: neo4j-schema
source: my-neo4j-instance
description: some tool description
authRequired:
- my-google-auth-service
- other-auth-service
`,
want: server.ToolConfigs{
"example_tool": Config{
Name: "example_tool",
Kind: "neo4j-schema",
Source: "my-neo4j-instance",
Description: "some tool description",
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
CacheExpireMinutes: nil,
},
},
},
{
desc: "cache expire minutes set explicitly",
in: `
tools:
example_tool:
kind: neo4j-schema
source: my-neo4j-instance
description: some tool description
cacheExpireMinutes: 30
`,
want: server.ToolConfigs{
"example_tool": Config{
Name: "example_tool",
Kind: "neo4j-schema",
Source: "my-neo4j-instance",
Description: "some tool description",
AuthRequired: []string{}, // Expect an empty slice, not nil.
CacheExpireMinutes: &exp,
},
},
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Tools server.ToolConfigs `yaml:"tools"`
}{}
// Parse contents
err = yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
t.Fatalf("incorrect parse: diff %v", diff)
}
})
}
}
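Since `CacheExpireMinutes` is parsed as a pointer (nil when the YAML omits `cacheExpireMinutes`), callers need to resolve it against a default. A minimal sketch follows; the helper name and the default of 60 minutes are assumptions for illustration, not the tool's documented defaults.

```go
package neo4jschema

import "time"

// resolveCacheExpiration is a hypothetical helper: it turns the optional
// cacheExpireMinutes setting into a time.Duration, falling back to an
// assumed default when the field was not set in the YAML.
func resolveCacheExpiration(cfg Config) time.Duration {
	const defaultExpireMinutes = 60 // assumption; the real default may differ
	minutes := defaultExpireMinutes
	if cfg.CacheExpireMinutes != nil {
		minutes = *cfg.CacheExpireMinutes
	}
	return time.Duration(minutes) * time.Minute
}
```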

View File

@@ -0,0 +1,127 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package types contains the shared data structures for Neo4j schema representation.
package types
// SchemaInfo represents the complete database schema.
type SchemaInfo struct {
NodeLabels []NodeLabel `json:"nodeLabels"`
Relationships []Relationship `json:"relationships"`
Constraints []Constraint `json:"constraints"`
Indexes []Index `json:"indexes"`
DatabaseInfo DatabaseInfo `json:"databaseInfo"`
Statistics Statistics `json:"statistics"`
Errors []string `json:"errors,omitempty"`
}
// NodeLabel represents a node label with its properties.
type NodeLabel struct {
Name string `json:"name"`
Properties []PropertyInfo `json:"properties"`
Count int64 `json:"count"`
}
// RelConnectivityInfo holds information about a relationship's start and end nodes,
// primarily used during schema extraction without APOC procedures.
type RelConnectivityInfo struct {
StartNode string
EndNode string
Count int64
}
// Relationship represents a relationship type with its properties.
type Relationship struct {
Type string `json:"type"`
Properties []PropertyInfo `json:"properties"`
StartNode string `json:"startNode,omitempty"`
EndNode string `json:"endNode,omitempty"`
Count int64 `json:"count"`
}
// PropertyInfo represents a property with its data types.
type PropertyInfo struct {
Name string `json:"name"`
Types []string `json:"types"`
Mandatory bool `json:"-"`
Unique bool `json:"-"`
Indexed bool `json:"-"`
}
// Constraint represents a database constraint.
type Constraint struct {
Name string `json:"name"`
Type string `json:"type"`
EntityType string `json:"entityType"`
Label string `json:"label,omitempty"`
Properties []string `json:"properties"`
}
// Index represents a database index.
type Index struct {
Name string `json:"name"`
State string `json:"state"`
Type string `json:"type"`
EntityType string `json:"entityType"`
Label string `json:"label,omitempty"`
Properties []string `json:"properties"`
}
// DatabaseInfo contains general database information.
type DatabaseInfo struct {
Name string `json:"name"`
Version string `json:"version"`
Edition string `json:"edition,omitempty"`
}
// Statistics contains database statistics.
type Statistics struct {
TotalNodes int64 `json:"totalNodes"`
TotalRelationships int64 `json:"totalRelationships"`
TotalProperties int64 `json:"totalProperties"`
NodesByLabel map[string]int64 `json:"nodesByLabel"`
RelationshipsByType map[string]int64 `json:"relationshipsByType"`
PropertiesByLabel map[string]int64 `json:"propertiesByLabel"`
PropertiesByRelType map[string]int64 `json:"propertiesByRelType"`
}
// APOCSchemaResult represents the result from apoc.meta.schema().
type APOCSchemaResult struct {
Value map[string]APOCEntity `json:"value"`
}
// APOCEntity represents a node or relationship in APOC schema.
type APOCEntity struct {
Type string `json:"type"`
Count int64 `json:"count"`
Labels []string `json:"labels,omitempty"`
Properties map[string]APOCProperty `json:"properties"`
Relationships map[string]APOCRelationshipInfo `json:"relationships,omitempty"`
}
// APOCProperty represents property info from APOC.
type APOCProperty struct {
Type string `json:"type"`
Indexed bool `json:"indexed"`
Unique bool `json:"unique"`
Existence bool `json:"existence"`
}
// APOCRelationshipInfo represents relationship info from APOC.
type APOCRelationshipInfo struct {
Count int64 `json:"count"`
Direction string `json:"direction"`
Labels []string `json:"labels"`
Properties map[string]APOCProperty `json:"properties"`
}
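To make the APOC-shaped types above concrete, here is a small decoding sketch. The JSON payload is an illustrative approximation of `apoc.meta.schema` output rather than output captured from a real database, and the exact shape can vary across APOC versions.

```go
package types_test

import (
	"encoding/json"
	"fmt"

	"github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jschema/types"
)

func ExampleAPOCSchemaResult() {
	// Illustrative payload only (assumption): a single Person node entity
	// with one indexed, unique "name" property.
	raw := []byte(`{
		"value": {
			"Person": {
				"type": "node",
				"count": 1,
				"properties": {
					"name": {"type": "STRING", "indexed": true, "unique": true, "existence": false}
				}
			}
		}
	}`)

	var result types.APOCSchemaResult
	if err := json.Unmarshal(raw, &result); err != nil {
		panic(err)
	}
	person := result.Value["Person"]
	fmt.Println(person.Type, person.Count, person.Properties["name"].Type)
	// Output: node 1 STRING
}
```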

View File

@@ -27,6 +27,8 @@ import (
"testing"
"time"
"github.com/neo4j/neo4j-go-driver/v5/neo4j"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/tests"
)
@@ -39,6 +41,8 @@ var (
Neo4jPass = os.Getenv("NEO4J_PASS")
)
// getNeo4jVars retrieves necessary Neo4j connection details from environment variables.
// It fails the test if any required variable is not set.
func getNeo4jVars(t *testing.T) map[string]any {
switch "" {
case Neo4jDatabase:
@@ -60,6 +64,8 @@ func getNeo4jVars(t *testing.T) map[string]any {
}
}
// TestNeo4jToolEndpoints sets up an integration test server and tests the API endpoints
// for various Neo4j tools, including cypher execution and schema retrieval.
func TestNeo4jToolEndpoints(t *testing.T) {
sourceConfig := getNeo4jVars(t)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
@@ -67,7 +73,8 @@ func TestNeo4jToolEndpoints(t *testing.T) {
var args []string
// Write config into a file and pass it to command
// Write config into a file and pass it to the command.
// This configuration defines the data source and the tools to be tested.
toolsFile := map[string]any{
"sources": map[string]any{
"my-neo4j-instance": sourceConfig,
@@ -90,6 +97,22 @@ func TestNeo4jToolEndpoints(t *testing.T) {
"description": "A readonly cypher execution tool.",
"readOnly": true,
},
"my-schema-tool": map[string]any{
"kind": "neo4j-schema",
"source": "my-neo4j-instance",
"description": "A tool to get the Neo4j schema.",
},
"my-schema-tool-with-cache": map[string]any{
"kind": "neo4j-schema",
"source": "my-neo4j-instance",
"description": "A schema tool with a custom cache expiration.",
"cacheExpireMinutes": 10,
},
"my-populated-schema-tool": map[string]any{
"kind": "neo4j-schema",
"source": "my-neo4j-instance",
"description": "A tool to get the Neo4j schema from a populated DB.",
},
},
}
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
@@ -106,7 +129,7 @@ func TestNeo4jToolEndpoints(t *testing.T) {
t.Fatalf("toolbox didn't start successfully: %s", err)
}
// Test tool get endpoint
// Test tool `GET` endpoints to verify their manifests are correct.
tcs := []struct {
name string
api string
@@ -142,6 +165,28 @@ func TestNeo4jToolEndpoints(t *testing.T) {
},
},
},
{
name: "get my-schema-tool",
api: "http://127.0.0.1:5000/api/tool/my-schema-tool/",
want: map[string]any{
"my-schema-tool": map[string]any{
"description": "A tool to get the Neo4j schema.",
"parameters": []any{},
"authRequired": []any{},
},
},
},
{
name: "get my-schema-tool-with-cache",
api: "http://127.0.0.1:5000/api/tool/my-schema-tool-with-cache/",
want: map[string]any{
"my-schema-tool-with-cache": map[string]any{
"description": "A schema tool with a custom cache expiration.",
"parameters": []any{},
"authRequired": []any{},
},
},
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
@@ -170,7 +215,7 @@ func TestNeo4jToolEndpoints(t *testing.T) {
})
}
// Test tool invoke endpoint
// Test tool `invoke` endpoints to verify their functionality.
invokeTcs := []struct {
name string
api string
@@ -178,6 +223,8 @@ func TestNeo4jToolEndpoints(t *testing.T) {
want string
wantStatus int
wantErrorSubstring string
prepareData func(t *testing.T)
validateFunc func(t *testing.T, body string)
}{
{
name: "invoke my-simple-cypher-tool",
@@ -200,9 +247,225 @@ func TestNeo4jToolEndpoints(t *testing.T) {
wantStatus: http.StatusBadRequest,
wantErrorSubstring: "this tool is read-only and cannot execute write queries",
},
{
name: "invoke my-schema-tool",
api: "http://127.0.0.1:5000/api/tool/my-schema-tool/invoke",
requestBody: bytes.NewBuffer([]byte(`{}`)),
wantStatus: http.StatusOK,
validateFunc: func(t *testing.T, body string) {
var result map[string]any
if err := json.Unmarshal([]byte(body), &result); err != nil {
t.Fatalf("failed to unmarshal schema result: %v", err)
}
// Check for the presence of top-level keys in the schema response.
expectedKeys := []string{"nodeLabels", "relationships", "constraints", "indexes", "databaseInfo", "statistics"}
for _, key := range expectedKeys {
if _, ok := result[key]; !ok {
t.Errorf("expected key %q not found in schema response", key)
}
}
},
},
{
name: "invoke my-schema-tool-with-cache",
api: "http://127.0.0.1:5000/api/tool/my-schema-tool-with-cache/invoke",
requestBody: bytes.NewBuffer([]byte(`{}`)),
wantStatus: http.StatusOK,
validateFunc: func(t *testing.T, body string) {
var result map[string]any
if err := json.Unmarshal([]byte(body), &result); err != nil {
t.Fatalf("failed to unmarshal schema result: %v", err)
}
// Also check the structure of the schema response for the cached tool.
expectedKeys := []string{"nodeLabels", "relationships", "constraints", "indexes", "databaseInfo", "statistics"}
for _, key := range expectedKeys {
if _, ok := result[key]; !ok {
t.Errorf("expected key %q not found in schema response", key)
}
}
},
},
{
name: "invoke my-schema-tool with populated data",
api: "http://127.0.0.1:5000/api/tool/my-populated-schema-tool/invoke",
requestBody: bytes.NewBuffer([]byte(`{}`)),
wantStatus: http.StatusOK,
prepareData: func(t *testing.T) {
ctx := context.Background()
driver, err := neo4j.NewDriverWithContext(Neo4jUri, neo4j.BasicAuth(Neo4jUser, Neo4jPass, ""))
if err != nil {
t.Fatalf("failed to create neo4j driver: %v", err)
}
// Helper to execute queries for setup and teardown.
execute := func(query string) {
session := driver.NewSession(ctx, neo4j.SessionConfig{DatabaseName: Neo4jDatabase})
defer session.Close(ctx)
// Use ExecuteWrite to ensure the query is committed before proceeding.
_, err := session.ExecuteWrite(ctx, func(tx neo4j.ManagedTransaction) (any, error) {
_, err := tx.Run(ctx, query, nil)
return nil, err
})
// Ignore errors from DROP statements during teardown (e.g., the entity no longer exists); fail on any other query error.
if err != nil && !strings.Contains(query, "DROP") {
t.Fatalf("query failed: %s\nerror: %v", query, err)
}
}
// Teardown is registered with t.Cleanup so it runs even if the test fails.
// The driver is closed as the final step of that cleanup.
t.Cleanup(func() {
execute("DROP CONSTRAINT PersonNameUnique IF EXISTS")
execute("DROP INDEX MovieTitleIndex IF EXISTS")
execute("MATCH (n) DETACH DELETE n")
if err := driver.Close(ctx); err != nil {
t.Errorf("failed to close driver during cleanup: %v", err)
}
})
// Setup: Create constraints, indexes, and data.
execute("MERGE (p:Person {name: 'Alice'}) MERGE (m:Movie {title: 'The Matrix'}) MERGE (p)-[:ACTED_IN]->(m)")
execute("CREATE CONSTRAINT PersonNameUnique IF NOT EXISTS FOR (p:Person) REQUIRE p.name IS UNIQUE")
execute("CREATE INDEX MovieTitleIndex IF NOT EXISTS FOR (m:Movie) ON (m.title)")
},
validateFunc: func(t *testing.T, body string) {
// Define structs for unmarshaling the detailed schema.
type Property struct {
Name string `json:"name"`
Types []string `json:"types"`
}
type NodeLabel struct {
Name string `json:"name"`
Properties []Property `json:"properties"`
}
type Relationship struct {
Type string `json:"type"`
StartNode string `json:"startNode"`
EndNode string `json:"endNode"`
}
type Constraint struct {
Name string `json:"name"`
Label string `json:"label"`
Properties []string `json:"properties"`
}
type Index struct {
Name string `json:"name"`
Label string `json:"label"`
Properties []string `json:"properties"`
}
type Schema struct {
NodeLabels []NodeLabel `json:"nodeLabels"`
Relationships []Relationship `json:"relationships"`
Constraints []Constraint `json:"constraints"`
Indexes []Index `json:"indexes"`
}
var schema Schema
if err := json.Unmarshal([]byte(body), &schema); err != nil {
t.Fatalf("failed to unmarshal schema json: %v\nResponse body: %s", err, body)
}
// --- Validate Node Labels and Properties ---
var personLabelFound, movieLabelFound bool
for _, l := range schema.NodeLabels {
if l.Name == "Person" {
personLabelFound = true
propFound := false
for _, p := range l.Properties {
if p.Name == "name" {
propFound = true
break
}
}
if !propFound {
t.Errorf("expected Person label to have 'name' property, but it was not found")
}
}
if l.Name == "Movie" {
movieLabelFound = true
propFound := false
for _, p := range l.Properties {
if p.Name == "title" {
propFound = true
break
}
}
if !propFound {
t.Errorf("expected Movie label to have 'title' property, but it was not found")
}
}
}
if !personLabelFound {
t.Error("expected to find 'Person' in nodeLabels")
}
if !movieLabelFound {
t.Error("expected to find 'Movie' in nodeLabels")
}
// --- Validate Relationships ---
relFound := false
for _, r := range schema.Relationships {
if r.Type == "ACTED_IN" && r.StartNode == "Person" && r.EndNode == "Movie" {
relFound = true
break
}
}
if !relFound {
t.Errorf("expected to find relationship '(:Person)-[:ACTED_IN]->(:Movie)', but it was not found")
}
// --- Validate Constraints ---
constraintFound := false
for _, c := range schema.Constraints {
if c.Name == "PersonNameUnique" && c.Label == "Person" {
propFound := false
for _, p := range c.Properties {
if p == "name" {
propFound = true
break
}
}
if propFound {
constraintFound = true
break
}
}
}
if !constraintFound {
t.Errorf("expected to find constraint 'PersonNameUnique' on Person(name), but it was not found")
}
// --- Validate Indexes ---
indexFound := false
for _, i := range schema.Indexes {
if i.Name == "MovieTitleIndex" && i.Label == "Movie" {
propFound := false
for _, p := range i.Properties {
if p == "title" {
propFound = true
break
}
}
if propFound {
indexFound = true
break
}
}
}
if !indexFound {
t.Errorf("expected to find index 'MovieTitleIndex' on Movie(title), but it was not found")
}
},
},
}
for _, tc := range invokeTcs {
t.Run(tc.name, func(t *testing.T) {
// Prepare data if a preparation function is provided.
if tc.prepareData != nil {
tc.prepareData(t)
}
resp, err := http.Post(tc.api, "application/json", tc.requestBody)
if err != nil {
t.Fatalf("error when sending a request: %s", err)
@@ -224,7 +487,11 @@ func TestNeo4jToolEndpoints(t *testing.T) {
t.Fatalf("unable to find result in response body")
}
if got != tc.want {
if tc.validateFunc != nil {
// Use the custom validation function if provided.
tc.validateFunc(t, got)
} else if got != tc.want {
// Otherwise, perform a direct string comparison.
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
}
} else {