From 8a1224b9e0145c4e214d42f14f5308b508ea27ce Mon Sep 17 00:00:00 2001 From: Michael Hunger Date: Tue, 14 Jan 2025 17:17:18 +0100 Subject: [PATCH] feat: Added Neo4j Source and Tool (#189) - configure neo4j source with url, username, password, database - configure neo4j tools with cypher statement and paramters - tests based on the postgres tests - neo4j.yaml for integration tests --------- Co-authored-by: duwenxin --- .ci/integration.cloudbuild.yaml | 22 ++++ README.md | 28 +++++ docs/sources/README.md | 1 + docs/sources/neo4j.md | 40 ++++++ docs/tools/README.md | 3 + docs/tools/neo4j-cypher.md | 59 +++++++++ go.mod | 1 + go.sum | 2 + internal/server/config.go | 14 +++ internal/sources/neo4j/neo4j.go | 99 +++++++++++++++ internal/sources/neo4j/neo4j_test.go | 68 ++++++++++ internal/tools/neo4j/neo4j.go | 132 ++++++++++++++++++++ internal/tools/neo4j/neo4j_test.go | 79 ++++++++++++ tests/neo4j_integration_test.go | 178 +++++++++++++++++++++++++++ 14 files changed, 726 insertions(+) create mode 100644 docs/sources/neo4j.md create mode 100644 docs/tools/neo4j-cypher.md create mode 100644 internal/sources/neo4j/neo4j.go create mode 100644 internal/sources/neo4j/neo4j_test.go create mode 100644 internal/tools/neo4j/neo4j.go create mode 100644 internal/tools/neo4j/neo4j_test.go create mode 100644 tests/neo4j_integration_test.go diff --git a/.ci/integration.cloudbuild.yaml b/.ci/integration.cloudbuild.yaml index 1cd03f7479..a953f33071 100644 --- a/.ci/integration.cloudbuild.yaml +++ b/.ci/integration.cloudbuild.yaml @@ -101,6 +101,23 @@ steps: - | go test -race -v -tags=integration,spanner ./tests + - id: "neo4j" + name: golang:1 + waitFor: ["install-dependencies"] + entrypoint: /bin/bash + env: + - "GOPATH=/gopath" + - "NEO4J_DATABASE=$_NEO4J_DATABASE" + - "NEO4J_URI=$_NEO4J_URI" + secretEnv: ["NEO4J_USER", "NEO4J_PASS"] + volumes: + - name: "go" + path: "/gopath" + args: + - -c + - | + go test -race -v -tags=integration,neo4j ./tests + availableSecrets: secretManager: - versionName: projects/$PROJECT_ID/secrets/cloud_sql_pg_user/versions/latest @@ -117,6 +134,10 @@ availableSecrets: env: POSTGRES_PASS - versionName: projects/$PROJECT_ID/secrets/client_id/versions/latest env: CLIENT_ID + - versionName: projects/$PROJECT_ID/secrets/neo4j_user/versions/latest + env: NEO4J_USER + - versionName: projects/$PROJECT_ID/secrets/neo4j_pass/versions/latest + env: NEO4J_PASS options: logging: CLOUD_LOGGING_ONLY @@ -135,3 +156,4 @@ substitutions: _POSTGRES_HOST: 127.0.0.1 _POSTGRES_PORT: "5432" _SPANNER_INSTANCE: "spanner-testing" + _NEO4J_DATABASE: "neo4j" diff --git a/README.md b/README.md index e2226f2584..06cc23a9db 100644 --- a/README.md +++ b/README.md @@ -180,6 +180,18 @@ sources: database: my_db ``` +Example for Neo4j + +```yaml +sources: + my-neo4j-source: + kind: neo4j + uri: neo4j+s://my-neo4j-host:7687 + user: neo4j + password: my-password + database: my_db +``` + For more details on configuring different types of sources, see the [Source documentation.](docs/sources/README.md) @@ -202,6 +214,22 @@ tools: description: 'id' represents the unique ID for each flight. ``` +Neo4j-Cypher Example + +```yaml +tools: + get_movies_in_year: + kind: neo4j-cypher + source: my-neo4j-instance + description: > + Use this tool to list all movies titles in a given year. + statement: "MATCH (m:Movie) WHERE m.year = $year RETURN m.title" + parameters: + - name: "year" + type: integer + description: 'year' represents a 4 digit year since 1900 up to the current year +``` + For more details on configuring different types of tools, see the [Tool documentation.](docs/tools/README.md) diff --git a/docs/sources/README.md b/docs/sources/README.md index 26f715b72a..206f025f25 100644 --- a/docs/sources/README.md +++ b/docs/sources/README.md @@ -30,3 +30,4 @@ We currently support the following types of kinds of sources: PostgreSQL instance. * [postgres](./postgres.md) - Connect to any PostgreSQL compatible database. * [spanner](./spanner.md) - Connect to a Spanner database. +* [neo4j](./neo4j.md) - Connect to a Neo4j instance. \ No newline at end of file diff --git a/docs/sources/neo4j.md b/docs/sources/neo4j.md new file mode 100644 index 0000000000..f90e3800c5 --- /dev/null +++ b/docs/sources/neo4j.md @@ -0,0 +1,40 @@ +# Neo4j Source + +[Neo4j][neo4j-docs] is a powerful, open source graph database +system with over 15 years of active development that has earned it a strong +reputation for reliability, feature robustness, and performance. + +[neo4j-docs]: https://neo4j.com/docs + +## Requirements + +### Database User + +This source only uses standard authentication. You will need to [create a +Neo4j user][neo4j-users] to log in to the database with, or use the default `neo4j` user if available. + +[neo4j-users]: https://neo4j.com/docs/operations-manual/current/authentication-authorization/manage-users/ + +## Example + +```yaml +sources: + my-neo4j-source: + kind: "neo4j" + uri: "neo4j+s://xxxx.databases.neo4j.io:7687" + user: "neo4j" + password: "my-password" + database: "neo4j" +``` + +## Reference + +| **field** | **type** | **required** | **description** | +|-----------|:--------:|:------------:|---------------------------------------------------------------------| +| kind | string | true | Must be "neo4j". | +| uri | string | true | Connect URI ("bolt://localhost", "neo4j+s://xxx.databases.neo4j.io") | +| user | string | true | Name of the Neo4j user to connect as (e.g. "neo4j"). | +| password | string | true | Password of the Neo4j user (e.g. "my-password"). | +| database | string | true | Name of the Neo4j database to connect to (e.g. "neo4j"). | + + diff --git a/docs/tools/README.md b/docs/tools/README.md index 890fd1fe30..0e0899973c 100644 --- a/docs/tools/README.md +++ b/docs/tools/README.md @@ -50,6 +50,9 @@ We currently support the following types of kinds of tools: PostgreSQL-compatible database. * [spanner](./spanner.md) - Run a Spanner (either googlesql or postgresql) statement againts Spanner database. +* [neo4j-cypher](./neo4j-cypher.md) - Run a Cypher statement against a + Neo4j database. + ## Specifying Parameters diff --git a/docs/tools/neo4j-cypher.md b/docs/tools/neo4j-cypher.md new file mode 100644 index 0000000000..dab63c1873 --- /dev/null +++ b/docs/tools/neo4j-cypher.md @@ -0,0 +1,59 @@ +# Neo4j Cypher Tool + +A "neo4j-cypher" tool executes a pre-defined Cypher statement against a Neo4j database. It's compatible with any of the following sources: +- [neo4j](../sources/neo4j.md) + +The specified Cypher statement is executed as a [parameterized statement][neo4j-parameters], +and specified parameters will be used according to their name: e.g. `$id`. + +[neo4j-parameters]: https://neo4j.com/docs/cypher-manual/current/syntax/parameters/ + +## Example + +```yaml +tools: + search_movies_by_actor: + kind: neo4j-cypher + source: my-neo4j-movies-instance + statement: | + MATCH (m:Movie)<-[:ACTED_IN]-(p:Person) + WHERE p.name = $name AND m.year > $year + RETURN m.title, m.year + LIMIT 10 + description: | + Use this tool to get a list of movies for a specific actor and a given minium release year. + Takes an full actor name, e.g. "Tom Hanks" and a year e.g 1993 and returns a list of movie titles and release years. + Do NOT use this tool with a movie title. Do NOT guess an actor name, Do NOT guess a year. + A actor name is a fully qualified name with first and last name separated by a space. + For example, if given "Hanks, Tom" the actor name is "Tom Hanks". + If the tool returns more than one option choose the most recent movies. + Example: + {{ + "name": "Meg Ryan", + "year": 1993 + }} + Example: + {{ + "name": "Clint Eastwood", + "year": 2000 + }} + parameters: + - name: name + type: string + description: Full actor name, "firstname lastname" + - name: year + type: integer + description: 4 digit number starting in 1900 up to the current year +``` + +## Reference + +| **field** | **type** | **required** | **description** | +|-------------|----------:|:------------:|----------------------------------------------------------------------------------------------------| +| kind | string | true | Must be "neo4j-cypher". | +| source | string | true | Name of the source the Cypher query should execute on. | +| description | string | true | Description of the tool | +| statement | string | true | Cypher statement to execute | +| parameters | parameter | true | List of [parameters](README.md#specifying-parameters) that will be used with the Cypher statement. | + + diff --git a/go.mod b/go.mod index e9fdaff6d0..1e96a1cf6c 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( github.com/google/go-cmp v0.6.0 github.com/google/uuid v1.6.0 github.com/jackc/pgx/v5 v5.7.1 + github.com/neo4j/neo4j-go-driver/v5 v5.26.0 github.com/spf13/cobra v1.8.1 go.opentelemetry.io/contrib/propagators/autoprop v0.58.0 go.opentelemetry.io/otel v1.33.0 diff --git a/go.sum b/go.sum index 3718e33363..e7612b7ee1 100644 --- a/go.sum +++ b/go.sum @@ -918,6 +918,8 @@ github.com/microsoft/go-mssqldb v1.8.0 h1:7cyZ/AT7ycDsEoWPIXibd+aVKFtteUNhDGf3ao github.com/microsoft/go-mssqldb v1.8.0/go.mod h1:6znkekS3T2vp0waiMhen4GPU1BiAsrP+iXHcE7a7rFo= github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8/go.mod h1:mC1jAcsrzbxHt8iiaC+zU4b1ylILSosueou12R++wfY= github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3/go.mod h1:RagcQ7I8IeTMnF8JTXieKnO4Z6JCsikNEzj0DwauVzE= +github.com/neo4j/neo4j-go-driver/v5 v5.26.0 h1:GB3o4VtIGsvU+RmfgvF7L6nt1IpbPZaGtPMtPSOKmvc= +github.com/neo4j/neo4j-go-driver/v5 v5.26.0/go.mod h1:Vff8OwT7QpLm7L2yYr85XNWe9Rbqlbeb9asNXJTHO4k= github.com/phpdave11/gofpdf v1.4.2/go.mod h1:zpO6xFn9yxo3YLyMvW8HcKWVdbNqgIfOOp2dXMnm1mY= github.com/phpdave11/gofpdi v1.0.12/go.mod h1:vBmVV0Do6hSBHC8uKUQ71JGW+ZGQq74llk/7bXwjDoI= github.com/phpdave11/gofpdi v1.0.13/go.mod h1:vBmVV0Do6hSBHC8uKUQ71JGW+ZGQq74llk/7bXwjDoI= diff --git a/internal/server/config.go b/internal/server/config.go index 83e46e46f7..45388aeba6 100644 --- a/internal/server/config.go +++ b/internal/server/config.go @@ -22,9 +22,11 @@ import ( "github.com/googleapis/genai-toolbox/internal/sources" alloydbpgsrc "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" cloudsqlpgsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" + neo4jrc "github.com/googleapis/genai-toolbox/internal/sources/neo4j" postgressrc "github.com/googleapis/genai-toolbox/internal/sources/postgres" spannersrc "github.com/googleapis/genai-toolbox/internal/sources/spanner" "github.com/googleapis/genai-toolbox/internal/tools" + neo4jtool "github.com/googleapis/genai-toolbox/internal/tools/neo4j" "github.com/googleapis/genai-toolbox/internal/tools/postgressql" "github.com/googleapis/genai-toolbox/internal/tools/spanner" "gopkg.in/yaml.v3" @@ -156,6 +158,12 @@ func (c *SourceConfigs) UnmarshalYAML(node *yaml.Node) error { return fmt.Errorf("unable to parse as %q: %w", k.Kind, err) } (*c)[name] = actual + case neo4jrc.SourceKind: + actual := neo4jrc.Config{Name: name} + if err := n.Decode(&actual); err != nil { + return fmt.Errorf("unable to parse as %q: %w", k.Kind, err) + } + (*c)[name] = actual default: return fmt.Errorf("%q is not a valid kind of data source", k.Kind) } @@ -235,6 +243,12 @@ func (c *ToolConfigs) UnmarshalYAML(node *yaml.Node) error { return fmt.Errorf("unable to parse as %q: %w", k.Kind, err) } (*c)[name] = actual + case neo4jtool.ToolKind: + actual := neo4jtool.Config{Name: name} + if err := n.Decode(&actual); err != nil { + return fmt.Errorf("unable to parse as %q: %w", k.Kind, err) + } + (*c)[name] = actual default: return fmt.Errorf("%q is not a valid kind of tool", k.Kind) } diff --git a/internal/sources/neo4j/neo4j.go b/internal/sources/neo4j/neo4j.go new file mode 100644 index 0000000000..5f3de6e9c6 --- /dev/null +++ b/internal/sources/neo4j/neo4j.go @@ -0,0 +1,99 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package neo4j + +import ( + "context" + "fmt" + + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/neo4j/neo4j-go-driver/v5/neo4j" + "go.opentelemetry.io/otel/trace" +) + +const SourceKind string = "neo4j" + +// validate interface +var _ sources.SourceConfig = Config{} + +type Config struct { + Name string `yaml:"name"` + Kind string `yaml:"kind"` + Uri string `yaml:"uri"` + User string `yaml:"user"` + Password string `yaml:"password"` + Database string `yaml:"database"` +} + +func (r Config) SourceConfigKind() string { + return SourceKind +} + +func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) { + driver, err := initNeo4jDriver(ctx, tracer, r.Uri, r.User, r.Password, r.Name) + if err != nil { + return nil, fmt.Errorf("Unable to create driver: %w", err) + } + + err = driver.VerifyConnectivity(context.Background()) + if err != nil { + return nil, fmt.Errorf("Unable to connect successfully: %w", err) + } + + if r.Database == "" { + r.Database = "neo4j" + } + s := &Source{ + Name: r.Name, + Kind: SourceKind, + Database: r.Database, + Driver: driver, + } + return s, nil +} + +var _ sources.Source = &Source{} + +type Source struct { + Name string `yaml:"name"` + Kind string `yaml:"kind"` + Database string `yaml:"database"` + Driver neo4j.DriverWithContext +} + +func (s *Source) SourceKind() string { + return SourceKind +} + +func (s *Source) Neo4jDriver() neo4j.DriverWithContext { + return s.Driver +} + +func (s *Source) Neo4jDatabase() string { + return s.Database +} + +func initNeo4jDriver(ctx context.Context, tracer trace.Tracer, uri, user, password, name string) (neo4j.DriverWithContext, error) { + //nolint:all // Reassigned ctx + ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name) + defer span.End() + + auth := neo4j.BasicAuth(user, password, "") + driver, err := neo4j.NewDriverWithContext(uri, auth) + if err != nil { + return nil, fmt.Errorf("unable to create connection driver: %w", err) + } + return driver, nil +} diff --git a/internal/sources/neo4j/neo4j_test.go b/internal/sources/neo4j/neo4j_test.go new file mode 100644 index 0000000000..a9c165ad25 --- /dev/null +++ b/internal/sources/neo4j/neo4j_test.go @@ -0,0 +1,68 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package neo4j_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/sources/neo4j" + "github.com/googleapis/genai-toolbox/internal/testutils" + "gopkg.in/yaml.v3" +) + +func TestParseFromYamlNeo4j(t *testing.T) { + tcs := []struct { + desc string + in string + want server.SourceConfigs + }{ + { + desc: "basic example", + in: ` + sources: + my-neo4j-instance: + kind: neo4j + uri: neo4j+s://my-host:7687 + database: my_db + `, + want: server.SourceConfigs{ + "my-neo4j-instance": neo4j.Config{ + Name: "my-neo4j-instance", + Kind: neo4j.SourceKind, + Uri: "neo4j+s://my-host:7687", + Database: "my_db", + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Sources server.SourceConfigs `yaml:"sources"` + }{} + // Parse contents + err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if !cmp.Equal(tc.want, got.Sources) { + t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources) + } + }) + } + +} diff --git a/internal/tools/neo4j/neo4j.go b/internal/tools/neo4j/neo4j.go new file mode 100644 index 0000000000..216545658a --- /dev/null +++ b/internal/tools/neo4j/neo4j.go @@ -0,0 +1,132 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package neo4j + +import ( + "context" + "fmt" + "strings" + + neo4jsc "github.com/googleapis/genai-toolbox/internal/sources/neo4j" + "github.com/neo4j/neo4j-go-driver/v5/neo4j" + + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/tools" +) + +const ToolKind string = "neo4j-cypher" + +type compatibleSource interface { + Neo4jDriver() neo4j.DriverWithContext + Neo4jDatabase() string +} + +// validate compatible sources are still compatible +var _ compatibleSource = &neo4jsc.Source{} + +var compatibleSources = [...]string{neo4jsc.SourceKind} + +type Config struct { + Name string `yaml:"name"` + Kind string `yaml:"kind"` + Source string `yaml:"source"` + Description string `yaml:"description"` + Statement string `yaml:"statement"` + Parameters tools.Parameters `yaml:"parameters"` +} + +// validate interface +var _ tools.ToolConfig = Config{} + +func (cfg Config) ToolConfigKind() string { + return ToolKind +} + +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + // verify source exists + rawS, ok := srcs[cfg.Source] + if !ok { + return nil, fmt.Errorf("no source named %q configured", cfg.Source) + } + + // verify the source is compatible + s, ok := rawS.(compatibleSource) + if !ok { + return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", ToolKind, compatibleSources) + } + + // finish tool setup + t := Tool{ + Name: cfg.Name, + Kind: ToolKind, + Parameters: cfg.Parameters, + Statement: cfg.Statement, + Driver: s.Neo4jDriver(), + Database: s.Neo4jDatabase(), + manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()}, + } + return t, nil +} + +// validate interface +var _ tools.Tool = Tool{} + +type Tool struct { + Name string `yaml:"name"` + Kind string `yaml:"kind"` + Parameters tools.Parameters `yaml:"parameters"` + AuthRequired []string `yaml:"authRequired"` + + Driver neo4j.DriverWithContext + Database string + Statement string + manifest tools.Manifest +} + +func (t Tool) Invoke(params tools.ParamValues) (string, error) { + paramsMap := params.AsMap() + + fmt.Printf("Invoked tool %s\n", t.Name) + ctx := context.Background() + config := neo4j.ExecuteQueryWithDatabase(t.Database) + results, err := neo4j.ExecuteQuery[*neo4j.EagerResult](ctx, t.Driver, t.Statement, paramsMap, + neo4j.EagerResultTransformer, config) + if err != nil { + return "", fmt.Errorf("unable to execute query: %w", err) + } + + var out strings.Builder + keys := results.Keys + records := results.Records + for _, record := range records { + out.WriteString("\n") // fmt.Sprintf("Row: %d\n", row)) + for col, value := range record.Values { + out.WriteString(fmt.Sprintf("\t%s: %s\n", keys[col], value)) + } + } + return fmt.Sprintf("Stub tool call for %q! Parameters parsed: %q \n Output: %s", t.Name, paramsMap, out.String()), nil +} + +func (t Tool) ParseParams(data map[string]any, claimsMap map[string]map[string]any) (tools.ParamValues, error) { + return tools.ParseParams(t.Parameters, data, claimsMap) +} + +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t Tool) Authorized(verifiedAuthSources []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthSources) +} diff --git a/internal/tools/neo4j/neo4j_test.go b/internal/tools/neo4j/neo4j_test.go new file mode 100644 index 0000000000..cfb3d19005 --- /dev/null +++ b/internal/tools/neo4j/neo4j_test.go @@ -0,0 +1,79 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package neo4j_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/tools/neo4j" + "gopkg.in/yaml.v3" +) + +func TestParseFromYamlNeo4j(t *testing.T) { + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + tools: + example_tool: + kind: neo4j-cypher + source: my-neo4j-instance + description: some tool description + statement: | + MATCH (c:Country) WHERE c.name = $country RETURN c.id as id; + parameters: + - name: country + type: string + description: country parameter description + `, + want: server.ToolConfigs{ + "example_tool": neo4j.Config{ + Name: "example_tool", + Kind: neo4j.ToolKind, + Source: "my-neo4j-instance", + Description: "some tool description", + Statement: "MATCH (c:Country) WHERE c.name = $country RETURN c.id as id;\n", + Parameters: []tools.Parameter{ + tools.NewStringParameter("country", "country parameter description"), + }, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Tools server.ToolConfigs `yaml:"tools"` + }{} + // Parse contents + err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got.Tools); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } + +} diff --git a/tests/neo4j_integration_test.go b/tests/neo4j_integration_test.go new file mode 100644 index 0000000000..47fc5d151b --- /dev/null +++ b/tests/neo4j_integration_test.go @@ -0,0 +1,178 @@ +//go:build integration && neo4j + +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "os" + "reflect" + "regexp" + "testing" + "time" +) + +var ( + NEO4J_DATABASE = os.Getenv("NEO4J_DATABASE") + NEO4J_URI = os.Getenv("NEO4J_URI") + NEO4J_USER = os.Getenv("NEO4J_USER") + NEO4J_PASS = os.Getenv("NEO4J_PASS") +) + +func requireNeo4jVars(t *testing.T) { + switch "" { + case NEO4J_DATABASE: + t.Fatal("'NEO4J_DATABASE' not set") + case NEO4J_URI: + t.Fatal("'NEO4J_URI' not set") + case NEO4J_USER: + t.Fatal("'NEO4J_USER' not set") + case NEO4J_PASS: + t.Fatal("'NEO4J_PASS' not set") + } +} + +func TestNeo4j(t *testing.T) { + requireNeo4jVars(t) + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + var args []string + + // Write config into a file and pass it to command + toolsFile := map[string]any{ + "sources": map[string]any{ + "my-neo4j-instance": map[string]any{ + "kind": "neo4j", + "uri": NEO4J_URI, + "database": NEO4J_DATABASE, + "user": NEO4J_USER, + "password": NEO4J_PASS, + }, + }, + "tools": map[string]any{ + "my-simple-cypher-tool": map[string]any{ + "kind": "neo4j-cypher", + "source": "my-neo4j-instance", + "description": "Simple tool to test end to end functionality.", + "statement": "RETURN 1 as a;", + }, + }, + } + cmd, cleanup, err := StartCmd(ctx, toolsFile, args...) + if err != nil { + t.Fatalf("command initialization returned an error: %s", err) + } + defer cleanup() + + waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`)) + if err != nil { + t.Logf("toolbox command logs: \n%s", out) + t.Fatalf("toolbox didn't start successfully: %s", err) + } + + // Test tool get endpoint + tcs := []struct { + name string + api string + want map[string]any + }{ + { + name: "get my-simple-tool", + api: "http://127.0.0.1:5000/api/tool/my-simple-cypher-tool/", + want: map[string]any{ + "my-simple-cypher-tool": map[string]any{ + "description": "Simple tool to test end to end functionality.", + "parameters": []any{}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + resp, err := http.Get(tc.api) + if err != nil { + t.Fatalf("error when sending a request: %s", err) + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + t.Fatalf("response status code is not 200") + } + + var body map[string]interface{} + err = json.NewDecoder(resp.Body).Decode(&body) + if err != nil { + t.Fatalf("error parsing response body") + } + + got, ok := body["tools"] + if !ok { + t.Fatalf("unable to find tools in response body") + } + if !reflect.DeepEqual(got, tc.want) { + t.Fatalf("got %q, want %q", got, tc.want) + } + }) + } + + // Test tool invoke endpoint + invokeTcs := []struct { + name string + api string + requestBody io.Reader + want string + }{ + { + name: "invoke my-simple-cypher-tool", + api: "http://127.0.0.1:5000/api/tool/my-simple-cypher-tool/invoke", + requestBody: bytes.NewBuffer([]byte(`{}`)), + want: "Stub tool call for \"my-simple-cypher-tool\"! Parameters parsed: map[] \n Output: \n\ta: %!s(int64=1)\n", + }, + } + for _, tc := range invokeTcs { + t.Run(tc.name, func(t *testing.T) { + resp, err := http.Post(tc.api, "application/json", tc.requestBody) + if err != nil { + t.Fatalf("error when sending a request: %s", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes)) + } + + var body map[string]interface{} + err = json.NewDecoder(resp.Body).Decode(&body) + if err != nil { + t.Fatalf("error parsing response body") + } + got, ok := body["result"].(string) + if !ok { + t.Fatalf("unable to find result in response body") + } + + if got != tc.want { + t.Fatalf("unexpected value: got %q, want %q", got, tc.want) + } + }) + } +}