mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-08 23:18:04 -05:00
feat: add 'alloydb-ai-nl' tool (#358)
Co-authored-by: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> Co-authored-by: Averi Kitsch <akitsch@google.com> Co-authored-by: Yuan <45984206+Yuan325@users.noreply.github.com>
This commit is contained in:
@@ -66,6 +66,27 @@ steps:
|
||||
- |
|
||||
go test -race -v -tags=integration,alloydb ./tests
|
||||
|
||||
- id: "alloydb-ai-nl"
|
||||
name: golang:1
|
||||
waitFor: ["install-dependencies"]
|
||||
entrypoint: /bin/bash
|
||||
env:
|
||||
- "GOPATH=/gopath"
|
||||
- "ALLOYDB_AI_NL_PROJECT=$PROJECT_ID"
|
||||
- "ALLOYDB_AI_NL_CLUSTER=$_ALLOYDB_AI_NL_CLUSTER"
|
||||
- "ALLOYDB_AI_NL_INSTANCE=$_ALLOYDB_AI_NL_INSTANCE"
|
||||
- "ALLOYDB_AI_NL_DATABASE=$_DATABASE_NAME"
|
||||
- "ALLOYDB_AI_NL_REGION=$_REGION"
|
||||
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
|
||||
secretEnv: ["ALLOYDB_AI_NL_USER", "ALLOYDB_AI_NL_PASS", "CLIENT_ID"]
|
||||
volumes:
|
||||
- name: "go"
|
||||
path: "/gopath"
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
go test -race -v -tags=integration,alloydb_ai_nl ./tests
|
||||
|
||||
- id: "postgres"
|
||||
name: golang:1
|
||||
waitFor: ["install-dependencies"]
|
||||
@@ -239,6 +260,10 @@ availableSecrets:
|
||||
env: ALLOYDB_POSTGRES_USER
|
||||
- versionName: projects/$PROJECT_ID/secrets/alloydb_pg_pass/versions/latest
|
||||
env: ALLOYDB_POSTGRES_PASS
|
||||
- versionName: projects/$PROJECT_ID/secrets/alloydb_ai_nl_user/versions/latest
|
||||
env: ALLOYDB_AI_NL_USER
|
||||
- versionName: projects/$PROJECT_ID/secrets/alloydb_ai_nl_pass/versions/latest
|
||||
env: ALLOYDB_AI_NL_PASS
|
||||
- versionName: projects/$PROJECT_ID/secrets/postgres_user/versions/latest
|
||||
env: POSTGRES_USER
|
||||
- versionName: projects/$PROJECT_ID/secrets/postgres_pass/versions/latest
|
||||
@@ -281,6 +306,8 @@ substitutions:
|
||||
_CLOUD_SQL_POSTGRES_INSTANCE: "cloud-sql-pg-testing"
|
||||
_ALLOYDB_POSTGRES_CLUSTER: "alloydb-pg-testing"
|
||||
_ALLOYDB_POSTGRES_INSTANCE: "alloydb-pg-testing-instance"
|
||||
_ALLOYDB_AI_NL_CLUSTER: "alloydb-ai-nl-testing"
|
||||
_ALLOYDB_AI_NL_INSTANCE: "alloydb-ai-nl-testing-instance"
|
||||
_POSTGRES_HOST: 127.0.0.1
|
||||
_POSTGRES_PORT: "5432"
|
||||
_SPANNER_INSTANCE: "spanner-testing"
|
||||
|
||||
@@ -44,3 +44,4 @@ run:
|
||||
- mssql
|
||||
- mysql
|
||||
- http
|
||||
- alloydb_ai_nl
|
||||
|
||||
104
docs/en/resources/tools/alloydb-ai-nl.md
Normal file
104
docs/en/resources/tools/alloydb-ai-nl.md
Normal file
@@ -0,0 +1,104 @@
|
||||
---
|
||||
title: "alloydb-ai-nl"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
The "alloydb-ai-nl" tool leverages
|
||||
[AlloyDB AI](https://cloud.google.com/alloydb/ai) next-generation Natural
|
||||
Language support to provide the ability to query the database directly using
|
||||
natural language.
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
The `alloydb-ai-nl` tool leverages [AlloyDB AI next-generation natural
|
||||
Language][alloydb-ai-nl-overview] support to allow an Agent the ability to query
|
||||
the database directly using natural language. Natural language streamlines the
|
||||
development of generative AI applications by transferring the complexity of
|
||||
converting natural language to SQL from the application layer to the database
|
||||
layer.
|
||||
|
||||
This tool is compatible with the following sources:
|
||||
- [alloydb-postgres](../sources/alloydb-pg.md)
|
||||
|
||||
AlloyDB AI Natural Language delivers secure and accurate responses for
|
||||
application end user natural language questions. Natural language streamlines
|
||||
the development of generative AI applications by transferring the complexity
|
||||
of converting natural language to SQL from the application layer to the
|
||||
database layer.
|
||||
|
||||
## Requirements
|
||||
{{< notice tip >}} AlloyDB AI natural language is currently in gated public
|
||||
preview. For more information on availability and limitations, please see
|
||||
[AlloyDB AI natural language
|
||||
overview](https://cloud.google.com/alloydb/docs/natural-language-questions-overview)
|
||||
{{< /notice >}}
|
||||
|
||||
To enable AlloyDB AI natural language for your AlloyDB cluster, please follow
|
||||
the steps listed in the [Generate SQL queries that answer natural language
|
||||
questions][alloydb-ai-gen-nl], including enabling the extension and configuring
|
||||
context for your application.
|
||||
|
||||
[alloydb-ai-nl-overview]: https://cloud.google.com/alloydb/docs/natural-language-questions-overview
|
||||
[alloydb-ai-gen-nl]: https://cloud.google.com/alloydb/docs/alloydb/docs/ai/generate-queries-natural-language
|
||||
|
||||
|
||||
## Configuration
|
||||
|
||||
### Specifying an `nl_config`
|
||||
A `nl_config` is a configuration that associates an application to schema
|
||||
objects, examples and other contexts that can be used. A large application can
|
||||
also use different configurations for different parts of the app, as long as the
|
||||
correct configuration can be specified when a question is sent from that part of
|
||||
the application.
|
||||
|
||||
Once you've followed the steps for configuring context, you can use the
|
||||
`context` field when configuring a `alloydb-ai-nl` tool. When this tool is
|
||||
invoked, the SQL will be generated and executed using this context.
|
||||
|
||||
### Specifying Parameters to PSV's
|
||||
|
||||
[Parameterized Secure Views (PSVs)][alloydb-psv] are a feature unique to AlloyDB
|
||||
that allows you allow you to require one or more named parameter values passed
|
||||
to the view when querying it, somewhat like bind variables with ordinary
|
||||
database queries.
|
||||
|
||||
You can use the `nlConfigParameters` to list the parameters required for your
|
||||
`nl_config`. You **must** supply all parameters required for all PSVs in the
|
||||
context. It's strongly recommended to use features like [Authenticated
|
||||
Parameters](../tools/#array-parameters) or Bound Parameters to provide secure
|
||||
access to queries generated using natural language, as these parameters are not
|
||||
visible to the LLM.
|
||||
|
||||
[alloydb-psv]: https://cloud.google.com/alloydb/docs/ai/use-psvs#parameterized_secure_views
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
ask_questions:
|
||||
kind: alloydb-ai-nl
|
||||
source: my-alloydb-source
|
||||
description: "Ask questions to check information about flights"
|
||||
nlConfig: "cymbal_air_nl_config"
|
||||
nlConfigParameters:
|
||||
- name: user_email
|
||||
type: string
|
||||
description: User ID of the logged in user.
|
||||
# note: we strongly recommend using features like Authenticated or
|
||||
# Bound parameters to prevent the LLM from seeing these params and
|
||||
# specifying values it shouldn't in the tool input
|
||||
authServices:
|
||||
- name: my_google_service
|
||||
field: email
|
||||
```
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|--------------------|:------------------------------------------:|:------------:|--------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be "alloydb-ai-nl". |
|
||||
| source | string | true | Name of the AlloyDB source the natural language query should execute on. |
|
||||
| description | string | true | Description of the tool that is passed to the LLM. |
|
||||
| nlConfig | string | true | The name of the `nl_config` in AlloyDB |
|
||||
| nlConfigParameters | [parameters](_index#specifying-parameters) | true | List of PSV parameters defined in the `nl_config` |
|
||||
@@ -40,6 +40,7 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/mysqlsql"
|
||||
neo4jtool "github.com/googleapis/genai-toolbox/internal/tools/neo4j"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/postgressql"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/alloydbainl"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/spanner"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
)
|
||||
@@ -307,6 +308,12 @@ func (c *ToolConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interfac
|
||||
return fmt.Errorf("unable to parse as %q: %w", kind, err)
|
||||
}
|
||||
(*c)[name] = actual
|
||||
case alloydbainl.ToolKind:
|
||||
actual := alloydbainl.Config{Name: name}
|
||||
if err := dec.DecodeContext(ctx, &actual); err != nil {
|
||||
return fmt.Errorf("unable to parse as %q: %w", kind, err)
|
||||
}
|
||||
(*c)[name] = actual
|
||||
case mysqlsql.ToolKind:
|
||||
actual := mysqlsql.Config{Name: name}
|
||||
if err := dec.DecodeContext(ctx, &actual); err != nil {
|
||||
|
||||
190
internal/tools/alloydbainl/alloydbainl.go
Normal file
190
internal/tools/alloydbainl/alloydbainl.go
Normal file
@@ -0,0 +1,190 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package alloydbainl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/alloydbpg"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
const ToolKind string = "alloydb-ai-nl"
|
||||
|
||||
type compatibleSource interface {
|
||||
PostgresPool() *pgxpool.Pool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &alloydbpg.Source{}
|
||||
|
||||
var compatibleSources = [...]string{alloydbpg.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
NLConfig string `yaml:"nlConfig" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
NLConfigParameters tools.Parameters `yaml:"nlConfigParameters"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return ToolKind
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", ToolKind, compatibleSources)
|
||||
}
|
||||
|
||||
numParams := len(cfg.NLConfigParameters)
|
||||
quotedNameParts := make([]string, 0, numParams)
|
||||
placeholderParts := make([]string, 0, numParams)
|
||||
|
||||
for i, paramDef := range cfg.NLConfigParameters {
|
||||
name := paramDef.GetName()
|
||||
escapedName := strings.ReplaceAll(name, "'", "''") // Escape for SQL literal
|
||||
quotedNameParts = append(quotedNameParts, fmt.Sprintf("'%s'", escapedName))
|
||||
placeholderParts = append(placeholderParts, fmt.Sprintf("$%d", i+3)) // $1, $2 reserved
|
||||
}
|
||||
|
||||
var paramNamesSQL string
|
||||
var paramValuesSQL string
|
||||
|
||||
if numParams > 0 {
|
||||
paramNamesSQL = fmt.Sprintf("ARRAY[%s]", strings.Join(quotedNameParts, ", "))
|
||||
paramValuesSQL = fmt.Sprintf("ARRAY[%s]", strings.Join(placeholderParts, ", "))
|
||||
} else {
|
||||
paramNamesSQL = "ARRAY[]::TEXT[]"
|
||||
paramValuesSQL = "ARRAY[]::TEXT[]"
|
||||
}
|
||||
|
||||
// execute_nl_query is the AlloyDB AI function that executes the natural language query
|
||||
// The first parameter is the natural language query, which is passed as $1
|
||||
// The second parameter is the NLConfig, which is passed as a $2
|
||||
// The following params are the list of PSV values passed to the NLConfig
|
||||
// Example SQL statement being executed:
|
||||
// SELECT alloydb_ai_nl.execute_nl_query('How many tickets do I have?', 'cymbal_air_nl_config', param_names => ARRAY ['user_email'], param_values => ARRAY ['hailongli@google.com']);
|
||||
stmtFormat := "SELECT alloydb_ai_nl.execute_nl_query($1, $2, param_names => %s, param_values => %s);"
|
||||
stmt := fmt.Sprintf(stmtFormat, paramNamesSQL, paramValuesSQL)
|
||||
|
||||
newQuestionParam := tools.NewStringParameter(
|
||||
"question", // name
|
||||
"The natural language question to ask.", // description
|
||||
)
|
||||
|
||||
cfg.NLConfigParameters = append([]tools.Parameter{newQuestionParam}, cfg.NLConfigParameters...)
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: cfg.Description,
|
||||
InputSchema: cfg.NLConfigParameters.McpManifest(),
|
||||
}
|
||||
|
||||
t := Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: ToolKind,
|
||||
Parameters: cfg.NLConfigParameters,
|
||||
Statement: stmt,
|
||||
NLConfig: cfg.NLConfig,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Pool: s.PostgresPool(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.NLConfigParameters.Manifest()},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
|
||||
Pool *pgxpool.Pool
|
||||
Statement string
|
||||
NLConfig string
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(params tools.ParamValues) ([]any, error) {
|
||||
sliceParams := params.AsSlice()
|
||||
allParamValues := make([]any, len(sliceParams)+1)
|
||||
allParamValues[0] = fmt.Sprintf("%s", sliceParams[0]) // nl_question
|
||||
allParamValues[1] = fmt.Sprintf("%s", t.NLConfig) // nl_config
|
||||
for i, param := range sliceParams[1:] {
|
||||
allParamValues[i+2] = fmt.Sprintf("%s", param)
|
||||
}
|
||||
|
||||
results, err := t.Pool.Query(context.Background(), t.Statement, allParamValues...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w. Query: %v , Values: %v", err, t.Statement, allParamValues)
|
||||
}
|
||||
|
||||
fields := results.FieldDescriptions()
|
||||
|
||||
var out []any
|
||||
for results.Next() {
|
||||
v, err := results.Values()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse row: %w", err)
|
||||
}
|
||||
vMap := make(map[string]any)
|
||||
for i, f := range fields {
|
||||
vMap[f.Name] = v[i]
|
||||
}
|
||||
out = append(out, vMap)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.Parameters, data, claims)
|
||||
}
|
||||
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
131
internal/tools/alloydbainl/alloydbainl_test.go
Normal file
131
internal/tools/alloydbainl/alloydbainl_test.go
Normal file
@@ -0,0 +1,131 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package alloydbainl_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/alloydbainl"
|
||||
)
|
||||
|
||||
func TestParseFromYamlAlloyDBNLA(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: alloydb-ai-nl
|
||||
source: my-alloydb-instance
|
||||
description: AlloyDB natural language query tool
|
||||
nlConfig: 'my_nl_config'
|
||||
authRequired:
|
||||
- my-google-auth-service
|
||||
nlConfigParameters:
|
||||
- name: user_id
|
||||
type: string
|
||||
description: user_id to use
|
||||
authServices:
|
||||
- name: my-google-auth-service
|
||||
field: sub
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": alloydbainl.Config{
|
||||
Name: "example_tool",
|
||||
Kind: alloydbainl.ToolKind,
|
||||
Source: "my-alloydb-instance",
|
||||
Description: "AlloyDB natural language query tool",
|
||||
NLConfig: "my_nl_config",
|
||||
AuthRequired: []string{"my-google-auth-service"},
|
||||
NLConfigParameters: []tools.Parameter{
|
||||
tools.NewStringParameterWithAuth("user_id", "user_id to use",
|
||||
[]tools.ParamAuthService{{Name: "my-google-auth-service", Field: "sub"}}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "with multiple parameters",
|
||||
in: `
|
||||
tools:
|
||||
complex_tool:
|
||||
kind: alloydb-ai-nl
|
||||
source: my-alloydb-instance
|
||||
description: AlloyDB natural language query tool with multiple parameters
|
||||
nlConfig: 'complex_nl_config'
|
||||
authRequired:
|
||||
- my-google-auth-service
|
||||
- other-auth-service
|
||||
nlConfigParameters:
|
||||
- name: user_id
|
||||
type: string
|
||||
description: user_id to use
|
||||
authServices:
|
||||
- name: my-google-auth-service
|
||||
field: sub
|
||||
- name: user_email
|
||||
type: string
|
||||
description: user_email to use
|
||||
authServices:
|
||||
- name: my-google-auth-service
|
||||
field: user_email
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"complex_tool": alloydbainl.Config{
|
||||
Name: "complex_tool",
|
||||
Kind: alloydbainl.ToolKind,
|
||||
Source: "my-alloydb-instance",
|
||||
Description: "AlloyDB natural language query tool with multiple parameters",
|
||||
NLConfig: "complex_nl_config",
|
||||
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
|
||||
NLConfigParameters: []tools.Parameter{
|
||||
tools.NewStringParameterWithAuth("user_id", "user_id to use",
|
||||
[]tools.ParamAuthService{{Name: "my-google-auth-service", Field: "sub"}}),
|
||||
tools.NewStringParameterWithAuth("user_email", "user_email to use",
|
||||
[]tools.ParamAuthService{{Name: "my-google-auth-service", Field: "user_email"}}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
326
tests/alloydb_ai_nl_integration_test.go
Normal file
326
tests/alloydb_ai_nl_integration_test.go
Normal file
@@ -0,0 +1,326 @@
|
||||
//go:build integration && alloydb_ai_nl
|
||||
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tests
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
ALLOYDB_AI_NL_SOURCE_KIND = "alloydb-postgres"
|
||||
ALLOYDB_AI_NL_TOOL_KIND = "alloydb-ai-nl"
|
||||
ALLOYDB_AI_NL_PROJECT = os.Getenv("ALLOYDB_AI_NL_PROJECT")
|
||||
ALLOYDB_AI_NL_REGION = os.Getenv("ALLOYDB_AI_NL_REGION")
|
||||
ALLOYDB_AI_NL_CLUSTER = os.Getenv("ALLOYDB_AI_NL_CLUSTER")
|
||||
ALLOYDB_AI_NL_INSTANCE = os.Getenv("ALLOYDB_AI_NL_INSTANCE")
|
||||
ALLOYDB_AI_NL_DATABASE = os.Getenv("ALLOYDB_AI_NL_DATABASE")
|
||||
ALLOYDB_AI_NL_USER = os.Getenv("ALLOYDB_AI_NL_USER")
|
||||
ALLOYDB_AI_NL_PASS = os.Getenv("ALLOYDB_AI_NL_PASS")
|
||||
)
|
||||
|
||||
func getAlloyDBAiNlVars(t *testing.T) map[string]any {
|
||||
switch "" {
|
||||
case ALLOYDB_AI_NL_PROJECT:
|
||||
t.Fatal("'ALLOYDB_AI_NL_PROJECT' not set")
|
||||
case ALLOYDB_AI_NL_REGION:
|
||||
t.Fatal("'ALLOYDB_AI_NL_REGION' not set")
|
||||
case ALLOYDB_AI_NL_CLUSTER:
|
||||
t.Fatal("'ALLOYDB_AI_NL_CLUSTER' not set")
|
||||
case ALLOYDB_AI_NL_INSTANCE:
|
||||
t.Fatal("'ALLOYDB_AI_NL_INSTANCE' not set")
|
||||
case ALLOYDB_AI_NL_DATABASE:
|
||||
t.Fatal("'ALLOYDB_AI_NL_DATABASE' not set")
|
||||
case ALLOYDB_AI_NL_USER:
|
||||
t.Fatal("'ALLOYDB_AI_NL_USER' not set")
|
||||
case ALLOYDB_AI_NL_PASS:
|
||||
t.Fatal("'ALLOYDB_AI_NL_PASS' not set")
|
||||
}
|
||||
return map[string]any{
|
||||
"kind": ALLOYDB_AI_NL_SOURCE_KIND,
|
||||
"project": ALLOYDB_AI_NL_PROJECT,
|
||||
"cluster": ALLOYDB_AI_NL_CLUSTER,
|
||||
"instance": ALLOYDB_AI_NL_INSTANCE,
|
||||
"region": ALLOYDB_AI_NL_REGION,
|
||||
"database": ALLOYDB_AI_NL_DATABASE,
|
||||
"user": ALLOYDB_AI_NL_USER,
|
||||
"password": ALLOYDB_AI_NL_PASS,
|
||||
}
|
||||
}
|
||||
|
||||
func TestAlloyDBAiNlToolEndpoints(t *testing.T) {
|
||||
sourceConfig := getAlloyDBAiNlVars(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
var args []string
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := getAiNlToolsConfig(sourceConfig)
|
||||
|
||||
cmd, cleanup, err := StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
t.Fatalf("command initialization returned an error: %s", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
||||
if err != nil {
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
}
|
||||
|
||||
runAiNlToolGetTest(t)
|
||||
|
||||
runAiNlToolInvokeTest(t)
|
||||
}
|
||||
|
||||
func runAiNlToolGetTest(t *testing.T) {
|
||||
// Test tool get endpoint
|
||||
tcs := []struct {
|
||||
name string
|
||||
api string
|
||||
want map[string]any
|
||||
}{
|
||||
{
|
||||
name: "get my-simple-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/",
|
||||
want: map[string]any{
|
||||
"my-simple-tool": map[string]any{
|
||||
"description": "Simple tool to test end to end functionality.",
|
||||
"parameters": []any{
|
||||
map[string]any{
|
||||
"name": "question",
|
||||
"type": "string",
|
||||
"description": "The natural language question to ask.",
|
||||
"authSources": []any{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
resp, err := http.Get(tc.api)
|
||||
if err != nil {
|
||||
t.Fatalf("error when sending a request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != 200 {
|
||||
t.Fatalf("response status code is not 200")
|
||||
}
|
||||
|
||||
var body map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing response body")
|
||||
}
|
||||
|
||||
got, ok := body["tools"]
|
||||
if !ok {
|
||||
t.Fatalf("unable to find tools in response body")
|
||||
}
|
||||
if !reflect.DeepEqual(got, tc.want) {
|
||||
t.Fatalf("got %q, want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func runAiNlToolInvokeTest(t *testing.T) {
|
||||
// Get ID token
|
||||
idToken, err := GetGoogleIdToken(ClientId)
|
||||
if err != nil {
|
||||
t.Fatalf("error getting Google ID token: %s", err)
|
||||
}
|
||||
|
||||
// Test tool invoke endpoint
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
api string
|
||||
requestHeader map[string]string
|
||||
requestBody io.Reader
|
||||
want string
|
||||
isErr bool
|
||||
}{
|
||||
{
|
||||
name: "invoke my-simple-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"question": "return 1"}`)),
|
||||
want: "[{\"execute_nl_query\":{\"?column?\":1}}]",
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-tool without parameters",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
isErr: true,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-auth-tool with auth token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke",
|
||||
requestHeader: map[string]string{"my-google-auth_token": idToken},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"question": "can you show me the name of this user?"}`)),
|
||||
want: "[{\"execute_nl_query\":{\"name\":\"Alice\"}}]",
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-auth-tool with invalid auth token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke",
|
||||
requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"question": "return 1"}`)),
|
||||
isErr: true,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-auth-tool without auth token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"question": "return 1"}`)),
|
||||
isErr: true,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-auth-required-tool with auth token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-required-tool/invoke",
|
||||
requestHeader: map[string]string{"my-google-auth_token": idToken},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"question": "return 1"}`)),
|
||||
isErr: false,
|
||||
want: "[{\"execute_nl_query\":{\"?column?\":1}}]",
|
||||
},
|
||||
{
|
||||
name: "Invoke my-auth-required-tool with invalid auth token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-required-tool/invoke",
|
||||
requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"question": "return 1"}`)),
|
||||
isErr: true,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-auth-required-tool without auth token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"question": "return 1"}`)),
|
||||
isErr: true,
|
||||
},
|
||||
}
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Send Tool invocation request
|
||||
req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %s", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
for k, v := range tc.requestHeader {
|
||||
req.Header.Add(k, v)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if tc.isErr == true {
|
||||
return
|
||||
}
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
// Check response body
|
||||
var body map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing response body")
|
||||
}
|
||||
got, ok := body["result"].(string)
|
||||
if !ok {
|
||||
t.Fatalf("unable to find result in response body")
|
||||
}
|
||||
|
||||
if got != tc.want {
|
||||
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func getAiNlToolsConfig(sourceConfig map[string]any) map[string]any {
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := map[string]any{
|
||||
"sources": map[string]any{
|
||||
"my-instance": sourceConfig,
|
||||
},
|
||||
"authServices": map[string]any{
|
||||
"my-google-auth": map[string]any{
|
||||
"kind": "google",
|
||||
"clientId": ClientId,
|
||||
},
|
||||
},
|
||||
"tools": map[string]any{
|
||||
"my-simple-tool": map[string]any{
|
||||
"kind": ALLOYDB_AI_NL_TOOL_KIND,
|
||||
"source": "my-instance",
|
||||
"description": "Simple tool to test end to end functionality.",
|
||||
"nlConfig": "my_nl_config",
|
||||
},
|
||||
"my-auth-tool": map[string]any{
|
||||
"kind": ALLOYDB_AI_NL_TOOL_KIND,
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test authenticated parameters.",
|
||||
"nlConfig": "my_nl_config",
|
||||
"nlConfigParameters": []map[string]any{
|
||||
{
|
||||
"name": "email",
|
||||
"type": "string",
|
||||
"description": "user email",
|
||||
"authServices": []map[string]string{
|
||||
{
|
||||
"name": "my-google-auth",
|
||||
"field": "email",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"my-auth-required-tool": map[string]any{
|
||||
"kind": ALLOYDB_AI_NL_TOOL_KIND,
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test auth required invocation.",
|
||||
"nlConfig": "my_nl_config",
|
||||
"authRequired": []string{
|
||||
"my-google-auth",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return toolsFile
|
||||
}
|
||||
Reference in New Issue
Block a user