feat: add 'alloydb-ai-nl' tool (#358)

Co-authored-by: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com>
Co-authored-by: Averi Kitsch <akitsch@google.com>
Co-authored-by: Yuan <45984206+Yuan325@users.noreply.github.com>
This commit is contained in:
totoleon
2025-04-04 11:30:58 -07:00
committed by GitHub
parent a7d1d4eb2a
commit f02885fd4a
7 changed files with 786 additions and 0 deletions

View File

@@ -66,6 +66,27 @@ steps:
- |
go test -race -v -tags=integration,alloydb ./tests
- id: "alloydb-ai-nl"
name: golang:1
waitFor: ["install-dependencies"]
entrypoint: /bin/bash
env:
- "GOPATH=/gopath"
- "ALLOYDB_AI_NL_PROJECT=$PROJECT_ID"
- "ALLOYDB_AI_NL_CLUSTER=$_ALLOYDB_AI_NL_CLUSTER"
- "ALLOYDB_AI_NL_INSTANCE=$_ALLOYDB_AI_NL_INSTANCE"
- "ALLOYDB_AI_NL_DATABASE=$_DATABASE_NAME"
- "ALLOYDB_AI_NL_REGION=$_REGION"
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
secretEnv: ["ALLOYDB_AI_NL_USER", "ALLOYDB_AI_NL_PASS", "CLIENT_ID"]
volumes:
- name: "go"
path: "/gopath"
args:
- -c
- |
go test -race -v -tags=integration,alloydb_ai_nl ./tests
- id: "postgres"
name: golang:1
waitFor: ["install-dependencies"]
@@ -239,6 +260,10 @@ availableSecrets:
env: ALLOYDB_POSTGRES_USER
- versionName: projects/$PROJECT_ID/secrets/alloydb_pg_pass/versions/latest
env: ALLOYDB_POSTGRES_PASS
- versionName: projects/$PROJECT_ID/secrets/alloydb_ai_nl_user/versions/latest
env: ALLOYDB_AI_NL_USER
- versionName: projects/$PROJECT_ID/secrets/alloydb_ai_nl_pass/versions/latest
env: ALLOYDB_AI_NL_PASS
- versionName: projects/$PROJECT_ID/secrets/postgres_user/versions/latest
env: POSTGRES_USER
- versionName: projects/$PROJECT_ID/secrets/postgres_pass/versions/latest
@@ -281,6 +306,8 @@ substitutions:
_CLOUD_SQL_POSTGRES_INSTANCE: "cloud-sql-pg-testing"
_ALLOYDB_POSTGRES_CLUSTER: "alloydb-pg-testing"
_ALLOYDB_POSTGRES_INSTANCE: "alloydb-pg-testing-instance"
_ALLOYDB_AI_NL_CLUSTER: "alloydb-ai-nl-testing"
_ALLOYDB_AI_NL_INSTANCE: "alloydb-ai-nl-testing-instance"
_POSTGRES_HOST: 127.0.0.1
_POSTGRES_PORT: "5432"
_SPANNER_INSTANCE: "spanner-testing"

View File

@@ -44,3 +44,4 @@ run:
- mssql
- mysql
- http
- alloydb_ai_nl

View File

@@ -0,0 +1,104 @@
---
title: "alloydb-ai-nl"
type: docs
weight: 1
description: >
The "alloydb-ai-nl" tool leverages
[AlloyDB AI](https://cloud.google.com/alloydb/ai) next-generation Natural
Language support to provide the ability to query the database directly using
natural language.
---
## About
The `alloydb-ai-nl` tool leverages [AlloyDB AI next-generation natural
Language][alloydb-ai-nl-overview] support to allow an Agent the ability to query
the database directly using natural language. Natural language streamlines the
development of generative AI applications by transferring the complexity of
converting natural language to SQL from the application layer to the database
layer.
This tool is compatible with the following sources:
- [alloydb-postgres](../sources/alloydb-pg.md)
AlloyDB AI Natural Language delivers secure and accurate responses for
application end user natural language questions. Natural language streamlines
the development of generative AI applications by transferring the complexity
of converting natural language to SQL from the application layer to the
database layer.
## Requirements
{{< notice tip >}} AlloyDB AI natural language is currently in gated public
preview. For more information on availability and limitations, please see
[AlloyDB AI natural language
overview](https://cloud.google.com/alloydb/docs/natural-language-questions-overview)
{{< /notice >}}
To enable AlloyDB AI natural language for your AlloyDB cluster, please follow
the steps listed in the [Generate SQL queries that answer natural language
questions][alloydb-ai-gen-nl], including enabling the extension and configuring
context for your application.
[alloydb-ai-nl-overview]: https://cloud.google.com/alloydb/docs/natural-language-questions-overview
[alloydb-ai-gen-nl]: https://cloud.google.com/alloydb/docs/alloydb/docs/ai/generate-queries-natural-language
## Configuration
### Specifying an `nl_config`
A `nl_config` is a configuration that associates an application to schema
objects, examples and other contexts that can be used. A large application can
also use different configurations for different parts of the app, as long as the
correct configuration can be specified when a question is sent from that part of
the application.
Once you've followed the steps for configuring context, you can use the
`context` field when configuring a `alloydb-ai-nl` tool. When this tool is
invoked, the SQL will be generated and executed using this context.
### Specifying Parameters to PSV's
[Parameterized Secure Views (PSVs)][alloydb-psv] are a feature unique to AlloyDB
that allows you allow you to require one or more named parameter values passed
to the view when querying it, somewhat like bind variables with ordinary
database queries.
You can use the `nlConfigParameters` to list the parameters required for your
`nl_config`. You **must** supply all parameters required for all PSVs in the
context. It's strongly recommended to use features like [Authenticated
Parameters](../tools/#array-parameters) or Bound Parameters to provide secure
access to queries generated using natural language, as these parameters are not
visible to the LLM.
[alloydb-psv]: https://cloud.google.com/alloydb/docs/ai/use-psvs#parameterized_secure_views
## Example
```yaml
tools:
ask_questions:
kind: alloydb-ai-nl
source: my-alloydb-source
description: "Ask questions to check information about flights"
nlConfig: "cymbal_air_nl_config"
nlConfigParameters:
- name: user_email
type: string
description: User ID of the logged in user.
# note: we strongly recommend using features like Authenticated or
# Bound parameters to prevent the LLM from seeing these params and
# specifying values it shouldn't in the tool input
authServices:
- name: my_google_service
field: email
```
## Reference
| **field** | **type** | **required** | **description** |
|--------------------|:------------------------------------------:|:------------:|--------------------------------------------------------------------------|
| kind | string | true | Must be "alloydb-ai-nl". |
| source | string | true | Name of the AlloyDB source the natural language query should execute on. |
| description | string | true | Description of the tool that is passed to the LLM. |
| nlConfig | string | true | The name of the `nl_config` in AlloyDB |
| nlConfigParameters | [parameters](_index#specifying-parameters) | true | List of PSV parameters defined in the `nl_config` |

View File

@@ -40,6 +40,7 @@ import (
"github.com/googleapis/genai-toolbox/internal/tools/mysqlsql"
neo4jtool "github.com/googleapis/genai-toolbox/internal/tools/neo4j"
"github.com/googleapis/genai-toolbox/internal/tools/postgressql"
"github.com/googleapis/genai-toolbox/internal/tools/alloydbainl"
"github.com/googleapis/genai-toolbox/internal/tools/spanner"
"github.com/googleapis/genai-toolbox/internal/util"
)
@@ -307,6 +308,12 @@ func (c *ToolConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interfac
return fmt.Errorf("unable to parse as %q: %w", kind, err)
}
(*c)[name] = actual
case alloydbainl.ToolKind:
actual := alloydbainl.Config{Name: name}
if err := dec.DecodeContext(ctx, &actual); err != nil {
return fmt.Errorf("unable to parse as %q: %w", kind, err)
}
(*c)[name] = actual
case mysqlsql.ToolKind:
actual := mysqlsql.Config{Name: name}
if err := dec.DecodeContext(ctx, &actual); err != nil {

View File

@@ -0,0 +1,190 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package alloydbainl
import (
"context"
"fmt"
"strings"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/alloydbpg"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/jackc/pgx/v5/pgxpool"
)
const ToolKind string = "alloydb-ai-nl"
type compatibleSource interface {
PostgresPool() *pgxpool.Pool
}
// validate compatible sources are still compatible
var _ compatibleSource = &alloydbpg.Source{}
var compatibleSources = [...]string{alloydbpg.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Source string `yaml:"source" validate:"required"`
Description string `yaml:"description" validate:"required"`
NLConfig string `yaml:"nlConfig" validate:"required"`
AuthRequired []string `yaml:"authRequired"`
NLConfigParameters tools.Parameters `yaml:"nlConfigParameters"`
}
// validate interface
var _ tools.ToolConfig = Config{}
func (cfg Config) ToolConfigKind() string {
return ToolKind
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", ToolKind, compatibleSources)
}
numParams := len(cfg.NLConfigParameters)
quotedNameParts := make([]string, 0, numParams)
placeholderParts := make([]string, 0, numParams)
for i, paramDef := range cfg.NLConfigParameters {
name := paramDef.GetName()
escapedName := strings.ReplaceAll(name, "'", "''") // Escape for SQL literal
quotedNameParts = append(quotedNameParts, fmt.Sprintf("'%s'", escapedName))
placeholderParts = append(placeholderParts, fmt.Sprintf("$%d", i+3)) // $1, $2 reserved
}
var paramNamesSQL string
var paramValuesSQL string
if numParams > 0 {
paramNamesSQL = fmt.Sprintf("ARRAY[%s]", strings.Join(quotedNameParts, ", "))
paramValuesSQL = fmt.Sprintf("ARRAY[%s]", strings.Join(placeholderParts, ", "))
} else {
paramNamesSQL = "ARRAY[]::TEXT[]"
paramValuesSQL = "ARRAY[]::TEXT[]"
}
// execute_nl_query is the AlloyDB AI function that executes the natural language query
// The first parameter is the natural language query, which is passed as $1
// The second parameter is the NLConfig, which is passed as a $2
// The following params are the list of PSV values passed to the NLConfig
// Example SQL statement being executed:
// SELECT alloydb_ai_nl.execute_nl_query('How many tickets do I have?', 'cymbal_air_nl_config', param_names => ARRAY ['user_email'], param_values => ARRAY ['hailongli@google.com']);
stmtFormat := "SELECT alloydb_ai_nl.execute_nl_query($1, $2, param_names => %s, param_values => %s);"
stmt := fmt.Sprintf(stmtFormat, paramNamesSQL, paramValuesSQL)
newQuestionParam := tools.NewStringParameter(
"question", // name
"The natural language question to ask.", // description
)
cfg.NLConfigParameters = append([]tools.Parameter{newQuestionParam}, cfg.NLConfigParameters...)
mcpManifest := tools.McpManifest{
Name: cfg.Name,
Description: cfg.Description,
InputSchema: cfg.NLConfigParameters.McpManifest(),
}
t := Tool{
Name: cfg.Name,
Kind: ToolKind,
Parameters: cfg.NLConfigParameters,
Statement: stmt,
NLConfig: cfg.NLConfig,
AuthRequired: cfg.AuthRequired,
Pool: s.PostgresPool(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.NLConfigParameters.Manifest()},
mcpManifest: mcpManifest,
}
return t, nil
}
// validate interface
var _ tools.Tool = Tool{}
type Tool struct {
Name string `yaml:"name"`
Kind string `yaml:"kind"`
AuthRequired []string `yaml:"authRequired"`
Parameters tools.Parameters `yaml:"parameters"`
Pool *pgxpool.Pool
Statement string
NLConfig string
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(params tools.ParamValues) ([]any, error) {
sliceParams := params.AsSlice()
allParamValues := make([]any, len(sliceParams)+1)
allParamValues[0] = fmt.Sprintf("%s", sliceParams[0]) // nl_question
allParamValues[1] = fmt.Sprintf("%s", t.NLConfig) // nl_config
for i, param := range sliceParams[1:] {
allParamValues[i+2] = fmt.Sprintf("%s", param)
}
results, err := t.Pool.Query(context.Background(), t.Statement, allParamValues...)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w. Query: %v , Values: %v", err, t.Statement, allParamValues)
}
fields := results.FieldDescriptions()
var out []any
for results.Next() {
v, err := results.Values()
if err != nil {
return nil, fmt.Errorf("unable to parse row: %w", err)
}
vMap := make(map[string]any)
for i, f := range fields {
vMap[f.Name] = v[i]
}
out = append(out, vMap)
}
return out, nil
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
return tools.ParseParams(t.Parameters, data, claims)
}
func (t Tool) Manifest() tools.Manifest {
return t.manifest
}
func (t Tool) McpManifest() tools.McpManifest {
return t.mcpManifest
}
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}

View File

@@ -0,0 +1,131 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package alloydbainl_test
import (
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/alloydbainl"
)
func TestParseFromYamlAlloyDBNLA(t *testing.T) {
ctx, err := testutils.ContextWithNewLogger()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
tcs := []struct {
desc string
in string
want server.ToolConfigs
}{
{
desc: "basic example",
in: `
tools:
example_tool:
kind: alloydb-ai-nl
source: my-alloydb-instance
description: AlloyDB natural language query tool
nlConfig: 'my_nl_config'
authRequired:
- my-google-auth-service
nlConfigParameters:
- name: user_id
type: string
description: user_id to use
authServices:
- name: my-google-auth-service
field: sub
`,
want: server.ToolConfigs{
"example_tool": alloydbainl.Config{
Name: "example_tool",
Kind: alloydbainl.ToolKind,
Source: "my-alloydb-instance",
Description: "AlloyDB natural language query tool",
NLConfig: "my_nl_config",
AuthRequired: []string{"my-google-auth-service"},
NLConfigParameters: []tools.Parameter{
tools.NewStringParameterWithAuth("user_id", "user_id to use",
[]tools.ParamAuthService{{Name: "my-google-auth-service", Field: "sub"}}),
},
},
},
},
{
desc: "with multiple parameters",
in: `
tools:
complex_tool:
kind: alloydb-ai-nl
source: my-alloydb-instance
description: AlloyDB natural language query tool with multiple parameters
nlConfig: 'complex_nl_config'
authRequired:
- my-google-auth-service
- other-auth-service
nlConfigParameters:
- name: user_id
type: string
description: user_id to use
authServices:
- name: my-google-auth-service
field: sub
- name: user_email
type: string
description: user_email to use
authServices:
- name: my-google-auth-service
field: user_email
`,
want: server.ToolConfigs{
"complex_tool": alloydbainl.Config{
Name: "complex_tool",
Kind: alloydbainl.ToolKind,
Source: "my-alloydb-instance",
Description: "AlloyDB natural language query tool with multiple parameters",
NLConfig: "complex_nl_config",
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
NLConfigParameters: []tools.Parameter{
tools.NewStringParameterWithAuth("user_id", "user_id to use",
[]tools.ParamAuthService{{Name: "my-google-auth-service", Field: "sub"}}),
tools.NewStringParameterWithAuth("user_email", "user_email to use",
[]tools.ParamAuthService{{Name: "my-google-auth-service", Field: "user_email"}}),
},
},
},
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Tools server.ToolConfigs `yaml:"tools"`
}{}
// Parse contents
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
t.Fatalf("incorrect parse: diff %v", diff)
}
})
}
}

