feat: add dgraph tool and source (#233)

* add dgraph tool and source
This commit is contained in:
Shivaji Kharse
2025-02-02 04:02:06 +05:30
committed by GitHub
parent 8fca0a95ee
commit 617cc872d1
15 changed files with 1114 additions and 2 deletions

View File

@@ -178,6 +178,21 @@ steps:
- |
go test -race -v -tags=integration,mysql ./tests
- id: "dgraph"
name: golang:1
waitFor: ["install-dependencies"]
entrypoint: /bin/bash
env:
- "GOPATH=/gopath"
- "DGRAPH_URL=$_DGRAPHURL"
volumes:
- name: "go"
path: "/gopath"
args:
- -c
- |
go test -race -v -tags=integration,dgraph ./tests
availableSecrets:
secretManager:
- versionName: projects/$PROJECT_ID/secrets/cloud_sql_pg_user/versions/latest
@@ -233,3 +248,4 @@ substitutions:
_CLOUD_SQL_MYSQL_INSTANCE: "cloud-sql-mysql-testing"
_MYSQL_HOST: 127.0.0.1
_MYSQL_PORT: "3306"
_DGRAPHURL: "https://play.dgraph.io"

View File

@@ -0,0 +1,55 @@
---
title: "Dgraph"
type: docs
weight: 1
description: >
Dgraph is a horizontally scalable and distributed graph database.
---
## About
[Dgraph][dgraph-docs] is a horizontally scalable and distributed graph database.
It provides ACID transactions, consistent replication, and linearizable reads.
This source can connect to either a self-managed Dgraph cluster or one hosted on
Dgraph Cloud. If you're new to Dgraph, the fastest way to get started is to
[sign up for Dgraph Cloud][dgraph-login].
[dgraph-docs]: https://dgraph.io/docs
[dgraph-login]: https://cloud.dgraph.io/login
## Requirements
### Database User
When **connecting to a hosted Dgraph database**, this source uses the API key
for access. If you are using a dedicated environment, you will additionally need
the namespace and user credentials for that namespace.
For **connecting to a local or self-hosted Dgraph database**, use the namespace
and user credentials for that namespace.
## Example
```yaml
sources:
my-dgraph-source:
kind: "dgraph"
dgraphUrl: "https://xxxx.cloud.dgraph.io"
user: "groot"
password: "password"
apiKey: abc123
namepace : 0
```
## Reference
| **Field** | **Type** | **Required** | **Description** |
|-------------|:--------:|:------------:|--------------------------------------------------------------------------------------------------|
| kind | string | true | Must be "dgraph". |
| dgraphUrl | string | true | Connection URI (e.g. "https://xxx.cloud.dgraph.io", "https://localhost:8080"). |
| user | string | false | Name of the Dgraph user to connect as (e.g., "groot"). |
| password | string | false | Password of the Dgraph user (e.g., "password"). |
| apiKey | string | false | API key to connect to a Dgraph Cloud instance. |
| namespace | uint64 | false | Dgraph namespace (not required for Dgraph Cloud Shared Clusters). |

View File

@@ -52,6 +52,8 @@ We currently support the following types of kinds of tools:
statement againts Spanner database.
* [neo4j-cypher](./neo4j-cypher.md) - Run a Cypher statement against a
Neo4j database.
* [dgraph-dql](./dgraph-dql.md) - Run a DQL statement against a
Dgraph database.
## Specifying Parameters

105
docs/tools/dgraph-dql.md Normal file
View File

@@ -0,0 +1,105 @@
# Dgraph DQL Tool
A "dgraph-dql" tool executes a pre-defined DQL statement against a Dgraph database. It's compatible with any of the following
sources:
- [dgraph](../sources/dgraph.md)
To run a statement as a query, you need to set the config isQuery=true. For upserts or mutations, set isQuery=false.
You can also configure timeout for a query.
## Example
### Query:
```yaml
tools:
search_user:
kind: dgraph-dql
source: my-dgraph-source
statement: |
query all($role: string){
users(func: has(name)) @filter(eq(role, $role) AND ge(age, 30) AND le(age, 50)) {
uid
name
email
role
age
}
}
isQuery: true
timeout: 20s
description: |
Use this tool to retrieve the details of users who are admins and are between 30 and 50 years old.
The query returns the user's name, email, role, and age.
This can be helpful when you want to fetch admin users within a specific age range.
Example: Fetch admins aged between 30 and 50:
[
{
"name": "Alice",
"role": "admin",
"age": 35
},
{
"name": "Bob",
"role": "admin",
"age": 45
}
]
parameters:
- name: $role
type: string
description: admin
```
### Mutation:
```yaml
tools:
dgraph-manage-user-instance:
kind: dgraph-dql
source: my-dgraph-source
isQuery: false
statement: |
{
set {
_:user1 <name> $user1 .
_:user1 <email> $email1 .
_:user1 <role> "admin" .
_:user1 <age> "35" .
_:user2 <name> $user2 .
_:user2 <email> $email2 .
_:user2 <role> "admin" .
_:user2 <age> "45" .
}
}
description: |
Use this tool to insert or update user data into the Dgraph database.
The mutation adds or updates user details like name, email, role, and age.
Example: Add users Alice and Bob as admins with specific ages.
parameters:
- name: user1
type: string
description: Alice
- name: email1
type: string
description: alice@email.com
- name: user2
type: string
description: Bob
- name: email2
type: string
description: bob@email.com
```
## Reference
| **field** | **type** | **required** | **description** |
|-------------|----------:|:------------:|----------------------------------------------------------------------------------------------------|
| kind | string | true | Must be "dgraph-dql". |
| source | string | true | Name of the source the dql query should execute on. |
| description | string | true | Description of the tool |
| statement | string | true | dql statement to execute |
| isQuery | boolean | false | To run statment as query set true otherwise false |
| timeout | string | false | To set timout for query |
| parameters | parameter | true | List of [parameters](README.md#specifying-parameters) that will be used with the dql statement. |

2
go.mod
View File

@@ -52,6 +52,7 @@ require (
github.com/census-instrumentation/opencensus-proto v0.4.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/cncf/xds/go v0.0.0-20240905190251-b4127c9b8d78 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/envoyproxy/go-control-plane v0.13.1 // indirect
github.com/envoyproxy/protoc-gen-validate v1.1.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
@@ -69,6 +70,7 @@ require (
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/spf13/pflag v1.0.5 // indirect
go.opencensus.io v0.24.0 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect

6
go.sum
View File

@@ -699,8 +699,9 @@ github.com/cncf/xds/go v0.0.0-20240905190251-b4127c9b8d78/go.mod h1:W+zGtBO5Y1Ig
github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3ebgob9U8Nd0kOddGdZWjyMGR8Wziv+TBNwSE=
github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
@@ -951,8 +952,9 @@ github.com/pkg/sftp v1.10.1/go.mod h1:lYOWFsE0bwd1+KfKJaKeuokY15vzFx25BLbzYYoAxZ
github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qRg=
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo=
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/client_model v0.3.0/go.mod h1:LDGWKZIo7rky3hgvBe+caln+Dr3dPggB5dvjtD7w9+w=

View File

@@ -25,12 +25,14 @@ import (
cloudsqlmssqlsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmssql"
cloudsqlmysqlsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmysql"
cloudsqlpgsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg"
dgraphsrc "github.com/googleapis/genai-toolbox/internal/sources/dgraph"
mssqlsrc "github.com/googleapis/genai-toolbox/internal/sources/mssql"
mysqlsrc "github.com/googleapis/genai-toolbox/internal/sources/mysql"
neo4jrc "github.com/googleapis/genai-toolbox/internal/sources/neo4j"
postgressrc "github.com/googleapis/genai-toolbox/internal/sources/postgres"
spannersrc "github.com/googleapis/genai-toolbox/internal/sources/spanner"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/dgraph"
"github.com/googleapis/genai-toolbox/internal/tools/mssqlsql"
"github.com/googleapis/genai-toolbox/internal/tools/mysqlsql"
neo4jtool "github.com/googleapis/genai-toolbox/internal/tools/neo4j"
@@ -195,6 +197,12 @@ func (c *SourceConfigs) UnmarshalYAML(unmarshal func(interface{}) error) error {
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
}
(*c)[name] = actual
case dgraphsrc.SourceKind:
actual := dgraphsrc.Config{Name: name}
if err := u.Unmarshal(&actual); err != nil {
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
}
(*c)[name] = actual
default:
return fmt.Errorf("%q is not a valid kind of data source", k.Kind)
}
@@ -292,6 +300,12 @@ func (c *ToolConfigs) UnmarshalYAML(unmarshal func(interface{}) error) error {
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
}
(*c)[name] = actual
case dgraph.ToolKind:
actual := dgraph.Config{Name: name}
if err := u.Unmarshal(&actual); err != nil {
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
}
(*c)[name] = actual
default:
return fmt.Errorf("%q is not a valid kind of tool", k.Kind)
}

View File

@@ -0,0 +1,379 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package dgraph
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"github.com/googleapis/genai-toolbox/internal/sources"
"go.opentelemetry.io/otel/trace"
)
const SourceKind string = "dgraph"
// validate interface
var _ sources.SourceConfig = Config{}
// HttpToken stores credentials for making HTTP request
type HttpToken struct {
UserId string
Password string
AccessJwt string
RefreshToken string
Namespace uint64
}
type DgraphClient struct {
httpClient *http.Client
*HttpToken
baseUrl string
apiKey string
}
type Config struct {
Name string `yaml:"name"`
Kind string `yaml:"kind"`
DgraphUrl string `yaml:"dgraphUrl"`
User string `yaml:"user"`
Password string `yaml:"password"`
Namespace uint64 `yaml:"namespace"`
ApiKey string `yaml:"apiKey"`
}
func (r Config) SourceConfigKind() string {
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
hc, err := initDgraphHttpClient(ctx, tracer, r)
if err != nil {
return nil, err
}
if err := hc.healthCheck(); err != nil {
return nil, err
}
s := &Source{
Name: r.Name,
Kind: SourceKind,
Client: hc,
}
return s, nil
}
var _ sources.Source = &Source{}
type Source struct {
Name string `yaml:"name"`
Kind string `yaml:"kind"`
Client *DgraphClient `yaml:"client"`
}
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) DgraphClient() *DgraphClient {
return s.Client
}
func initDgraphHttpClient(ctx context.Context, tracer trace.Tracer, r Config) (*DgraphClient, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, r.Name)
defer span.End()
if r.DgraphUrl == "" {
return nil, fmt.Errorf("dgraph url should not be empty")
}
hc := &DgraphClient{
httpClient: &http.Client{},
baseUrl: r.DgraphUrl,
HttpToken: &HttpToken{
UserId: r.User,
Namespace: r.Namespace,
Password: r.Password,
},
apiKey: r.ApiKey,
}
if r.User != "" || r.Password != "" {
if err := hc.loginWithCredentials(); err != nil {
return nil, err
}
}
return hc, nil
}
func (hc *DgraphClient) ExecuteQuery(query string, paramsMap map[string]interface{},
isQuery bool, timeout string) ([]byte, error) {
if isQuery {
return hc.postDqlQuery(query, paramsMap, timeout)
} else {
return hc.mutate(query, paramsMap)
}
}
// postDqlQuery sends a DQL query to the Dgraph server with query, parameters, and optional timeout.
// Returns the response body ([]byte) and an error, if any.
func (hc *DgraphClient) postDqlQuery(query string, paramsMap map[string]interface{}, timeout string) ([]byte, error) {
urlParams := url.Values{}
urlParams.Add("timeout", timeout)
url, err := getUrl(hc.baseUrl, "/query", urlParams)
if err != nil {
return nil, err
}
p := struct {
Query string `json:"query"`
Variables map[string]interface{} `json:"variables"`
}{
Query: query,
Variables: paramsMap,
}
body, err := json.Marshal(p)
if err != nil {
return nil, fmt.Errorf("error marshlling json: %v", err)
}
req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(body))
if err != nil {
return nil, fmt.Errorf("error building req for endpoint [%v] :%v", url, err)
}
req.Header.Add("Content-Type", "application/json")
return hc.doReq(req)
}
// mutate sends an RDF mutation to the Dgraph server with "commitNow: true", embedding parameters.
// Returns the server's response as a byte slice or an error if the mutation fails.
func (hc *DgraphClient) mutate(mutation string, paramsMap map[string]interface{}) ([]byte, error) {
mu := embedParamsIntoMutation(mutation, paramsMap)
params := url.Values{}
params.Add("commitNow", "true")
url, err := getUrl(hc.baseUrl, "/mutate", params)
if err != nil {
return nil, err
}
req, err := http.NewRequest(http.MethodPost, url, bytes.NewBufferString(mu))
if err != nil {
return nil, fmt.Errorf("error building req for endpoint [%v] :%v", url, err)
}
req.Header.Add("Content-Type", "application/rdf")
return hc.doReq(req)
}
func (hc *DgraphClient) doReq(req *http.Request) ([]byte, error) {
if hc.HttpToken != nil {
req.Header.Add("X-Dgraph-AccessToken", hc.AccessJwt)
}
if hc.apiKey != "" {
req.Header.Set("Dg-Auth", hc.apiKey)
}
resp, err := hc.httpClient.Do(req)
if err != nil && !strings.Contains(err.Error(), "Token is expired") {
return nil, fmt.Errorf("error performing HTTP request: %w", err)
} else if err != nil && strings.Contains(err.Error(), "Token is expired") {
if errLogin := hc.loginWithToken(); errLogin != nil {
return nil, errLogin
}
if hc.HttpToken != nil {
req.Header.Add("X-Dgraph-AccessToken", hc.AccessJwt)
}
resp, err = hc.httpClient.Do(req)
if err != nil {
return nil, err
}
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("error reading response body: url: [%v], err: [%v]", req.URL, err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("got non 200 resp: %v", string(respBody))
}
return respBody, nil
}
func (hc *DgraphClient) loginWithCredentials() error {
credentials := map[string]interface{}{
"userid": hc.UserId,
"password": hc.Password,
"namespace": hc.Namespace,
}
return hc.doLogin(credentials)
}
func (hc *DgraphClient) loginWithToken() error {
credentials := map[string]interface{}{
"refreshJWT": hc.RefreshToken,
"namespace": hc.Namespace,
}
return hc.doLogin(credentials)
}
func (hc *DgraphClient) doLogin(creds map[string]interface{}) error {
url, err := getUrl(hc.baseUrl, "/login", nil)
if err != nil {
return err
}
payload, err := json.Marshal(creds)
if err != nil {
return fmt.Errorf("failed to marshal credentials: %v", err)
}
req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(payload))
if err != nil {
return fmt.Errorf("error building req for endpoint [%v] : %v", url, err)
}
req.Header.Add("Content-Type", "application/json")
if hc.apiKey != "" {
req.Header.Set("Dg-Auth", hc.apiKey)
}
resp, err := hc.doReq(req)
if err != nil {
if strings.Contains(err.Error(), "Token is expired") &&
!strings.Contains(err.Error(), "unable to authenticate the refresh token") {
return hc.loginWithToken()
}
return err
}
if err := CheckError(resp); err != nil {
return err
}
var r struct {
Data struct {
AccessJWT string `json:"accessJWT"`
RefreshJWT string `json:"refreshJWT"`
} `json:"data"`
}
if err := json.Unmarshal(resp, &r); err != nil {
return fmt.Errorf("failed to unmarshal response: %v", err)
}
if r.Data.AccessJWT == "" {
return fmt.Errorf("no access JWT found in the response")
}
if r.Data.RefreshJWT == "" {
return fmt.Errorf("no refresh JWT found in the response")
}
hc.AccessJwt = r.Data.AccessJWT
hc.RefreshToken = r.Data.RefreshJWT
return nil
}
func (hc *DgraphClient) healthCheck() error {
url, err := getUrl(hc.baseUrl, "/health", nil)
if err != nil {
return err
}
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return fmt.Errorf("error creating request: %w", err)
}
resp, err := hc.httpClient.Do(req)
if err != nil {
return fmt.Errorf("error performing request: %w", err)
}
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
var result []struct {
Instance string `json:"instance"`
Address string `json:"address"`
Status string `json:"status"`
}
// Unmarshal response into the struct
if err := json.Unmarshal(data, &result); err != nil {
return fmt.Errorf("failed to unmarshal json: %v", err)
}
if len(result) == 0 {
return fmt.Errorf("health info should not empty for: %v", url)
}
var unhealthyErr error
for _, info := range result {
if info.Status != "healthy" {
unhealthyErr = fmt.Errorf("dgraph instance [%v] is not in healthy state, address is %v",
info.Instance, info.Address)
} else {
return nil
}
}
return unhealthyErr
}
func getUrl(baseUrl, resource string, params url.Values) (string, error) {
u, err := url.ParseRequestURI(baseUrl)
if err != nil {
return "", fmt.Errorf("failed to get url %v", err)
}
u.Path = resource
u.RawQuery = params.Encode()
return u.String(), nil
}
func CheckError(resp []byte) error {
var errResp struct {
Errors []struct {
Message string `json:"message"`
} `json:"errors"`
}
if err := json.Unmarshal(resp, &errResp); err != nil {
return fmt.Errorf("failed to unmarshal json: %v", err)
}
if len(errResp.Errors) > 0 {
return fmt.Errorf("error : %v", errResp.Errors)
}
return nil
}
func embedParamsIntoMutation(mutation string, paramsMap map[string]interface{}) string {
for key, value := range paramsMap {
mutation = strings.ReplaceAll(mutation, key, fmt.Sprintf(`"%v"`, value))
}
return mutation
}

