feat: refactor Oracle integration tests and enhance environment variable handling

This commit is contained in:
RUN
2026-01-28 21:19:25 +01:00
parent 418d6d791e
commit a48101b3c5

View File

@@ -16,49 +16,200 @@ import (
"github.com/google/uuid"
"github.com/googleapis/genai-toolbox/internal/sources/oracle"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/internal/util"
"github.com/googleapis/genai-toolbox/tests"
)
var (
OracleSourceType = "oracle"
OracleToolType = "oracle-sql"
OracleHost = os.Getenv("ORACLE_HOST")
OraclePort = os.Getenv("ORACLE_PORT")
OracleUser = os.Getenv("ORACLE_USER")
OraclePass = os.Getenv("ORACLE_PASS")
OracleServerName = os.Getenv("ORACLE_SERVER_NAME")
OraclePort = os.Getenv("ORACLE_PORT")
OracleService = os.Getenv("ORACLE_SERVICE")
OracleWalletPath = os.Getenv("ORACLE_WALLET_PATH")
oracleUseOCI = os.Getenv("ORACLE_USE_OCI")
OracleUseOCI = os.Getenv("ORACLE_USE_OCI")
OracleWalletLocation = os.Getenv("ORACLE_WALLET_LOCATION")
OracleTnsAdmin = os.Getenv("ORACLE_TNS_ADMIN")
OracleConnStr = fmt.Sprintf(
"%s:%s/%s", OracleHost, "1521", OracleServerName)
"%s:%s/%s", OracleHost, "1521", OracleServerName) // Default port 1521??
)
func getOracleVars(t *testing.T) map[string]any {
switch "" {
case oracleHost:
case OracleHost:
t.Fatal("'ORACLE_HOST not set")
case oracleUser:
case OracleUser:
t.Fatal("'ORACLE_USER' not set")
case oraclePassword:
t.Fatal("'ORACLE_PASSWORD' not set")
case oraclePort:
t.Fatal("'ORACLE_PORT' not set")
case oracleService:
t.Fatal("'ORACLE_SERVICE' not set")
case OraclePass:
t.Fatal("'ORACLE_PASS' not set")
case OracleServerName:
t.Fatal("'ORACLE_SERVER_NAME' not set")
}
return map[string]any{
"type": OracleSourceType,
"connectionString": OracleConnStr,
"useOCI": true,
"useOCI": OracleUseOCI,
"walletLocation": OracleWalletLocation,
"tnsAdmin": OracleTnsAdmin,
"host": OracleHost,
"port": OraclePort,
"service": OracleServerName,
"user": OracleUser,
"password": OraclePass,
}
}
// getOracleConfigFromEnv constructs an oracle.Config from environment variables.
func getOracleConfigFromEnv(t *testing.T) oracle.Config {
t.Helper()
vars := getOracleVars(t)
port, err := strconv.Atoi(vars["port"].(string))
if err != nil && vars["port"].(string) != "" {
t.Fatalf("invalid ORACLE_PORT: %v", err)
}
useOCI, err := strconv.ParseBool(vars["ORACLE_USE_OCI"].(string))
if err != nil && vars["ORACLE_USE_OCI"].(string) != "" {
useOCI = false
}
return oracle.Config{
Name: "test-oracle-instance",
Kind: vars["kind"].(string),
User: vars["user"].(string),
Password: vars["password"].(string),
Host: vars["host"].(string),
Port: port,
ServiceName: vars["service"].(string),
WalletLocation: vars["walletLocation"].(string),
TnsAdmin: vars["tnsAdmin"].(string),
UseOCI: useOCI,
}
}
// setOracleEnv sets Oracle-related environment variables for testing and returns a cleanup function.
func setOracleEnv(t *testing.T, host, user, password, service, port, connStr, tnsAlias, tnsAdmin, walletLocation string, useOCI bool) func() {
t.Helper()
original := map[string]string{
"ORACLE_HOST": os.Getenv("ORACLE_HOST"),
"ORACLE_USER": os.Getenv("ORACLE_USER"),
"ORACLE_PASSWORD": os.Getenv("ORACLE_PASSWORD"),
"ORACLE_SERVICE": os.Getenv("ORACLE_SERVICE"),
"ORACLE_PORT": os.Getenv("ORACLE_PORT"),
"ORACLE_TNS_ADMIN": os.Getenv("ORACLE_TNS_ADMIN"),
"ORACLE_WALLET_LOCATION": os.Getenv("ORACLE_WALLET_LOCATION"),
"ORACLE_USE_OCI": os.Getenv("ORACLE_USE_OCI"),
}
os.Setenv("ORACLE_HOST", host)
os.Setenv("ORACLE_USER", user)
os.Setenv("ORACLE_PASSWORD", password)
os.Setenv("ORACLE_SERVICE", service)
os.Setenv("ORACLE_PORT", port)
os.Setenv("ORACLE_TNS_ADMIN", tnsAdmin)
os.Setenv("ORACLE_WALLET_LOCATION", walletLocation)
os.Setenv("ORACLE_USE_OCI", fmt.Sprintf("%v", useOCI))
return func() {
for k, v := range original {
os.Setenv(k, v)
}
}
}
// Copied over from oracle.go
func initOracleConnection(ctx context.Context, user, pass, connStr string) (*sql.DB, error) {
// Build the full Oracle connection string for godror driver
fullConnStr := fmt.Sprintf(`user="%s" password="%s" connectString="%s"`,
user, pass, connStr)
db, err := sql.Open("godror", fullConnStr)
if err != nil {
return nil, fmt.Errorf("unable to open Oracle connection: %w", err)
}
err = db.PingContext(ctx)
if err != nil {
return nil, fmt.Errorf("unable to ping Oracle connection: %w", err)
}
return db, nil
}
// TestOracleSimpleToolEndpoints tests Oracle SQL tool endpoints
func TestOracleSimpleToolEndpoints(t *testing.T) {
sourceConfig := getOracleVars(t)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
var args []string
db, err := initOracleConnection(ctx, OracleUser, OraclePass, OracleConnStr)
if err != nil {
t.Fatalf("unable to create Oracle connection pool: %s", err)
}
dropAllUserTables(t, ctx, db)
// create table name with UUID
tableNameParam := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
tableNameAuth := "auth_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
// set up data for param tool
createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := getOracleParamToolInfo(tableNameParam)
teardownTable1 := setupOracleTable(t, ctx, db, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams)
defer teardownTable1(t)
// set up data for auth tool
createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := getOracleAuthToolInfo(tableNameAuth)
teardownTable2 := setupOracleTable(t, ctx, db, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
defer teardownTable2(t)
// Write config into a file and pass it to command
toolsFile := tests.GetToolsConfig(sourceConfig, OracleToolType, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
toolsFile = tests.AddExecuteSqlConfig(t, toolsFile, "oracle-execute-sql")
tmplSelectCombined, tmplSelectFilterCombined := tests.GetMySQLTmplToolStatement()
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, OracleToolType, tmplSelectCombined, tmplSelectFilterCombined, "")
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
if err != nil {
t.Fatalf("command initialization returned an error: %s", err)
}
defer cleanup()
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
if err != nil {
t.Logf("toolbox command logs: \n%s", out)
t.Fatalf("toolbox didn't start successfully: %s", err)
}
// Get configs for tests
select1Want := "[{\"1\":1}]"
mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: dpiStmt_execute: ORA-00900: invalid SQL statement"}],"isError":true}}`
createTableStatement := `"CREATE TABLE t (id NUMBER GENERATED AS IDENTITY PRIMARY KEY, name VARCHAR2(255))"`
mcpSelect1Want := `{"jsonrpc":"2.0","id":"invoke my-auth-required-tool","result":{"content":[{"type":"text","text":"{\"1\":1}"}]}}`
// Run tests
tests.RunToolGetTest(t)
tests.RunToolInvokeTest(t, select1Want,
tests.DisableOptionalNullParamTest(),
tests.WithMyToolById4Want("[{\"id\":4,\"name\":\"\"}]"),
tests.DisableArrayTest(),
)
tests.RunMCPToolCallMethod(t, mcpMyFailToolWant, mcpSelect1Want)
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam)
}
func TestOracleTools(t *testing.T) {
sourceConfig := getOracleVars(t)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
@@ -66,11 +217,10 @@ func TestOracleTools(t *testing.T) {
var args []string
logger, err := util.NewLogger(os.Stderr, "info")
ctx, err := testutils.ContextWithNewLogger()
if err != nil {
t.Fatalf("unable to create logger: %v", err)
t.Fatalf("unexpected error: %s", err)
}
ctx = util.ContextWithLogger(ctx, logger)
cfg := getOracleConfigFromEnv(t)
source, err := cfg.Initialize(ctx, nil)
@@ -96,8 +246,8 @@ func TestOracleTools(t *testing.T) {
defer teardownTable1(t)
// set up data for auth tool
createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := getOracleAuthToolInfo(t, tableNameAuth)
teardownTable2 := setupOracleTable(t, ctx, db, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams...)
createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := getOracleAuthToolInfo(tableNameAuth)
teardownTable2 := setupOracleTable(t, ctx, db, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
defer teardownTable2(t)
// Write config into a file and pass it to command
@@ -150,182 +300,131 @@ func TestOracleTools(t *testing.T) {
})
}
func getOracleConfigFromEnv(t *testing.T) oracle.Config {
t.Helper()
vars := getOracleVars(t)
port, err := strconv.Atoi(vars["port"].(string))
if err != nil && vars["port"].(string) != "" {
t.Fatalf("invalid ORACLE_PORT: %v", err)
}
useOCI, err := strconv.ParseBool(vars["useOCI"].(string))
if err != nil && vars["useOCI"].(string) != "" {
useOCI = false
}
return oracle.Config{
Name: "test-oracle-instance",
Kind: vars["kind"].(string),
User: vars["user"].(string),
Password: vars["password"].(string),
Host: vars["host"].(string),
Port: port,
ServiceName: vars["service"].(string),
WalletLocation: vars["walletLocation"].(string),
TnsAdmin: vars["tnsAdmin"].(string),
UseOCI: useOCI,
}
}
// new integration tests
// TestOracleConnectionPureGoWithWallet tests pure Go driver connection with wallet
func TestOracleConnectionPureGoWithWallet(t *testing.T) {
t.Parallel()
// This test expects the connection to fail because the wallet file won't exist.
// It verifies that the walletLocation parameter is correctly passed to the pure Go driver.
t.Parallel()
// This test expects the connection to fail because the wallet file won't exist.
// It verifies that the walletLocation parameter is correctly passed to the pure Go driver.
// Save original env vars and restore them at the end
cleanup := setOracleEnv(t,
oracleHost, oracleUser, oraclePassword, oracleService, oraclePort, // Use existing base connection details
"", // connectionString
"", // tnsAlias
"", // tnsAdmin
"/tmp/nonexistent_wallet", // walletLocation
false, // useOCI
)
defer cleanup()
// Save original env vars and restore them at the end
cleanup := setOracleEnv(t,
OracleHost, OracleUser, OraclePass, OracleServerName, OraclePort, // Use existing base connection details
"", // connectionString
"", // tnsAlias
"", // tnsAdmin
"/tmp/nonexistent_wallet", // walletLocation
false, // useOCI
)
defer cleanup()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
cfg := getOracleConfigFromEnv(t)
_, err := cfg.Initialize(ctx, nil) // Pass nil for tracer as it's not critical for this test
cfg := getOracleConfigFromEnv(t)
_, err := cfg.Initialize(ctx, nil) // Pass nil for tracer as it's not critical for this test
if err == nil {
t.Fatalf("Expected connection to fail with non-existent wallet, but it succeeded")
}
if err == nil {
t.Fatalf("Expected connection to fail with non-existent wallet, but it succeeded")
}
// Check for error message indicating wallet usage or connection failure related to wallet
// The exact error message might vary depending on the go-ora version and OS.
// We are looking for an error that suggests the wallet path was attempted.
expectedErrorSubstring := "wallet"
if !strings.Contains(strings.ToLower(err.Error()), expectedErrorSubstring) {
t.Errorf("Expected error message to contain '%s' (case-insensitive) but got: %v", expectedErrorSubstring, err)
}
t.Logf("Connection failed as expected (Pure Go with Wallet): %v", err)
}
func setOracleEnv(t *testing.T, host, user, password, service, port, connStr, tnsAlias, tnsAdmin, walletLocation string, useOCI bool) func() {
t.Helper()
original := map[string]string{
"ORACLE_HOST": os.Getenv("ORACLE_HOST"),
"ORACLE_USER": os.Getenv("ORACLE_USER"),
"ORACLE_PASSWORD": os.Getenv("ORACLE_PASSWORD"),
"ORACLE_SERVICE": os.Getenv("ORACLE_SERVICE"),
"ORACLE_PORT": os.Getenv("ORACLE_PORT"),
"ORACLE_TNS_ADMIN": os.Getenv("ORACLE_TNS_ADMIN"),
"ORACLE_WALLET_LOCATION": os.Getenv("ORACLE_WALLET_LOCATION"),
"ORACLE_USE_OCI": os.Getenv("ORACLE_USE_OCI"),
}
os.Setenv("ORACLE_HOST", host)
os.Setenv("ORACLE_USER", user)
os.Setenv("ORACLE_PASSWORD", password)
os.Setenv("ORACLE_SERVICE", service)
os.Setenv("ORACLE_PORT", port)
os.Setenv("ORACLE_TNS_ADMIN", tnsAdmin)
os.Setenv("ORACLE_WALLET_LOCATION", walletLocation)
os.Setenv("ORACLE_USE_OCI", fmt.Sprintf("%v", useOCI))
return func() {
for k, v := range original {
os.Setenv(k, v)
}
}
// Check for error message indicating wallet usage or connection failure related to wallet
// The exact error message might vary depending on the go-ora version and OS.
// We are looking for an error that suggests the wallet path was attempted.
expectedErrorSubstring := "wallet"
if !strings.Contains(strings.ToLower(err.Error()), expectedErrorSubstring) {
t.Errorf("Expected error message to contain '%s' (case-insensitive) but got: %v", expectedErrorSubstring, err)
}
t.Logf("Connection failed as expected (Pure Go with Wallet): %v", err)
}
// TestOracleConnectionOCI tests OCI driver connection without wallet
func TestOracleConnectionOCI(t *testing.T) {
t.Parallel()
// This test verifies that the useOCI=true parameter is correctly passed to the OCI driver.
// It will likely fail if Oracle Instant Client is not installed or configured.
t.Parallel()
// This test verifies that the useOCI=true parameter is correctly passed to the OCI driver.
// It will likely fail if Oracle Instant Client is not installed or configured.
// Save original env vars and restore them at the end
cleanup := setOracleEnv(t,
oracleHost, oracleUser, oraclePassword, oracleService, oraclePort, // Use existing base connection details
"", // connectionString
"", // tnsAlias
"", // tnsAdmin
"", // walletLocation
"true", // useOCI
)
defer cleanup()
// Save original env vars and restore them at the end
cleanup := setOracleEnv(t,
OracleHost, OracleUser, OraclePass, OracleServerName, OraclePort, // Use existing base connection details
"", // connectionString
"", // tnsAlias
"", // tnsAdmin
"", // walletLocation
true, // useOCI
)
defer cleanup()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
cfg := getOracleConfigFromEnv(t)
_, err := cfg.Initialize(ctx, nil)
cfg := getOracleConfigFromEnv(t)
_, err := cfg.Initialize(ctx, nil)
if err == nil {
t.Fatalf("Expected connection to fail (OCI driver without Instant Client), but it succeeded")
}
if err == nil {
t.Fatalf("Expected connection to fail (OCI driver without Instant Client), but it succeeded")
}
// Check for error message indicating OCI driver usage or connection failure related to OCI.
// Common errors include "OCI environment not initialized", "driver: bad connection", etc.
expectedErrorSubstrings := []string{"oci", "driver", "connection"}
foundExpectedError := false
for _, sub := range expectedErrorSubstrings {
if strings.Contains(strings.ToLower(err.Error()), sub) {
foundExpectedError = true
break
}
}
if !foundExpectedError {
t.Errorf("Expected error message to contain one of %v (case-insensitive) but got: %v", expectedErrorSubstrings, err)
}
t.Logf("Connection failed as expected (OCI Driver): %v", err)
// Check for error message indicating OCI driver usage or connection failure related to OCI.
// Common errors include "OCI environment not initialized", "driver: bad connection", etc.
expectedErrorSubstrings := []string{"oci", "driver", "connection"}
foundExpectedError := false
for _, sub := range expectedErrorSubstrings {
if strings.Contains(strings.ToLower(err.Error()), sub) {
foundExpectedError = true
break
}
}
if !foundExpectedError {
t.Errorf("Expected error message to contain one of %v (case-insensitive) but got: %v", expectedErrorSubstrings, err)
}
t.Logf("Connection failed as expected (OCI Driver): %v", err)
}
// TestOracleConnectionOCIWithWallet tests OCI driver connection with TNS Admin and Wallet
func TestOracleConnectionOCIWithWallet(t *testing.T) {
t.Parallel()
// This test verifies that useOCI=true and tnsAdmin parameters are correctly passed for OCI wallet.
// It will likely fail due to missing tnsnames.ora and wallet files.
t.Parallel()
// This test verifies that useOCI=true and tnsAdmin parameters are correctly passed for OCI wallet.
// It will likely fail due to missing tnsnames.ora and wallet files.
// Save original env vars and restore them at the end
cleanup := setOracleEnv(t,
"", oracleUser, oraclePassword, "", "", // Unset host/port/service for TNS alias, but keep user/pass
"", // connectionString
"MY_TNS_ALIAS", // tnsAlias
"/tmp/nonexistent_tns_admin", // tnsAdmin
"", // walletLocation
true, // useOCI
)
defer cleanup()
// Save original env vars and restore them at the end
cleanup := setOracleEnv(t,
"", OracleUser, OraclePass, "", "", // Unset host/port/service for TNS alias, but keep user/pass
"", // connectionString
"MY_TNS_ALIAS", // tnsAlias
"/tmp/nonexistent_tns_admin", // tnsAdmin
"", // walletLocation
true, // useOCI
)
defer cleanup()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
cfg := getOracleConfigFromEnv(t)
_, err := cfg.Initialize(ctx, nil)
cfg := getOracleConfigFromEnv(t)
_, err := cfg.Initialize(ctx, nil)
if err == nil {
t.Fatalf("Expected connection to fail (OCI driver with TNS Admin/Wallet), but it succeeded")
}
if err == nil {
t.Fatalf("Expected connection to fail (OCI driver with TNS Admin/Wallet), but it succeeded")
}
// Check for error message indicating TNS Admin/Wallet usage or connection failure.
expectedErrorSubstrings := []string{"tns", "wallet", "oci", "driver", "connection"}
foundExpectedError := false
for _, sub := range expectedErrorSubstrings {
if strings.Contains(strings.ToLower(err.Error()), sub) {
foundExpectedError = true
break
}
}
if !foundExpectedError {
t.Errorf("Expected error message to contain one of %v (case-insensitive) but got: %v", expectedErrorSubstrings, err)
}
t.Logf("Connection failed as expected (OCI Driver with TNS Admin/Wallet): %v", err)
// Check for error message indicating TNS Admin/Wallet usage or connection failure.
expectedErrorSubstrings := []string{"tns", "wallet", "oci", "driver", "connection"}
foundExpectedError := false
for _, sub := range expectedErrorSubstrings {
if strings.Contains(strings.ToLower(err.Error()), sub) {
foundExpectedError = true
break
}
}
if !foundExpectedError {
t.Errorf("Expected error message to contain one of %v (case-insensitive) but got: %v", expectedErrorSubstrings, err)
}
t.Logf("Connection failed as expected (OCI Driver with TNS Admin/Wallet): %v", err)
}
//test utils
func setupOracleTable(t *testing.T, ctx context.Context, pool *sql.DB, createStatement, insertStatement, tableName string, params []any) func(*testing.T) {
err := pool.PingContext(ctx)
if err != nil {
@@ -381,7 +480,7 @@ func getOracleParamToolInfo(tableName string) (string, string, string, string, s
}
// getOracleAuthToolInfo returns statements and params for my-auth-tool for Oracle SQL
func getOracleAuthToolInfo(t *testing.T, tableName string) (string, string, string, []any) {
func getOracleAuthToolInfo(tableName string) (string, string, string, []any) {
createStatement := fmt.Sprintf(`CREATE TABLE %s ("id" NUMBER GENERATED AS IDENTITY PRIMARY KEY, "name" VARCHAR2(255), "email" VARCHAR2(255))`, tableName)
// MODIFIED: Use a PL/SQL block for multiple inserts
@@ -432,4 +531,4 @@ func dropAllUserTables(t *testing.T, ctx context.Context, db *sql.DB) {
t.Logf("failed to drop table %s: %v", tableName, err)
}
}
}
}