View File

@@ -0,0 +1,326 @@
//go:build integration && alloydb_ai_nl
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tests
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"os"
"reflect"
"regexp"
"testing"
"time"
)
var (
ALLOYDB_AI_NL_SOURCE_KIND = "alloydb-postgres"
ALLOYDB_AI_NL_TOOL_KIND = "alloydb-ai-nl"
ALLOYDB_AI_NL_PROJECT = os.Getenv("ALLOYDB_AI_NL_PROJECT")
ALLOYDB_AI_NL_REGION = os.Getenv("ALLOYDB_AI_NL_REGION")
ALLOYDB_AI_NL_CLUSTER = os.Getenv("ALLOYDB_AI_NL_CLUSTER")
ALLOYDB_AI_NL_INSTANCE = os.Getenv("ALLOYDB_AI_NL_INSTANCE")
ALLOYDB_AI_NL_DATABASE = os.Getenv("ALLOYDB_AI_NL_DATABASE")
ALLOYDB_AI_NL_USER = os.Getenv("ALLOYDB_AI_NL_USER")
ALLOYDB_AI_NL_PASS = os.Getenv("ALLOYDB_AI_NL_PASS")
)
func getAlloyDBAiNlVars(t *testing.T) map[string]any {
switch "" {
case ALLOYDB_AI_NL_PROJECT:
t.Fatal("'ALLOYDB_AI_NL_PROJECT' not set")
case ALLOYDB_AI_NL_REGION:
t.Fatal("'ALLOYDB_AI_NL_REGION' not set")
case ALLOYDB_AI_NL_CLUSTER:
t.Fatal("'ALLOYDB_AI_NL_CLUSTER' not set")
case ALLOYDB_AI_NL_INSTANCE:
t.Fatal("'ALLOYDB_AI_NL_INSTANCE' not set")
case ALLOYDB_AI_NL_DATABASE:
t.Fatal("'ALLOYDB_AI_NL_DATABASE' not set")
case ALLOYDB_AI_NL_USER:
t.Fatal("'ALLOYDB_AI_NL_USER' not set")
case ALLOYDB_AI_NL_PASS:
t.Fatal("'ALLOYDB_AI_NL_PASS' not set")
}
return map[string]any{
"kind": ALLOYDB_AI_NL_SOURCE_KIND,
"project": ALLOYDB_AI_NL_PROJECT,
"cluster": ALLOYDB_AI_NL_CLUSTER,
"instance": ALLOYDB_AI_NL_INSTANCE,
"region": ALLOYDB_AI_NL_REGION,
"database": ALLOYDB_AI_NL_DATABASE,
"user": ALLOYDB_AI_NL_USER,
"password": ALLOYDB_AI_NL_PASS,
}
}
func TestAlloyDBAiNlToolEndpoints(t *testing.T) {
sourceConfig := getAlloyDBAiNlVars(t)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
var args []string
// Write config into a file and pass it to command
toolsFile := getAiNlToolsConfig(sourceConfig)
cmd, cleanup, err := StartCmd(ctx, toolsFile, args...)
if err != nil {
t.Fatalf("command initialization returned an error: %s", err)
}
defer cleanup()
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
if err != nil {
t.Logf("toolbox command logs: \n%s", out)
t.Fatalf("toolbox didn't start successfully: %s", err)
}
runAiNlToolGetTest(t)
runAiNlToolInvokeTest(t)
}
func runAiNlToolGetTest(t *testing.T) {
// Test tool get endpoint
tcs := []struct {
name string
api string
want map[string]any
}{
{
name: "get my-simple-tool",
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/",
want: map[string]any{
"my-simple-tool": map[string]any{
"description": "Simple tool to test end to end functionality.",
"parameters": []any{
map[string]any{
"name": "question",
"type": "string",
"description": "The natural language question to ask.",
"authSources": []any{},
},
},
},
},
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
resp, err := http.Get(tc.api)
if err != nil {
t.Fatalf("error when sending a request: %s", err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
t.Fatalf("response status code is not 200")
}
var body map[string]interface{}
err = json.NewDecoder(resp.Body).Decode(&body)
if err != nil {
t.Fatalf("error parsing response body")
}
got, ok := body["tools"]
if !ok {
t.Fatalf("unable to find tools in response body")
}
if !reflect.DeepEqual(got, tc.want) {
t.Fatalf("got %q, want %q", got, tc.want)
}
})
}
}
func runAiNlToolInvokeTest(t *testing.T) {
// Get ID token
idToken, err := GetGoogleIdToken(ClientId)
if err != nil {
t.Fatalf("error getting Google ID token: %s", err)
}
// Test tool invoke endpoint
invokeTcs := []struct {
name string
api string
requestHeader map[string]string
requestBody io.Reader
want string
isErr bool
}{
{
name: "invoke my-simple-tool",
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke",
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(`{"question": "return 1"}`)),
want: "[{\"execute_nl_query\":{\"?column?\":1}}]",
isErr: false,
},
{
name: "Invoke my-tool without parameters",
api: "http://127.0.0.1:5000/api/tool/my-tool/invoke",
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(`{}`)),
isErr: true,
},
{
name: "Invoke my-auth-tool with auth token",
api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke",
requestHeader: map[string]string{"my-google-auth_token": idToken},
requestBody: bytes.NewBuffer([]byte(`{"question": "can you show me the name of this user?"}`)),
want: "[{\"execute_nl_query\":{\"name\":\"Alice\"}}]",
isErr: false,
},
{
name: "Invoke my-auth-tool with invalid auth token",
api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke",
requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"},
requestBody: bytes.NewBuffer([]byte(`{"question": "return 1"}`)),
isErr: true,
},
{
name: "Invoke my-auth-tool without auth token",
api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke",
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(`{"question": "return 1"}`)),
isErr: true,
},
{
name: "Invoke my-auth-required-tool with auth token",
api: "http://127.0.0.1:5000/api/tool/my-auth-required-tool/invoke",
requestHeader: map[string]string{"my-google-auth_token": idToken},
requestBody: bytes.NewBuffer([]byte(`{"question": "return 1"}`)),
isErr: false,
want: "[{\"execute_nl_query\":{\"?column?\":1}}]",
},
{
name: "Invoke my-auth-required-tool with invalid auth token",
api: "http://127.0.0.1:5000/api/tool/my-auth-required-tool/invoke",
requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"},
requestBody: bytes.NewBuffer([]byte(`{"question": "return 1"}`)),
isErr: true,
},
{
name: "Invoke my-auth-required-tool without auth token",
api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke",
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(`{"question": "return 1"}`)),
isErr: true,
},
}
for _, tc := range invokeTcs {
t.Run(tc.name, func(t *testing.T) {
// Send Tool invocation request
req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
if err != nil {
t.Fatalf("unable to create request: %s", err)
}
req.Header.Add("Content-type", "application/json")
for k, v := range tc.requestHeader {
req.Header.Add(k, v)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("unable to send request: %s", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if tc.isErr == true {
return
}
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
}
// Check response body
var body map[string]interface{}
err = json.NewDecoder(resp.Body).Decode(&body)
if err != nil {
t.Fatalf("error parsing response body")
}
got, ok := body["result"].(string)
if !ok {
t.Fatalf("unable to find result in response body")
}
if got != tc.want {
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
}
})
}
}
func getAiNlToolsConfig(sourceConfig map[string]any) map[string]any {
// Write config into a file and pass it to command
toolsFile := map[string]any{
"sources": map[string]any{
"my-instance": sourceConfig,
},
"authServices": map[string]any{
"my-google-auth": map[string]any{
"kind": "google",
"clientId": ClientId,
},
},
"tools": map[string]any{
"my-simple-tool": map[string]any{
"kind": ALLOYDB_AI_NL_TOOL_KIND,
"source": "my-instance",
"description": "Simple tool to test end to end functionality.",
"nlConfig": "my_nl_config",
},
"my-auth-tool": map[string]any{
"kind": ALLOYDB_AI_NL_TOOL_KIND,
"source": "my-instance",
"description": "Tool to test authenticated parameters.",
"nlConfig": "my_nl_config",
"nlConfigParameters": []map[string]any{
{
"name": "email",
"type": "string",
"description": "user email",
"authServices": []map[string]string{
{
"name": "my-google-auth",
"field": "email",
},
},
},
},
},
"my-auth-required-tool": map[string]any{
"kind": ALLOYDB_AI_NL_TOOL_KIND,
"source": "my-instance",
"description": "Tool to test auth required invocation.",
"nlConfig": "my_nl_config",
"authRequired": []string{
"my-google-auth",
},
},
},
}
return toolsFile
}