From f02885fd4a919103fdabaa4ca38d975dc8497542 Mon Sep 17 00:00:00 2001 From: totoleon Date: Fri, 4 Apr 2025 11:30:58 -0700 Subject: [PATCH] feat: add 'alloydb-ai-nl' tool (#358) Co-authored-by: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> Co-authored-by: Averi Kitsch Co-authored-by: Yuan <45984206+Yuan325@users.noreply.github.com> --- .ci/integration.cloudbuild.yaml | 27 ++ .golangci.yaml | 1 + docs/en/resources/tools/alloydb-ai-nl.md | 104 ++++++ internal/server/config.go | 7 + internal/tools/alloydbainl/alloydbainl.go | 190 ++++++++++ .../tools/alloydbainl/alloydbainl_test.go | 131 +++++++ tests/alloydb_ai_nl_integration_test.go | 326 ++++++++++++++++++ 7 files changed, 786 insertions(+) create mode 100644 docs/en/resources/tools/alloydb-ai-nl.md create mode 100644 internal/tools/alloydbainl/alloydbainl.go create mode 100644 internal/tools/alloydbainl/alloydbainl_test.go create mode 100644 tests/alloydb_ai_nl_integration_test.go diff --git a/.ci/integration.cloudbuild.yaml b/.ci/integration.cloudbuild.yaml index d976a77ba8..92558347c1 100644 --- a/.ci/integration.cloudbuild.yaml +++ b/.ci/integration.cloudbuild.yaml @@ -66,6 +66,27 @@ steps: - | go test -race -v -tags=integration,alloydb ./tests + - id: "alloydb-ai-nl" + name: golang:1 + waitFor: ["install-dependencies"] + entrypoint: /bin/bash + env: + - "GOPATH=/gopath" + - "ALLOYDB_AI_NL_PROJECT=$PROJECT_ID" + - "ALLOYDB_AI_NL_CLUSTER=$_ALLOYDB_AI_NL_CLUSTER" + - "ALLOYDB_AI_NL_INSTANCE=$_ALLOYDB_AI_NL_INSTANCE" + - "ALLOYDB_AI_NL_DATABASE=$_DATABASE_NAME" + - "ALLOYDB_AI_NL_REGION=$_REGION" + - "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL" + secretEnv: ["ALLOYDB_AI_NL_USER", "ALLOYDB_AI_NL_PASS", "CLIENT_ID"] + volumes: + - name: "go" + path: "/gopath" + args: + - -c + - | + go test -race -v -tags=integration,alloydb_ai_nl ./tests + - id: "postgres" name: golang:1 waitFor: ["install-dependencies"] @@ -239,6 +260,10 @@ availableSecrets: env: ALLOYDB_POSTGRES_USER - versionName: projects/$PROJECT_ID/secrets/alloydb_pg_pass/versions/latest env: ALLOYDB_POSTGRES_PASS + - versionName: projects/$PROJECT_ID/secrets/alloydb_ai_nl_user/versions/latest + env: ALLOYDB_AI_NL_USER + - versionName: projects/$PROJECT_ID/secrets/alloydb_ai_nl_pass/versions/latest + env: ALLOYDB_AI_NL_PASS - versionName: projects/$PROJECT_ID/secrets/postgres_user/versions/latest env: POSTGRES_USER - versionName: projects/$PROJECT_ID/secrets/postgres_pass/versions/latest @@ -281,6 +306,8 @@ substitutions: _CLOUD_SQL_POSTGRES_INSTANCE: "cloud-sql-pg-testing" _ALLOYDB_POSTGRES_CLUSTER: "alloydb-pg-testing" _ALLOYDB_POSTGRES_INSTANCE: "alloydb-pg-testing-instance" + _ALLOYDB_AI_NL_CLUSTER: "alloydb-ai-nl-testing" + _ALLOYDB_AI_NL_INSTANCE: "alloydb-ai-nl-testing-instance" _POSTGRES_HOST: 127.0.0.1 _POSTGRES_PORT: "5432" _SPANNER_INSTANCE: "spanner-testing" diff --git a/.golangci.yaml b/.golangci.yaml index a384e6435a..6dbb7a296c 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -44,3 +44,4 @@ run: - mssql - mysql - http + - alloydb_ai_nl diff --git a/docs/en/resources/tools/alloydb-ai-nl.md b/docs/en/resources/tools/alloydb-ai-nl.md new file mode 100644 index 0000000000..15ad0ae4f9 --- /dev/null +++ b/docs/en/resources/tools/alloydb-ai-nl.md @@ -0,0 +1,104 @@ +--- +title: "alloydb-ai-nl" +type: docs +weight: 1 +description: > + The "alloydb-ai-nl" tool leverages + [AlloyDB AI](https://cloud.google.com/alloydb/ai) next-generation Natural + Language support to provide the ability to query the database directly using + natural language. +--- + +## About + +The `alloydb-ai-nl` tool leverages [AlloyDB AI next-generation natural +Language][alloydb-ai-nl-overview] support to allow an Agent the ability to query +the database directly using natural language. Natural language streamlines the +development of generative AI applications by transferring the complexity of +converting natural language to SQL from the application layer to the database +layer. + +This tool is compatible with the following sources: +- [alloydb-postgres](../sources/alloydb-pg.md) + +AlloyDB AI Natural Language delivers secure and accurate responses for +application end user natural language questions. Natural language streamlines +the development of generative AI applications by transferring the complexity +of converting natural language to SQL from the application layer to the +database layer. + +## Requirements +{{< notice tip >}} AlloyDB AI natural language is currently in gated public +preview. For more information on availability and limitations, please see +[AlloyDB AI natural language +overview](https://cloud.google.com/alloydb/docs/natural-language-questions-overview) +{{< /notice >}} + +To enable AlloyDB AI natural language for your AlloyDB cluster, please follow +the steps listed in the [Generate SQL queries that answer natural language +questions][alloydb-ai-gen-nl], including enabling the extension and configuring +context for your application. + +[alloydb-ai-nl-overview]: https://cloud.google.com/alloydb/docs/natural-language-questions-overview +[alloydb-ai-gen-nl]: https://cloud.google.com/alloydb/docs/alloydb/docs/ai/generate-queries-natural-language + + +## Configuration + +### Specifying an `nl_config` +A `nl_config` is a configuration that associates an application to schema +objects, examples and other contexts that can be used. A large application can +also use different configurations for different parts of the app, as long as the +correct configuration can be specified when a question is sent from that part of +the application. + +Once you've followed the steps for configuring context, you can use the +`context` field when configuring a `alloydb-ai-nl` tool. When this tool is +invoked, the SQL will be generated and executed using this context. + +### Specifying Parameters to PSV's + +[Parameterized Secure Views (PSVs)][alloydb-psv] are a feature unique to AlloyDB +that allows you allow you to require one or more named parameter values passed +to the view when querying it, somewhat like bind variables with ordinary +database queries. + +You can use the `nlConfigParameters` to list the parameters required for your +`nl_config`. You **must** supply all parameters required for all PSVs in the +context. It's strongly recommended to use features like [Authenticated +Parameters](../tools/#array-parameters) or Bound Parameters to provide secure +access to queries generated using natural language, as these parameters are not +visible to the LLM. + +[alloydb-psv]: https://cloud.google.com/alloydb/docs/ai/use-psvs#parameterized_secure_views + +## Example + +```yaml +tools: + ask_questions: + kind: alloydb-ai-nl + source: my-alloydb-source + description: "Ask questions to check information about flights" + nlConfig: "cymbal_air_nl_config" + nlConfigParameters: + - name: user_email + type: string + description: User ID of the logged in user. + # note: we strongly recommend using features like Authenticated or + # Bound parameters to prevent the LLM from seeing these params and + # specifying values it shouldn't in the tool input + authServices: + - name: my_google_service + field: email +``` + +## Reference + +| **field** | **type** | **required** | **description** | +|--------------------|:------------------------------------------:|:------------:|--------------------------------------------------------------------------| +| kind | string | true | Must be "alloydb-ai-nl". | +| source | string | true | Name of the AlloyDB source the natural language query should execute on. | +| description | string | true | Description of the tool that is passed to the LLM. | +| nlConfig | string | true | The name of the `nl_config` in AlloyDB | +| nlConfigParameters | [parameters](_index#specifying-parameters) | true | List of PSV parameters defined in the `nl_config` | diff --git a/internal/server/config.go b/internal/server/config.go index 58facfea88..6c35f1abef 100644 --- a/internal/server/config.go +++ b/internal/server/config.go @@ -40,6 +40,7 @@ import ( "github.com/googleapis/genai-toolbox/internal/tools/mysqlsql" neo4jtool "github.com/googleapis/genai-toolbox/internal/tools/neo4j" "github.com/googleapis/genai-toolbox/internal/tools/postgressql" + "github.com/googleapis/genai-toolbox/internal/tools/alloydbainl" "github.com/googleapis/genai-toolbox/internal/tools/spanner" "github.com/googleapis/genai-toolbox/internal/util" ) @@ -307,6 +308,12 @@ func (c *ToolConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interfac return fmt.Errorf("unable to parse as %q: %w", kind, err) } (*c)[name] = actual + case alloydbainl.ToolKind: + actual := alloydbainl.Config{Name: name} + if err := dec.DecodeContext(ctx, &actual); err != nil { + return fmt.Errorf("unable to parse as %q: %w", kind, err) + } + (*c)[name] = actual case mysqlsql.ToolKind: actual := mysqlsql.Config{Name: name} if err := dec.DecodeContext(ctx, &actual); err != nil { diff --git a/internal/tools/alloydbainl/alloydbainl.go b/internal/tools/alloydbainl/alloydbainl.go new file mode 100644 index 0000000000..6877ba74f6 --- /dev/null +++ b/internal/tools/alloydbainl/alloydbainl.go @@ -0,0 +1,190 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package alloydbainl + +import ( + "context" + "fmt" + "strings" + + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" + "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/jackc/pgx/v5/pgxpool" +) + +const ToolKind string = "alloydb-ai-nl" + +type compatibleSource interface { + PostgresPool() *pgxpool.Pool +} + +// validate compatible sources are still compatible +var _ compatibleSource = &alloydbpg.Source{} + +var compatibleSources = [...]string{alloydbpg.SourceKind} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description" validate:"required"` + NLConfig string `yaml:"nlConfig" validate:"required"` + AuthRequired []string `yaml:"authRequired"` + NLConfigParameters tools.Parameters `yaml:"nlConfigParameters"` +} + +// validate interface +var _ tools.ToolConfig = Config{} + +func (cfg Config) ToolConfigKind() string { + return ToolKind +} + +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + // verify source exists + rawS, ok := srcs[cfg.Source] + if !ok { + return nil, fmt.Errorf("no source named %q configured", cfg.Source) + } + + // verify the source is compatible + s, ok := rawS.(compatibleSource) + if !ok { + return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", ToolKind, compatibleSources) + } + + numParams := len(cfg.NLConfigParameters) + quotedNameParts := make([]string, 0, numParams) + placeholderParts := make([]string, 0, numParams) + + for i, paramDef := range cfg.NLConfigParameters { + name := paramDef.GetName() + escapedName := strings.ReplaceAll(name, "'", "''") // Escape for SQL literal + quotedNameParts = append(quotedNameParts, fmt.Sprintf("'%s'", escapedName)) + placeholderParts = append(placeholderParts, fmt.Sprintf("$%d", i+3)) // $1, $2 reserved + } + + var paramNamesSQL string + var paramValuesSQL string + + if numParams > 0 { + paramNamesSQL = fmt.Sprintf("ARRAY[%s]", strings.Join(quotedNameParts, ", ")) + paramValuesSQL = fmt.Sprintf("ARRAY[%s]", strings.Join(placeholderParts, ", ")) + } else { + paramNamesSQL = "ARRAY[]::TEXT[]" + paramValuesSQL = "ARRAY[]::TEXT[]" + } + + // execute_nl_query is the AlloyDB AI function that executes the natural language query + // The first parameter is the natural language query, which is passed as $1 + // The second parameter is the NLConfig, which is passed as a $2 + // The following params are the list of PSV values passed to the NLConfig + // Example SQL statement being executed: + // SELECT alloydb_ai_nl.execute_nl_query('How many tickets do I have?', 'cymbal_air_nl_config', param_names => ARRAY ['user_email'], param_values => ARRAY ['hailongli@google.com']); + stmtFormat := "SELECT alloydb_ai_nl.execute_nl_query($1, $2, param_names => %s, param_values => %s);" + stmt := fmt.Sprintf(stmtFormat, paramNamesSQL, paramValuesSQL) + + newQuestionParam := tools.NewStringParameter( + "question", // name + "The natural language question to ask.", // description + ) + + cfg.NLConfigParameters = append([]tools.Parameter{newQuestionParam}, cfg.NLConfigParameters...) + + mcpManifest := tools.McpManifest{ + Name: cfg.Name, + Description: cfg.Description, + InputSchema: cfg.NLConfigParameters.McpManifest(), + } + + t := Tool{ + Name: cfg.Name, + Kind: ToolKind, + Parameters: cfg.NLConfigParameters, + Statement: stmt, + NLConfig: cfg.NLConfig, + AuthRequired: cfg.AuthRequired, + Pool: s.PostgresPool(), + manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.NLConfigParameters.Manifest()}, + mcpManifest: mcpManifest, + } + + return t, nil +} + +// validate interface +var _ tools.Tool = Tool{} + +type Tool struct { + Name string `yaml:"name"` + Kind string `yaml:"kind"` + AuthRequired []string `yaml:"authRequired"` + Parameters tools.Parameters `yaml:"parameters"` + + Pool *pgxpool.Pool + Statement string + NLConfig string + manifest tools.Manifest + mcpManifest tools.McpManifest +} + +func (t Tool) Invoke(params tools.ParamValues) ([]any, error) { + sliceParams := params.AsSlice() + allParamValues := make([]any, len(sliceParams)+1) + allParamValues[0] = fmt.Sprintf("%s", sliceParams[0]) // nl_question + allParamValues[1] = fmt.Sprintf("%s", t.NLConfig) // nl_config + for i, param := range sliceParams[1:] { + allParamValues[i+2] = fmt.Sprintf("%s", param) + } + + results, err := t.Pool.Query(context.Background(), t.Statement, allParamValues...) + if err != nil { + return nil, fmt.Errorf("unable to execute query: %w. Query: %v , Values: %v", err, t.Statement, allParamValues) + } + + fields := results.FieldDescriptions() + + var out []any + for results.Next() { + v, err := results.Values() + if err != nil { + return nil, fmt.Errorf("unable to parse row: %w", err) + } + vMap := make(map[string]any) + for i, f := range fields { + vMap[f.Name] = v[i] + } + out = append(out, vMap) + } + + return out, nil +} + +func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) { + return tools.ParseParams(t.Parameters, data, claims) +} + +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +} diff --git a/internal/tools/alloydbainl/alloydbainl_test.go b/internal/tools/alloydbainl/alloydbainl_test.go new file mode 100644 index 0000000000..50ebe6e443 --- /dev/null +++ b/internal/tools/alloydbainl/alloydbainl_test.go @@ -0,0 +1,131 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package alloydbainl_test + +import ( + "testing" + + yaml "github.com/goccy/go-yaml" + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/tools/alloydbainl" +) + +func TestParseFromYamlAlloyDBNLA(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + tools: + example_tool: + kind: alloydb-ai-nl + source: my-alloydb-instance + description: AlloyDB natural language query tool + nlConfig: 'my_nl_config' + authRequired: + - my-google-auth-service + nlConfigParameters: + - name: user_id + type: string + description: user_id to use + authServices: + - name: my-google-auth-service + field: sub + `, + want: server.ToolConfigs{ + "example_tool": alloydbainl.Config{ + Name: "example_tool", + Kind: alloydbainl.ToolKind, + Source: "my-alloydb-instance", + Description: "AlloyDB natural language query tool", + NLConfig: "my_nl_config", + AuthRequired: []string{"my-google-auth-service"}, + NLConfigParameters: []tools.Parameter{ + tools.NewStringParameterWithAuth("user_id", "user_id to use", + []tools.ParamAuthService{{Name: "my-google-auth-service", Field: "sub"}}), + }, + }, + }, + }, + { + desc: "with multiple parameters", + in: ` + tools: + complex_tool: + kind: alloydb-ai-nl + source: my-alloydb-instance + description: AlloyDB natural language query tool with multiple parameters + nlConfig: 'complex_nl_config' + authRequired: + - my-google-auth-service + - other-auth-service + nlConfigParameters: + - name: user_id + type: string + description: user_id to use + authServices: + - name: my-google-auth-service + field: sub + - name: user_email + type: string + description: user_email to use + authServices: + - name: my-google-auth-service + field: user_email + `, + want: server.ToolConfigs{ + "complex_tool": alloydbainl.Config{ + Name: "complex_tool", + Kind: alloydbainl.ToolKind, + Source: "my-alloydb-instance", + Description: "AlloyDB natural language query tool with multiple parameters", + NLConfig: "complex_nl_config", + AuthRequired: []string{"my-google-auth-service", "other-auth-service"}, + NLConfigParameters: []tools.Parameter{ + tools.NewStringParameterWithAuth("user_id", "user_id to use", + []tools.ParamAuthService{{Name: "my-google-auth-service", Field: "sub"}}), + tools.NewStringParameterWithAuth("user_email", "user_email to use", + []tools.ParamAuthService{{Name: "my-google-auth-service", Field: "user_email"}}), + }, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Tools server.ToolConfigs `yaml:"tools"` + }{} + // Parse contents + err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got.Tools); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } +} diff --git a/tests/alloydb_ai_nl_integration_test.go b/tests/alloydb_ai_nl_integration_test.go new file mode 100644 index 0000000000..53c081fc4e --- /dev/null +++ b/tests/alloydb_ai_nl_integration_test.go @@ -0,0 +1,326 @@ +//go:build integration && alloydb_ai_nl + +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "os" + "reflect" + "regexp" + "testing" + "time" +) + +var ( + ALLOYDB_AI_NL_SOURCE_KIND = "alloydb-postgres" + ALLOYDB_AI_NL_TOOL_KIND = "alloydb-ai-nl" + ALLOYDB_AI_NL_PROJECT = os.Getenv("ALLOYDB_AI_NL_PROJECT") + ALLOYDB_AI_NL_REGION = os.Getenv("ALLOYDB_AI_NL_REGION") + ALLOYDB_AI_NL_CLUSTER = os.Getenv("ALLOYDB_AI_NL_CLUSTER") + ALLOYDB_AI_NL_INSTANCE = os.Getenv("ALLOYDB_AI_NL_INSTANCE") + ALLOYDB_AI_NL_DATABASE = os.Getenv("ALLOYDB_AI_NL_DATABASE") + ALLOYDB_AI_NL_USER = os.Getenv("ALLOYDB_AI_NL_USER") + ALLOYDB_AI_NL_PASS = os.Getenv("ALLOYDB_AI_NL_PASS") +) + +func getAlloyDBAiNlVars(t *testing.T) map[string]any { + switch "" { + case ALLOYDB_AI_NL_PROJECT: + t.Fatal("'ALLOYDB_AI_NL_PROJECT' not set") + case ALLOYDB_AI_NL_REGION: + t.Fatal("'ALLOYDB_AI_NL_REGION' not set") + case ALLOYDB_AI_NL_CLUSTER: + t.Fatal("'ALLOYDB_AI_NL_CLUSTER' not set") + case ALLOYDB_AI_NL_INSTANCE: + t.Fatal("'ALLOYDB_AI_NL_INSTANCE' not set") + case ALLOYDB_AI_NL_DATABASE: + t.Fatal("'ALLOYDB_AI_NL_DATABASE' not set") + case ALLOYDB_AI_NL_USER: + t.Fatal("'ALLOYDB_AI_NL_USER' not set") + case ALLOYDB_AI_NL_PASS: + t.Fatal("'ALLOYDB_AI_NL_PASS' not set") + } + return map[string]any{ + "kind": ALLOYDB_AI_NL_SOURCE_KIND, + "project": ALLOYDB_AI_NL_PROJECT, + "cluster": ALLOYDB_AI_NL_CLUSTER, + "instance": ALLOYDB_AI_NL_INSTANCE, + "region": ALLOYDB_AI_NL_REGION, + "database": ALLOYDB_AI_NL_DATABASE, + "user": ALLOYDB_AI_NL_USER, + "password": ALLOYDB_AI_NL_PASS, + } +} + +func TestAlloyDBAiNlToolEndpoints(t *testing.T) { + sourceConfig := getAlloyDBAiNlVars(t) + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + var args []string + + // Write config into a file and pass it to command + toolsFile := getAiNlToolsConfig(sourceConfig) + + cmd, cleanup, err := StartCmd(ctx, toolsFile, args...) + if err != nil { + t.Fatalf("command initialization returned an error: %s", err) + } + defer cleanup() + + waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`)) + if err != nil { + t.Logf("toolbox command logs: \n%s", out) + t.Fatalf("toolbox didn't start successfully: %s", err) + } + + runAiNlToolGetTest(t) + + runAiNlToolInvokeTest(t) +} + +func runAiNlToolGetTest(t *testing.T) { + // Test tool get endpoint + tcs := []struct { + name string + api string + want map[string]any + }{ + { + name: "get my-simple-tool", + api: "http://127.0.0.1:5000/api/tool/my-simple-tool/", + want: map[string]any{ + "my-simple-tool": map[string]any{ + "description": "Simple tool to test end to end functionality.", + "parameters": []any{ + map[string]any{ + "name": "question", + "type": "string", + "description": "The natural language question to ask.", + "authSources": []any{}, + }, + }, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + resp, err := http.Get(tc.api) + if err != nil { + t.Fatalf("error when sending a request: %s", err) + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + t.Fatalf("response status code is not 200") + } + + var body map[string]interface{} + err = json.NewDecoder(resp.Body).Decode(&body) + if err != nil { + t.Fatalf("error parsing response body") + } + + got, ok := body["tools"] + if !ok { + t.Fatalf("unable to find tools in response body") + } + if !reflect.DeepEqual(got, tc.want) { + t.Fatalf("got %q, want %q", got, tc.want) + } + }) + } +} + +func runAiNlToolInvokeTest(t *testing.T) { + // Get ID token + idToken, err := GetGoogleIdToken(ClientId) + if err != nil { + t.Fatalf("error getting Google ID token: %s", err) + } + + // Test tool invoke endpoint + invokeTcs := []struct { + name string + api string + requestHeader map[string]string + requestBody io.Reader + want string + isErr bool + }{ + { + name: "invoke my-simple-tool", + api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke", + requestHeader: map[string]string{}, + requestBody: bytes.NewBuffer([]byte(`{"question": "return 1"}`)), + want: "[{\"execute_nl_query\":{\"?column?\":1}}]", + isErr: false, + }, + { + name: "Invoke my-tool without parameters", + api: "http://127.0.0.1:5000/api/tool/my-tool/invoke", + requestHeader: map[string]string{}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + isErr: true, + }, + { + name: "Invoke my-auth-tool with auth token", + api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke", + requestHeader: map[string]string{"my-google-auth_token": idToken}, + requestBody: bytes.NewBuffer([]byte(`{"question": "can you show me the name of this user?"}`)), + want: "[{\"execute_nl_query\":{\"name\":\"Alice\"}}]", + isErr: false, + }, + { + name: "Invoke my-auth-tool with invalid auth token", + api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke", + requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"}, + requestBody: bytes.NewBuffer([]byte(`{"question": "return 1"}`)), + isErr: true, + }, + { + name: "Invoke my-auth-tool without auth token", + api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke", + requestHeader: map[string]string{}, + requestBody: bytes.NewBuffer([]byte(`{"question": "return 1"}`)), + isErr: true, + }, + { + name: "Invoke my-auth-required-tool with auth token", + api: "http://127.0.0.1:5000/api/tool/my-auth-required-tool/invoke", + requestHeader: map[string]string{"my-google-auth_token": idToken}, + requestBody: bytes.NewBuffer([]byte(`{"question": "return 1"}`)), + isErr: false, + want: "[{\"execute_nl_query\":{\"?column?\":1}}]", + }, + { + name: "Invoke my-auth-required-tool with invalid auth token", + api: "http://127.0.0.1:5000/api/tool/my-auth-required-tool/invoke", + requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"}, + requestBody: bytes.NewBuffer([]byte(`{"question": "return 1"}`)), + isErr: true, + }, + { + name: "Invoke my-auth-required-tool without auth token", + api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke", + requestHeader: map[string]string{}, + requestBody: bytes.NewBuffer([]byte(`{"question": "return 1"}`)), + isErr: true, + }, + } + for _, tc := range invokeTcs { + t.Run(tc.name, func(t *testing.T) { + // Send Tool invocation request + req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody) + if err != nil { + t.Fatalf("unable to create request: %s", err) + } + req.Header.Add("Content-type", "application/json") + for k, v := range tc.requestHeader { + req.Header.Add(k, v) + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("unable to send request: %s", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + if tc.isErr == true { + return + } + bodyBytes, _ := io.ReadAll(resp.Body) + t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes)) + } + + // Check response body + var body map[string]interface{} + err = json.NewDecoder(resp.Body).Decode(&body) + if err != nil { + t.Fatalf("error parsing response body") + } + got, ok := body["result"].(string) + if !ok { + t.Fatalf("unable to find result in response body") + } + + if got != tc.want { + t.Fatalf("unexpected value: got %q, want %q", got, tc.want) + } + }) + } + +} + +func getAiNlToolsConfig(sourceConfig map[string]any) map[string]any { + // Write config into a file and pass it to command + toolsFile := map[string]any{ + "sources": map[string]any{ + "my-instance": sourceConfig, + }, + "authServices": map[string]any{ + "my-google-auth": map[string]any{ + "kind": "google", + "clientId": ClientId, + }, + }, + "tools": map[string]any{ + "my-simple-tool": map[string]any{ + "kind": ALLOYDB_AI_NL_TOOL_KIND, + "source": "my-instance", + "description": "Simple tool to test end to end functionality.", + "nlConfig": "my_nl_config", + }, + "my-auth-tool": map[string]any{ + "kind": ALLOYDB_AI_NL_TOOL_KIND, + "source": "my-instance", + "description": "Tool to test authenticated parameters.", + "nlConfig": "my_nl_config", + "nlConfigParameters": []map[string]any{ + { + "name": "email", + "type": "string", + "description": "user email", + "authServices": []map[string]string{ + { + "name": "my-google-auth", + "field": "email", + }, + }, + }, + }, + }, + "my-auth-required-tool": map[string]any{ + "kind": ALLOYDB_AI_NL_TOOL_KIND, + "source": "my-instance", + "description": "Tool to test auth required invocation.", + "nlConfig": "my_nl_config", + "authRequired": []string{ + "my-google-auth", + }, + }, + }, + } + + return toolsFile +}