mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-11 08:28:11 -05:00
Compare commits
29 Commits
spanner-cr
...
kuzu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ff6ddaf73b | ||
|
|
4f725dc85e | ||
|
|
0faf75db95 | ||
|
|
5e08aeb7ed | ||
|
|
a00880aeda | ||
|
|
b72a4b9855 | ||
|
|
e6e1545a97 | ||
|
|
48f49c5e9a | ||
|
|
6ab871ea93 | ||
|
|
78ec2bb52e | ||
|
|
da5e130819 | ||
|
|
ce831a47c3 | ||
|
|
a0c4483a47 | ||
|
|
7aae8d50b3 | ||
|
|
b41d512759 | ||
|
|
898403f753 | ||
|
|
38db9bd62b | ||
|
|
376c3bec92 | ||
|
|
8190c92818 | ||
|
|
064adf2c79 | ||
|
|
8fc63c3c73 | ||
|
|
8a9344ac06 | ||
|
|
780e67fd22 | ||
|
|
54dddb772b | ||
|
|
20e2c1703f | ||
|
|
9f545ce9c4 | ||
|
|
3ef18e1713 | ||
|
|
0615e3a4e5 | ||
|
|
4ca4b96223 |
@@ -532,6 +532,25 @@ steps:
|
||||
tidb \
|
||||
tidbsql tidbexecutesql
|
||||
|
||||
- id: "kuzu"
|
||||
name: golang:1
|
||||
waitFor: ["compile-test-binary"]
|
||||
entrypoint: /bin/bash
|
||||
env:
|
||||
- "GOPATH=/gopath"
|
||||
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
|
||||
volumes:
|
||||
- name: "go"
|
||||
path: "/gopath"
|
||||
secretEnv: ["CLIENT_ID"]
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
.ci/test_with_coverage.sh \
|
||||
"Kuzu" \
|
||||
kuzu \
|
||||
kuzu
|
||||
|
||||
availableSecrets:
|
||||
secretManager:
|
||||
- versionName: projects/$PROJECT_ID/secrets/cloud_sql_pg_user/versions/latest
|
||||
|
||||
@@ -61,6 +61,7 @@ import (
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestorequerycollection"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestorevalidaterules"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/http"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/kuzu/kuzucypher"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetdimensions"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetexplores"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetfilters"
|
||||
@@ -113,6 +114,7 @@ import (
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/dgraph"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/firestore"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/http"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/kuzu"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/looker"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/mongodb"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/mssql"
|
||||
|
||||
80
docs/en/resources/sources/kuzu.md
Normal file
80
docs/en/resources/sources/kuzu.md
Normal file
@@ -0,0 +1,80 @@
|
||||
---
|
||||
title: "Kùzu"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
Kùzu is an open-source, embedded graph database built for query speed and scalability, optimized for complex join-heavy analytical workloads using the Cypher query language.
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
[Kuzu](https://kuzudb.com/) is an embedded graph database designed for high query speed and scalability, optimized for complex, join-heavy analytical workloads on large graph datasets. It provides a lightweight, in-process integration with applications, making it ideal for scenarios requiring fast and efficient graph data processing.
|
||||
|
||||
Kuzu has the following core features:
|
||||
|
||||
- **Property Graph Data Model and Cypher Query Language**: Supports the property graph model and uses Cypher, a powerful and expressive query language for graph databases.
|
||||
- **Embedded Integration**: Runs in-process with applications, eliminating the need for a separate server.
|
||||
- **Columnar Disk-Based Storage**: Utilizes columnar storage for efficient data access and management.
|
||||
- **Columnar and Compressed Sparse Row-Based (CSR) Adjacency List and Join Indices**: Optimizes storage and query performance for large graphs.
|
||||
- **Vectorized and Factorized Query Processing**: Enhances query execution speed through advanced processing techniques.
|
||||
- **Novel and Efficient Join Algorithms**: Improves performance for complex join operations.
|
||||
- **Multi-Core Query Parallelism**: Leverages multiple cores for faster query execution.
|
||||
- **Serializable ACID Transactions**: Ensures data consistency and reliability with full ACID compliance.
|
||||
|
||||
|
||||
## Available Tools
|
||||
|
||||
- [`kuzu-cypher`](../tools/kuzu/kuzu-cypher.md)
|
||||
Execute pre-defined Cypher queries with placeholder parameters.
|
||||
|
||||
## Requirements
|
||||
|
||||
### Database File
|
||||
|
||||
To use Kuzu, you can either:
|
||||
|
||||
- Specify a file path for a persistent database file stored on the filesystem
|
||||
- Omit the file path to use an in-memory database
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
my-kuzu-db:
|
||||
kind: "kuzu"
|
||||
database: "/path/to/database.db"
|
||||
bufferPoolSize: 1073741824 # 1GB
|
||||
maxNumThreads: 4
|
||||
enableCompression: true
|
||||
readOnly: false
|
||||
maxDbSize: 5368709120 # 5GB
|
||||
```
|
||||
|
||||
For an in-memory database:
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
my-kuzu-memory-db:
|
||||
kind: "kuzu"
|
||||
bufferPoolSize: 1073741824 # 1GB
|
||||
maxNumThreads: 4
|
||||
enableCompression: true
|
||||
readOnly: false
|
||||
maxDbSize: 5368709120 # 5GB
|
||||
```
|
||||
|
||||
## Reference
|
||||
|
||||
### Configuration Fields
|
||||
|
||||
| **Field** | **Type** | **Required** | **Description** |
|
||||
|--------------------|:--------:|:------------:|---------------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be "kuzu". |
|
||||
| database | string | false | Path to the database directory. Default is ":memory:" which creates an in-memory database. |
|
||||
| bufferPoolSize | uint64 | false | Size of the buffer pool in bytes (e.g., 1073741824 for 1GB). |
|
||||
| maxNumThreads | uint64 | false | Maximum number of threads for query execution. |
|
||||
| enableCompression | bool | false | Enables or disables data compression. Default is true. |
|
||||
| readOnly | bool | false | Sets the database to read-only mode if true. Default is false. |
|
||||
| maxDbSize | uint64 | false | Maximum database size in bytes (e.g., 5368709120 for 5GB). |
|
||||
|
||||
For a complete list of available configuration options, refer to the [Kuzu SystemConfig options](https://pkg.go.dev/github.com/kuzudb/go-kuzu#SystemConfig).
|
||||
7
docs/en/resources/tools/kuzu/_index.md
Normal file
7
docs/en/resources/tools/kuzu/_index.md
Normal file
@@ -0,0 +1,7 @@
|
||||
---
|
||||
title: "Kuzu"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
Tools that work with Kuzu Sources.
|
||||
---
|
||||
110
docs/en/resources/tools/kuzu/kuzu-cypher.md
Normal file
110
docs/en/resources/tools/kuzu/kuzu-cypher.md
Normal file
@@ -0,0 +1,110 @@
|
||||
---
|
||||
title: "kuzu-cypher"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
A "kuzu-cypher" tool executes a pre-defined cypher statement against a Kuzu
|
||||
database.
|
||||
aliases:
|
||||
- /resources/tools/kuzu-cypher
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
A `kuzu-cypher` tool executes a pre-defined Cypher statement against a Kuzu graph database. It is designed to work with Kuzu's embedded graph database, optimized for high query speed and scalability. The tool is compatible with the following sources:
|
||||
|
||||
- [kuzu](../../sources/kuzu.md)
|
||||
|
||||
The specified Cypher statement is executed as a [parameterized statement][kuzu-parameters], with parameters referenced by their name (e.g., `$id`). This approach ensures security by preventing Cypher injection attacks.
|
||||
|
||||
> **Note:** This tool uses parameterized queries to prevent Cypher injections. \
|
||||
> Query parameters can be used as substitutes for arbitrary expressions but cannot replace identifiers, node labels, relationship types, or other structural parts of the query.
|
||||
|
||||
[kuzu-parameters]:
|
||||
https://docs.kuzudb.com/get-started/prepared-statements/
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
find_collaborators:
|
||||
kind: kuzu-cypher
|
||||
source: my-kuzu-social-network
|
||||
statement: |
|
||||
MATCH (p1:Person)-[:Collaborated_With]->(p2:Person)
|
||||
WHERE p1.name = $name AND p2.age > $min_age
|
||||
RETURN p2.name, p2.age
|
||||
LIMIT 10
|
||||
description: |
|
||||
Use this tool to find collaborators for a specific person in a social network, filtered by a minimum age.
|
||||
Takes a full person name (e.g., "Alice Smith") and a minimum age (e.g., 25) and returns a list of collaborator names and their ages.
|
||||
Do NOT use this tool with incomplete names or arbitrary values. Do NOT guess a name or age.
|
||||
A person name is a fully qualified name with first and last name separated by a space.
|
||||
For example, if given "Smith, Alice" the person name is "Alice Smith".
|
||||
If multiple results are returned, prioritize those with the closest collaboration ties.
|
||||
Example:
|
||||
{{
|
||||
"name": "Bob Johnson",
|
||||
"min_age": 30
|
||||
}}
|
||||
Example:
|
||||
{{
|
||||
"name": "Emma Davis",
|
||||
"min_age": 25
|
||||
}}
|
||||
parameters:
|
||||
- name: name
|
||||
type: string
|
||||
description: Full person name, "firstname lastname"
|
||||
- name: min_age
|
||||
type: integer
|
||||
description: Minimum age as a positive integer
|
||||
```
|
||||
### Example with Template Parameters
|
||||
|
||||
> **Note:** This tool allows direct modifications to the Cypher statement,
|
||||
> including identifiers, column names, and table names. **This makes it more
|
||||
> vulnerable to Cypher injections**. Using basic parameters only (see above) is
|
||||
> recommended for performance and safety reasons. For more details, please check
|
||||
> [templateParameters](../#template-parameters).
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
find_friends:
|
||||
kind: kuzu-cypher
|
||||
source: my-kuzu-social-network
|
||||
statement: |
|
||||
MATCH (p1:{{.nodeLabel}})-[:friends_with]->(p2:{{.nodeLabel}})
|
||||
WHERE p1.name = $name
|
||||
RETURN p2.name
|
||||
LIMIT 5
|
||||
description: |
|
||||
Use this tool to find friends of a specific person in a social network.
|
||||
Takes a node label (e.g., "Person") and a full person name (e.g., "Alice Smith") and returns a list of friend names.
|
||||
Do NOT use with incomplete names. A person name is a full name with first and last name separated by a space.
|
||||
Example:
|
||||
{
|
||||
"nodeLabel": "Person",
|
||||
"name": "Bob Johnson"
|
||||
}
|
||||
templateParameters:
|
||||
- name: nodeLabel
|
||||
type: string
|
||||
description: Node label for the table to query, e.g., "Person"
|
||||
parameters:
|
||||
- name: name
|
||||
type: string
|
||||
description: Full person name, "firstname lastname"
|
||||
```
|
||||
|
||||
## Reference
|
||||
|
||||
| **Field** | **Type** | **Required** | **Description** |
|
||||
|----------------------|:-------------------------------------:|:------------:|---------------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be "kuzu-cypher". |
|
||||
| source | string | true | Name of the Kuzu source the Cypher query should execute on. |
|
||||
| description | string | true | Description of the tool that is passed to the LLM for context. |
|
||||
| statement | string | true | Cypher statement to execute. |
|
||||
| authRequired | []string | false | List of authentication requirements for executing the query (if applicable). |
|
||||
| parameters | [parameters](../#specifying-parameters) | false | List of parameters used with the Cypher statement. |
|
||||
| templateParameters | [templateParameters](../#template-parameters) | false | List of [templateParameters](../#template-parameters) that will be inserted into the Cypher statement before executing prepared statement. |
|
||||
2
go.mod
2
go.mod
@@ -28,6 +28,7 @@ require (
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/jackc/pgx/v5 v5.7.5
|
||||
github.com/json-iterator/go v1.1.12
|
||||
github.com/kuzudb/go-kuzu v0.11.0
|
||||
github.com/looker-open-source/sdk-codegen/go v0.25.10
|
||||
github.com/microsoft/go-mssqldb v1.9.2
|
||||
github.com/neo4j/neo4j-go-driver/v5 v5.28.1
|
||||
@@ -119,6 +120,7 @@ require (
|
||||
github.com/pierrec/lz4/v4 v4.1.22 // indirect
|
||||
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/shopspring/decimal v1.4.0 // indirect
|
||||
github.com/spf13/pflag v1.0.6 // indirect
|
||||
github.com/spiffe/go-spiffe/v2 v2.5.0 // indirect
|
||||
github.com/xdg-go/pbkdf2 v1.0.0 // indirect
|
||||
|
||||
4
go.sum
4
go.sum
@@ -1016,6 +1016,8 @@ github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NB
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/kuzudb/go-kuzu v0.11.0 h1:7nH5zabXH+IBZruyyML6YIi4tayqg3diwbXmXmnZE8k=
|
||||
github.com/kuzudb/go-kuzu v0.11.0/go.mod h1:s2NvXX3fB2QZfWGf6SjJSYawgTPE17a7WHZmzfLIZtU=
|
||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
@@ -1081,6 +1083,8 @@ github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/f
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/ruudk/golang-pdf417 v0.0.0-20181029194003-1af4ab5afa58/go.mod h1:6lfFZQK844Gfx8o5WFuvpxWRwnSoipWe/p622j1v06w=
|
||||
github.com/ruudk/golang-pdf417 v0.0.0-20201230142125-a7e3863a1245/go.mod h1:pQAZKsJ8yyVxGRWYNEm9oFB8ieLgKFnamEyDmSA0BRk=
|
||||
github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
|
||||
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
|
||||
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
|
||||
github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
|
||||
github.com/spf13/afero v1.3.3/go.mod h1:5KUK8ByomD5Ti5Artl0RtHeI5pTF7MIDuXL3yY520V4=
|
||||
|
||||
122
internal/sources/kuzu/kuzu.go
Normal file
122
internal/sources/kuzu/kuzu.go
Normal file
@@ -0,0 +1,122 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package kuzu
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/kuzudb/go-kuzu"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
var SourceKind string = "kuzu"
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources.SourceConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required" `
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Database string `yaml:"database" default:":memory:"`
|
||||
BufferPoolSize uint64 `yaml:"bufferPoolSize"`
|
||||
MaxNumThreads uint64 `yaml:"maxNumThreads"`
|
||||
EnableCompression bool `yaml:"enableCompression"`
|
||||
ReadOnly bool `yaml:"readOnly"`
|
||||
MaxDbSize uint64 `yaml:"maxDbSize"`
|
||||
}
|
||||
|
||||
func (cfg Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
conn, err := initKuzuConnection(ctx, tracer, cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to open a database connection: %w", err)
|
||||
}
|
||||
|
||||
source := &Source{
|
||||
Name: cfg.Name,
|
||||
Kind: SourceKind,
|
||||
Connection: conn,
|
||||
}
|
||||
return source, nil
|
||||
}
|
||||
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
type Source struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Connection *kuzu.Connection
|
||||
}
|
||||
|
||||
// SourceKind implements sources.Source.
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) KuzuDB() *kuzu.Connection {
|
||||
return s.Connection
|
||||
}
|
||||
|
||||
var _ sources.Source = &Source{}
|
||||
|
||||
func initKuzuConnection(ctx context.Context, tracer trace.Tracer, config Config) (*kuzu.Connection, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, config.Name)
|
||||
defer span.End()
|
||||
systemConfig := kuzu.DefaultSystemConfig()
|
||||
if config.BufferPoolSize != 0 {
|
||||
systemConfig.BufferPoolSize = config.BufferPoolSize
|
||||
}
|
||||
if config.EnableCompression {
|
||||
systemConfig.EnableCompression = config.EnableCompression
|
||||
}
|
||||
if config.MaxDbSize != 0 {
|
||||
systemConfig.MaxDbSize = config.MaxDbSize
|
||||
}
|
||||
if config.ReadOnly {
|
||||
systemConfig.ReadOnly = config.ReadOnly
|
||||
}
|
||||
if config.MaxNumThreads != 0 {
|
||||
systemConfig.MaxNumThreads = config.MaxNumThreads
|
||||
}
|
||||
|
||||
db, err := kuzu.OpenDatabase(config.Database, systemConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to connect to database: %w", err)
|
||||
}
|
||||
|
||||
conn, err := kuzu.OpenConnection(db)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to open a database connection: %w", err)
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
89
internal/sources/kuzu/kuzu_test.go
Normal file
89
internal/sources/kuzu/kuzu_test.go
Normal file
@@ -0,0 +1,89 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package kuzu_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/kuzu"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
|
||||
func TestParseFromYamlKuzu(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.SourceConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
sources:
|
||||
my-kuzu-db:
|
||||
kind: kuzu
|
||||
database: /path/to/database.db
|
||||
`,
|
||||
want: server.SourceConfigs{
|
||||
"my-kuzu-db": kuzu.Config{
|
||||
Name: "my-kuzu-db",
|
||||
Kind: kuzu.SourceKind,
|
||||
Database: "/path/to/database.db",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "with configuration",
|
||||
in: `
|
||||
sources:
|
||||
my-kuzu-db:
|
||||
kind: kuzu
|
||||
database: /path/to/database.db
|
||||
maxNumThreads: 10
|
||||
readOnly: true
|
||||
`,
|
||||
want: server.SourceConfigs{
|
||||
"my-kuzu-db": kuzu.Config{
|
||||
Name: "my-kuzu-db",
|
||||
Kind: kuzu.SourceKind,
|
||||
Database: "/path/to/database.db",
|
||||
MaxNumThreads: 10,
|
||||
ReadOnly: true,
|
||||
MaxDbSize: 0,
|
||||
BufferPoolSize: 0,
|
||||
EnableCompression: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
201
internal/tools/kuzu/kuzucypher/kuzucypher.go
Normal file
201
internal/tools/kuzu/kuzucypher/kuzucypher.go
Normal file
@@ -0,0 +1,201 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package kuzucypher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
kuzuSource "github.com/googleapis/genai-toolbox/internal/sources/kuzu"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/kuzudb/go-kuzu"
|
||||
)
|
||||
|
||||
var kind string = "kuzu-cypher"
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
Statement string `yaml:"statement" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
TemplateParameters tools.Parameters `yaml:"templateParameters"`
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
KuzuDB() *kuzu.Connection
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &kuzuSource.Source{}
|
||||
var compatibleSources = [...]string{kuzuSource.SourceKind}
|
||||
|
||||
// Initialize implements tools.ToolConfig.
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
allParameters, paramManifest, paramMcpManifest := tools.ProcessParameters(cfg.TemplateParameters, cfg.Parameters)
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: cfg.Description,
|
||||
InputSchema: paramMcpManifest,
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Parameters: cfg.Parameters,
|
||||
TemplateParameters: cfg.TemplateParameters,
|
||||
AllParams: allParameters,
|
||||
Statement: cfg.Statement,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Connection: s.KuzuDB(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// ToolConfigKind implements tools.ToolConfig.
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
TemplateParameters tools.Parameters `yaml:"templateParameters"`
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
|
||||
Connection *kuzu.Connection
|
||||
Statement string `yaml:"statement"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
// Authorized implements tools.Tool.
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
// Invoke implements tools.Tool.
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
conn := t.Connection
|
||||
paramsMap := params.AsMap()
|
||||
newStatement, err := tools.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to extract template params %w", err)
|
||||
}
|
||||
|
||||
preparedStatement, err := conn.Prepare(newStatement)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to generate prepared statement %w", err)
|
||||
}
|
||||
newParamMap, err := getParams(t.Parameters, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to extract standard params %w", err)
|
||||
}
|
||||
|
||||
result, err := conn.Execute(preparedStatement, newParamMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
defer result.Close()
|
||||
cols := result.GetColumnNames()
|
||||
var out []any
|
||||
for result.HasNext() {
|
||||
tuple, err := result.Next()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse row: %w", err)
|
||||
}
|
||||
defer tuple.Close()
|
||||
|
||||
// The result is a tuple, which can be converted to a slice.
|
||||
slice, err := tuple.GetAsSlice()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to slice row: %w", err)
|
||||
}
|
||||
rowMap := make(map[string]interface{})
|
||||
for i, col := range cols {
|
||||
val := slice[i]
|
||||
// Store the value in the map
|
||||
rowMap[col] = val
|
||||
}
|
||||
out = append(out, rowMap)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// Manifest implements tools.Tool.
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
// McpManifest implements tools.Tool.
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
// ParseParams implements tools.Tool.
|
||||
func (t Tool) ParseParams(data map[string]any, claimsMap map[string]map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.AllParams, data, claimsMap)
|
||||
}
|
||||
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
func getParams(params tools.Parameters, paramValuesMap map[string]interface{}) (map[string]interface{}, error) {
|
||||
newParamMap := make(map[string]any)
|
||||
for _, p := range params {
|
||||
k := p.GetName()
|
||||
v, ok := paramValuesMap[k]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing parameter %s", k)
|
||||
}
|
||||
newParamMap[k] = v
|
||||
}
|
||||
return newParamMap, nil
|
||||
}
|
||||
94
internal/tools/kuzu/kuzucypher/kuzucypher_test.go
Normal file
94
internal/tools/kuzu/kuzucypher/kuzucypher_test.go
Normal file
@@ -0,0 +1,94 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package kuzucypher_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/kuzu/kuzucypher"
|
||||
)
|
||||
|
||||
func TestParseFromYamlKuzu(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: kuzu-cypher
|
||||
source: my-kuzu-instance
|
||||
description: some description
|
||||
statement: |
|
||||
match (a:user {name:$name}) return a.*;
|
||||
authRequired:
|
||||
- my-google-auth-service
|
||||
- other-auth-service
|
||||
parameters:
|
||||
- name: name
|
||||
type: string
|
||||
description: some description
|
||||
authServices:
|
||||
- name: my-google-auth-service
|
||||
field: user_id
|
||||
- name: other-auth-service
|
||||
field: user_id
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": kuzucypher.Config{
|
||||
Name: "example_tool",
|
||||
Kind: "kuzu-cypher",
|
||||
Source: "my-kuzu-instance",
|
||||
Description: "some description",
|
||||
Statement: "match (a:user {name:$name}) return a.*;\n",
|
||||
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
|
||||
Parameters: []tools.Parameter{
|
||||
tools.NewStringParameterWithAuth("name", "some description",
|
||||
[]tools.ParamAuthService{{Name: "my-google-auth-service", Field: "user_id"},
|
||||
{Name: "other-auth-service", Field: "user_id"}}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
442
tests/kuzu/kuzu_integration_test.go
Normal file
442
tests/kuzu/kuzu_integration_test.go
Normal file
@@ -0,0 +1,442 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package kuzu
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/tests"
|
||||
"github.com/kuzudb/go-kuzu"
|
||||
)
|
||||
|
||||
var (
|
||||
database = "/tmp/example.kuzu"
|
||||
toolKind = "kuzu-cypher"
|
||||
sourceKind = "kuzu"
|
||||
)
|
||||
|
||||
func getSourceConfig() map[string]any {
|
||||
return map[string]any{
|
||||
"name": sourceKind,
|
||||
"kind": sourceKind,
|
||||
"database": database,
|
||||
"maxNumThreads": 10,
|
||||
}
|
||||
}
|
||||
func initKuzuDbConnection() error {
|
||||
queries := []string{
|
||||
"create node table user(name string primary key, age int64, email string)",
|
||||
"create node table city(name string primary key, population int64)",
|
||||
"create rel table follows(from user to user, since int64)",
|
||||
"create rel table livesin(from user to city)",
|
||||
fmt.Sprintf("create (u:user {name:'Alice', age:20, email: %q})", tests.ServiceAccountEmail),
|
||||
"create (u:user {name:'Jane', age:30, email: 'janedoe@gmail.com'})",
|
||||
"create (u:city {name:'London', population:100})",
|
||||
"create (u:city {name:'New York', population:200})",
|
||||
"match (u1:user), (u2:user) where u1.name='Alice' and u2.name='Jane' create (u1)-[:follows {since: 2019}]->(u2)",
|
||||
"match (u:user), (c:city) where u.name='Alice' and c.name='New York' create (u)-[:livesin]->(c)",
|
||||
}
|
||||
database, err := kuzu.OpenDatabase(database, kuzu.DefaultSystemConfig())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conn, err := kuzu.OpenConnection(database)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, q := range queries {
|
||||
_, err := conn.Query(q)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestKuzuDbToolEndpoints(t *testing.T) {
|
||||
err := initKuzuDbConnection()
|
||||
if err != nil {
|
||||
t.Fatalf("could not create kuzudb connection")
|
||||
}
|
||||
defer os.Remove(database)
|
||||
defer os.Remove(fmt.Sprintf("%s.lock", database))
|
||||
defer os.Remove(fmt.Sprintf("%s.wal", database))
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
var args []string
|
||||
|
||||
paramToolStatement, paramToolStatement2, authToolStatement := createParamQueries()
|
||||
templateParamToolStmt, templateParamToolStmt2 := createTemplateQueries()
|
||||
toolsFile := getToolConfig(paramToolStatement, paramToolStatement2, authToolStatement, templateParamToolStmt, templateParamToolStmt2)
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
t.Fatalf("command initialization returned an error: %s", err)
|
||||
}
|
||||
defer cleanup()
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||
if err != nil {
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
}
|
||||
tests.RunToolGetTest(t)
|
||||
runToolInvokeTest(t)
|
||||
runToolInvokeWithTemplateParameters(t, "user")
|
||||
}
|
||||
|
||||
func createParamQueries() (string, string, string) {
|
||||
toolStatement := "match (u:user {name:$name}) return u.age, u.name"
|
||||
toolStatement2 := "match (a:user)-[:follows {since:$year}]->(b:user) return a.name, b.name"
|
||||
authToolStatement := "match (u:user {email:$email}) return u.age, u.name"
|
||||
return toolStatement, toolStatement2, authToolStatement
|
||||
}
|
||||
func createTemplateQueries() (string, string) {
|
||||
toolStatement := "match (u:{{.tableName}} {name:$name}) return u.age, u.name"
|
||||
toolStatement2 := "match (a:{{.tableName}})-[:follows { {{.edgeFilter}} :$year}]->(b:user) return a.name, b.name"
|
||||
return toolStatement, toolStatement2
|
||||
}
|
||||
|
||||
func getToolConfig(paramToolStatement, paramToolStatement2, authToolStatement, templateParamToolStmt, templateParamToolStmt2 string) map[string]any {
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := map[string]any{
|
||||
"sources": map[string]any{
|
||||
"my-instance": getSourceConfig(),
|
||||
},
|
||||
"authServices": map[string]any{
|
||||
"my-google-auth": map[string]any{
|
||||
"kind": "google",
|
||||
"clientId": tests.ClientId,
|
||||
},
|
||||
},
|
||||
"tools": map[string]any{
|
||||
"my-simple-tool": map[string]any{
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
"description": "Simple tool to test end to end functionality.",
|
||||
"statement": "Match (a) return a.name order by a.name;",
|
||||
},
|
||||
"my-param-tool": map[string]any{
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test invocation with params.",
|
||||
"statement": paramToolStatement,
|
||||
"parameters": []any{
|
||||
map[string]any{
|
||||
"name": "name",
|
||||
"type": "string",
|
||||
"description": "user name",
|
||||
},
|
||||
},
|
||||
},
|
||||
"my-param-tool2": map[string]any{
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test invocation with params.",
|
||||
"statement": paramToolStatement2,
|
||||
"parameters": []any{
|
||||
map[string]any{
|
||||
"name": "year",
|
||||
"type": "integer",
|
||||
"description": "year since when one user follows the other user",
|
||||
},
|
||||
},
|
||||
},
|
||||
"my-fail-tool": map[string]any{
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test statement with incorrect syntax.",
|
||||
"statement": "SELEC 1;",
|
||||
},
|
||||
"my-auth-tool": map[string]any{
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test authenticated parameters.",
|
||||
// statement to auto-fill authenticated parameter
|
||||
"statement": authToolStatement,
|
||||
"parameters": []map[string]any{
|
||||
{
|
||||
"name": "email",
|
||||
"type": "string",
|
||||
"description": "user email",
|
||||
"authServices": []map[string]string{
|
||||
{
|
||||
"name": "my-google-auth",
|
||||
"field": "email",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"my-auth-required-tool": map[string]any{
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test auth required invocation.",
|
||||
"statement": "MATCH (a) return a;",
|
||||
"authRequired": []string{
|
||||
"my-google-auth",
|
||||
},
|
||||
},
|
||||
"select-fields-templateParams-tool": map[string]any{
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test invocation with template params.",
|
||||
"statement": templateParamToolStmt,
|
||||
"parameters": []any{
|
||||
map[string]any{
|
||||
"name": "name",
|
||||
"type": "string",
|
||||
"description": "user name",
|
||||
},
|
||||
},
|
||||
"templateParameters": []tools.Parameter{
|
||||
tools.NewStringParameter("tableName", "some description"),
|
||||
},
|
||||
},
|
||||
"select-filter-templateParams-tool": map[string]any{
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test invocation with template param filter.",
|
||||
"statement": templateParamToolStmt2,
|
||||
"parameters": []any{
|
||||
map[string]any{
|
||||
"name": "year",
|
||||
"type": "integer",
|
||||
"description": "year since when one user follows the other user",
|
||||
},
|
||||
},
|
||||
"templateParameters": []tools.Parameter{
|
||||
tools.NewStringParameter("tableName", "some description"),
|
||||
tools.NewStringParameter("edgeFilter", "some description"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return toolsFile
|
||||
}
|
||||
|
||||
func runToolInvokeTest(t *testing.T) {
|
||||
// Get ID token
|
||||
idToken, err := tests.GetGoogleIdToken(tests.ClientId)
|
||||
if err != nil {
|
||||
t.Fatalf("error getting Google ID token: %s", err)
|
||||
}
|
||||
// Test tool invoke endpoint
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
api string
|
||||
requestHeader map[string]string
|
||||
requestBody io.Reader
|
||||
want string
|
||||
isErr bool
|
||||
}{
|
||||
{
|
||||
name: "invoke my-simple-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
want: "[{\"a.name\":\"Alice\"},{\"a.name\":\"Jane\"},{\"a.name\":\"London\"},{\"a.name\":\"New York\"}]",
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "invoke my-param-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-param-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"name": "Alice"}`)),
|
||||
want: "[{\"u.age\":20,\"u.name\":\"Alice\"}]",
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "invoke my-param-tool2",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-param-tool2/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"year": 2019}`)),
|
||||
want: "[{\"a.name\":\"Alice\",\"b.name\":\"Jane\"}]",
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "invoke my-param-tool2 with nil response",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-param-tool2/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"year": 2020}`)),
|
||||
want: "null",
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-param-tool without parameters",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-param-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
isErr: true,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-auth-tool with auth token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke",
|
||||
requestHeader: map[string]string{"my-google-auth_token": idToken},
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
want: "[{\"u.age\":20,\"u.name\":\"Alice\"}]",
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-auth-tool with invalid auth token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke",
|
||||
requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"},
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
isErr: true,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-auth-tool without auth token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
isErr: true,
|
||||
},
|
||||
}
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Send Tool invocation request
|
||||
req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %s", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
for k, v := range tc.requestHeader {
|
||||
req.Header.Add(k, v)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if tc.isErr {
|
||||
return
|
||||
}
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
// Check response body
|
||||
var body map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing response body")
|
||||
}
|
||||
|
||||
got, ok := body["result"].(string)
|
||||
if !ok {
|
||||
t.Fatalf("unable to find result in response body")
|
||||
}
|
||||
|
||||
if got != tc.want {
|
||||
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func runToolInvokeWithTemplateParameters(t *testing.T, tableName string) {
|
||||
|
||||
// Test tool invoke endpoint
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
ddl bool
|
||||
insert bool
|
||||
api string
|
||||
requestHeader map[string]string
|
||||
requestBody io.Reader
|
||||
want string
|
||||
isErr bool
|
||||
}{
|
||||
{
|
||||
name: "invoke select-fields-templateParams-tool",
|
||||
ddl: true,
|
||||
api: "http://127.0.0.1:5000/api/tool/select-fields-templateParams-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s", "name": "Alice"}`, tableName))),
|
||||
want: "[{\"u.age\":20,\"u.name\":\"Alice\"}]",
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "invoke select-filter-templateParams-tool",
|
||||
insert: true,
|
||||
api: "http://127.0.0.1:5000/api/tool/select-filter-templateParams-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s", "edgeFilter": "since", "year":2019}`, tableName))),
|
||||
want: "[{\"a.name\":\"Alice\",\"b.name\":\"Jane\"}]",
|
||||
isErr: false,
|
||||
},
|
||||
}
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
|
||||
// Send Tool invocation request
|
||||
req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %s", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
for k, v := range tc.requestHeader {
|
||||
req.Header.Add(k, v)
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if tc.isErr {
|
||||
return
|
||||
}
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
// Check response body
|
||||
var body map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing response body")
|
||||
}
|
||||
|
||||
got, ok := body["result"].(string)
|
||||
if !ok {
|
||||
t.Fatalf("unable to find result in response body")
|
||||
}
|
||||
|
||||
if got != tc.want {
|
||||
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user