View File

@@ -0,0 +1,76 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package dgraph_test
import (
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources/dgraph"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
func TestParseFromYamlDgraph(t *testing.T) {
tcs := []struct {
desc string
in string
want server.SourceConfigs
}{
{
desc: "basic example",
in: `
sources:
my-dgraph-instance:
kind: dgraph
dgraphUrl: https://localhost:8080
apiKey: abc123
password: pass@123
namespace: 0
user: user123
`,
want: server.SourceConfigs{
"my-dgraph-instance": dgraph.Config{
Name: "my-dgraph-instance",
Kind: dgraph.SourceKind,
DgraphUrl: "https://localhost:8080",
ApiKey: "abc123",
Password: "pass@123",
Namespace: 0,
User: "user123",
},
},
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if diff := cmp.Diff(tc.want, got.Sources); diff != "" {
t.Fatalf("incorrect parse: diff %v", diff)
}
})
}
}

View File

@@ -0,0 +1,133 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package dgraph
import (
"encoding/json"
"fmt"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/dgraph"
"github.com/googleapis/genai-toolbox/internal/tools"
)
const ToolKind string = "dgraph-dql"
type compatibleSource interface {
DgraphClient() *dgraph.DgraphClient
}
// validate compatible sources are still compatible
var _ compatibleSource = &dgraph.Source{}
var compatibleSources = [...]string{dgraph.SourceKind}
type Config struct {
Name string `yaml:"name"`
Kind string `yaml:"kind"`
Source string `yaml:"source"`
Description string `yaml:"description"`
Statement string `yaml:"statement"`
IsQuery bool `yaml:"isQuery"`
Timeout string `yaml:"timeout"`
Parameters tools.Parameters `yaml:"parameters"`
}
// validate interface
var _ tools.ToolConfig = Config{}
func (cfg Config) ToolConfigKind() string {
return ToolKind
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", ToolKind, compatibleSources)
}
// finish tool setup
t := Tool{
Name: cfg.Name,
Kind: ToolKind,
Parameters: cfg.Parameters,
Statement: cfg.Statement,
DgraphClient: s.DgraphClient(),
IsQuery: cfg.IsQuery,
Timeout: cfg.Timeout,
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
}
return t, nil
}
// validate interface
var _ tools.Tool = Tool{}
type Tool struct {
Name string `yaml:"name"`
Kind string `yaml:"kind"`
Parameters tools.Parameters `yaml:"parameters"`
AuthRequired []string `yaml:"authRequired"`
DgraphClient *dgraph.DgraphClient
IsQuery bool
Timeout string
Statement string
manifest tools.Manifest
}
func (t Tool) Invoke(params tools.ParamValues) (string, error) {
paramsMap := params.AsMapWithDollarPrefix()
resp, err := t.DgraphClient.ExecuteQuery(t.Statement, paramsMap, t.IsQuery, t.Timeout)
if err != nil {
return "", err
}
if err := dgraph.CheckError(resp); err != nil {
return "", err
}
var result struct {
Data map[string]interface{} `json:"data"`
}
if err := json.Unmarshal(resp, &result); err != nil {
return "", fmt.Errorf("error parsing JSON: %v", err)
}
return fmt.Sprintf(
"Stub tool call for %q! Parameters parsed: %q \n Output: %v",
t.Name, paramsMap, result.Data,
), nil
}
func (t Tool) ParseParams(data map[string]any, claimsMap map[string]map[string]any) (tools.ParamValues, error) {
return tools.ParseParams(t.Parameters, data, claimsMap)
}
func (t Tool) Manifest() tools.Manifest {
return t.manifest
}
func (t Tool) Authorized(verifiedAuthSources []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthSources)
}

