mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-10 16:08:16 -05:00
@@ -178,6 +178,21 @@ steps:
|
||||
- |
|
||||
go test -race -v -tags=integration,mysql ./tests
|
||||
|
||||
- id: "dgraph"
|
||||
name: golang:1
|
||||
waitFor: ["install-dependencies"]
|
||||
entrypoint: /bin/bash
|
||||
env:
|
||||
- "GOPATH=/gopath"
|
||||
- "DGRAPH_URL=$_DGRAPHURL"
|
||||
volumes:
|
||||
- name: "go"
|
||||
path: "/gopath"
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
go test -race -v -tags=integration,dgraph ./tests
|
||||
|
||||
availableSecrets:
|
||||
secretManager:
|
||||
- versionName: projects/$PROJECT_ID/secrets/cloud_sql_pg_user/versions/latest
|
||||
@@ -233,3 +248,4 @@ substitutions:
|
||||
_CLOUD_SQL_MYSQL_INSTANCE: "cloud-sql-mysql-testing"
|
||||
_MYSQL_HOST: 127.0.0.1
|
||||
_MYSQL_PORT: "3306"
|
||||
_DGRAPHURL: "https://play.dgraph.io"
|
||||
|
||||
55
docs/en/resources/sources/dgraph.md
Normal file
55
docs/en/resources/sources/dgraph.md
Normal file
@@ -0,0 +1,55 @@
|
||||
---
|
||||
title: "Dgraph"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
Dgraph is a horizontally scalable and distributed graph database.
|
||||
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
[Dgraph][dgraph-docs] is a horizontally scalable and distributed graph database.
|
||||
It provides ACID transactions, consistent replication, and linearizable reads.
|
||||
|
||||
This source can connect to either a self-managed Dgraph cluster or one hosted on
|
||||
Dgraph Cloud. If you're new to Dgraph, the fastest way to get started is to
|
||||
[sign up for Dgraph Cloud][dgraph-login].
|
||||
|
||||
[dgraph-docs]: https://dgraph.io/docs
|
||||
[dgraph-login]: https://cloud.dgraph.io/login
|
||||
|
||||
## Requirements
|
||||
|
||||
### Database User
|
||||
|
||||
When **connecting to a hosted Dgraph database**, this source uses the API key
|
||||
for access. If you are using a dedicated environment, you will additionally need
|
||||
the namespace and user credentials for that namespace.
|
||||
|
||||
For **connecting to a local or self-hosted Dgraph database**, use the namespace
|
||||
and user credentials for that namespace.
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
my-dgraph-source:
|
||||
kind: "dgraph"
|
||||
dgraphUrl: "https://xxxx.cloud.dgraph.io"
|
||||
user: "groot"
|
||||
password: "password"
|
||||
apiKey: abc123
|
||||
namepace : 0
|
||||
```
|
||||
|
||||
## Reference
|
||||
|
||||
| **Field** | **Type** | **Required** | **Description** |
|
||||
|-------------|:--------:|:------------:|--------------------------------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be "dgraph". |
|
||||
| dgraphUrl | string | true | Connection URI (e.g. "https://xxx.cloud.dgraph.io", "https://localhost:8080"). |
|
||||
| user | string | false | Name of the Dgraph user to connect as (e.g., "groot"). |
|
||||
| password | string | false | Password of the Dgraph user (e.g., "password"). |
|
||||
| apiKey | string | false | API key to connect to a Dgraph Cloud instance. |
|
||||
| namespace | uint64 | false | Dgraph namespace (not required for Dgraph Cloud Shared Clusters). |
|
||||
@@ -52,6 +52,8 @@ We currently support the following types of kinds of tools:
|
||||
statement againts Spanner database.
|
||||
* [neo4j-cypher](./neo4j-cypher.md) - Run a Cypher statement against a
|
||||
Neo4j database.
|
||||
* [dgraph-dql](./dgraph-dql.md) - Run a DQL statement against a
|
||||
Dgraph database.
|
||||
|
||||
|
||||
## Specifying Parameters
|
||||
|
||||
105
docs/tools/dgraph-dql.md
Normal file
105
docs/tools/dgraph-dql.md
Normal file
@@ -0,0 +1,105 @@
|
||||
# Dgraph DQL Tool
|
||||
|
||||
|
||||
A "dgraph-dql" tool executes a pre-defined DQL statement against a Dgraph database. It's compatible with any of the following
|
||||
sources:
|
||||
- [dgraph](../sources/dgraph.md)
|
||||
|
||||
To run a statement as a query, you need to set the config isQuery=true. For upserts or mutations, set isQuery=false.
|
||||
You can also configure timeout for a query.
|
||||
|
||||
## Example
|
||||
|
||||
### Query:
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
search_user:
|
||||
kind: dgraph-dql
|
||||
source: my-dgraph-source
|
||||
statement: |
|
||||
query all($role: string){
|
||||
users(func: has(name)) @filter(eq(role, $role) AND ge(age, 30) AND le(age, 50)) {
|
||||
uid
|
||||
name
|
||||
email
|
||||
role
|
||||
age
|
||||
}
|
||||
}
|
||||
isQuery: true
|
||||
timeout: 20s
|
||||
description: |
|
||||
Use this tool to retrieve the details of users who are admins and are between 30 and 50 years old.
|
||||
The query returns the user's name, email, role, and age.
|
||||
This can be helpful when you want to fetch admin users within a specific age range.
|
||||
Example: Fetch admins aged between 30 and 50:
|
||||
[
|
||||
{
|
||||
"name": "Alice",
|
||||
"role": "admin",
|
||||
"age": 35
|
||||
},
|
||||
{
|
||||
"name": "Bob",
|
||||
"role": "admin",
|
||||
"age": 45
|
||||
}
|
||||
]
|
||||
parameters:
|
||||
- name: $role
|
||||
type: string
|
||||
description: admin
|
||||
```
|
||||
|
||||
### Mutation:
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
dgraph-manage-user-instance:
|
||||
kind: dgraph-dql
|
||||
source: my-dgraph-source
|
||||
isQuery: false
|
||||
statement: |
|
||||
{
|
||||
set {
|
||||
_:user1 <name> $user1 .
|
||||
_:user1 <email> $email1 .
|
||||
_:user1 <role> "admin" .
|
||||
_:user1 <age> "35" .
|
||||
|
||||
_:user2 <name> $user2 .
|
||||
_:user2 <email> $email2 .
|
||||
_:user2 <role> "admin" .
|
||||
_:user2 <age> "45" .
|
||||
}
|
||||
}
|
||||
description: |
|
||||
Use this tool to insert or update user data into the Dgraph database.
|
||||
The mutation adds or updates user details like name, email, role, and age.
|
||||
Example: Add users Alice and Bob as admins with specific ages.
|
||||
parameters:
|
||||
- name: user1
|
||||
type: string
|
||||
description: Alice
|
||||
- name: email1
|
||||
type: string
|
||||
description: alice@email.com
|
||||
- name: user2
|
||||
type: string
|
||||
description: Bob
|
||||
- name: email2
|
||||
type: string
|
||||
description: bob@email.com
|
||||
```
|
||||
|
||||
## Reference
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|-------------|----------:|:------------:|----------------------------------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be "dgraph-dql". |
|
||||
| source | string | true | Name of the source the dql query should execute on. |
|
||||
| description | string | true | Description of the tool |
|
||||
| statement | string | true | dql statement to execute |
|
||||
| isQuery | boolean | false | To run statment as query set true otherwise false |
|
||||
| timeout | string | false | To set timout for query |
|
||||
| parameters | parameter | true | List of [parameters](README.md#specifying-parameters) that will be used with the dql statement. |
|
||||
2
go.mod
2
go.mod
@@ -52,6 +52,7 @@ require (
|
||||
github.com/census-instrumentation/opencensus-proto v0.4.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/cncf/xds/go v0.0.0-20240905190251-b4127c9b8d78 // indirect
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||
github.com/envoyproxy/go-control-plane v0.13.1 // indirect
|
||||
github.com/envoyproxy/protoc-gen-validate v1.1.0 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
@@ -69,6 +70,7 @@ require (
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
github.com/spf13/pflag v1.0.5 // indirect
|
||||
go.opencensus.io v0.24.0 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
||||
|
||||
6
go.sum
6
go.sum
@@ -699,8 +699,9 @@ github.com/cncf/xds/go v0.0.0-20240905190251-b4127c9b8d78/go.mod h1:W+zGtBO5Y1Ig
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3ebgob9U8Nd0kOddGdZWjyMGR8Wziv+TBNwSE=
|
||||
github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
|
||||
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
||||
@@ -951,8 +952,9 @@ github.com/pkg/sftp v1.10.1/go.mod h1:lYOWFsE0bwd1+KfKJaKeuokY15vzFx25BLbzYYoAxZ
|
||||
github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qRg=
|
||||
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo=
|
||||
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||
github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||
github.com/prometheus/client_model v0.3.0/go.mod h1:LDGWKZIo7rky3hgvBe+caln+Dr3dPggB5dvjtD7w9+w=
|
||||
|
||||
@@ -25,12 +25,14 @@ import (
|
||||
cloudsqlmssqlsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmssql"
|
||||
cloudsqlmysqlsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmysql"
|
||||
cloudsqlpgsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg"
|
||||
dgraphsrc "github.com/googleapis/genai-toolbox/internal/sources/dgraph"
|
||||
mssqlsrc "github.com/googleapis/genai-toolbox/internal/sources/mssql"
|
||||
mysqlsrc "github.com/googleapis/genai-toolbox/internal/sources/mysql"
|
||||
neo4jrc "github.com/googleapis/genai-toolbox/internal/sources/neo4j"
|
||||
postgressrc "github.com/googleapis/genai-toolbox/internal/sources/postgres"
|
||||
spannersrc "github.com/googleapis/genai-toolbox/internal/sources/spanner"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/dgraph"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/mssqlsql"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/mysqlsql"
|
||||
neo4jtool "github.com/googleapis/genai-toolbox/internal/tools/neo4j"
|
||||
@@ -195,6 +197,12 @@ func (c *SourceConfigs) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
|
||||
}
|
||||
(*c)[name] = actual
|
||||
case dgraphsrc.SourceKind:
|
||||
actual := dgraphsrc.Config{Name: name}
|
||||
if err := u.Unmarshal(&actual); err != nil {
|
||||
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
|
||||
}
|
||||
(*c)[name] = actual
|
||||
default:
|
||||
return fmt.Errorf("%q is not a valid kind of data source", k.Kind)
|
||||
}
|
||||
@@ -292,6 +300,12 @@ func (c *ToolConfigs) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
|
||||
}
|
||||
(*c)[name] = actual
|
||||
case dgraph.ToolKind:
|
||||
actual := dgraph.Config{Name: name}
|
||||
if err := u.Unmarshal(&actual); err != nil {
|
||||
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
|
||||
}
|
||||
(*c)[name] = actual
|
||||
default:
|
||||
return fmt.Errorf("%q is not a valid kind of tool", k.Kind)
|
||||
}
|
||||
|
||||
379
internal/sources/dgraph/dgraph.go
Normal file
379
internal/sources/dgraph/dgraph.go
Normal file
@@ -0,0 +1,379 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package dgraph
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceKind string = "dgraph"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
// HttpToken stores credentials for making HTTP request
|
||||
type HttpToken struct {
|
||||
UserId string
|
||||
Password string
|
||||
AccessJwt string
|
||||
RefreshToken string
|
||||
Namespace uint64
|
||||
}
|
||||
|
||||
type DgraphClient struct {
|
||||
httpClient *http.Client
|
||||
*HttpToken
|
||||
baseUrl string
|
||||
apiKey string
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
DgraphUrl string `yaml:"dgraphUrl"`
|
||||
User string `yaml:"user"`
|
||||
Password string `yaml:"password"`
|
||||
Namespace uint64 `yaml:"namespace"`
|
||||
ApiKey string `yaml:"apiKey"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
hc, err := initDgraphHttpClient(ctx, tracer, r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := hc.healthCheck(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s := &Source{
|
||||
Name: r.Name,
|
||||
Kind: SourceKind,
|
||||
Client: hc,
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
var _ sources.Source = &Source{}
|
||||
|
||||
type Source struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Client *DgraphClient `yaml:"client"`
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) DgraphClient() *DgraphClient {
|
||||
return s.Client
|
||||
}
|
||||
|
||||
func initDgraphHttpClient(ctx context.Context, tracer trace.Tracer, r Config) (*DgraphClient, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, r.Name)
|
||||
defer span.End()
|
||||
|
||||
if r.DgraphUrl == "" {
|
||||
return nil, fmt.Errorf("dgraph url should not be empty")
|
||||
}
|
||||
|
||||
hc := &DgraphClient{
|
||||
httpClient: &http.Client{},
|
||||
baseUrl: r.DgraphUrl,
|
||||
HttpToken: &HttpToken{
|
||||
UserId: r.User,
|
||||
Namespace: r.Namespace,
|
||||
Password: r.Password,
|
||||
},
|
||||
apiKey: r.ApiKey,
|
||||
}
|
||||
|
||||
if r.User != "" || r.Password != "" {
|
||||
if err := hc.loginWithCredentials(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return hc, nil
|
||||
}
|
||||
|
||||
func (hc *DgraphClient) ExecuteQuery(query string, paramsMap map[string]interface{},
|
||||
isQuery bool, timeout string) ([]byte, error) {
|
||||
if isQuery {
|
||||
return hc.postDqlQuery(query, paramsMap, timeout)
|
||||
} else {
|
||||
return hc.mutate(query, paramsMap)
|
||||
}
|
||||
}
|
||||
|
||||
// postDqlQuery sends a DQL query to the Dgraph server with query, parameters, and optional timeout.
|
||||
// Returns the response body ([]byte) and an error, if any.
|
||||
func (hc *DgraphClient) postDqlQuery(query string, paramsMap map[string]interface{}, timeout string) ([]byte, error) {
|
||||
urlParams := url.Values{}
|
||||
urlParams.Add("timeout", timeout)
|
||||
url, err := getUrl(hc.baseUrl, "/query", urlParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p := struct {
|
||||
Query string `json:"query"`
|
||||
Variables map[string]interface{} `json:"variables"`
|
||||
}{
|
||||
Query: query,
|
||||
Variables: paramsMap,
|
||||
}
|
||||
body, err := json.Marshal(p)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshlling json: %v", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error building req for endpoint [%v] :%v", url, err)
|
||||
}
|
||||
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
|
||||
return hc.doReq(req)
|
||||
}
|
||||
|
||||
// mutate sends an RDF mutation to the Dgraph server with "commitNow: true", embedding parameters.
|
||||
// Returns the server's response as a byte slice or an error if the mutation fails.
|
||||
func (hc *DgraphClient) mutate(mutation string, paramsMap map[string]interface{}) ([]byte, error) {
|
||||
mu := embedParamsIntoMutation(mutation, paramsMap)
|
||||
params := url.Values{}
|
||||
params.Add("commitNow", "true")
|
||||
url, err := getUrl(hc.baseUrl, "/mutate", params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequest(http.MethodPost, url, bytes.NewBufferString(mu))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error building req for endpoint [%v] :%v", url, err)
|
||||
}
|
||||
|
||||
req.Header.Add("Content-Type", "application/rdf")
|
||||
|
||||
return hc.doReq(req)
|
||||
}
|
||||
|
||||
func (hc *DgraphClient) doReq(req *http.Request) ([]byte, error) {
|
||||
if hc.HttpToken != nil {
|
||||
req.Header.Add("X-Dgraph-AccessToken", hc.AccessJwt)
|
||||
}
|
||||
if hc.apiKey != "" {
|
||||
req.Header.Set("Dg-Auth", hc.apiKey)
|
||||
}
|
||||
|
||||
resp, err := hc.httpClient.Do(req)
|
||||
|
||||
if err != nil && !strings.Contains(err.Error(), "Token is expired") {
|
||||
return nil, fmt.Errorf("error performing HTTP request: %w", err)
|
||||
} else if err != nil && strings.Contains(err.Error(), "Token is expired") {
|
||||
if errLogin := hc.loginWithToken(); errLogin != nil {
|
||||
return nil, errLogin
|
||||
}
|
||||
if hc.HttpToken != nil {
|
||||
req.Header.Add("X-Dgraph-AccessToken", hc.AccessJwt)
|
||||
}
|
||||
resp, err = hc.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading response body: url: [%v], err: [%v]", req.URL, err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("got non 200 resp: %v", string(respBody))
|
||||
}
|
||||
|
||||
return respBody, nil
|
||||
}
|
||||
|
||||
func (hc *DgraphClient) loginWithCredentials() error {
|
||||
credentials := map[string]interface{}{
|
||||
"userid": hc.UserId,
|
||||
"password": hc.Password,
|
||||
"namespace": hc.Namespace,
|
||||
}
|
||||
return hc.doLogin(credentials)
|
||||
}
|
||||
|
||||
func (hc *DgraphClient) loginWithToken() error {
|
||||
credentials := map[string]interface{}{
|
||||
"refreshJWT": hc.RefreshToken,
|
||||
"namespace": hc.Namespace,
|
||||
}
|
||||
return hc.doLogin(credentials)
|
||||
}
|
||||
|
||||
func (hc *DgraphClient) doLogin(creds map[string]interface{}) error {
|
||||
url, err := getUrl(hc.baseUrl, "/login", nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
payload, err := json.Marshal(creds)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal credentials: %v", err)
|
||||
}
|
||||
req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(payload))
|
||||
if err != nil {
|
||||
return fmt.Errorf("error building req for endpoint [%v] : %v", url, err)
|
||||
}
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
if hc.apiKey != "" {
|
||||
req.Header.Set("Dg-Auth", hc.apiKey)
|
||||
}
|
||||
|
||||
resp, err := hc.doReq(req)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "Token is expired") &&
|
||||
!strings.Contains(err.Error(), "unable to authenticate the refresh token") {
|
||||
return hc.loginWithToken()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if err := CheckError(resp); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var r struct {
|
||||
Data struct {
|
||||
AccessJWT string `json:"accessJWT"`
|
||||
RefreshJWT string `json:"refreshJWT"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(resp, &r); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if r.Data.AccessJWT == "" {
|
||||
return fmt.Errorf("no access JWT found in the response")
|
||||
}
|
||||
if r.Data.RefreshJWT == "" {
|
||||
return fmt.Errorf("no refresh JWT found in the response")
|
||||
}
|
||||
|
||||
hc.AccessJwt = r.Data.AccessJWT
|
||||
hc.RefreshToken = r.Data.RefreshJWT
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hc *DgraphClient) healthCheck() error {
|
||||
url, err := getUrl(hc.baseUrl, "/health", nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := hc.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error performing request: %w", err)
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var result []struct {
|
||||
Instance string `json:"instance"`
|
||||
Address string `json:"address"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// Unmarshal response into the struct
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal json: %v", err)
|
||||
}
|
||||
|
||||
if len(result) == 0 {
|
||||
return fmt.Errorf("health info should not empty for: %v", url)
|
||||
}
|
||||
|
||||
var unhealthyErr error
|
||||
for _, info := range result {
|
||||
if info.Status != "healthy" {
|
||||
unhealthyErr = fmt.Errorf("dgraph instance [%v] is not in healthy state, address is %v",
|
||||
info.Instance, info.Address)
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return unhealthyErr
|
||||
}
|
||||
|
||||
func getUrl(baseUrl, resource string, params url.Values) (string, error) {
|
||||
u, err := url.ParseRequestURI(baseUrl)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get url %v", err)
|
||||
}
|
||||
u.Path = resource
|
||||
u.RawQuery = params.Encode()
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
func CheckError(resp []byte) error {
|
||||
var errResp struct {
|
||||
Errors []struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"errors"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(resp, &errResp); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal json: %v", err)
|
||||
}
|
||||
|
||||
if len(errResp.Errors) > 0 {
|
||||
return fmt.Errorf("error : %v", errResp.Errors)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func embedParamsIntoMutation(mutation string, paramsMap map[string]interface{}) string {
|
||||
for key, value := range paramsMap {
|
||||
mutation = strings.ReplaceAll(mutation, key, fmt.Sprintf(`"%v"`, value))
|
||||
}
|
||||
return mutation
|
||||
}
|
||||
76
internal/sources/dgraph/dgraph_test.go
Normal file
76
internal/sources/dgraph/dgraph_test.go
Normal file
@@ -0,0 +1,76 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package dgraph_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/dgraph"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
|
||||
func TestParseFromYamlDgraph(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.SourceConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
sources:
|
||||
my-dgraph-instance:
|
||||
kind: dgraph
|
||||
dgraphUrl: https://localhost:8080
|
||||
apiKey: abc123
|
||||
password: pass@123
|
||||
namespace: 0
|
||||
user: user123
|
||||
`,
|
||||
want: server.SourceConfigs{
|
||||
"my-dgraph-instance": dgraph.Config{
|
||||
Name: "my-dgraph-instance",
|
||||
Kind: dgraph.SourceKind,
|
||||
DgraphUrl: "https://localhost:8080",
|
||||
ApiKey: "abc123",
|
||||
Password: "pass@123",
|
||||
Namespace: 0,
|
||||
User: "user123",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tc.want, got.Sources); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
133
internal/tools/dgraph/dgraph.go
Normal file
133
internal/tools/dgraph/dgraph.go
Normal file
@@ -0,0 +1,133 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package dgraph
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/dgraph"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
const ToolKind string = "dgraph-dql"
|
||||
|
||||
type compatibleSource interface {
|
||||
DgraphClient() *dgraph.DgraphClient
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &dgraph.Source{}
|
||||
|
||||
var compatibleSources = [...]string{dgraph.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Source string `yaml:"source"`
|
||||
Description string `yaml:"description"`
|
||||
Statement string `yaml:"statement"`
|
||||
IsQuery bool `yaml:"isQuery"`
|
||||
Timeout string `yaml:"timeout"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return ToolKind
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", ToolKind, compatibleSources)
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: ToolKind,
|
||||
Parameters: cfg.Parameters,
|
||||
Statement: cfg.Statement,
|
||||
DgraphClient: s.DgraphClient(),
|
||||
IsQuery: cfg.IsQuery,
|
||||
Timeout: cfg.Timeout,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
DgraphClient *dgraph.DgraphClient
|
||||
IsQuery bool
|
||||
Timeout string
|
||||
Statement string
|
||||
manifest tools.Manifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(params tools.ParamValues) (string, error) {
|
||||
paramsMap := params.AsMapWithDollarPrefix()
|
||||
|
||||
resp, err := t.DgraphClient.ExecuteQuery(t.Statement, paramsMap, t.IsQuery, t.Timeout)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := dgraph.CheckError(resp); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Data map[string]interface{} `json:"data"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(resp, &result); err != nil {
|
||||
return "", fmt.Errorf("error parsing JSON: %v", err)
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"Stub tool call for %q! Parameters parsed: %q \n Output: %v",
|
||||
t.Name, paramsMap, result.Data,
|
||||
), nil
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claimsMap map[string]map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.Parameters, data, claimsMap)
|
||||
}
|
||||
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthSources []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthSources)
|
||||
}
|
||||
96
internal/tools/dgraph/dgraph_test.go
Normal file
96
internal/tools/dgraph/dgraph_test.go
Normal file
@@ -0,0 +1,96 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package dgraph_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/dgraph"
|
||||
)
|
||||
|
||||
func TestParseFromYamlDgraph(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic query example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: dgraph-dql
|
||||
source: my-dgraph-instance
|
||||
description: some tool description
|
||||
isQuery: true
|
||||
timeout: 20s
|
||||
statement: |
|
||||
query {q(func: eq(email, "example@email.com")) {email}}
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": dgraph.Config{
|
||||
Name: "example_tool",
|
||||
Kind: dgraph.ToolKind,
|
||||
Source: "my-dgraph-instance",
|
||||
Description: "some tool description",
|
||||
IsQuery: true,
|
||||
Timeout: "20s",
|
||||
Statement: "query {q(func: eq(email, \"example@email.com\")) {email}}\n",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "basic mutation example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: dgraph-dql
|
||||
source: my-dgraph-instance
|
||||
description: some tool description
|
||||
statement: |
|
||||
mutation {set { _:a <name> "a@email.com" . _:b <email> "b@email.com" .}}
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": dgraph.Config{
|
||||
Name: "example_tool",
|
||||
Kind: dgraph.ToolKind,
|
||||
Source: "my-dgraph-instance",
|
||||
Description: "some tool description",
|
||||
Statement: "mutation {set { _:a <name> \"a@email.com\" . _:b <email> \"b@email.com\" .}}\n",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
@@ -17,6 +17,7 @@ package tools
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
)
|
||||
@@ -78,6 +79,23 @@ func (p ParamValues) AsMapByOrderedKeys() map[string]interface{} {
|
||||
return params
|
||||
}
|
||||
|
||||
// AsMapWithDollarPrefix ensures all keys are prefixed with a dollar sign for Dgraph.
|
||||
// Example:
|
||||
// Input: {"role": "admin", "$age": 30}
|
||||
// Output: {"$role": "admin", "$age": 30}
|
||||
func (p ParamValues) AsMapWithDollarPrefix() map[string]interface{} {
|
||||
params := make(map[string]interface{})
|
||||
|
||||
for _, param := range p {
|
||||
key := param.Name
|
||||
if !strings.HasPrefix(key, "$") {
|
||||
key = "$" + key
|
||||
}
|
||||
params[key] = param.Value
|
||||
}
|
||||
return params
|
||||
}
|
||||
|
||||
func parseFromAuthSource(paramAuthSources []ParamAuthSource, claimsMap map[string]map[string]any) (any, error) {
|
||||
// parse a parameter from claims using its specified auth sources
|
||||
for _, a := range paramAuthSources {
|
||||
|
||||
@@ -613,6 +613,7 @@ func TestParamValues(t *testing.T) {
|
||||
wantSlice []any
|
||||
wantMap map[string]interface{}
|
||||
wantMapOrdered map[string]interface{}
|
||||
wantMapWithDollar map[string]interface{}
|
||||
}{
|
||||
{
|
||||
name: "string",
|
||||
@@ -620,6 +621,10 @@ func TestParamValues(t *testing.T) {
|
||||
wantSlice: []any{true, "hello world"},
|
||||
wantMap: map[string]interface{}{"my_bool": true, "my_string": "hello world"},
|
||||
wantMapOrdered: map[string]interface{}{"p1": true, "p2": "hello world"},
|
||||
wantMapWithDollar: map[string]interface{}{
|
||||
"$my_bool": true,
|
||||
"$my_string": "hello world",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
@@ -627,6 +632,7 @@ func TestParamValues(t *testing.T) {
|
||||
gotSlice := tc.in.AsSlice()
|
||||
gotMap := tc.in.AsMap()
|
||||
gotMapOrdered := tc.in.AsMapByOrderedKeys()
|
||||
gotMapWithDollar := tc.in.AsMapWithDollarPrefix()
|
||||
|
||||
for i, got := range gotSlice {
|
||||
want := tc.wantSlice[i]
|
||||
@@ -646,6 +652,12 @@ func TestParamValues(t *testing.T) {
|
||||
t.Fatalf("unexpected value: got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
for key, got := range gotMapWithDollar {
|
||||
want := tc.wantMapWithDollar[key]
|
||||
if got != want {
|
||||
t.Fatalf("unexpected value in AsMapWithDollarPrefix: got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
32
tests/dgraph.yaml
Normal file
32
tests/dgraph.yaml
Normal file
@@ -0,0 +1,32 @@
|
||||
|
||||
sources:
|
||||
dgraph-manage-user-instance:
|
||||
kind: "dgraph"
|
||||
dgraphUrl: "https://play.dgraph.io/query?latency=true"
|
||||
apiKey: ""
|
||||
|
||||
tools:
|
||||
search_user:
|
||||
kind: dgraph-dql
|
||||
source: dgraph-manage-user-instance
|
||||
statement: |
|
||||
query all($role: string){
|
||||
users(func: has(name)) @filter(eq(role, $role) AND ge(age, 30) AND le(age, 50)) {
|
||||
uid
|
||||
name
|
||||
email
|
||||
role
|
||||
age
|
||||
}
|
||||
}
|
||||
isQuery: true
|
||||
timeout: 20s
|
||||
description: |
|
||||
Use this tool to insert or update user data into the Dgraph database.
|
||||
The mutation adds or updates user details like name, email, role, and age.
|
||||
Example: Add users Alice and Bob as admins with specific ages.
|
||||
parameters:
|
||||
- name: role
|
||||
type: string
|
||||
description: admin
|
||||
|
||||
170
tests/dgraph_integration_test.go
Normal file
170
tests/dgraph_integration_test.go
Normal file
@@ -0,0 +1,170 @@
|
||||
//go:build integration && dgraph
|
||||
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tests
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
DGRAPH_URL = os.Getenv("DGRAPH_URL")
|
||||
)
|
||||
|
||||
func requireDgraphVars(t *testing.T) {
|
||||
if DGRAPH_URL =="" {
|
||||
t.Fatal("'DGRAPH_URL' not set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDgraph(t *testing.T) {
|
||||
requireDgraphVars(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
var args []string
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := map[string]any{
|
||||
"sources": map[string]any{
|
||||
"my-dgraph-instance": map[string]any{
|
||||
"kind": "dgraph",
|
||||
"dgraphUrl": DGRAPH_URL,
|
||||
"apiKey": "api-key",
|
||||
},
|
||||
},
|
||||
"tools": map[string]any{
|
||||
"my-simple-dql-tool": map[string]any{
|
||||
"kind": "dgraph-dql",
|
||||
"source": "my-dgraph-instance",
|
||||
"description": "Simple tool to test end to end functionality.",
|
||||
"statement": "{result(func: uid(0x0)) {constant: math(1)}}",
|
||||
"isQuery": true,
|
||||
"timeout": "20s",
|
||||
"parameters": []any{},
|
||||
},
|
||||
},
|
||||
}
|
||||
cmd, cleanup, err := StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
t.Fatalf("command initialization returned an error: %s", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
||||
if err != nil {
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
}
|
||||
|
||||
// Test tool get endpoint
|
||||
tcs := []struct {
|
||||
name string
|
||||
api string
|
||||
want map[string]any
|
||||
}{
|
||||
{
|
||||
name: "get my-simple-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-simple-dql-tool/",
|
||||
want: map[string]any{
|
||||
"my-simple-dql-tool": map[string]any{
|
||||
"description": "Simple tool to test end to end functionality.",
|
||||
"parameters": []any{},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
resp, err := http.Get(tc.api)
|
||||
if err != nil {
|
||||
t.Fatalf("error when sending a request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != 200 {
|
||||
t.Fatalf("response status code is not 200")
|
||||
}
|
||||
|
||||
var body map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing response body")
|
||||
}
|
||||
|
||||
got, ok := body["tools"]
|
||||
if !ok {
|
||||
t.Fatalf("unable to find tools in response body")
|
||||
}
|
||||
if !reflect.DeepEqual(got, tc.want) {
|
||||
t.Fatalf("got %q, want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Test tool invoke endpoint
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
api string
|
||||
requestBody io.Reader
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "invoke my-simple-dql-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-simple-dql-tool/invoke",
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
want: "Stub tool call for \"my-simple-dql-tool\"! Parameters parsed: map[]" +
|
||||
" \n Output: map[result:[map[constant:1]]]",
|
||||
},
|
||||
}
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
resp, err := http.Post(tc.api, "application/json", tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("error when sending a request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var body map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing response body")
|
||||
}
|
||||
got, ok := body["result"].(string)
|
||||
|
||||
if !ok {
|
||||
t.Fatalf("unable to find result in response body")
|
||||
}
|
||||
if got != tc.want {
|
||||
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user