mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-04-09 03:02:26 -04:00
feat: Added Neo4j Source and Tool (#189)
- configure neo4j source with url, username, password, database - configure neo4j tools with cypher statement and paramters - tests based on the postgres tests - neo4j.yaml for integration tests --------- Co-authored-by: duwenxin <duwenxin@google.com>
This commit is contained in:
132
internal/tools/neo4j/neo4j.go
Normal file
132
internal/tools/neo4j/neo4j.go
Normal file
@@ -0,0 +1,132 @@
|
||||
// Copyright 2024 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package neo4j
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
neo4jsc "github.com/googleapis/genai-toolbox/internal/sources/neo4j"
|
||||
"github.com/neo4j/neo4j-go-driver/v5/neo4j"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
const ToolKind string = "neo4j-cypher"
|
||||
|
||||
type compatibleSource interface {
|
||||
Neo4jDriver() neo4j.DriverWithContext
|
||||
Neo4jDatabase() string
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &neo4jsc.Source{}
|
||||
|
||||
var compatibleSources = [...]string{neo4jsc.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Source string `yaml:"source"`
|
||||
Description string `yaml:"description"`
|
||||
Statement string `yaml:"statement"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return ToolKind
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", ToolKind, compatibleSources)
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: ToolKind,
|
||||
Parameters: cfg.Parameters,
|
||||
Statement: cfg.Statement,
|
||||
Driver: s.Neo4jDriver(),
|
||||
Database: s.Neo4jDatabase(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
|
||||
Driver neo4j.DriverWithContext
|
||||
Database string
|
||||
Statement string
|
||||
manifest tools.Manifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(params tools.ParamValues) (string, error) {
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
fmt.Printf("Invoked tool %s\n", t.Name)
|
||||
ctx := context.Background()
|
||||
config := neo4j.ExecuteQueryWithDatabase(t.Database)
|
||||
results, err := neo4j.ExecuteQuery[*neo4j.EagerResult](ctx, t.Driver, t.Statement, paramsMap,
|
||||
neo4j.EagerResultTransformer, config)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
|
||||
var out strings.Builder
|
||||
keys := results.Keys
|
||||
records := results.Records
|
||||
for _, record := range records {
|
||||
out.WriteString("\n") // fmt.Sprintf("Row: %d\n", row))
|
||||
for col, value := range record.Values {
|
||||
out.WriteString(fmt.Sprintf("\t%s: %s\n", keys[col], value))
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("Stub tool call for %q! Parameters parsed: %q \n Output: %s", t.Name, paramsMap, out.String()), nil
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claimsMap map[string]map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.Parameters, data, claimsMap)
|
||||
}
|
||||
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthSources []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthSources)
|
||||
}
|
||||
79
internal/tools/neo4j/neo4j_test.go
Normal file
79
internal/tools/neo4j/neo4j_test.go
Normal file
@@ -0,0 +1,79 @@
|
||||
// Copyright 2024 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package neo4j_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/neo4j"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func TestParseFromYamlNeo4j(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: neo4j-cypher
|
||||
source: my-neo4j-instance
|
||||
description: some tool description
|
||||
statement: |
|
||||
MATCH (c:Country) WHERE c.name = $country RETURN c.id as id;
|
||||
parameters:
|
||||
- name: country
|
||||
type: string
|
||||
description: country parameter description
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": neo4j.Config{
|
||||
Name: "example_tool",
|
||||
Kind: neo4j.ToolKind,
|
||||
Source: "my-neo4j-instance",
|
||||
Description: "some tool description",
|
||||
Statement: "MATCH (c:Country) WHERE c.name = $country RETURN c.id as id;\n",
|
||||
Parameters: []tools.Parameter{
|
||||
tools.NewStringParameter("country", "country parameter description"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user