mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-11 08:28:11 -05:00
fix(tools/neo4j): Implement value conversion for Neo4j types to JSON-compatible (#1428)
This pull request introduces a utility function to standardize the conversion of Neo4j driver values into JSON-compatible types. The new `ConvertValue` function is added to the `helpers` package, and is now used in both the `neo4jcypher` and `neo4jexecutecypher` tools to ensure consistent output formatting. Comprehensive unit tests for this function are also included. Additionally, a new `ValueType` interface is defined to generalize Neo4j value stringification. **Helpers and Value Conversion:** * Added the `ConvertValue` function to `internal/tools/neo4j/neo4jschema/helpers/helpers.go` to recursively convert Neo4j driver types (including nodes, relationships, paths, points, and temporal types) into JSON-compatible Go values. This ensures proper serialization of complex Neo4j types. * Defined a `ValueType` interface in `internal/tools/neo4j/neo4jschema/types/types.go` to generalize the stringification of Neo4j value types. **Integration with Tools:** * Updated the `Invoke` methods in both `neo4jcypher` and `neo4jexecutecypher` tools to use `helpers.ConvertValue` when processing Neo4j query results, ensuring consistent and correct output formatting. [[1]](diffhunk://#diff-b3e792b742cb92c92d1f5136b444c8fe0a7ec0376920868182dc88f13002e8eeL138-R139) [[2]](diffhunk://#diff-de7fdd7e68c95ea9813c704a89fffb8fd6de34e81b43a484623fdff7683e18f3L160-R161) * Added the necessary imports for the helpers package in the affected files. [[1]](diffhunk://#diff-b3e792b742cb92c92d1f5136b444c8fe0a7ec0376920868182dc88f13002e8eeR23) [[2]](diffhunk://#diff-de7fdd7e68c95ea9813c704a89fffb8fd6de34e81b43a484623fdff7683e18f3R26) [[3]](diffhunk://#diff-97a47eb63017102fc7084aa79689e22ef8cb6d1952945177058f118d57be306fR26) **Testing:** * Added extensive tests for `ConvertValue` in `helpers_test.go`, covering all supported Neo4j types, primitives, slices, maps, and unhandled types. * Included required imports in the test file for Neo4j driver and time handling. **Relates to:** https://github.com/googleapis/genai-toolbox/issues/1344 Co-authored-by: Yuan Teoh <45984206+Yuan325@users.noreply.github.com>
This commit is contained in:
@@ -20,6 +20,7 @@ import (
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
neo4jsc "github.com/googleapis/genai-toolbox/internal/sources/neo4j"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jschema/helpers"
|
||||
"github.com/neo4j/neo4j-go-driver/v5/neo4j"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
@@ -135,7 +136,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
for _, record := range records {
|
||||
vMap := make(map[string]any)
|
||||
for col, value := range record.Values {
|
||||
vMap[keys[col]] = value
|
||||
vMap[keys[col]] = helpers.ConvertValue(value)
|
||||
}
|
||||
out = append(out, vMap)
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
neo4jsc "github.com/googleapis/genai-toolbox/internal/sources/neo4j"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jexecutecypher/classifier"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jschema/helpers"
|
||||
"github.com/neo4j/neo4j-go-driver/v5/neo4j"
|
||||
)
|
||||
|
||||
@@ -157,7 +158,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
for _, record := range records {
|
||||
vMap := make(map[string]any)
|
||||
for col, value := range record.Values {
|
||||
vMap[keys[col]] = value
|
||||
vMap[keys[col]] = helpers.ConvertValue(value)
|
||||
}
|
||||
out = append(out, vMap)
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jschema/types"
|
||||
"github.com/neo4j/neo4j-go-driver/v5/neo4j"
|
||||
)
|
||||
|
||||
// ConvertToStringSlice converts a slice of any type to a slice of strings.
|
||||
@@ -289,3 +290,73 @@ func sortAndClean(nodeLabels []types.NodeLabel, relationships []types.Relationsh
|
||||
stats.PropertiesByRelType = nil
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertValue converts Neo4j value to JSON-compatible value.
|
||||
func ConvertValue(value any) any {
|
||||
switch v := value.(type) {
|
||||
case nil, neo4j.InvalidValue:
|
||||
return nil
|
||||
case bool, string, int, int8, int16, int32, int64, float32, float64:
|
||||
return v
|
||||
case neo4j.Date, neo4j.LocalTime, neo4j.Time,
|
||||
neo4j.LocalDateTime, neo4j.Duration:
|
||||
if iv, ok := v.(types.ValueType); ok {
|
||||
return iv.String()
|
||||
}
|
||||
case neo4j.Node:
|
||||
return map[string]any{
|
||||
"elementId": v.GetElementId(),
|
||||
"labels": v.Labels,
|
||||
"properties": ConvertValue(v.GetProperties()),
|
||||
}
|
||||
case neo4j.Relationship:
|
||||
return map[string]any{
|
||||
"elementId": v.GetElementId(),
|
||||
"type": v.Type,
|
||||
"startElementId": v.StartElementId,
|
||||
"endElementId": v.EndElementId,
|
||||
"properties": ConvertValue(v.GetProperties()),
|
||||
}
|
||||
case neo4j.Entity:
|
||||
return map[string]any{
|
||||
"elementId": v.GetElementId(),
|
||||
"properties": ConvertValue(v.GetProperties()),
|
||||
}
|
||||
case neo4j.Path:
|
||||
var nodes []any
|
||||
var relationships []any
|
||||
for _, r := range v.Relationships {
|
||||
relationships = append(relationships, ConvertValue(r))
|
||||
}
|
||||
for _, n := range v.Nodes {
|
||||
nodes = append(nodes, ConvertValue(n))
|
||||
}
|
||||
return map[string]any{
|
||||
"nodes": nodes,
|
||||
"relationships": relationships,
|
||||
}
|
||||
case neo4j.Record:
|
||||
m := make(map[string]any)
|
||||
for i, key := range v.Keys {
|
||||
m[key] = ConvertValue(v.Values[i])
|
||||
}
|
||||
return m
|
||||
case neo4j.Point2D:
|
||||
return map[string]any{"x": v.X, "y": v.Y, "srid": v.SpatialRefId}
|
||||
case neo4j.Point3D:
|
||||
return map[string]any{"x": v.X, "y": v.Y, "z": v.Z, "srid": v.SpatialRefId}
|
||||
case []any:
|
||||
arr := make([]any, len(v))
|
||||
for i, elem := range v {
|
||||
arr[i] = ConvertValue(elem)
|
||||
}
|
||||
return arr
|
||||
case map[string]any:
|
||||
m := make(map[string]any)
|
||||
for key, val := range v {
|
||||
m[key] = ConvertValue(val)
|
||||
}
|
||||
return m
|
||||
}
|
||||
return fmt.Sprintf("%v", value)
|
||||
}
|
||||
|
||||
@@ -16,9 +16,11 @@ package helpers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jschema/types"
|
||||
"github.com/neo4j/neo4j-go-driver/v5/neo4j"
|
||||
)
|
||||
|
||||
func TestHelperFunctions(t *testing.T) {
|
||||
@@ -382,3 +384,176 @@ func TestProcessNonAPOCSchema(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestConvertValue(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input any
|
||||
want any
|
||||
}{
|
||||
{
|
||||
name: "nil value",
|
||||
input: nil,
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "neo4j.InvalidValue",
|
||||
input: neo4j.InvalidValue{},
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "primitive bool",
|
||||
input: true,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "primitive int",
|
||||
input: int64(42),
|
||||
want: int64(42),
|
||||
},
|
||||
{
|
||||
name: "primitive float",
|
||||
input: 3.14,
|
||||
want: 3.14,
|
||||
},
|
||||
{
|
||||
name: "primitive string",
|
||||
input: "hello",
|
||||
want: "hello",
|
||||
},
|
||||
{
|
||||
name: "neo4j.Date",
|
||||
input: neo4j.Date(time.Date(2024, 6, 1, 0, 0, 0, 0, time.UTC)),
|
||||
want: "2024-06-01",
|
||||
},
|
||||
{
|
||||
name: "neo4j.LocalTime",
|
||||
input: neo4j.LocalTime(time.Date(0, 0, 0, 12, 34, 56, 0, time.Local)),
|
||||
want: "12:34:56",
|
||||
},
|
||||
{
|
||||
name: "neo4j.Time",
|
||||
input: neo4j.Time(time.Date(0, 0, 0, 1, 2, 3, 0, time.UTC)),
|
||||
want: "01:02:03Z",
|
||||
},
|
||||
{
|
||||
name: "neo4j.LocalDateTime",
|
||||
input: neo4j.LocalDateTime(time.Date(2024, 6, 1, 10, 20, 30, 0, time.Local)),
|
||||
want: "2024-06-01T10:20:30",
|
||||
},
|
||||
{
|
||||
name: "neo4j.Duration",
|
||||
input: neo4j.Duration{Months: 1, Days: 2, Seconds: 3, Nanos: 4},
|
||||
want: "P1M2DT3.000000004S",
|
||||
},
|
||||
{
|
||||
name: "neo4j.Point2D",
|
||||
input: neo4j.Point2D{X: 1.1, Y: 2.2, SpatialRefId: 1234},
|
||||
want: map[string]any{"x": 1.1, "y": 2.2, "srid": uint32(1234)},
|
||||
},
|
||||
{
|
||||
name: "neo4j.Point3D",
|
||||
input: neo4j.Point3D{X: 1.1, Y: 2.2, Z: 3.3, SpatialRefId: 5467},
|
||||
want: map[string]any{"x": 1.1, "y": 2.2, "z": 3.3, "srid": uint32(5467)},
|
||||
},
|
||||
{
|
||||
name: "neo4j.Node (handled by Entity case, losing labels)",
|
||||
input: neo4j.Node{
|
||||
ElementId: "element-1",
|
||||
Labels: []string{"Person"},
|
||||
Props: map[string]any{"name": "Alice"},
|
||||
},
|
||||
want: map[string]any{
|
||||
"elementId": "element-1",
|
||||
"labels": []string{"Person"},
|
||||
"properties": map[string]any{"name": "Alice"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "neo4j.Relationship (handled by Entity case, losing type/endpoints)",
|
||||
input: neo4j.Relationship{
|
||||
ElementId: "element-2",
|
||||
StartElementId: "start-1",
|
||||
EndElementId: "end-1",
|
||||
Type: "KNOWS",
|
||||
Props: map[string]any{"since": 2024},
|
||||
},
|
||||
want: map[string]any{
|
||||
"elementId": "element-2",
|
||||
"properties": map[string]any{"since": 2024},
|
||||
"startElementId": "start-1",
|
||||
"endElementId": "end-1",
|
||||
"type": "KNOWS",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "neo4j.Path (elements handled by Entity case)",
|
||||
input: func() neo4j.Path {
|
||||
node1 := neo4j.Node{ElementId: "n10", Labels: []string{"A"}, Props: map[string]any{"p1": "v1"}}
|
||||
node2 := neo4j.Node{ElementId: "n11", Labels: []string{"B"}, Props: map[string]any{"p2": "v2"}}
|
||||
rel1 := neo4j.Relationship{ElementId: "r12", StartElementId: "n10", EndElementId: "n11", Type: "REL", Props: map[string]any{"p3": "v3"}}
|
||||
return neo4j.Path{
|
||||
Nodes: []neo4j.Node{node1, node2},
|
||||
Relationships: []neo4j.Relationship{rel1},
|
||||
}
|
||||
}(),
|
||||
want: map[string]any{
|
||||
"nodes": []any{
|
||||
map[string]any{
|
||||
"elementId": "n10",
|
||||
"properties": map[string]any{"p1": "v1"},
|
||||
"labels": []string{"A"},
|
||||
},
|
||||
map[string]any{
|
||||
"elementId": "n11",
|
||||
"properties": map[string]any{"p2": "v2"},
|
||||
"labels": []string{"B"},
|
||||
},
|
||||
},
|
||||
"relationships": []any{
|
||||
map[string]any{
|
||||
"elementId": "r12",
|
||||
"properties": map[string]any{"p3": "v3"},
|
||||
"startElementId": "n10",
|
||||
"endElementId": "n11",
|
||||
"type": "REL",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "slice of primitives",
|
||||
input: []any{"a", 1, true},
|
||||
want: []any{"a", 1, true},
|
||||
},
|
||||
{
|
||||
name: "slice of mixed types",
|
||||
input: []any{"a", neo4j.Date(time.Date(2024, 6, 1, 0, 0, 0, 0, time.UTC))},
|
||||
want: []any{"a", "2024-06-01"},
|
||||
},
|
||||
{
|
||||
name: "map of primitives",
|
||||
input: map[string]any{"foo": 1, "bar": "baz"},
|
||||
want: map[string]any{"foo": 1, "bar": "baz"},
|
||||
},
|
||||
{
|
||||
name: "map with nested neo4j type",
|
||||
input: map[string]any{"date": neo4j.Date(time.Date(2024, 6, 1, 0, 0, 0, 0, time.UTC))},
|
||||
want: map[string]any{"date": "2024-06-01"},
|
||||
},
|
||||
{
|
||||
name: "unhandled type",
|
||||
input: struct{ X int }{X: 5},
|
||||
want: "{5}",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ConvertValue(tt.input)
|
||||
if !cmp.Equal(got, tt.want) {
|
||||
t.Errorf("ConvertValue() mismatch (-want +got):\n%s", cmp.Diff(tt.want, got))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,6 +15,11 @@
|
||||
// Package types contains the shared data structures for Neo4j schema representation.
|
||||
package types
|
||||
|
||||
// ValueType interface representing a Neo4j value.
|
||||
type ValueType interface {
|
||||
String() string
|
||||
}
|
||||
|
||||
// SchemaInfo represents the complete database schema.
|
||||
type SchemaInfo struct {
|
||||
NodeLabels []NodeLabel `json:"nodeLabels"`
|
||||
|
||||
Reference in New Issue
Block a user