View File

@@ -0,0 +1,96 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package dgraph_test
import (
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/internal/tools/dgraph"
)
func TestParseFromYamlDgraph(t *testing.T) {
tcs := []struct {
desc string
in string
want server.ToolConfigs
}{
{
desc: "basic query example",
in: `
tools:
example_tool:
kind: dgraph-dql
source: my-dgraph-instance
description: some tool description
isQuery: true
timeout: 20s
statement: |
query {q(func: eq(email, "example@email.com")) {email}}
`,
want: server.ToolConfigs{
"example_tool": dgraph.Config{
Name: "example_tool",
Kind: dgraph.ToolKind,
Source: "my-dgraph-instance",
Description: "some tool description",
IsQuery: true,
Timeout: "20s",
Statement: "query {q(func: eq(email, \"example@email.com\")) {email}}\n",
},
},
},
{
desc: "basic mutation example",
in: `
tools:
example_tool:
kind: dgraph-dql
source: my-dgraph-instance
description: some tool description
statement: |
mutation {set { _:a <name> "a@email.com" . _:b <email> "b@email.com" .}}
`,
want: server.ToolConfigs{
"example_tool": dgraph.Config{
Name: "example_tool",
Kind: dgraph.ToolKind,
Source: "my-dgraph-instance",
Description: "some tool description",
Statement: "mutation {set { _:a <name> \"a@email.com\" . _:b <email> \"b@email.com\" .}}\n",
},
},
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Tools server.ToolConfigs `yaml:"tools"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
t.Fatalf("incorrect parse: diff %v", diff)
}
})
}
}

View File

@@ -17,6 +17,7 @@ package tools
import (
"encoding/json"
"fmt"
"strings"
"github.com/googleapis/genai-toolbox/internal/util"
)
@@ -78,6 +79,23 @@ func (p ParamValues) AsMapByOrderedKeys() map[string]interface{} {
return params
}
// AsMapWithDollarPrefix ensures all keys are prefixed with a dollar sign for Dgraph.
// Example:
// Input: {"role": "admin", "$age": 30}
// Output: {"$role": "admin", "$age": 30}
func (p ParamValues) AsMapWithDollarPrefix() map[string]interface{} {
params := make(map[string]interface{})
for _, param := range p {
key := param.Name
if !strings.HasPrefix(key, "$") {
key = "$" + key
}
params[key] = param.Value
}
return params
}
func parseFromAuthSource(paramAuthSources []ParamAuthSource, claimsMap map[string]map[string]any) (any, error) {
// parse a parameter from claims using its specified auth sources
for _, a := range paramAuthSources {

View File

@@ -613,6 +613,7 @@ func TestParamValues(t *testing.T) {
wantSlice []any
wantMap map[string]interface{}
wantMapOrdered map[string]interface{}
wantMapWithDollar map[string]interface{}
}{
{
name: "string",
@@ -620,6 +621,10 @@ func TestParamValues(t *testing.T) {
wantSlice: []any{true, "hello world"},
wantMap: map[string]interface{}{"my_bool": true, "my_string": "hello world"},
wantMapOrdered: map[string]interface{}{"p1": true, "p2": "hello world"},
wantMapWithDollar: map[string]interface{}{
"$my_bool": true,
"$my_string": "hello world",
},
},
}
for _, tc := range tcs {
@@ -627,6 +632,7 @@ func TestParamValues(t *testing.T) {
gotSlice := tc.in.AsSlice()
gotMap := tc.in.AsMap()
gotMapOrdered := tc.in.AsMapByOrderedKeys()
gotMapWithDollar := tc.in.AsMapWithDollarPrefix()
for i, got := range gotSlice {
want := tc.wantSlice[i]
@@ -646,6 +652,12 @@ func TestParamValues(t *testing.T) {
t.Fatalf("unexpected value: got %q, want %q", got, want)
}
}
for key, got := range gotMapWithDollar {
want := tc.wantMapWithDollar[key]
if got != want {
t.Fatalf("unexpected value in AsMapWithDollarPrefix: got %q, want %q", got, want)
}
}
})
}
}

32
tests/dgraph.yaml Normal file
View File

@@ -0,0 +1,32 @@
sources:
dgraph-manage-user-instance:
kind: "dgraph"
dgraphUrl: "https://play.dgraph.io/query?latency=true"
apiKey: ""
tools:
search_user:
kind: dgraph-dql
source: dgraph-manage-user-instance
statement: |
query all($role: string){
users(func: has(name)) @filter(eq(role, $role) AND ge(age, 30) AND le(age, 50)) {
uid
name
email
role
age
}
}
isQuery: true
timeout: 20s
description: |
Use this tool to insert or update user data into the Dgraph database.
The mutation adds or updates user details like name, email, role, and age.
Example: Add users Alice and Bob as admins with specific ages.
parameters:
- name: role
type: string
description: admin

View File

@@ -0,0 +1,170 @@
//go:build integration && dgraph
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tests
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"os"
"reflect"
"regexp"
"testing"
"time"
)
var (
DGRAPH_URL = os.Getenv("DGRAPH_URL")
)
func requireDgraphVars(t *testing.T) {
if DGRAPH_URL =="" {
t.Fatal("'DGRAPH_URL' not set")
}
}
func TestDgraph(t *testing.T) {
requireDgraphVars(t)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
var args []string
// Write config into a file and pass it to command
toolsFile := map[string]any{
"sources": map[string]any{
"my-dgraph-instance": map[string]any{
"kind": "dgraph",
"dgraphUrl": DGRAPH_URL,
"apiKey": "api-key",
},
},
"tools": map[string]any{
"my-simple-dql-tool": map[string]any{
"kind": "dgraph-dql",
"source": "my-dgraph-instance",
"description": "Simple tool to test end to end functionality.",
"statement": "{result(func: uid(0x0)) {constant: math(1)}}",
"isQuery": true,
"timeout": "20s",
"parameters": []any{},
},
},
}
cmd, cleanup, err := StartCmd(ctx, toolsFile, args...)
if err != nil {
t.Fatalf("command initialization returned an error: %s", err)
}
defer cleanup()
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
if err != nil {
t.Logf("toolbox command logs: \n%s", out)
t.Fatalf("toolbox didn't start successfully: %s", err)
}
// Test tool get endpoint
tcs := []struct {
name string
api string
want map[string]any
}{
{
name: "get my-simple-tool",
api: "http://127.0.0.1:5000/api/tool/my-simple-dql-tool/",
want: map[string]any{
"my-simple-dql-tool": map[string]any{
"description": "Simple tool to test end to end functionality.",
"parameters": []any{},
},
},
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
resp, err := http.Get(tc.api)
if err != nil {
t.Fatalf("error when sending a request: %s", err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
t.Fatalf("response status code is not 200")
}
var body map[string]interface{}
err = json.NewDecoder(resp.Body).Decode(&body)
if err != nil {
t.Fatalf("error parsing response body")
}
got, ok := body["tools"]
if !ok {
t.Fatalf("unable to find tools in response body")
}
if !reflect.DeepEqual(got, tc.want) {
t.Fatalf("got %q, want %q", got, tc.want)
}
})
}
// Test tool invoke endpoint
invokeTcs := []struct {
name string
api string
requestBody io.Reader
want string
}{
{
name: "invoke my-simple-dql-tool",
api: "http://127.0.0.1:5000/api/tool/my-simple-dql-tool/invoke",
requestBody: bytes.NewBuffer([]byte(`{}`)),
want: "Stub tool call for \"my-simple-dql-tool\"! Parameters parsed: map[]" +
" \n Output: map[result:[map[constant:1]]]",
},
}
for _, tc := range invokeTcs {
t.Run(tc.name, func(t *testing.T) {
resp, err := http.Post(tc.api, "application/json", tc.requestBody)
if err != nil {
t.Fatalf("error when sending a request: %s", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
}
var body map[string]interface{}
err = json.NewDecoder(resp.Body).Decode(&body)
if err != nil {
t.Fatalf("error parsing response body")
}
got, ok := body["result"].(string)
if !ok {
t.Fatalf("unable to find result in response body")
}
if got != tc.want {
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
}
})
}
}