mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-02-13 08:35:15 -05:00
feat: refactor Oracle integration tests and enhance environment variable handling
This commit is contained in:
@@ -16,49 +16,200 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/oracle"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/tests"
|
||||
)
|
||||
|
||||
|
||||
var (
|
||||
OracleSourceType = "oracle"
|
||||
OracleToolType = "oracle-sql"
|
||||
OracleHost = os.Getenv("ORACLE_HOST")
|
||||
OraclePort = os.Getenv("ORACLE_PORT")
|
||||
OracleUser = os.Getenv("ORACLE_USER")
|
||||
OraclePass = os.Getenv("ORACLE_PASS")
|
||||
OracleServerName = os.Getenv("ORACLE_SERVER_NAME")
|
||||
OraclePort = os.Getenv("ORACLE_PORT")
|
||||
OracleService = os.Getenv("ORACLE_SERVICE")
|
||||
OracleWalletPath = os.Getenv("ORACLE_WALLET_PATH")
|
||||
oracleUseOCI = os.Getenv("ORACLE_USE_OCI")
|
||||
|
||||
OracleUseOCI = os.Getenv("ORACLE_USE_OCI")
|
||||
OracleWalletLocation = os.Getenv("ORACLE_WALLET_LOCATION")
|
||||
OracleTnsAdmin = os.Getenv("ORACLE_TNS_ADMIN")
|
||||
|
||||
OracleConnStr = fmt.Sprintf(
|
||||
"%s:%s/%s", OracleHost, "1521", OracleServerName)
|
||||
"%s:%s/%s", OracleHost, "1521", OracleServerName) // Default port 1521??
|
||||
)
|
||||
|
||||
func getOracleVars(t *testing.T) map[string]any {
|
||||
switch "" {
|
||||
case oracleHost:
|
||||
case OracleHost:
|
||||
t.Fatal("'ORACLE_HOST not set")
|
||||
case oracleUser:
|
||||
case OracleUser:
|
||||
t.Fatal("'ORACLE_USER' not set")
|
||||
case oraclePassword:
|
||||
t.Fatal("'ORACLE_PASSWORD' not set")
|
||||
case oraclePort:
|
||||
t.Fatal("'ORACLE_PORT' not set")
|
||||
case oracleService:
|
||||
t.Fatal("'ORACLE_SERVICE' not set")
|
||||
case OraclePass:
|
||||
t.Fatal("'ORACLE_PASS' not set")
|
||||
case OracleServerName:
|
||||
t.Fatal("'ORACLE_SERVER_NAME' not set")
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"type": OracleSourceType,
|
||||
"connectionString": OracleConnStr,
|
||||
"useOCI": true,
|
||||
"useOCI": OracleUseOCI,
|
||||
"walletLocation": OracleWalletLocation,
|
||||
"tnsAdmin": OracleTnsAdmin,
|
||||
"host": OracleHost,
|
||||
"port": OraclePort,
|
||||
"service": OracleServerName,
|
||||
"user": OracleUser,
|
||||
"password": OraclePass,
|
||||
}
|
||||
}
|
||||
|
||||
// getOracleConfigFromEnv constructs an oracle.Config from environment variables.
|
||||
func getOracleConfigFromEnv(t *testing.T) oracle.Config {
|
||||
t.Helper()
|
||||
vars := getOracleVars(t)
|
||||
|
||||
port, err := strconv.Atoi(vars["port"].(string))
|
||||
if err != nil && vars["port"].(string) != "" {
|
||||
t.Fatalf("invalid ORACLE_PORT: %v", err)
|
||||
}
|
||||
|
||||
useOCI, err := strconv.ParseBool(vars["ORACLE_USE_OCI"].(string))
|
||||
if err != nil && vars["ORACLE_USE_OCI"].(string) != "" {
|
||||
useOCI = false
|
||||
}
|
||||
|
||||
return oracle.Config{
|
||||
Name: "test-oracle-instance",
|
||||
Kind: vars["kind"].(string),
|
||||
User: vars["user"].(string),
|
||||
Password: vars["password"].(string),
|
||||
Host: vars["host"].(string),
|
||||
Port: port,
|
||||
ServiceName: vars["service"].(string),
|
||||
WalletLocation: vars["walletLocation"].(string),
|
||||
TnsAdmin: vars["tnsAdmin"].(string),
|
||||
UseOCI: useOCI,
|
||||
}
|
||||
}
|
||||
|
||||
// setOracleEnv sets Oracle-related environment variables for testing and returns a cleanup function.
|
||||
func setOracleEnv(t *testing.T, host, user, password, service, port, connStr, tnsAlias, tnsAdmin, walletLocation string, useOCI bool) func() {
|
||||
t.Helper()
|
||||
|
||||
original := map[string]string{
|
||||
"ORACLE_HOST": os.Getenv("ORACLE_HOST"),
|
||||
"ORACLE_USER": os.Getenv("ORACLE_USER"),
|
||||
"ORACLE_PASSWORD": os.Getenv("ORACLE_PASSWORD"),
|
||||
"ORACLE_SERVICE": os.Getenv("ORACLE_SERVICE"),
|
||||
"ORACLE_PORT": os.Getenv("ORACLE_PORT"),
|
||||
"ORACLE_TNS_ADMIN": os.Getenv("ORACLE_TNS_ADMIN"),
|
||||
"ORACLE_WALLET_LOCATION": os.Getenv("ORACLE_WALLET_LOCATION"),
|
||||
"ORACLE_USE_OCI": os.Getenv("ORACLE_USE_OCI"),
|
||||
}
|
||||
|
||||
os.Setenv("ORACLE_HOST", host)
|
||||
os.Setenv("ORACLE_USER", user)
|
||||
os.Setenv("ORACLE_PASSWORD", password)
|
||||
os.Setenv("ORACLE_SERVICE", service)
|
||||
os.Setenv("ORACLE_PORT", port)
|
||||
os.Setenv("ORACLE_TNS_ADMIN", tnsAdmin)
|
||||
os.Setenv("ORACLE_WALLET_LOCATION", walletLocation)
|
||||
os.Setenv("ORACLE_USE_OCI", fmt.Sprintf("%v", useOCI))
|
||||
|
||||
return func() {
|
||||
for k, v := range original {
|
||||
os.Setenv(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Copied over from oracle.go
|
||||
func initOracleConnection(ctx context.Context, user, pass, connStr string) (*sql.DB, error) {
|
||||
// Build the full Oracle connection string for godror driver
|
||||
fullConnStr := fmt.Sprintf(`user="%s" password="%s" connectString="%s"`,
|
||||
user, pass, connStr)
|
||||
|
||||
db, err := sql.Open("godror", fullConnStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to open Oracle connection: %w", err)
|
||||
}
|
||||
|
||||
err = db.PingContext(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to ping Oracle connection: %w", err)
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// TestOracleSimpleToolEndpoints tests Oracle SQL tool endpoints
|
||||
func TestOracleSimpleToolEndpoints(t *testing.T) {
|
||||
sourceConfig := getOracleVars(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
var args []string
|
||||
|
||||
db, err := initOracleConnection(ctx, OracleUser, OraclePass, OracleConnStr)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create Oracle connection pool: %s", err)
|
||||
}
|
||||
|
||||
dropAllUserTables(t, ctx, db)
|
||||
|
||||
// create table name with UUID
|
||||
tableNameParam := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
tableNameAuth := "auth_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
|
||||
// set up data for param tool
|
||||
createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := getOracleParamToolInfo(tableNameParam)
|
||||
teardownTable1 := setupOracleTable(t, ctx, db, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams)
|
||||
defer teardownTable1(t)
|
||||
|
||||
// set up data for auth tool
|
||||
createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := getOracleAuthToolInfo(tableNameAuth)
|
||||
teardownTable2 := setupOracleTable(t, ctx, db, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
|
||||
defer teardownTable2(t)
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, OracleToolType, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
|
||||
toolsFile = tests.AddExecuteSqlConfig(t, toolsFile, "oracle-execute-sql")
|
||||
tmplSelectCombined, tmplSelectFilterCombined := tests.GetMySQLTmplToolStatement()
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, OracleToolType, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
t.Fatalf("command initialization returned an error: %s", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||
if err != nil {
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
}
|
||||
|
||||
// Get configs for tests
|
||||
select1Want := "[{\"1\":1}]"
|
||||
mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: dpiStmt_execute: ORA-00900: invalid SQL statement"}],"isError":true}}`
|
||||
createTableStatement := `"CREATE TABLE t (id NUMBER GENERATED AS IDENTITY PRIMARY KEY, name VARCHAR2(255))"`
|
||||
mcpSelect1Want := `{"jsonrpc":"2.0","id":"invoke my-auth-required-tool","result":{"content":[{"type":"text","text":"{\"1\":1}"}]}}`
|
||||
|
||||
// Run tests
|
||||
tests.RunToolGetTest(t)
|
||||
tests.RunToolInvokeTest(t, select1Want,
|
||||
tests.DisableOptionalNullParamTest(),
|
||||
tests.WithMyToolById4Want("[{\"id\":4,\"name\":\"\"}]"),
|
||||
tests.DisableArrayTest(),
|
||||
)
|
||||
tests.RunMCPToolCallMethod(t, mcpMyFailToolWant, mcpSelect1Want)
|
||||
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam)
|
||||
}
|
||||
|
||||
|
||||
func TestOracleTools(t *testing.T) {
|
||||
sourceConfig := getOracleVars(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
@@ -66,11 +217,10 @@ func TestOracleTools(t *testing.T) {
|
||||
|
||||
var args []string
|
||||
|
||||
logger, err := util.NewLogger(os.Stderr, "info")
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create logger: %v", err)
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
ctx = util.ContextWithLogger(ctx, logger)
|
||||
|
||||
cfg := getOracleConfigFromEnv(t)
|
||||
source, err := cfg.Initialize(ctx, nil)
|
||||
@@ -96,8 +246,8 @@ func TestOracleTools(t *testing.T) {
|
||||
defer teardownTable1(t)
|
||||
|
||||
// set up data for auth tool
|
||||
createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := getOracleAuthToolInfo(t, tableNameAuth)
|
||||
teardownTable2 := setupOracleTable(t, ctx, db, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams...)
|
||||
createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := getOracleAuthToolInfo(tableNameAuth)
|
||||
teardownTable2 := setupOracleTable(t, ctx, db, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
|
||||
defer teardownTable2(t)
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
@@ -150,182 +300,131 @@ func TestOracleTools(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func getOracleConfigFromEnv(t *testing.T) oracle.Config {
|
||||
t.Helper()
|
||||
vars := getOracleVars(t)
|
||||
port, err := strconv.Atoi(vars["port"].(string))
|
||||
if err != nil && vars["port"].(string) != "" {
|
||||
t.Fatalf("invalid ORACLE_PORT: %v", err)
|
||||
}
|
||||
|
||||
useOCI, err := strconv.ParseBool(vars["useOCI"].(string))
|
||||
if err != nil && vars["useOCI"].(string) != "" {
|
||||
useOCI = false
|
||||
}
|
||||
|
||||
return oracle.Config{
|
||||
Name: "test-oracle-instance",
|
||||
Kind: vars["kind"].(string),
|
||||
User: vars["user"].(string),
|
||||
Password: vars["password"].(string),
|
||||
Host: vars["host"].(string),
|
||||
Port: port,
|
||||
ServiceName: vars["service"].(string),
|
||||
WalletLocation: vars["walletLocation"].(string),
|
||||
TnsAdmin: vars["tnsAdmin"].(string),
|
||||
UseOCI: useOCI,
|
||||
}
|
||||
}
|
||||
// new integration tests
|
||||
|
||||
// TestOracleConnectionPureGoWithWallet tests pure Go driver connection with wallet
|
||||
func TestOracleConnectionPureGoWithWallet(t *testing.T) {
|
||||
t.Parallel()
|
||||
// This test expects the connection to fail because the wallet file won't exist.
|
||||
// It verifies that the walletLocation parameter is correctly passed to the pure Go driver.
|
||||
t.Parallel()
|
||||
// This test expects the connection to fail because the wallet file won't exist.
|
||||
// It verifies that the walletLocation parameter is correctly passed to the pure Go driver.
|
||||
|
||||
// Save original env vars and restore them at the end
|
||||
cleanup := setOracleEnv(t,
|
||||
oracleHost, oracleUser, oraclePassword, oracleService, oraclePort, // Use existing base connection details
|
||||
"", // connectionString
|
||||
"", // tnsAlias
|
||||
"", // tnsAdmin
|
||||
"/tmp/nonexistent_wallet", // walletLocation
|
||||
false, // useOCI
|
||||
)
|
||||
defer cleanup()
|
||||
// Save original env vars and restore them at the end
|
||||
cleanup := setOracleEnv(t,
|
||||
OracleHost, OracleUser, OraclePass, OracleServerName, OraclePort, // Use existing base connection details
|
||||
"", // connectionString
|
||||
"", // tnsAlias
|
||||
"", // tnsAdmin
|
||||
"/tmp/nonexistent_wallet", // walletLocation
|
||||
false, // useOCI
|
||||
)
|
||||
defer cleanup()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := getOracleConfigFromEnv(t)
|
||||
_, err := cfg.Initialize(ctx, nil) // Pass nil for tracer as it's not critical for this test
|
||||
cfg := getOracleConfigFromEnv(t)
|
||||
_, err := cfg.Initialize(ctx, nil) // Pass nil for tracer as it's not critical for this test
|
||||
|
||||
if err == nil {
|
||||
t.Fatalf("Expected connection to fail with non-existent wallet, but it succeeded")
|
||||
}
|
||||
if err == nil {
|
||||
t.Fatalf("Expected connection to fail with non-existent wallet, but it succeeded")
|
||||
}
|
||||
|
||||
// Check for error message indicating wallet usage or connection failure related to wallet
|
||||
// The exact error message might vary depending on the go-ora version and OS.
|
||||
// We are looking for an error that suggests the wallet path was attempted.
|
||||
expectedErrorSubstring := "wallet"
|
||||
if !strings.Contains(strings.ToLower(err.Error()), expectedErrorSubstring) {
|
||||
t.Errorf("Expected error message to contain '%s' (case-insensitive) but got: %v", expectedErrorSubstring, err)
|
||||
}
|
||||
t.Logf("Connection failed as expected (Pure Go with Wallet): %v", err)
|
||||
}
|
||||
|
||||
func setOracleEnv(t *testing.T, host, user, password, service, port, connStr, tnsAlias, tnsAdmin, walletLocation string, useOCI bool) func() {
|
||||
t.Helper()
|
||||
|
||||
original := map[string]string{
|
||||
"ORACLE_HOST": os.Getenv("ORACLE_HOST"),
|
||||
"ORACLE_USER": os.Getenv("ORACLE_USER"),
|
||||
"ORACLE_PASSWORD": os.Getenv("ORACLE_PASSWORD"),
|
||||
"ORACLE_SERVICE": os.Getenv("ORACLE_SERVICE"),
|
||||
"ORACLE_PORT": os.Getenv("ORACLE_PORT"),
|
||||
"ORACLE_TNS_ADMIN": os.Getenv("ORACLE_TNS_ADMIN"),
|
||||
"ORACLE_WALLET_LOCATION": os.Getenv("ORACLE_WALLET_LOCATION"),
|
||||
"ORACLE_USE_OCI": os.Getenv("ORACLE_USE_OCI"),
|
||||
}
|
||||
|
||||
os.Setenv("ORACLE_HOST", host)
|
||||
os.Setenv("ORACLE_USER", user)
|
||||
os.Setenv("ORACLE_PASSWORD", password)
|
||||
os.Setenv("ORACLE_SERVICE", service)
|
||||
os.Setenv("ORACLE_PORT", port)
|
||||
os.Setenv("ORACLE_TNS_ADMIN", tnsAdmin)
|
||||
os.Setenv("ORACLE_WALLET_LOCATION", walletLocation)
|
||||
os.Setenv("ORACLE_USE_OCI", fmt.Sprintf("%v", useOCI))
|
||||
|
||||
return func() {
|
||||
for k, v := range original {
|
||||
os.Setenv(k, v)
|
||||
}
|
||||
}
|
||||
// Check for error message indicating wallet usage or connection failure related to wallet
|
||||
// The exact error message might vary depending on the go-ora version and OS.
|
||||
// We are looking for an error that suggests the wallet path was attempted.
|
||||
expectedErrorSubstring := "wallet"
|
||||
if !strings.Contains(strings.ToLower(err.Error()), expectedErrorSubstring) {
|
||||
t.Errorf("Expected error message to contain '%s' (case-insensitive) but got: %v", expectedErrorSubstring, err)
|
||||
}
|
||||
t.Logf("Connection failed as expected (Pure Go with Wallet): %v", err)
|
||||
}
|
||||
|
||||
// TestOracleConnectionOCI tests OCI driver connection without wallet
|
||||
func TestOracleConnectionOCI(t *testing.T) {
|
||||
t.Parallel()
|
||||
// This test verifies that the useOCI=true parameter is correctly passed to the OCI driver.
|
||||
// It will likely fail if Oracle Instant Client is not installed or configured.
|
||||
t.Parallel()
|
||||
// This test verifies that the useOCI=true parameter is correctly passed to the OCI driver.
|
||||
// It will likely fail if Oracle Instant Client is not installed or configured.
|
||||
|
||||
// Save original env vars and restore them at the end
|
||||
cleanup := setOracleEnv(t,
|
||||
oracleHost, oracleUser, oraclePassword, oracleService, oraclePort, // Use existing base connection details
|
||||
"", // connectionString
|
||||
"", // tnsAlias
|
||||
"", // tnsAdmin
|
||||
"", // walletLocation
|
||||
"true", // useOCI
|
||||
)
|
||||
defer cleanup()
|
||||
// Save original env vars and restore them at the end
|
||||
cleanup := setOracleEnv(t,
|
||||
OracleHost, OracleUser, OraclePass, OracleServerName, OraclePort, // Use existing base connection details
|
||||
"", // connectionString
|
||||
"", // tnsAlias
|
||||
"", // tnsAdmin
|
||||
"", // walletLocation
|
||||
true, // useOCI
|
||||
)
|
||||
defer cleanup()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := getOracleConfigFromEnv(t)
|
||||
_, err := cfg.Initialize(ctx, nil)
|
||||
cfg := getOracleConfigFromEnv(t)
|
||||
_, err := cfg.Initialize(ctx, nil)
|
||||
|
||||
if err == nil {
|
||||
t.Fatalf("Expected connection to fail (OCI driver without Instant Client), but it succeeded")
|
||||
}
|
||||
if err == nil {
|
||||
t.Fatalf("Expected connection to fail (OCI driver without Instant Client), but it succeeded")
|
||||
}
|
||||
|
||||
// Check for error message indicating OCI driver usage or connection failure related to OCI.
|
||||
// Common errors include "OCI environment not initialized", "driver: bad connection", etc.
|
||||
expectedErrorSubstrings := []string{"oci", "driver", "connection"}
|
||||
foundExpectedError := false
|
||||
for _, sub := range expectedErrorSubstrings {
|
||||
if strings.Contains(strings.ToLower(err.Error()), sub) {
|
||||
foundExpectedError = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundExpectedError {
|
||||
t.Errorf("Expected error message to contain one of %v (case-insensitive) but got: %v", expectedErrorSubstrings, err)
|
||||
}
|
||||
t.Logf("Connection failed as expected (OCI Driver): %v", err)
|
||||
// Check for error message indicating OCI driver usage or connection failure related to OCI.
|
||||
// Common errors include "OCI environment not initialized", "driver: bad connection", etc.
|
||||
expectedErrorSubstrings := []string{"oci", "driver", "connection"}
|
||||
foundExpectedError := false
|
||||
for _, sub := range expectedErrorSubstrings {
|
||||
if strings.Contains(strings.ToLower(err.Error()), sub) {
|
||||
foundExpectedError = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundExpectedError {
|
||||
t.Errorf("Expected error message to contain one of %v (case-insensitive) but got: %v", expectedErrorSubstrings, err)
|
||||
}
|
||||
t.Logf("Connection failed as expected (OCI Driver): %v", err)
|
||||
}
|
||||
|
||||
// TestOracleConnectionOCIWithWallet tests OCI driver connection with TNS Admin and Wallet
|
||||
func TestOracleConnectionOCIWithWallet(t *testing.T) {
|
||||
t.Parallel()
|
||||
// This test verifies that useOCI=true and tnsAdmin parameters are correctly passed for OCI wallet.
|
||||
// It will likely fail due to missing tnsnames.ora and wallet files.
|
||||
t.Parallel()
|
||||
// This test verifies that useOCI=true and tnsAdmin parameters are correctly passed for OCI wallet.
|
||||
// It will likely fail due to missing tnsnames.ora and wallet files.
|
||||
|
||||
// Save original env vars and restore them at the end
|
||||
cleanup := setOracleEnv(t,
|
||||
"", oracleUser, oraclePassword, "", "", // Unset host/port/service for TNS alias, but keep user/pass
|
||||
"", // connectionString
|
||||
"MY_TNS_ALIAS", // tnsAlias
|
||||
"/tmp/nonexistent_tns_admin", // tnsAdmin
|
||||
"", // walletLocation
|
||||
true, // useOCI
|
||||
)
|
||||
defer cleanup()
|
||||
// Save original env vars and restore them at the end
|
||||
cleanup := setOracleEnv(t,
|
||||
"", OracleUser, OraclePass, "", "", // Unset host/port/service for TNS alias, but keep user/pass
|
||||
"", // connectionString
|
||||
"MY_TNS_ALIAS", // tnsAlias
|
||||
"/tmp/nonexistent_tns_admin", // tnsAdmin
|
||||
"", // walletLocation
|
||||
true, // useOCI
|
||||
)
|
||||
defer cleanup()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := getOracleConfigFromEnv(t)
|
||||
_, err := cfg.Initialize(ctx, nil)
|
||||
cfg := getOracleConfigFromEnv(t)
|
||||
_, err := cfg.Initialize(ctx, nil)
|
||||
|
||||
if err == nil {
|
||||
t.Fatalf("Expected connection to fail (OCI driver with TNS Admin/Wallet), but it succeeded")
|
||||
}
|
||||
if err == nil {
|
||||
t.Fatalf("Expected connection to fail (OCI driver with TNS Admin/Wallet), but it succeeded")
|
||||
}
|
||||
|
||||
// Check for error message indicating TNS Admin/Wallet usage or connection failure.
|
||||
expectedErrorSubstrings := []string{"tns", "wallet", "oci", "driver", "connection"}
|
||||
foundExpectedError := false
|
||||
for _, sub := range expectedErrorSubstrings {
|
||||
if strings.Contains(strings.ToLower(err.Error()), sub) {
|
||||
foundExpectedError = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundExpectedError {
|
||||
t.Errorf("Expected error message to contain one of %v (case-insensitive) but got: %v", expectedErrorSubstrings, err)
|
||||
}
|
||||
t.Logf("Connection failed as expected (OCI Driver with TNS Admin/Wallet): %v", err)
|
||||
// Check for error message indicating TNS Admin/Wallet usage or connection failure.
|
||||
expectedErrorSubstrings := []string{"tns", "wallet", "oci", "driver", "connection"}
|
||||
foundExpectedError := false
|
||||
for _, sub := range expectedErrorSubstrings {
|
||||
if strings.Contains(strings.ToLower(err.Error()), sub) {
|
||||
foundExpectedError = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundExpectedError {
|
||||
t.Errorf("Expected error message to contain one of %v (case-insensitive) but got: %v", expectedErrorSubstrings, err)
|
||||
}
|
||||
t.Logf("Connection failed as expected (OCI Driver with TNS Admin/Wallet): %v", err)
|
||||
}
|
||||
|
||||
//test utils
|
||||
func setupOracleTable(t *testing.T, ctx context.Context, pool *sql.DB, createStatement, insertStatement, tableName string, params []any) func(*testing.T) {
|
||||
err := pool.PingContext(ctx)
|
||||
if err != nil {
|
||||
@@ -381,7 +480,7 @@ func getOracleParamToolInfo(tableName string) (string, string, string, string, s
|
||||
}
|
||||
|
||||
// getOracleAuthToolInfo returns statements and params for my-auth-tool for Oracle SQL
|
||||
func getOracleAuthToolInfo(t *testing.T, tableName string) (string, string, string, []any) {
|
||||
func getOracleAuthToolInfo(tableName string) (string, string, string, []any) {
|
||||
createStatement := fmt.Sprintf(`CREATE TABLE %s ("id" NUMBER GENERATED AS IDENTITY PRIMARY KEY, "name" VARCHAR2(255), "email" VARCHAR2(255))`, tableName)
|
||||
|
||||
// MODIFIED: Use a PL/SQL block for multiple inserts
|
||||
@@ -432,4 +531,4 @@ func dropAllUserTables(t *testing.T, ctx context.Context, db *sql.DB) {
|
||||
t.Logf("failed to drop table %s: %v", tableName, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user