From a48101b3c5fa68fcf657b6f49ad897b8facf85d7 Mon Sep 17 00:00:00 2001 From: RUN Date: Wed, 28 Jan 2026 21:19:25 +0100 Subject: [PATCH] feat: refactor Oracle integration tests and enhance environment variable handling --- tests/oracle/oracle_integration_test.go | 447 +++++++++++++++--------- 1 file changed, 273 insertions(+), 174 deletions(-) diff --git a/tests/oracle/oracle_integration_test.go b/tests/oracle/oracle_integration_test.go index 0d410b41df..290826f9df 100644 --- a/tests/oracle/oracle_integration_test.go +++ b/tests/oracle/oracle_integration_test.go @@ -16,49 +16,200 @@ import ( "github.com/google/uuid" "github.com/googleapis/genai-toolbox/internal/sources/oracle" "github.com/googleapis/genai-toolbox/internal/testutils" - "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/tests" ) + var ( OracleSourceType = "oracle" OracleToolType = "oracle-sql" OracleHost = os.Getenv("ORACLE_HOST") + OraclePort = os.Getenv("ORACLE_PORT") OracleUser = os.Getenv("ORACLE_USER") OraclePass = os.Getenv("ORACLE_PASS") OracleServerName = os.Getenv("ORACLE_SERVER_NAME") - OraclePort = os.Getenv("ORACLE_PORT") - OracleService = os.Getenv("ORACLE_SERVICE") - OracleWalletPath = os.Getenv("ORACLE_WALLET_PATH") - oracleUseOCI = os.Getenv("ORACLE_USE_OCI") - + OracleUseOCI = os.Getenv("ORACLE_USE_OCI") + OracleWalletLocation = os.Getenv("ORACLE_WALLET_LOCATION") + OracleTnsAdmin = os.Getenv("ORACLE_TNS_ADMIN") + OracleConnStr = fmt.Sprintf( - "%s:%s/%s", OracleHost, "1521", OracleServerName) + "%s:%s/%s", OracleHost, "1521", OracleServerName) // Default port 1521?? ) func getOracleVars(t *testing.T) map[string]any { switch "" { - case oracleHost: + case OracleHost: t.Fatal("'ORACLE_HOST not set") - case oracleUser: + case OracleUser: t.Fatal("'ORACLE_USER' not set") - case oraclePassword: - t.Fatal("'ORACLE_PASSWORD' not set") - case oraclePort: - t.Fatal("'ORACLE_PORT' not set") - case oracleService: - t.Fatal("'ORACLE_SERVICE' not set") + case OraclePass: + t.Fatal("'ORACLE_PASS' not set") + case OracleServerName: + t.Fatal("'ORACLE_SERVER_NAME' not set") } return map[string]any{ "type": OracleSourceType, "connectionString": OracleConnStr, - "useOCI": true, + "useOCI": OracleUseOCI, + "walletLocation": OracleWalletLocation, + "tnsAdmin": OracleTnsAdmin, + "host": OracleHost, + "port": OraclePort, + "service": OracleServerName, "user": OracleUser, "password": OraclePass, } } +// getOracleConfigFromEnv constructs an oracle.Config from environment variables. +func getOracleConfigFromEnv(t *testing.T) oracle.Config { + t.Helper() + vars := getOracleVars(t) + + port, err := strconv.Atoi(vars["port"].(string)) + if err != nil && vars["port"].(string) != "" { + t.Fatalf("invalid ORACLE_PORT: %v", err) + } + + useOCI, err := strconv.ParseBool(vars["ORACLE_USE_OCI"].(string)) + if err != nil && vars["ORACLE_USE_OCI"].(string) != "" { + useOCI = false + } + + return oracle.Config{ + Name: "test-oracle-instance", + Kind: vars["kind"].(string), + User: vars["user"].(string), + Password: vars["password"].(string), + Host: vars["host"].(string), + Port: port, + ServiceName: vars["service"].(string), + WalletLocation: vars["walletLocation"].(string), + TnsAdmin: vars["tnsAdmin"].(string), + UseOCI: useOCI, + } +} + +// setOracleEnv sets Oracle-related environment variables for testing and returns a cleanup function. +func setOracleEnv(t *testing.T, host, user, password, service, port, connStr, tnsAlias, tnsAdmin, walletLocation string, useOCI bool) func() { + t.Helper() + + original := map[string]string{ + "ORACLE_HOST": os.Getenv("ORACLE_HOST"), + "ORACLE_USER": os.Getenv("ORACLE_USER"), + "ORACLE_PASSWORD": os.Getenv("ORACLE_PASSWORD"), + "ORACLE_SERVICE": os.Getenv("ORACLE_SERVICE"), + "ORACLE_PORT": os.Getenv("ORACLE_PORT"), + "ORACLE_TNS_ADMIN": os.Getenv("ORACLE_TNS_ADMIN"), + "ORACLE_WALLET_LOCATION": os.Getenv("ORACLE_WALLET_LOCATION"), + "ORACLE_USE_OCI": os.Getenv("ORACLE_USE_OCI"), + } + + os.Setenv("ORACLE_HOST", host) + os.Setenv("ORACLE_USER", user) + os.Setenv("ORACLE_PASSWORD", password) + os.Setenv("ORACLE_SERVICE", service) + os.Setenv("ORACLE_PORT", port) + os.Setenv("ORACLE_TNS_ADMIN", tnsAdmin) + os.Setenv("ORACLE_WALLET_LOCATION", walletLocation) + os.Setenv("ORACLE_USE_OCI", fmt.Sprintf("%v", useOCI)) + + return func() { + for k, v := range original { + os.Setenv(k, v) + } + } +} + +// Copied over from oracle.go +func initOracleConnection(ctx context.Context, user, pass, connStr string) (*sql.DB, error) { + // Build the full Oracle connection string for godror driver + fullConnStr := fmt.Sprintf(`user="%s" password="%s" connectString="%s"`, + user, pass, connStr) + + db, err := sql.Open("godror", fullConnStr) + if err != nil { + return nil, fmt.Errorf("unable to open Oracle connection: %w", err) + } + + err = db.PingContext(ctx) + if err != nil { + return nil, fmt.Errorf("unable to ping Oracle connection: %w", err) + } + + return db, nil +} + +// TestOracleSimpleToolEndpoints tests Oracle SQL tool endpoints +func TestOracleSimpleToolEndpoints(t *testing.T) { + sourceConfig := getOracleVars(t) + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + var args []string + + db, err := initOracleConnection(ctx, OracleUser, OraclePass, OracleConnStr) + if err != nil { + t.Fatalf("unable to create Oracle connection pool: %s", err) + } + + dropAllUserTables(t, ctx, db) + + // create table name with UUID + tableNameParam := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "") + tableNameAuth := "auth_table_" + strings.ReplaceAll(uuid.New().String(), "-", "") + tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "") + + // set up data for param tool + createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := getOracleParamToolInfo(tableNameParam) + teardownTable1 := setupOracleTable(t, ctx, db, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams) + defer teardownTable1(t) + + // set up data for auth tool + createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := getOracleAuthToolInfo(tableNameAuth) + teardownTable2 := setupOracleTable(t, ctx, db, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams) + defer teardownTable2(t) + + // Write config into a file and pass it to command + toolsFile := tests.GetToolsConfig(sourceConfig, OracleToolType, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt) + toolsFile = tests.AddExecuteSqlConfig(t, toolsFile, "oracle-execute-sql") + tmplSelectCombined, tmplSelectFilterCombined := tests.GetMySQLTmplToolStatement() + toolsFile = tests.AddTemplateParamConfig(t, toolsFile, OracleToolType, tmplSelectCombined, tmplSelectFilterCombined, "") + + cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...) + if err != nil { + t.Fatalf("command initialization returned an error: %s", err) + } + defer cleanup() + + waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out) + if err != nil { + t.Logf("toolbox command logs: \n%s", out) + t.Fatalf("toolbox didn't start successfully: %s", err) + } + + // Get configs for tests + select1Want := "[{\"1\":1}]" + mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: dpiStmt_execute: ORA-00900: invalid SQL statement"}],"isError":true}}` + createTableStatement := `"CREATE TABLE t (id NUMBER GENERATED AS IDENTITY PRIMARY KEY, name VARCHAR2(255))"` + mcpSelect1Want := `{"jsonrpc":"2.0","id":"invoke my-auth-required-tool","result":{"content":[{"type":"text","text":"{\"1\":1}"}]}}` + + // Run tests + tests.RunToolGetTest(t) + tests.RunToolInvokeTest(t, select1Want, + tests.DisableOptionalNullParamTest(), + tests.WithMyToolById4Want("[{\"id\":4,\"name\":\"\"}]"), + tests.DisableArrayTest(), + ) + tests.RunMCPToolCallMethod(t, mcpMyFailToolWant, mcpSelect1Want) + tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want) + tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam) +} + + func TestOracleTools(t *testing.T) { sourceConfig := getOracleVars(t) ctx, cancel := context.WithTimeout(context.Background(), time.Minute) @@ -66,11 +217,10 @@ func TestOracleTools(t *testing.T) { var args []string - logger, err := util.NewLogger(os.Stderr, "info") + ctx, err := testutils.ContextWithNewLogger() if err != nil { - t.Fatalf("unable to create logger: %v", err) + t.Fatalf("unexpected error: %s", err) } - ctx = util.ContextWithLogger(ctx, logger) cfg := getOracleConfigFromEnv(t) source, err := cfg.Initialize(ctx, nil) @@ -96,8 +246,8 @@ func TestOracleTools(t *testing.T) { defer teardownTable1(t) // set up data for auth tool - createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := getOracleAuthToolInfo(t, tableNameAuth) - teardownTable2 := setupOracleTable(t, ctx, db, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams...) + createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := getOracleAuthToolInfo(tableNameAuth) + teardownTable2 := setupOracleTable(t, ctx, db, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams) defer teardownTable2(t) // Write config into a file and pass it to command @@ -150,182 +300,131 @@ func TestOracleTools(t *testing.T) { }) } -func getOracleConfigFromEnv(t *testing.T) oracle.Config { - t.Helper() - vars := getOracleVars(t) - port, err := strconv.Atoi(vars["port"].(string)) - if err != nil && vars["port"].(string) != "" { - t.Fatalf("invalid ORACLE_PORT: %v", err) - } - - useOCI, err := strconv.ParseBool(vars["useOCI"].(string)) - if err != nil && vars["useOCI"].(string) != "" { - useOCI = false - } - - return oracle.Config{ - Name: "test-oracle-instance", - Kind: vars["kind"].(string), - User: vars["user"].(string), - Password: vars["password"].(string), - Host: vars["host"].(string), - Port: port, - ServiceName: vars["service"].(string), - WalletLocation: vars["walletLocation"].(string), - TnsAdmin: vars["tnsAdmin"].(string), - UseOCI: useOCI, - } -} +// new integration tests +// TestOracleConnectionPureGoWithWallet tests pure Go driver connection with wallet func TestOracleConnectionPureGoWithWallet(t *testing.T) { - t.Parallel() - // This test expects the connection to fail because the wallet file won't exist. - // It verifies that the walletLocation parameter is correctly passed to the pure Go driver. + t.Parallel() + // This test expects the connection to fail because the wallet file won't exist. + // It verifies that the walletLocation parameter is correctly passed to the pure Go driver. - // Save original env vars and restore them at the end - cleanup := setOracleEnv(t, - oracleHost, oracleUser, oraclePassword, oracleService, oraclePort, // Use existing base connection details - "", // connectionString - "", // tnsAlias - "", // tnsAdmin - "/tmp/nonexistent_wallet", // walletLocation - false, // useOCI - ) - defer cleanup() + // Save original env vars and restore them at the end + cleanup := setOracleEnv(t, + OracleHost, OracleUser, OraclePass, OracleServerName, OraclePort, // Use existing base connection details + "", // connectionString + "", // tnsAlias + "", // tnsAdmin + "/tmp/nonexistent_wallet", // walletLocation + false, // useOCI + ) + defer cleanup() - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() - cfg := getOracleConfigFromEnv(t) - _, err := cfg.Initialize(ctx, nil) // Pass nil for tracer as it's not critical for this test + cfg := getOracleConfigFromEnv(t) + _, err := cfg.Initialize(ctx, nil) // Pass nil for tracer as it's not critical for this test - if err == nil { - t.Fatalf("Expected connection to fail with non-existent wallet, but it succeeded") - } + if err == nil { + t.Fatalf("Expected connection to fail with non-existent wallet, but it succeeded") + } - // Check for error message indicating wallet usage or connection failure related to wallet - // The exact error message might vary depending on the go-ora version and OS. - // We are looking for an error that suggests the wallet path was attempted. - expectedErrorSubstring := "wallet" - if !strings.Contains(strings.ToLower(err.Error()), expectedErrorSubstring) { - t.Errorf("Expected error message to contain '%s' (case-insensitive) but got: %v", expectedErrorSubstring, err) - } - t.Logf("Connection failed as expected (Pure Go with Wallet): %v", err) -} - -func setOracleEnv(t *testing.T, host, user, password, service, port, connStr, tnsAlias, tnsAdmin, walletLocation string, useOCI bool) func() { - t.Helper() - - original := map[string]string{ - "ORACLE_HOST": os.Getenv("ORACLE_HOST"), - "ORACLE_USER": os.Getenv("ORACLE_USER"), - "ORACLE_PASSWORD": os.Getenv("ORACLE_PASSWORD"), - "ORACLE_SERVICE": os.Getenv("ORACLE_SERVICE"), - "ORACLE_PORT": os.Getenv("ORACLE_PORT"), - "ORACLE_TNS_ADMIN": os.Getenv("ORACLE_TNS_ADMIN"), - "ORACLE_WALLET_LOCATION": os.Getenv("ORACLE_WALLET_LOCATION"), - "ORACLE_USE_OCI": os.Getenv("ORACLE_USE_OCI"), - } - - os.Setenv("ORACLE_HOST", host) - os.Setenv("ORACLE_USER", user) - os.Setenv("ORACLE_PASSWORD", password) - os.Setenv("ORACLE_SERVICE", service) - os.Setenv("ORACLE_PORT", port) - os.Setenv("ORACLE_TNS_ADMIN", tnsAdmin) - os.Setenv("ORACLE_WALLET_LOCATION", walletLocation) - os.Setenv("ORACLE_USE_OCI", fmt.Sprintf("%v", useOCI)) - - return func() { - for k, v := range original { - os.Setenv(k, v) - } - } + // Check for error message indicating wallet usage or connection failure related to wallet + // The exact error message might vary depending on the go-ora version and OS. + // We are looking for an error that suggests the wallet path was attempted. + expectedErrorSubstring := "wallet" + if !strings.Contains(strings.ToLower(err.Error()), expectedErrorSubstring) { + t.Errorf("Expected error message to contain '%s' (case-insensitive) but got: %v", expectedErrorSubstring, err) + } + t.Logf("Connection failed as expected (Pure Go with Wallet): %v", err) } +// TestOracleConnectionOCI tests OCI driver connection without wallet func TestOracleConnectionOCI(t *testing.T) { - t.Parallel() - // This test verifies that the useOCI=true parameter is correctly passed to the OCI driver. - // It will likely fail if Oracle Instant Client is not installed or configured. + t.Parallel() + // This test verifies that the useOCI=true parameter is correctly passed to the OCI driver. + // It will likely fail if Oracle Instant Client is not installed or configured. - // Save original env vars and restore them at the end - cleanup := setOracleEnv(t, - oracleHost, oracleUser, oraclePassword, oracleService, oraclePort, // Use existing base connection details - "", // connectionString - "", // tnsAlias - "", // tnsAdmin - "", // walletLocation - "true", // useOCI - ) - defer cleanup() + // Save original env vars and restore them at the end + cleanup := setOracleEnv(t, + OracleHost, OracleUser, OraclePass, OracleServerName, OraclePort, // Use existing base connection details + "", // connectionString + "", // tnsAlias + "", // tnsAdmin + "", // walletLocation + true, // useOCI + ) + defer cleanup() - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() - cfg := getOracleConfigFromEnv(t) - _, err := cfg.Initialize(ctx, nil) + cfg := getOracleConfigFromEnv(t) + _, err := cfg.Initialize(ctx, nil) - if err == nil { - t.Fatalf("Expected connection to fail (OCI driver without Instant Client), but it succeeded") - } + if err == nil { + t.Fatalf("Expected connection to fail (OCI driver without Instant Client), but it succeeded") + } - // Check for error message indicating OCI driver usage or connection failure related to OCI. - // Common errors include "OCI environment not initialized", "driver: bad connection", etc. - expectedErrorSubstrings := []string{"oci", "driver", "connection"} - foundExpectedError := false - for _, sub := range expectedErrorSubstrings { - if strings.Contains(strings.ToLower(err.Error()), sub) { - foundExpectedError = true - break - } - } - if !foundExpectedError { - t.Errorf("Expected error message to contain one of %v (case-insensitive) but got: %v", expectedErrorSubstrings, err) - } - t.Logf("Connection failed as expected (OCI Driver): %v", err) + // Check for error message indicating OCI driver usage or connection failure related to OCI. + // Common errors include "OCI environment not initialized", "driver: bad connection", etc. + expectedErrorSubstrings := []string{"oci", "driver", "connection"} + foundExpectedError := false + for _, sub := range expectedErrorSubstrings { + if strings.Contains(strings.ToLower(err.Error()), sub) { + foundExpectedError = true + break + } + } + if !foundExpectedError { + t.Errorf("Expected error message to contain one of %v (case-insensitive) but got: %v", expectedErrorSubstrings, err) + } + t.Logf("Connection failed as expected (OCI Driver): %v", err) } +// TestOracleConnectionOCIWithWallet tests OCI driver connection with TNS Admin and Wallet func TestOracleConnectionOCIWithWallet(t *testing.T) { - t.Parallel() - // This test verifies that useOCI=true and tnsAdmin parameters are correctly passed for OCI wallet. - // It will likely fail due to missing tnsnames.ora and wallet files. + t.Parallel() + // This test verifies that useOCI=true and tnsAdmin parameters are correctly passed for OCI wallet. + // It will likely fail due to missing tnsnames.ora and wallet files. - // Save original env vars and restore them at the end - cleanup := setOracleEnv(t, - "", oracleUser, oraclePassword, "", "", // Unset host/port/service for TNS alias, but keep user/pass - "", // connectionString - "MY_TNS_ALIAS", // tnsAlias - "/tmp/nonexistent_tns_admin", // tnsAdmin - "", // walletLocation - true, // useOCI - ) - defer cleanup() + // Save original env vars and restore them at the end + cleanup := setOracleEnv(t, + "", OracleUser, OraclePass, "", "", // Unset host/port/service for TNS alias, but keep user/pass + "", // connectionString + "MY_TNS_ALIAS", // tnsAlias + "/tmp/nonexistent_tns_admin", // tnsAdmin + "", // walletLocation + true, // useOCI + ) + defer cleanup() - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() - cfg := getOracleConfigFromEnv(t) - _, err := cfg.Initialize(ctx, nil) + cfg := getOracleConfigFromEnv(t) + _, err := cfg.Initialize(ctx, nil) - if err == nil { - t.Fatalf("Expected connection to fail (OCI driver with TNS Admin/Wallet), but it succeeded") - } + if err == nil { + t.Fatalf("Expected connection to fail (OCI driver with TNS Admin/Wallet), but it succeeded") + } - // Check for error message indicating TNS Admin/Wallet usage or connection failure. - expectedErrorSubstrings := []string{"tns", "wallet", "oci", "driver", "connection"} - foundExpectedError := false - for _, sub := range expectedErrorSubstrings { - if strings.Contains(strings.ToLower(err.Error()), sub) { - foundExpectedError = true - break - } - } - if !foundExpectedError { - t.Errorf("Expected error message to contain one of %v (case-insensitive) but got: %v", expectedErrorSubstrings, err) - } - t.Logf("Connection failed as expected (OCI Driver with TNS Admin/Wallet): %v", err) + // Check for error message indicating TNS Admin/Wallet usage or connection failure. + expectedErrorSubstrings := []string{"tns", "wallet", "oci", "driver", "connection"} + foundExpectedError := false + for _, sub := range expectedErrorSubstrings { + if strings.Contains(strings.ToLower(err.Error()), sub) { + foundExpectedError = true + break + } + } + if !foundExpectedError { + t.Errorf("Expected error message to contain one of %v (case-insensitive) but got: %v", expectedErrorSubstrings, err) + } + t.Logf("Connection failed as expected (OCI Driver with TNS Admin/Wallet): %v", err) } +//test utils func setupOracleTable(t *testing.T, ctx context.Context, pool *sql.DB, createStatement, insertStatement, tableName string, params []any) func(*testing.T) { err := pool.PingContext(ctx) if err != nil { @@ -381,7 +480,7 @@ func getOracleParamToolInfo(tableName string) (string, string, string, string, s } // getOracleAuthToolInfo returns statements and params for my-auth-tool for Oracle SQL -func getOracleAuthToolInfo(t *testing.T, tableName string) (string, string, string, []any) { +func getOracleAuthToolInfo(tableName string) (string, string, string, []any) { createStatement := fmt.Sprintf(`CREATE TABLE %s ("id" NUMBER GENERATED AS IDENTITY PRIMARY KEY, "name" VARCHAR2(255), "email" VARCHAR2(255))`, tableName) // MODIFIED: Use a PL/SQL block for multiple inserts @@ -432,4 +531,4 @@ func dropAllUserTables(t *testing.T, ctx context.Context, db *sql.DB) { t.Logf("failed to drop table %s: %v", tableName, err) } } -} +} \ No newline at end of file