From 8ea39ec32fbbaa97939c626fec8c5d86040ed464 Mon Sep 17 00:00:00 2001 From: Wenxin Du <117315983+duwenxin99@users.noreply.github.com> Date: Thu, 18 Dec 2025 14:02:17 -0500 Subject: [PATCH] feat(sources/oracle): Add Oracle OCI and Wallet support (#1945) Previously we used go-ora (a pure Go Oracle driver) because our release pipeline did not support cross-compilation with CGO. Now that it's fixed, we want to add support for Oracle OCI driver for advanced features including digital wallet etc. Users will be able to configure a source to use OCI by specifying a `UseOCI: true` field. The source defaults to use the pure Go driver otherwise. Oracle Wallet: - OCI users should use the `tnsAdmin` to set the wallet location - Non-OCI users can should use the `walletLocation` field. fix: https://github.com/googleapis/genai-toolbox/issues/1779 --- .ci/continuous.release.cloudbuild.yaml | 2 +- .ci/integration.cloudbuild.yaml | 27 ++- docs/en/resources/sources/oracle.md | 88 ++++++-- go.mod | 4 + go.sum | 14 ++ internal/sources/oracle/oracle.go | 79 +++++-- internal/sources/oracle/oracle_test.go | 200 ++++++++++++++++++ .../oracleexecutesql/oracleexecutesql.go | 2 +- .../oracleexecutesql/oracleexecutesql_test.go | 82 +++++++ .../tools/oracle/oraclesql/oraclesql_test.go | 85 ++++++++ tests/oracle/oracle_integration_test.go | 11 +- 11 files changed, 547 insertions(+), 47 deletions(-) create mode 100644 internal/sources/oracle/oracle_test.go create mode 100644 internal/tools/oracle/oracleexecutesql/oracleexecutesql_test.go create mode 100644 internal/tools/oracle/oraclesql/oraclesql_test.go diff --git a/.ci/continuous.release.cloudbuild.yaml b/.ci/continuous.release.cloudbuild.yaml index b73000aa1b..0025d46719 100644 --- a/.ci/continuous.release.cloudbuild.yaml +++ b/.ci/continuous.release.cloudbuild.yaml @@ -305,4 +305,4 @@ substitutions: _AR_HOSTNAME: ${_REGION}-docker.pkg.dev _AR_REPO_NAME: toolbox-dev _BUCKET_NAME: genai-toolbox-dev - _DOCKER_URI: ${_AR_HOSTNAME}/${PROJECT_ID}/${_AR_REPO_NAME}/toolbox + _DOCKER_URI: ${_AR_HOSTNAME}/${PROJECT_ID}/${_AR_REPO_NAME}/toolbox \ No newline at end of file diff --git a/.ci/integration.cloudbuild.yaml b/.ci/integration.cloudbuild.yaml index c0d7909c9d..b424a490e7 100644 --- a/.ci/integration.cloudbuild.yaml +++ b/.ci/integration.cloudbuild.yaml @@ -846,8 +846,8 @@ steps: cassandra - id: "oracle" - name: golang:1 - waitFor: ["compile-test-binary"] + name: ghcr.io/oracle/oraclelinux9-instantclient:23 + waitFor: ["install-dependencies"] entrypoint: /bin/bash env: - "GOPATH=/gopath" @@ -860,10 +860,25 @@ steps: args: - -c - | - .ci/test_with_coverage.sh \ - "Oracle" \ - oracle \ - oracle + # Install the C compiler and Oracle SDK headers needed for cgo + dnf install -y gcc oracle-instantclient-devel + # Install Go + curl -L -o go.tar.gz "https://go.dev/dl/go1.25.1.linux-amd64.tar.gz" + tar -C /usr/local -xzf go.tar.gz + export PATH="/usr/local/go/bin:$$PATH" + + go test -v ./internal/sources/oracle/... \ + -coverprofile=oracle_coverage.out \ + -coverpkg=./internal/sources/oracle/...,./internal/tools/oracle/... + + # Coverage check + total_coverage=$(go tool cover -func=oracle_coverage.out | grep "total:" | awk '{print $3}') + echo "Oracle total coverage: $total_coverage" + coverage_numeric=$(echo "$total_coverage" | sed 's/%//') + if awk -v cov="$coverage_numeric" 'BEGIN {exit !(cov < 30)}'; then + echo "Coverage failure: $total_coverage is below 30%." + exit 1 + fi - id: "serverless-spark" name: golang:1 diff --git a/docs/en/resources/sources/oracle.md b/docs/en/resources/sources/oracle.md index 4932ea6e22..51fa18fe13 100644 --- a/docs/en/resources/sources/oracle.md +++ b/docs/en/resources/sources/oracle.md @@ -18,10 +18,10 @@ DW) database workloads. ## Available Tools - [`oracle-sql`](../tools/oracle/oracle-sql.md) - Execute pre-defined prepared SQL queries in Oracle. + Execute pre-defined prepared SQL queries in Oracle. - [`oracle-execute-sql`](../tools/oracle/oracle-execute-sql.md) - Run parameterized SQL queries in Oracle. + Run parameterized SQL queries in Oracle. ## Requirements @@ -33,6 +33,25 @@ user][oracle-users] to log in to the database with the necessary permissions. [oracle-users]: https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/CREATE-USER.html +### Oracle Driver Requirement (Conditional) + +The Oracle source offers two connection drivers: + +1. **Pure Go Driver (`useOCI: false`, default):** Uses the `go-ora` library. + This driver is simpler and does not require any local Oracle software + installation, but it **lacks support for advanced features** like Oracle + Wallets or Kerberos authentication. + +2. **OCI-Based Driver (`useOCI: true`):** Uses the `godror` library, which + provides access to **advanced Oracle features** like Digital Wallet support. + +If you set `useOCI: true`, you **must** install the **Oracle Instant Client** +libraries on the machine where this tool runs. + +You can download the Instant Client from the official Oracle website: [Oracle +Instant Client +Downloads](https://www.oracle.com/database/technologies/instant-client/downloads.html) + ## Connection Methods You can configure the connection to your Oracle database using one of the @@ -66,12 +85,15 @@ using a TNS (Transparent Network Substrate) alias. containing it. This setting will override the `TNS_ADMIN` environment variable. -## Example +## Examples + +This example demonstrates the four connection methods you could choose from: ```yaml sources: my-oracle-source: kind: oracle + # --- Choose one connection method --- # 1. Host, Port, and Service Name host: 127.0.0.1 @@ -88,6 +110,43 @@ sources: user: ${USER_NAME} password: ${PASSWORD} + # Optional: Set to true to use the OCI-based driver for advanced features (Requires Oracle Instant Client) +``` + +### Using an Oracle Wallet + +Oracle Wallet allows you to store credentails used for database connection. Depending whether you are using an OCI-based driver, the wallet configuration is different. + +#### Pure Go Driver (`useOCI: false`) - Oracle Wallet + +The `go-ora` driver uses the `walletLocation` field to connect to a database secured with an Oracle Wallet without standard username and password. + +```yaml +sources: + pure-go-wallet: + kind: oracle + connectionString: "127.0.0.1:1521/XEPDB1" + user: ${USER_NAME} + password: ${PASSWORD} + # The TNS Alias is often required to connect to a service registered in tnsnames.ora + tnsAlias: "SECURE_DB_ALIAS" + walletLocation: "/path/to/my/wallet/directory" +``` + +#### OCI-Based Driver (`useOCI: true`) - Oracle Wallet + +For the OCI-based driver, wallet authentication is triggered by setting tnsAdmin to the wallet directory and connecting via a tnsAlias. + +```yaml +sources: + oci-wallet: + kind: oracle + connectionString: "127.0.0.1:1521/XEPDB1" + user: ${USER_NAME} + password: ${PASSWORD} + tnsAlias: "WALLET_DB_ALIAS" + tnsAdmin: "/opt/oracle/wallet" # Directory containing tnsnames.ora, sqlnet.ora, and wallet files + useOCI: true ``` {{< notice tip >}} @@ -97,14 +156,15 @@ instead of hardcoding your secrets into the configuration file. ## Reference -| **field** | **type** | **required** | **description** | -|------------------|:--------:|:------------:|-----------------------------------------------------------------------------------------------------------------------------| -| kind | string | true | Must be "oracle". | -| user | string | true | Name of the Oracle user to connect as (e.g. "my-oracle-user"). | -| password | string | true | Password of the Oracle user (e.g. "my-password"). | -| host | string | false | IP address or hostname to connect to (e.g. "127.0.0.1"). Required if not using `connectionString` or `tnsAlias`. | -| port | integer | false | Port to connect to (e.g. "1521"). Required if not using `connectionString` or `tnsAlias`. | -| serviceName | string | false | The Oracle service name of the database to connect to. Required if not using `connectionString` or `tnsAlias`. | -| connectionString | string | false | A direct connection string (e.g. "hostname:port/servicename"). Use as an alternative to `host`, `port`, and `serviceName`. | -| tnsAlias | string | false | A TNS alias from a `tnsnames.ora` file. Use as an alternative to `host`/`port` or `connectionString`. | -| tnsAdmin | string | false | Path to the directory containing the `tnsnames.ora` file. This overrides the `TNS_ADMIN` environment variable if it is set. | +| **field** | **type** | **required** | **description** | +|------------------|:--------:|:------------:|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| kind | string | true | Must be "oracle". | +| user | string | true | Name of the Oracle user to connect as (e.g. "my-oracle-user"). | +| password | string | true | Password of the Oracle user (e.g. "my-password"). | +| host | string | false | IP address or hostname to connect to (e.g. "127.0.0.1"). Required if not using `connectionString` or `tnsAlias`. | +| port | integer | false | Port to connect to (e.g. "1521"). Required if not using `connectionString` or `tnsAlias`. | +| serviceName | string | false | The Oracle service name of the database to connect to. Required if not using `connectionString` or `tnsAlias`. | +| connectionString | string | false | A direct connection string (e.g. "hostname:port/servicename"). Use as an alternative to `host`, `port`, and `serviceName`. | +| tnsAlias | string | false | A TNS alias from a `tnsnames.ora` file. Use as an alternative to `host`/`port` or `connectionString`. | +| tnsAdmin | string | false | Path to the directory containing the `tnsnames.ora` file. This overrides the `TNS_ADMIN` environment variable if it is set. | +| useOCI | bool | false | If true, uses the OCI-based driver (godror) which supports Oracle Wallet/Kerberos but requires the Oracle Instant Client libraries to be installed. Defaults to false (pure Go driver). | diff --git a/go.mod b/go.mod index 074c18a5d6..e10d45187e 100644 --- a/go.mod +++ b/go.mod @@ -33,6 +33,7 @@ require ( github.com/go-playground/validator/v10 v10.28.0 github.com/go-sql-driver/mysql v1.9.3 github.com/goccy/go-yaml v1.18.0 + github.com/godror/godror v0.49.4 github.com/google/go-cmp v0.7.0 github.com/google/uuid v1.6.0 github.com/jackc/pgx/v5 v5.7.6 @@ -91,6 +92,7 @@ require ( github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.29.0 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.54.0 // indirect github.com/PuerkitoBio/goquery v1.10.3 // indirect + github.com/VictoriaMetrics/easyproto v0.1.4 // indirect github.com/ajg/form v1.5.1 // indirect github.com/apache/arrow/go/v15 v15.0.2 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect @@ -107,11 +109,13 @@ require ( github.com/felixge/httpsnoop v1.0.4 // indirect github.com/gabriel-vasile/mimetype v1.4.10 // indirect github.com/go-jose/go-jose/v4 v4.1.2 // indirect + github.com/go-logfmt/logfmt v0.6.0 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/goccy/go-json v0.10.5 // indirect + github.com/godror/knownpb v0.3.0 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect diff --git a/go.sum b/go.sum index 6fa294f79c..3c270b9ba7 100644 --- a/go.sum +++ b/go.sum @@ -683,6 +683,10 @@ github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5/go.mod h1:lmUJ/7eu/Q8 github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/goquery v1.10.3 h1:pFYcNSqHxBD06Fpj/KsbStFRsgRATgnf3LeXiUkhzPo= github.com/PuerkitoBio/goquery v1.10.3/go.mod h1:tMUX0zDMHXYlAQk6p35XxQMqMweEKB7iK7iLNd4RH4Y= +github.com/UNO-SOFT/zlog v0.8.1 h1:TEFkGJHtUfTRgMkLZiAjLSHALjwSBdw6/zByMC5GJt4= +github.com/UNO-SOFT/zlog v0.8.1/go.mod h1:yqFOjn3OhvJ4j7ArJqQNA+9V+u6t9zSAyIZdWdMweWc= +github.com/VictoriaMetrics/easyproto v0.1.4 h1:r8cNvo8o6sR4QShBXQd1bKw/VVLSQma/V2KhTBPf+Sc= +github.com/VictoriaMetrics/easyproto v0.1.4/go.mod h1:QlGlzaJnDfFd8Lk6Ci/fuLxfTo3/GThPs2KH23mv710= github.com/ahmetb/dlog v0.0.0-20170105205344-4fb5f8204f26 h1:3YVZUqkoev4mL+aCwVOSWV4M7pN+NURHL38Z2zq5JKA= github.com/ahmetb/dlog v0.0.0-20170105205344-4fb5f8204f26/go.mod h1:ymXt5bw5uSNu4jveerFxE0vNYxF8ncqbptntMaFMg3k= github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU= @@ -884,6 +888,8 @@ github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vb github.com/go-latex/latex v0.0.0-20210118124228-b3d85cf34e07/go.mod h1:CO1AlKB2CSIqUrmQPqA0gdRIlnLEY0gK5JGjh37zN5U= github.com/go-latex/latex v0.0.0-20210823091927-c0d11ff05a81/go.mod h1:SX0U8uGpxhq9o2S/CELCSUxEWWAuoCUcVCQWv7G2OCk= github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= +github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4= +github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= @@ -909,6 +915,10 @@ github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw= github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= +github.com/godror/godror v0.49.4 h1:8kKWKoR17nPX7u10hr4GwD4u10hzTZED9ihdkuzRrKI= +github.com/godror/godror v0.49.4/go.mod h1:kTMcxZzRw73RT5kn9v3JkBK4kHI6dqowHotqV72ebU8= +github.com/godror/knownpb v0.3.0 h1:+caUdy8hTtl7X05aPl3tdL540TvCcaQA6woZQroLZMw= +github.com/godror/knownpb v0.3.0/go.mod h1:PpTyfJwiOEAzQl7NtVCM8kdPCnp3uhxsZYIzZ5PV4zU= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= @@ -1172,6 +1182,8 @@ github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdh github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/neo4j/neo4j-go-driver/v5 v5.28.4 h1:7toxehVcYkZbyxV4W3Ib9VcnyRBQPucF+VwNNmtSXi4= github.com/neo4j/neo4j-go-driver/v5 v5.28.4/go.mod h1:Vff8OwT7QpLm7L2yYr85XNWe9Rbqlbeb9asNXJTHO4k= +github.com/oklog/ulid/v2 v2.0.2 h1:r4fFzBm+bv0wNKNh5eXTwU7i85y5x+uwkxCUTNVQqLc= +github.com/oklog/ulid/v2 v2.0.2/go.mod h1:mtBL0Qe/0HAx6/a4Z30qxVIAL1eQDweXq5lxOEiwQ68= github.com/onsi/gomega v1.36.2 h1:koNYke6TVk6ZmnyHrCXba/T/MoLBXFjeC1PtvYgw0A8= github.com/onsi/gomega v1.36.2/go.mod h1:DdwyADRjrc825LhMEkD76cHR5+pUnjhUN8GlHlRPHzY= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= @@ -1671,6 +1683,8 @@ golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= +golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU= +golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/internal/sources/oracle/oracle.go b/internal/sources/oracle/oracle.go index 3b37560004..4de64b402b 100644 --- a/internal/sources/oracle/oracle.go +++ b/internal/sources/oracle/oracle.go @@ -9,9 +9,11 @@ import ( "strings" "github.com/goccy/go-yaml" + _ "github.com/godror/godror" // OCI driver + _ "github.com/sijms/go-ora/v2" // Pure Go driver + "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/util" - _ "github.com/sijms/go-ora/v2" "go.opentelemetry.io/otel/trace" ) @@ -32,7 +34,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources return nil, err } - // Validate that we have one of: tns_alias, connection_string, or host+service_name + // Validate that we have one of: tnsAlias, connectionString, or host+service_name if err := actual.validate(); err != nil { return nil, fmt.Errorf("invalid Oracle configuration: %w", err) } @@ -43,21 +45,24 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` - ConnectionString string `yaml:"connectionString,omitempty"` // Direct connection string (hostname[:port]/servicename) - TnsAlias string `yaml:"tnsAlias,omitempty"` // TNS alias from tnsnames.ora - Host string `yaml:"host,omitempty"` // Optional when using connectionString/tnsAlias - Port int `yaml:"port,omitempty"` // Explicit port support - ServiceName string `yaml:"serviceName,omitempty"` // Optional when using connectionString/tnsAlias + ConnectionString string `yaml:"connectionString,omitempty"` + TnsAlias string `yaml:"tnsAlias,omitempty"` + TnsAdmin string `yaml:"tnsAdmin,omitempty"` + Host string `yaml:"host,omitempty"` + Port int `yaml:"port,omitempty"` + ServiceName string `yaml:"serviceName,omitempty"` User string `yaml:"user" validate:"required"` Password string `yaml:"password" validate:"required"` - TnsAdmin string `yaml:"tnsAdmin,omitempty"` // Optional: override TNS_ADMIN environment variable + UseOCI bool `yaml:"useOCI,omitempty"` + WalletLocation string `yaml:"walletLocation,omitempty"` } -// validate ensures we have one of: tns_alias, connection_string, or host+service_name func (c Config) validate() error { + hasTnsAdmin := strings.TrimSpace(c.TnsAdmin) != "" hasTnsAlias := strings.TrimSpace(c.TnsAlias) != "" hasConnStr := strings.TrimSpace(c.ConnectionString) != "" hasHostService := strings.TrimSpace(c.Host) != "" && strings.TrimSpace(c.ServiceName) != "" + hasWallet := strings.TrimSpace(c.WalletLocation) != "" connectionMethods := 0 if hasTnsAlias { @@ -78,6 +83,14 @@ func (c Config) validate() error { return fmt.Errorf("provide only one connection method: 'tns_alias', 'connection_string', or 'host'+'service_name'") } + if hasTnsAdmin && !c.UseOCI { + return fmt.Errorf("`tnsAdmin` can only be used when `UseOCI` is true, or use `walletLocation` instead") + } + + if hasWallet && c.UseOCI { + return fmt.Errorf("when using an OCI driver, use `tnsAdmin` to specify credentials file location instead") + } + return nil } @@ -132,7 +145,8 @@ func initOracleConnection(ctx context.Context, tracer trace.Tracer, config Confi panic(err) } - // Set TNS_ADMIN environment variable if specified in config. + hasWallet := strings.TrimSpace(config.WalletLocation) != "" + if config.TnsAdmin != "" { originalTnsAdmin := os.Getenv("TNS_ADMIN") os.Setenv("TNS_ADMIN", config.TnsAdmin) @@ -147,28 +161,49 @@ func initOracleConnection(ctx context.Context, tracer trace.Tracer, config Confi }() } - var serverString string + var connectStringBase string if config.TnsAlias != "" { - // Use TNS alias - serverString = strings.TrimSpace(config.TnsAlias) + connectStringBase = strings.TrimSpace(config.TnsAlias) } else if config.ConnectionString != "" { - // Use provided connection string directly (hostname[:port]/servicename format) - serverString = strings.TrimSpace(config.ConnectionString) + connectStringBase = strings.TrimSpace(config.ConnectionString) } else { - // Build connection string from host and service_name if config.Port > 0 { - serverString = fmt.Sprintf("%s:%d/%s", config.Host, config.Port, config.ServiceName) + connectStringBase = fmt.Sprintf("%s:%d/%s", config.Host, config.Port, config.ServiceName) } else { - serverString = fmt.Sprintf("%s/%s", config.Host, config.ServiceName) + connectStringBase = fmt.Sprintf("%s/%s", config.Host, config.ServiceName) } } - connStr := fmt.Sprintf("oracle://%s:%s@%s", - config.User, config.Password, serverString) + var driverName string + var finalConnStr string - db, err := sql.Open("oracle", connStr) + if config.UseOCI { + // Use godror driver (requires OCI) + driverName = "godror" + finalConnStr = fmt.Sprintf(`user="%s" password="%s" connectString="%s"`, + config.User, config.Password, connectStringBase) + logger.DebugContext(ctx, fmt.Sprintf("Using godror driver (OCI-based) with connectString: %s\n", connectStringBase)) + } else { + // Use go-ora driver (pure Go) + driverName = "oracle" + + user := config.User + password := config.Password + + if hasWallet { + finalConnStr = fmt.Sprintf("oracle://%s:%s@%s?ssl=true&wallet=%s", + user, password, connectStringBase, config.WalletLocation) + } else { + // Standard go-ora connection + finalConnStr = fmt.Sprintf("oracle://%s:%s@%s", + config.User, config.Password, connectStringBase) + logger.DebugContext(ctx, fmt.Sprintf("Using go-ora driver (pure-Go) with serverString: %s\n", connectStringBase)) + } + } + + db, err := sql.Open(driverName, finalConnStr) if err != nil { - return nil, fmt.Errorf("unable to open Oracle connection: %w", err) + return nil, fmt.Errorf("unable to open Oracle connection with driver %s: %w", driverName, err) } return db, nil diff --git a/internal/sources/oracle/oracle_test.go b/internal/sources/oracle/oracle_test.go new file mode 100644 index 0000000000..3d8f4c7ba5 --- /dev/null +++ b/internal/sources/oracle/oracle_test.go @@ -0,0 +1,200 @@ +// Copyright © 2025, Oracle and/or its affiliates. + +package oracle_test + +import ( + "strings" + "testing" + + yaml "github.com/goccy/go-yaml" + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/sources/oracle" + "github.com/googleapis/genai-toolbox/internal/testutils" +) + +func TestParseFromYamlOracle(t *testing.T) { + tcs := []struct { + desc string + in string + want server.SourceConfigs + }{ + { + desc: "connection string and useOCI=true", + in: ` + sources: + my-oracle-cs: + kind: oracle + connectionString: "my-host:1521/XEPDB1" + user: my_user + password: my_pass + useOCI: true + `, + want: server.SourceConfigs{ + "my-oracle-cs": oracle.Config{ + Name: "my-oracle-cs", + Kind: oracle.SourceKind, + ConnectionString: "my-host:1521/XEPDB1", + User: "my_user", + Password: "my_pass", + UseOCI: true, + }, + }, + }, + { + desc: "host/port/serviceName and default useOCI=false", + in: ` + sources: + my-oracle-host: + kind: oracle + host: my-host + port: 1521 + serviceName: ORCLPDB + user: my_user + password: my_pass + `, + want: server.SourceConfigs{ + "my-oracle-host": oracle.Config{ + Name: "my-oracle-host", + Kind: oracle.SourceKind, + Host: "my-host", + Port: 1521, + ServiceName: "ORCLPDB", + User: "my_user", + Password: "my_pass", + UseOCI: false, + }, + }, + }, + { + desc: "tnsAlias and TnsAdmin specified with explicit useOCI=true", + in: ` + sources: + my-oracle-tns-oci: + kind: oracle + tnsAlias: FINANCE_DB + tnsAdmin: /opt/oracle/network/admin + user: my_user + password: my_pass + useOCI: true + `, + want: server.SourceConfigs{ + "my-oracle-tns-oci": oracle.Config{ + Name: "my-oracle-tns-oci", + Kind: oracle.SourceKind, + TnsAlias: "FINANCE_DB", + TnsAdmin: "/opt/oracle/network/admin", + User: "my_user", + Password: "my_pass", + UseOCI: true, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Sources server.SourceConfigs `yaml:"sources"` + }{} + + err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if !cmp.Equal(tc.want, got.Sources) { + t.Fatalf("incorrect parse:\nwant: %v\ngot: %v\ndiff: %s", tc.want, got.Sources, cmp.Diff(tc.want, got.Sources)) + } + }) + } +} + +func TestFailParseFromYamlOracle(t *testing.T) { + tcs := []struct { + desc string + in string + err string + }{ + { + desc: "extra field", + in: ` + sources: + my-oracle-instance: + kind: oracle + host: my-host + serviceName: ORCL + user: my_user + password: my_pass + extraField: value + `, + err: "unable to parse source \"my-oracle-instance\" as \"oracle\": [1:1] unknown field \"extraField\"\n> 1 | extraField: value\n ^\n 2 | host: my-host\n 3 | kind: oracle\n 4 | password: my_pass\n 5 | ", + }, + { + desc: "missing required password field", + in: ` + sources: + my-oracle-instance: + kind: oracle + host: my-host + serviceName: ORCL + user: my_user + `, + err: "unable to parse source \"my-oracle-instance\" as \"oracle\": Key: 'Config.Password' Error:Field validation for 'Password' failed on the 'required' tag", + }, + { + desc: "missing connection method fields (validate fails)", + in: ` + sources: + my-oracle-instance: + kind: oracle + user: my_user + password: my_pass + `, + err: "unable to parse source \"my-oracle-instance\" as \"oracle\": invalid Oracle configuration: must provide one of: 'tns_alias', 'connection_string', or both 'host' and 'service_name'", + }, + { + desc: "multiple connection methods provided (validate fails)", + in: ` + sources: + my-oracle-instance: + kind: oracle + host: my-host + serviceName: ORCL + connectionString: "my-host:1521/XEPDB1" + user: my_user + password: my_pass + `, + err: "unable to parse source \"my-oracle-instance\" as \"oracle\": invalid Oracle configuration: provide only one connection method: 'tns_alias', 'connection_string', or 'host'+'service_name'", + }, + { + desc: "fail on tnsAdmin with useOCI=false", + in: ` + sources: + my-oracle-fail: + kind: oracle + tnsAlias: FINANCE_DB + tnsAdmin: /opt/oracle/network/admin + user: my_user + password: my_pass + useOCI: false + `, + err: "unable to parse source \"my-oracle-fail\" as \"oracle\": invalid Oracle configuration: `tnsAdmin` can only be used when `UseOCI` is true, or use `walletLocation` instead", + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Sources server.SourceConfigs `yaml:"sources"` + }{} + + err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got) + if err == nil { + t.Fatalf("expect parsing to fail") + } + errStr := strings.ReplaceAll(err.Error(), "\r", "") + + if errStr != tc.err { + t.Fatalf("unexpected error:\ngot:\n%q\nwant:\n%q\n", errStr, tc.err) + } + }) + } +} diff --git a/internal/tools/oracle/oracleexecutesql/oracleexecutesql.go b/internal/tools/oracle/oracleexecutesql/oracleexecutesql.go index 23d3a9b3de..1dd708f471 100644 --- a/internal/tools/oracle/oracleexecutesql/oracleexecutesql.go +++ b/internal/tools/oracle/oracleexecutesql/oracleexecutesql.go @@ -110,7 +110,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if err != nil { return nil, fmt.Errorf("error getting logger: %s", err) } - logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sqlParam)) + logger.DebugContext(ctx, "executing `%s` tool query: %s", kind, sqlParam) results, err := t.Pool.QueryContext(ctx, sqlParam) if err != nil { diff --git a/internal/tools/oracle/oracleexecutesql/oracleexecutesql_test.go b/internal/tools/oracle/oracleexecutesql/oracleexecutesql_test.go new file mode 100644 index 0000000000..834d3d6981 --- /dev/null +++ b/internal/tools/oracle/oracleexecutesql/oracleexecutesql_test.go @@ -0,0 +1,82 @@ +// Copyright © 2025, Oracle and/or its affiliates. + +package oracleexecutesql_test + +import ( + "testing" + + yaml "github.com/goccy/go-yaml" + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools/oracle/oracleexecutesql" +) + +func TestParseFromYamlOracleExecuteSql(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example with auth", + in: ` + tools: + run_adhoc_query: + kind: oracle-execute-sql + source: my-oracle-instance + description: Executes arbitrary SQL statements like INSERT or UPDATE. + authRequired: + - my-google-auth-service + `, + want: server.ToolConfigs{ + "run_adhoc_query": oracleexecutesql.Config{ + Name: "run_adhoc_query", + Kind: "oracle-execute-sql", + Source: "my-oracle-instance", + Description: "Executes arbitrary SQL statements like INSERT or UPDATE.", + AuthRequired: []string{"my-google-auth-service"}, + }, + }, + }, + { + desc: "example without authRequired", + in: ` + tools: + run_simple_update: + kind: oracle-execute-sql + source: db-dev + description: Runs a simple update operation. + `, + want: server.ToolConfigs{ + "run_simple_update": oracleexecutesql.Config{ + Name: "run_simple_update", + Kind: "oracle-execute-sql", + Source: "db-dev", + Description: "Runs a simple update operation.", + AuthRequired: []string{}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Tools server.ToolConfigs `yaml:"tools"` + }{} + + err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got.Tools); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } + +} diff --git a/internal/tools/oracle/oraclesql/oraclesql_test.go b/internal/tools/oracle/oraclesql/oraclesql_test.go new file mode 100644 index 0000000000..2ba0a7321c --- /dev/null +++ b/internal/tools/oracle/oraclesql/oraclesql_test.go @@ -0,0 +1,85 @@ +// Copyright © 2025, Oracle and/or its affiliates. +package oraclesql_test + +import ( + "testing" + + yaml "github.com/goccy/go-yaml" + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools/oracle/oraclesql" +) + +func TestParseFromYamlOracleSql(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example with statement and auth", + in: ` + tools: + get_user_by_id: + kind: oracle-sql + source: my-oracle-instance + description: Retrieves user details by ID. + statement: "SELECT id, name, email FROM users WHERE id = :1" + authRequired: + - my-google-auth-service + `, + want: server.ToolConfigs{ + "get_user_by_id": oraclesql.Config{ + Name: "get_user_by_id", + Kind: "oracle-sql", + Source: "my-oracle-instance", + Description: "Retrieves user details by ID.", + Statement: "SELECT id, name, email FROM users WHERE id = :1", + AuthRequired: []string{"my-google-auth-service"}, + }, + }, + }, + { + desc: "example with parameters and template parameters", + in: ` + tools: + get_orders: + kind: oracle-sql + source: db-prod + description: Gets orders for a customer with optional filtering. + statement: "SELECT * FROM ${SCHEMA}.ORDERS WHERE customer_id = :customer_id AND status = :status" + `, + want: server.ToolConfigs{ + "get_orders": oraclesql.Config{ + Name: "get_orders", + Kind: "oracle-sql", + Source: "db-prod", + Description: "Gets orders for a customer with optional filtering.", + Statement: "SELECT * FROM ${SCHEMA}.ORDERS WHERE customer_id = :customer_id AND status = :status", + AuthRequired: []string{}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Tools server.ToolConfigs `yaml:"tools"` + }{} + + err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got.Tools); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } + +} diff --git a/tests/oracle/oracle_integration_test.go b/tests/oracle/oracle_integration_test.go index 04f272a1b8..0021679e9e 100644 --- a/tests/oracle/oracle_integration_test.go +++ b/tests/oracle/oracle_integration_test.go @@ -43,6 +43,7 @@ func getOracleVars(t *testing.T) map[string]any { return map[string]any{ "kind": OracleSourceKind, "connectionString": OracleConnStr, + "useOCI": true, "user": OracleUser, "password": OraclePass, } @@ -50,9 +51,11 @@ func getOracleVars(t *testing.T) map[string]any { // Copied over from oracle.go func initOracleConnection(ctx context.Context, user, pass, connStr string) (*sql.DB, error) { - fullConnStr := fmt.Sprintf("oracle://%s:%s@%s", user, pass, connStr) + // Build the full Oracle connection string for godror driver + fullConnStr := fmt.Sprintf(`user="%s" password="%s" connectString="%s"`, + user, pass, connStr) - db, err := sql.Open("oracle", fullConnStr) + db, err := sql.Open("godror", fullConnStr) if err != nil { return nil, fmt.Errorf("unable to open Oracle connection: %w", err) } @@ -116,13 +119,15 @@ func TestOracleSimpleToolEndpoints(t *testing.T) { // Get configs for tests select1Want := "[{\"1\":1}]" - mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: ORA-00900: invalid SQL statement\n error occur at position: 0"}],"isError":true}}` + mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: dpiStmt_execute: ORA-00900: invalid SQL statement"}],"isError":true}}` createTableStatement := `"CREATE TABLE t (id NUMBER GENERATED AS IDENTITY PRIMARY KEY, name VARCHAR2(255))"` mcpSelect1Want := `{"jsonrpc":"2.0","id":"invoke my-auth-required-tool","result":{"content":[{"type":"text","text":"{\"1\":1}"}]}}` // Run tests tests.RunToolGetTest(t) tests.RunToolInvokeTest(t, select1Want, + tests.DisableOptionalNullParamTest(), + tests.WithMyToolById4Want("[{\"id\":4,\"name\":\"\"}]"), tests.DisableArrayTest(), ) tests.RunMCPToolCallMethod(t, mcpMyFailToolWant, mcpSelect1Want)