feat(sources/oracle): add Oracle Source and Tool (#1456)

## Description

---
> Should include a concise description of the changes (bug or feature),
it's
> impact, along with a summary of the solution

## PR Checklist

---
> Thank you for opening a Pull Request! Before submitting your PR, there
are a
> few things you can do to make sure it goes smoothly:

- [ ] Make sure you reviewed

[CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md)
- [ ] Make sure to open an issue as a

[bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose)
before writing your code! That way we can discuss the change, evaluate
  designs, and agree on the general idea
- [ ] Ensure the tests and linter pass
- [ ] Code coverage does not decrease (if any source code was changed)
- [ ] Appropriate docs were updated (if necessary)
- [ ] Make sure to add `!` if this involve a breaking change

🛠️ Fixes https://github.com/googleapis/genai-toolbox/issues/488

---------

Co-authored-by: duwenxin <duwenxin@google.com>
Co-authored-by: Wenxin Du <117315983+duwenxin99@users.noreply.github.com>
This commit is contained in:
Vijay Balebail
2025-10-09 20:29:01 -05:00
committed by GitHub
parent 98f7ee2e36
commit 3a19a50ff2
13 changed files with 1129 additions and 2 deletions

View File

@@ -703,7 +703,32 @@ steps:
"Cassandra" \
cassandra \
cassandra
- id: "oracle"
name: ghcr.io/oracle/oraclelinux8-instantclient:21
waitFor: ["install-dependencies"]
entrypoint: /bin/bash
env:
- "GOPATH=/gopath"
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
- "ORACLE_SERVER_NAME=$_ORACLE_SERVER_NAME"
secretEnv: ["CLIENT_ID", "ORACLE_USER", "ORACLE_PASS", "ORACLE_HOST"]
volumes:
- name: "go"
path: "/gopath"
args:
- -c
- |
# Install the C compiler and Oracle SDK headers needed for cgo
dnf install -y gcc oracle-instantclient-devel
# Install Go
curl -L -o go.tar.gz "https://go.dev/dl/go1.25.1.linux-amd64.tar.gz"
tar -C /usr/local -xzf go.tar.gz
export PATH="/usr/local/go/bin:$$PATH"
go test ./tests/oracle
availableSecrets:
secretManager:
- versionName: projects/$PROJECT_ID/secrets/cloud_sql_pg_user/versions/latest
@@ -794,6 +819,12 @@ availableSecrets:
env: CASSANDRA_PASS
- versionName: projects/$PROJECT_ID/secrets/cassandra_host/versions/latest
env: CASSANDRA_HOST
- versionName: projects/$PROJECT_ID/secrets/oracle_user/versions/latest
env: ORACLE_USER
- versionName: projects/$PROJECT_ID/secrets/oracle_pass/versions/latest
env: ORACLE_PASS
- versionName: projects/$PROJECT_ID/secrets/oracle_host/versions/latest
env: ORACLE_HOST
options:
logging: CLOUD_LOGGING_ONLY
@@ -845,3 +876,4 @@ substitutions:
_YUGABYTEDB_DATABASE: "yugabyte"
_YUGABYTEDB_PORT: "5433"
_YUGABYTEDB_LOADBALANCE: "false"
_ORACLE_SERVER_NAME: "FREEPDB1"

View File

@@ -139,6 +139,8 @@ import (
_ "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jschema"
_ "github.com/googleapis/genai-toolbox/internal/tools/oceanbase/oceanbaseexecutesql"
_ "github.com/googleapis/genai-toolbox/internal/tools/oceanbase/oceanbasesql"
_ "github.com/googleapis/genai-toolbox/internal/tools/oracle/oracleexecutesql"
_ "github.com/googleapis/genai-toolbox/internal/tools/oracle/oraclesql"
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresexecutesql"
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistactivequeries"
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistavailableextensions"
@@ -184,6 +186,7 @@ import (
_ "github.com/googleapis/genai-toolbox/internal/sources/mysql"
_ "github.com/googleapis/genai-toolbox/internal/sources/neo4j"
_ "github.com/googleapis/genai-toolbox/internal/sources/oceanbase"
_ "github.com/googleapis/genai-toolbox/internal/sources/oracle"
_ "github.com/googleapis/genai-toolbox/internal/sources/postgres"
_ "github.com/googleapis/genai-toolbox/internal/sources/redis"
_ "github.com/googleapis/genai-toolbox/internal/sources/spanner"

View File

@@ -0,0 +1,104 @@
---
title: "Oracle"
type: docs
weight: 1
description: >
Oracle Database is a widely-used relational database management system.
---
## About
[Oracle Database][oracle-docs] is a multi-model database management system produced and marketed by Oracle Corporation. It is commonly used for running online transaction processing (OLTP), data warehousing (DW), and mixed (OLTP & DW) database workloads.
[oracle-docs]: https://www.oracle.com/database/
## Available Tools
- [`oracle-sql`](../tools/oracle/oracle-sql.md)
Execute pre-defined prepared SQL queries in Oracle.
- [`oracle-execute-sql`](../tools/oracle/oracle-execute-sql.md)
Run parameterized SQL queries in Oracle.
## Requirements
### Database User
This source uses standard authentication. You will need to [create an Oracle user][oracle-users] to log in to the database with the necessary permissions.
[oracle-users]:
https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/CREATE-USER.html
### Oracle Instant Client (OIC)
The underlying database driver requires the [Oracle Instant Client][oracle-ic] libraries to connect to the database. These libraries must be installed on the machine where the application is running.
After installing the client, ensure the library path is correctly configured for your operating system (e.g., by setting the `LD_LIBRARY_PATH` environment variable on Linux or adding the directory to the `PATH` on Windows) so the application can find the necessary files at runtime.
[oracle-ic]: https://www.oracle.com/database/technologies/instant-client/downloads.html
## Connection Methods
You can configure the connection to your Oracle database using one of the following three methods. **You should only use one method** in your source configuration.
### Basic Connection (Host/Port/Service Name)
This is the most straightforward method, where you provide the connection details as separate fields:
- `host`: The IP address or hostname of the database server.
- `port`: The port number the Oracle listener is running on (typically 1521).
- `serviceName`: The service name for the database instance you wish to connect to.
### Connection String
As an alternative, you can provide all the connection details in a single `connectionString`. This is a convenient way to consolidate the connection information. The typical format is `hostname:port/servicename`.
### TNS Alias
For environments that use a `tnsnames.ora` configuration file, you can connect using a TNS (Transparent Network Substrate) alias.
- `tnsAlias`: Specify the alias name defined in your `tnsnames.ora` file.
- `tnsAdmin` (Optional): If your configuration file is not in a standard location, you can use this field to provide the path to the directory containing it. This setting will override the `TNS_ADMIN` environment variable.
## Example
```yaml
sources:
my-oracle-source:
kind: oracle
# --- Choose one connection method ---
# 1. Host, Port, and Service Name
host: 127.0.0.1
port: 1521
serviceName: XEPDB1
# 2. Direct Connection String
connectionString: "127.0.0.1:1521/XEPDB1"
# 3. TNS Alias (requires tnsnames.ora)
tnsAlias: "MY_DB_ALIAS"
tnsAdmin: "/opt/oracle/network/admin" # Optional: overrides TNS_ADMIN env var
user: ${USER_NAME}
password: ${PASSWORD}
```
{{< notice tip >}}
Use environment variable replacement with the format ${ENV_NAME}
instead of hardcoding your secrets into the configuration file.
{{< /notice >}}
## Reference
| **field** | **type** | **required** | **description** |
|------------------|:--------:|:------------:|-----------------------------------------------------------------------------------------------------------------------------|
| kind | string | true | Must be "oracle". |
| user | string | true | Name of the Oracle user to connect as (e.g. "my-oracle-user"). |
| password | string | true | Password of the Oracle user (e.g. "my-password"). |
| host | string | false | IP address or hostname to connect to (e.g. "127.0.0.1"). Required if not using `connectionString` or `tnsAlias`. |
| port | integer | false | Port to connect to (e.g. "1521"). Required if not using `connectionString` or `tnsAlias`. |
| serviceName | string | false | The Oracle service name of the database to connect to. Required if not using `connectionString` or `tnsAlias`. |
| connectionString | string | false | A direct connection string (e.g. "hostname:port/servicename"). Use as an alternative to `host`, `port`, and `serviceName`. |
| tnsAlias | string | false | A TNS alias from a `tnsnames.ora` file. Use as an alternative to `host`/`port` or `connectionString`. |
| tnsAdmin | string | false | Path to the directory containing the `tnsnames.ora` file. This overrides the `TNS_ADMIN` environment variable if it is set. |

View File

@@ -0,0 +1,7 @@
---
title: "Oracle"
type: docs
weight: 1
description: >
Tools that work with Oracle Sources.
---

View File

@@ -0,0 +1,31 @@
---
title: "oracle-execute-sql"
type: docs
weight: 1
description: >
An "oracle-execute-sql" tool executes a SQL statement against an Oracle database.
aliases:
- /resources/tools/oracle-execute-sql
---
## About
An `oracle-execute-sql` tool executes a SQL statement against an Oracle
database. It's compatible with the following source:
- [oracle](../../sources/oracle.md)
`oracle-execute-sql` takes one input parameter `sql` and runs the sql
statement against the `source`.
> **Note:** This tool is intended for developer assistant workflows with
> human-in-the-loop and shouldn't be used for production agents.
## Example
```yaml
tools:
execute_sql_tool:
kind: oracle-execute-sql
source: my-oracle-instance
description: Use this tool to execute sql statement.

View File

@@ -0,0 +1,57 @@
---
title: "oracle-sql"
type: docs
weight: 1
description: >
An "oracle-sql" tool executes a pre-defined SQL statement against an Oracle database.
aliases:
- /resources/tools/oracle-sql
---
## About
An `oracle-sql` tool executes a pre-defined SQL statement against an
Oracle database. It's compatible with the following source:
- [oracle](../../sources/oracle.md)
The specified SQL statement is executed using [prepared statements][oracle-stmt]
for security and performance. It expects parameter placeholders in the SQL query
to be in the native Oracle format (e.g., `:1`, `:2`).
[oracle-stmt]: https://docs.oracle.com/javase/tutorial/jdbc/basics/prepared.html
## Example
> **Note:** This tool uses parameterized queries to prevent SQL injections.
> Query parameters can be used as substitutes for arbitrary expressions.
> Parameters cannot be used as substitutes for identifiers, column names, table
> names, or other parts of the query.
```yaml
tools:
search_flights_by_number:
kind: oracle-sql
source: my-oracle-instance
statement: |
SELECT * FROM flights
WHERE airline = :1
AND flight_number = :2
FETCH FIRST 10 ROWS ONLY
description: |
Use this tool to get information for a specific flight.
Takes an airline code and flight number and returns info on the flight.
Do NOT use this tool with a flight id. Do NOT guess an airline code or flight number.
Example:
{{
"airline": "CY",
"flight_number": "888",
}}
parameters:
- name: airline
type: string
description: Airline unique 2 letter identifier
- name: flight_number
type: string
description: 1 to 4 digit number
```

4
go.mod
View File

@@ -28,6 +28,7 @@ require (
github.com/go-sql-driver/mysql v1.9.3
github.com/goccy/go-yaml v1.18.0
github.com/gocql/gocql v1.7.0
github.com/godror/godror v0.49.3
github.com/google/go-cmp v0.7.0
github.com/google/uuid v1.6.0
github.com/jackc/pgx/v5 v5.7.6
@@ -85,6 +86,7 @@ require (
github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.29.0 // indirect
github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.53.0 // indirect
github.com/PuerkitoBio/goquery v1.10.3 // indirect
github.com/VictoriaMetrics/easyproto v0.1.4 // indirect
github.com/ajg/form v1.5.1 // indirect
github.com/apache/arrow/go/v15 v15.0.2 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
@@ -101,11 +103,13 @@ require (
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/gabriel-vasile/mimetype v1.4.10 // indirect
github.com/go-jose/go-jose/v4 v4.1.1 // indirect
github.com/go-logfmt/logfmt v0.6.0 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/goccy/go-json v0.10.5 // indirect
github.com/godror/knownpb v0.3.0 // indirect
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
github.com/golang-sql/sqlexp v0.1.0 // indirect
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect

14
go.sum
View File

@@ -681,6 +681,10 @@ github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5/go.mod h1:lmUJ/7eu/Q8
github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
github.com/PuerkitoBio/goquery v1.10.3 h1:pFYcNSqHxBD06Fpj/KsbStFRsgRATgnf3LeXiUkhzPo=
github.com/PuerkitoBio/goquery v1.10.3/go.mod h1:tMUX0zDMHXYlAQk6p35XxQMqMweEKB7iK7iLNd4RH4Y=
github.com/UNO-SOFT/zlog v0.8.1 h1:TEFkGJHtUfTRgMkLZiAjLSHALjwSBdw6/zByMC5GJt4=
github.com/UNO-SOFT/zlog v0.8.1/go.mod h1:yqFOjn3OhvJ4j7ArJqQNA+9V+u6t9zSAyIZdWdMweWc=
github.com/VictoriaMetrics/easyproto v0.1.4 h1:r8cNvo8o6sR4QShBXQd1bKw/VVLSQma/V2KhTBPf+Sc=
github.com/VictoriaMetrics/easyproto v0.1.4/go.mod h1:QlGlzaJnDfFd8Lk6Ci/fuLxfTo3/GThPs2KH23mv710=
github.com/ahmetb/dlog v0.0.0-20170105205344-4fb5f8204f26 h1:3YVZUqkoev4mL+aCwVOSWV4M7pN+NURHL38Z2zq5JKA=
github.com/ahmetb/dlog v0.0.0-20170105205344-4fb5f8204f26/go.mod h1:ymXt5bw5uSNu4jveerFxE0vNYxF8ncqbptntMaFMg3k=
github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU=
@@ -878,6 +882,8 @@ github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vb
github.com/go-latex/latex v0.0.0-20210118124228-b3d85cf34e07/go.mod h1:CO1AlKB2CSIqUrmQPqA0gdRIlnLEY0gK5JGjh37zN5U=
github.com/go-latex/latex v0.0.0-20210823091927-c0d11ff05a81/go.mod h1:SX0U8uGpxhq9o2S/CELCSUxEWWAuoCUcVCQWv7G2OCk=
github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A=
github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4=
github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
@@ -905,6 +911,10 @@ github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw=
github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
github.com/gocql/gocql v1.7.0 h1:O+7U7/1gSN7QTEAaMEsJc1Oq2QHXvCWoF3DFK9HDHus=
github.com/gocql/gocql v1.7.0/go.mod h1:vnlvXyFZeLBF0Wy+RS8hrOdbn0UWsWtdg07XJnFxZ+4=
github.com/godror/godror v0.49.3 h1:84CPEu1p3qPvpN7PTHv8NDept+t+d+AoO/7WjYVsFNc=
github.com/godror/godror v0.49.3/go.mod h1:kTMcxZzRw73RT5kn9v3JkBK4kHI6dqowHotqV72ebU8=
github.com/godror/knownpb v0.3.0 h1:+caUdy8hTtl7X05aPl3tdL540TvCcaQA6woZQroLZMw=
github.com/godror/knownpb v0.3.0/go.mod h1:PpTyfJwiOEAzQl7NtVCM8kdPCnp3uhxsZYIzZ5PV4zU=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
@@ -1168,6 +1178,8 @@ github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdh
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/neo4j/neo4j-go-driver/v5 v5.28.4 h1:7toxehVcYkZbyxV4W3Ib9VcnyRBQPucF+VwNNmtSXi4=
github.com/neo4j/neo4j-go-driver/v5 v5.28.4/go.mod h1:Vff8OwT7QpLm7L2yYr85XNWe9Rbqlbeb9asNXJTHO4k=
github.com/oklog/ulid/v2 v2.0.2 h1:r4fFzBm+bv0wNKNh5eXTwU7i85y5x+uwkxCUTNVQqLc=
github.com/oklog/ulid/v2 v2.0.2/go.mod h1:mtBL0Qe/0HAx6/a4Z30qxVIAL1eQDweXq5lxOEiwQ68=
github.com/onsi/gomega v1.36.2 h1:koNYke6TVk6ZmnyHrCXba/T/MoLBXFjeC1PtvYgw0A8=
github.com/onsi/gomega v1.36.2/go.mod h1:DdwyADRjrc825LhMEkD76cHR5+pUnjhUN8GlHlRPHzY=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
@@ -1661,6 +1673,8 @@ golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU=
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
golang.org/x/term v0.35.0 h1:bZBVKBudEyhRcajGcNc3jIfWPqV4y/Kt2XcoigOWtDQ=
golang.org/x/term v0.35.0/go.mod h1:TPGtkTLesOwf2DE8CgVYiZinHAOuy5AYUYT1lENIZnA=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=

View File

@@ -0,0 +1,175 @@
// Copyright © 2025, Oracle and/or its affiliates.
package oracle
import (
"context"
"database/sql"
"fmt"
"os"
"strings"
"github.com/goccy/go-yaml"
_ "github.com/godror/godror"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/util"
"go.opentelemetry.io/otel/trace"
)
const SourceKind string = "oracle"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources.SourceConfig, error) {
actual := Config{Name: name}
if err := decoder.DecodeContext(ctx, &actual); err != nil {
return nil, err
}
// Validate that we have one of: tns_alias, connection_string, or host+service_name
if err := actual.validate(); err != nil {
return nil, fmt.Errorf("invalid Oracle configuration: %w", err)
}
return actual, nil
}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
ConnectionString string `yaml:"connectionString,omitempty"` // Direct connection string (hostname[:port]/servicename)
TnsAlias string `yaml:"tnsAlias,omitempty"` // TNS alias from tnsnames.ora
Host string `yaml:"host,omitempty"` // Optional when using connectionString/tnsAlias
Port int `yaml:"port,omitempty"` // Explicit port support
ServiceName string `yaml:"serviceName,omitempty"` // Optional when using connectionString/tnsAlias
User string `yaml:"user" validate:"required"`
Password string `yaml:"password" validate:"required"`
TnsAdmin string `yaml:"tnsAdmin,omitempty"` // Optional: override TNS_ADMIN environment variable
}
// validate ensures we have one of: tns_alias, connection_string, or host+service_name
func (c Config) validate() error {
hasTnsAlias := strings.TrimSpace(c.TnsAlias) != ""
hasConnStr := strings.TrimSpace(c.ConnectionString) != ""
hasHostService := strings.TrimSpace(c.Host) != "" && strings.TrimSpace(c.ServiceName) != ""
connectionMethods := 0
if hasTnsAlias {
connectionMethods++
}
if hasConnStr {
connectionMethods++
}
if hasHostService {
connectionMethods++
}
if connectionMethods == 0 {
return fmt.Errorf("must provide one of: 'tns_alias', 'connection_string', or both 'host' and 'service_name'")
}
if connectionMethods > 1 {
return fmt.Errorf("provide only one connection method: 'tns_alias', 'connection_string', or 'host'+'service_name'")
}
return nil
}
func (r Config) SourceConfigKind() string {
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
db, err := initOracleConnection(ctx, tracer, r)
if err != nil {
return nil, fmt.Errorf("unable to create Oracle connection: %w", err)
}
err = db.PingContext(ctx)
if err != nil {
return nil, fmt.Errorf("unable to connect to Oracle successfully: %w", err)
}
s := &Source{
Name: r.Name,
Kind: SourceKind,
DB: db,
}
return s, nil
}
var _ sources.Source = &Source{}
type Source struct {
Name string `yaml:"name"`
Kind string `yaml:"kind"`
DB *sql.DB
}
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) OracleDB() *sql.DB {
return s.DB
}
func initOracleConnection(ctx context.Context, tracer trace.Tracer, config Config) (*sql.DB, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, config.Name)
defer span.End()
var connectString string
logger, err := util.LoggerFromContext(ctx)
if err != nil {
panic(err)
}
// Set TNS_ADMIN environment variable if specified in config
if config.TnsAdmin != "" {
originalTnsAdmin := os.Getenv("TNS_ADMIN")
os.Setenv("TNS_ADMIN", config.TnsAdmin)
logger.DebugContext(ctx, fmt.Sprintf("Setting TNS_ADMIN to: %s\n", config.TnsAdmin))
// Restore original TNS_ADMIN after connection
defer func() {
if originalTnsAdmin != "" {
os.Setenv("TNS_ADMIN", originalTnsAdmin)
} else {
os.Unsetenv("TNS_ADMIN")
}
}()
}
// Determine the connection string to use (priority order)
if config.TnsAlias != "" {
// Use TNS alias - godror will resolve from tnsnames.ora
connectString = strings.TrimSpace(config.TnsAlias)
} else if config.ConnectionString != "" {
// Use provided connection string directly (hostname[:port]/servicename format)
connectString = strings.TrimSpace(config.ConnectionString)
} else {
// Build connection string from host and service_name
if config.Port > 0 {
connectString = fmt.Sprintf("%s:%d/%s", config.Host, config.Port, config.ServiceName)
} else {
connectString = fmt.Sprintf("%s/%s", config.Host, config.ServiceName)
}
}
// Build the full Oracle connection string for godror driver
connStr := fmt.Sprintf(`user="%s" password="%s" connectString="%s"`,
config.User, config.Password, connectString)
db, err := sql.Open("godror", connStr)
if err != nil {
return nil, fmt.Errorf("unable to open Oracle connection: %w", err)
}
return db, nil
}

View File

@@ -0,0 +1,223 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright © 2025, Oracle and/or its affiliates.
package oracleexecutesql
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strconv"
"strings"
yaml "github.com/goccy/go-yaml"
"github.com/godror/godror"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/oracle"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util"
)
const kind string = "oracle-execute-sql"
func init() {
if !tools.Register(kind, newConfig) {
panic(fmt.Sprintf("tool kind %q already registered", kind))
}
}
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
actual := Config{Name: name}
if err := decoder.DecodeContext(ctx, &actual); err != nil {
return nil, err
}
return actual, nil
}
type compatibleSource interface {
OracleDB() *sql.DB
}
// validate compatible sources are still compatible
var _ compatibleSource = &oracle.Source{}
var compatibleSources = [...]string{oracle.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Source string `yaml:"source" validate:"required"`
Description string `yaml:"description" validate:"required"`
AuthRequired []string `yaml:"authRequired"`
}
// validate interface
var _ tools.ToolConfig = Config{}
func (cfg Config) ToolConfigKind() string {
return kind
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
sqlParameter := tools.NewStringParameter("sql", "The SQL to execute.")
parameters := tools.Parameters{sqlParameter}
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, parameters)
// finish tool setup
t := Tool{
Name: cfg.Name,
Kind: kind,
Parameters: parameters,
AuthRequired: cfg.AuthRequired,
Pool: s.OracleDB(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
// validate interface
var _ tools.Tool = Tool{}
type Tool struct {
Name string `yaml:"name"`
Kind string `yaml:"kind"`
AuthRequired []string `yaml:"authRequired"`
Parameters tools.Parameters `yaml:"parameters"`
Pool *sql.DB
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
paramsMap := params.AsMap()
sql, ok := paramsMap["sql"].(string)
if !ok {
return nil, fmt.Errorf("unable to get cast %s", paramsMap["sql"])
}
// Log the query executed for debugging.
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return nil, fmt.Errorf("error getting logger: %s", err)
}
logger.DebugContext(ctx, "executing `%s` tool query: %s", kind, sql)
results, err := t.Pool.QueryContext(ctx, sql)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
defer results.Close()
cols, _ := results.Columns()
// If Columns() errors, it might be a DDL/DML without an OUTPUT clause.
// We proceed, and results.Err() will catch actual query execution errors.
// 'out' will remain nil if cols is empty or err is not nil here.
// Get Column types
colTypes, err := results.ColumnTypes()
if err != nil {
return nil, fmt.Errorf("unable to get column types: %w", err)
}
var out []any
for results.Next() {
// Create slice to hold values
values := make([]any, len(cols))
valuePtrs := make([]any, len(cols))
for i := range values {
valuePtrs[i] = &values[i]
}
// Scan the values
if err := results.Scan(valuePtrs...); err != nil {
return nil, fmt.Errorf("unable to scan row: %w", err)
}
// Create result map
vMap := make(map[string]any)
for i, col := range cols {
val := values[i]
switch colTypes[i].DatabaseTypeName() {
case "JSON":
// unmarshal JSON data before storing to prevent double marshaling
var unmarshaledData any
err := json.Unmarshal(val.([]byte), &unmarshaledData)
if err != nil {
return nil, fmt.Errorf("unable to unmarshal json data %s", val)
}
vMap[col] = unmarshaledData
case "TEXT", "VARCHAR", "NVARCHAR":
vMap[col] = string(val.([]byte))
case "NUMBER":
s := string(val.(godror.Number))
if strings.Contains(s, ".") {
vMap[col], err = strconv.ParseFloat(s, 64)
} else {
vMap[col], err = strconv.ParseInt(s, 10, 64)
}
if err != nil {
return nil, fmt.Errorf("unable to convert NUMBER data '%s' for column %s: %w", s, col, err)
}
default:
vMap[col] = val
}
}
out = append(out, vMap)
}
// Check for errors from iterating over rows or from the query execution itself.
// results.Close() is handled by defer.
if err := results.Err(); err != nil {
return nil, fmt.Errorf("errors encountered during query execution or row processing: %w", err)
}
return out, nil
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
return tools.ParseParams(t.Parameters, data, claims)
}
func (t Tool) Manifest() tools.Manifest {
return t.manifest
}
func (t Tool) McpManifest() tools.McpManifest {
return t.mcpManifest
}
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization() bool {
return false
}

View File

@@ -0,0 +1,234 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package oraclesql
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strconv"
"strings"
yaml "github.com/goccy/go-yaml"
"github.com/godror/godror"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/oracle"
"github.com/googleapis/genai-toolbox/internal/tools"
)
const kind string = "oracle-sql"
func init() {
if !tools.Register(kind, newConfig) {
panic(fmt.Sprintf("tool kind %q already registered", kind))
}
}
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
actual := Config{Name: name}
if err := decoder.DecodeContext(ctx, &actual); err != nil {
return nil, err
}
return actual, nil
}
type compatibleSource interface {
OracleDB() *sql.DB
}
// validate compatible sources are still compatible
var _ compatibleSource = &oracle.Source{}
var compatibleSources = [...]string{oracle.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Source string `yaml:"source" validate:"required"`
Description string `yaml:"description" validate:"required"`
Statement string `yaml:"statement" validate:"required"`
AuthRequired []string `yaml:"authRequired"`
Parameters tools.Parameters `yaml:"parameters"`
TemplateParameters tools.Parameters `yaml:"templateParameters"`
}
// validate interface
var _ tools.ToolConfig = Config{}
func (cfg Config) ToolConfigKind() string {
return kind
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
allParameters, paramManifest, err := tools.ProcessParameters(cfg.TemplateParameters, cfg.Parameters)
if err != nil {
return nil, fmt.Errorf("error processing parameters: %w", err)
}
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters)
// finish tool setup
t := Tool{
Name: cfg.Name,
Kind: kind,
Parameters: cfg.Parameters,
TemplateParameters: cfg.TemplateParameters,
AllParams: allParameters,
Statement: cfg.Statement,
AuthRequired: cfg.AuthRequired,
DB: s.OracleDB(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
// validate interface
var _ tools.Tool = Tool{}
type Tool struct {
Name string `yaml:"name"`
Kind string `yaml:"kind"`
AuthRequired []string `yaml:"authRequired"`
Parameters tools.Parameters `yaml:"parameters"`
TemplateParameters tools.Parameters `yaml:"templateParameters"`
AllParams tools.Parameters `yaml:"allParams"`
DB *sql.DB
Statement string
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
paramsMap := params.AsMap()
newStatement, err := tools.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
if err != nil {
return nil, fmt.Errorf("unable to extract template params %w", err)
}
newParams, err := tools.GetParams(t.Parameters, paramsMap)
if err != nil {
return nil, fmt.Errorf("unable to extract standard params %w", err)
}
sliceParams := newParams.AsSlice()
for i, p := range sliceParams {
fmt.Printf("[%d]=%T ", i, p)
}
fmt.Printf("\n")
// NO PARAMETER CONVERSION - godror supports :1, :2, :3 natively
// Execute Oracle query with original statement
rows, err := t.DB.QueryContext(ctx, newStatement, sliceParams...)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
defer rows.Close()
cols, _ := rows.Columns()
// Get Column types
colTypes, err := rows.ColumnTypes()
if err != nil {
return nil, fmt.Errorf("unable to get column types: %w", err)
}
var out []any
for rows.Next() {
// Create slice to hold values
values := make([]any, len(cols))
valuePtrs := make([]any, len(cols))
for i := range values {
valuePtrs[i] = &values[i]
}
// Scan the values
if err := rows.Scan(valuePtrs...); err != nil {
return nil, fmt.Errorf("unable to scan row: %w", err)
}
// Create result map
vMap := make(map[string]any)
for i, col := range cols {
val := values[i]
switch colTypes[i].DatabaseTypeName() {
case "JSON":
// unmarshal JSON data before storing to prevent double marshaling
var unmarshaledData any
err := json.Unmarshal(val.([]byte), &unmarshaledData)
if err != nil {
return nil, fmt.Errorf("unable to unmarshal json data %s", val)
}
vMap[col] = unmarshaledData
case "TEXT", "VARCHAR", "NVARCHAR":
vMap[col] = string(val.([]byte))
case "NUMBER":
s := string(val.(godror.Number))
if strings.Contains(s, ".") {
vMap[col], err = strconv.ParseFloat(s, 64)
} else {
vMap[col], err = strconv.ParseInt(s, 10, 64)
}
if err != nil {
return nil, fmt.Errorf("unable to convert NUMBER data '%s' for column %s: %w", s, col, err)
}
default:
vMap[col] = val
}
}
out = append(out, vMap)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("errors encountered during query execution or row processing: %w", err)
}
return out, nil
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
return tools.ParseParams(t.AllParams, data, claims)
}
func (t Tool) Manifest() tools.Manifest {
return t.manifest
}
func (t Tool) McpManifest() tools.McpManifest {
return t.mcpManifest
}
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization() bool {
return false
}

View File

@@ -0,0 +1,243 @@
// Copyright © 2025, Oracle and/or its affiliates.
package oracle
import (
"context"
"database/sql"
"fmt"
"os"
"regexp"
"strings"
"testing"
"time"
"github.com/google/uuid"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/tests"
)
var (
OracleSourceKind = "oracle"
OracleToolKind = "oracle-sql"
OracleHost = os.Getenv("ORACLE_HOST")
OracleUser = os.Getenv("ORACLE_USER")
OraclePass = os.Getenv("ORACLE_PASS")
OracleServerName = os.Getenv("ORACLE_SERVER_NAME")
OracleConnStr = fmt.Sprintf(
"%s:%s/%s", OracleHost, "1521", OracleServerName)
)
func getOracleVars(t *testing.T) map[string]any {
switch "" {
case OracleHost:
t.Fatal("'ORACLE_HOST not set")
case OracleUser:
t.Fatal("'ORACLE_USER' not set")
case OraclePass:
t.Fatal("'ORACLE_PASS' not set")
case OracleServerName:
t.Fatal("'ORACLE_SERVER_NAME' not set")
}
return map[string]any{
"kind": OracleSourceKind,
"connectionString": OracleConnStr,
"user": OracleUser,
"password": OraclePass,
}
}
// Copied over from oracle.go
func initOracleConnection(ctx context.Context, user, pass, connStr string) (*sql.DB, error) {
// Build the full Oracle connection string for godror driver
fullConnStr := fmt.Sprintf(`user="%s" password="%s" connectString="%s"`,
user, pass, connStr)
db, err := sql.Open("godror", fullConnStr)
if err != nil {
return nil, fmt.Errorf("unable to open Oracle connection: %w", err)
}
err = db.PingContext(ctx)
if err != nil {
return nil, fmt.Errorf("unable to ping Oracle connection: %w", err)
}
return db, nil
}
func TestOracleSimpleToolEndpoints(t *testing.T) {
sourceConfig := getOracleVars(t)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
var args []string
db, err := initOracleConnection(ctx, OracleUser, OraclePass, OracleConnStr)
if err != nil {
t.Fatalf("unable to create Oracle connection pool: %s", err)
}
dropAllUserTables(t, ctx, db)
// create table name with UUID
tableNameParam := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
tableNameAuth := "auth_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
// set up data for param tool
createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := getOracleParamToolInfo(tableNameParam)
teardownTable1 := setupOracleTable(t, ctx, db, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams)
defer teardownTable1(t)
// set up data for auth tool
createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := getOracleAuthToolInfo(tableNameAuth)
teardownTable2 := setupOracleTable(t, ctx, db, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
defer teardownTable2(t)
// Write config into a file and pass it to command
toolsFile := tests.GetToolsConfig(sourceConfig, OracleToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
toolsFile = tests.AddExecuteSqlConfig(t, toolsFile, "oracle-execute-sql")
tmplSelectCombined, tmplSelectFilterCombined := tests.GetMySQLTmplToolStatement()
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, OracleToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
if err != nil {
t.Fatalf("command initialization returned an error: %s", err)
}
defer cleanup()
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
if err != nil {
t.Logf("toolbox command logs: \n%s", out)
t.Fatalf("toolbox didn't start successfully: %s", err)
}
// Get configs for tests
select1Want := "[{\"1\":1}]"
mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: dpiStmt_execute: ORA-00900: invalid SQL statement"}],"isError":true}}`
createTableStatement := `"CREATE TABLE t (id NUMBER GENERATED AS IDENTITY PRIMARY KEY, name VARCHAR2(255))"`
mcpSelect1Want := `{"jsonrpc":"2.0","id":"invoke my-auth-required-tool","result":{"content":[{"type":"text","text":"{\"1\":1}"}]}}`
// Run tests
tests.RunToolGetTest(t)
tests.RunToolInvokeTest(t, select1Want,
tests.DisableOptionalNullParamTest(),
tests.WithMyToolById4Want("[{\"id\":4,\"name\":\"\"}]"),
tests.DisableArrayTest(),
)
tests.RunMCPToolCallMethod(t, mcpMyFailToolWant, mcpSelect1Want)
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam)
}
func setupOracleTable(t *testing.T, ctx context.Context, pool *sql.DB, createStatement, insertStatement, tableName string, params []any) func(*testing.T) {
err := pool.PingContext(ctx)
if err != nil {
t.Fatalf("unable to connect to test database: %s", err)
}
// Create table
_, err = pool.QueryContext(ctx, createStatement)
if err != nil {
t.Fatalf("unable to create test table %s: %s", tableName, err)
}
// Insert test data
_, err = pool.QueryContext(ctx, insertStatement, params...)
if err != nil {
t.Fatalf("unable to insert test data: %s", err)
}
return func(t *testing.T) {
// tear down test
_, err = pool.ExecContext(ctx, fmt.Sprintf("DROP TABLE %s", tableName))
if err != nil {
t.Errorf("Teardown failed: %s", err)
}
}
}
func getOracleParamToolInfo(tableName string) (string, string, string, string, string, string, []any) {
// Use GENERATED AS IDENTITY for auto-incrementing primary keys.
// VARCHAR2 is the standard string type in Oracle.
createStatement := fmt.Sprintf(`CREATE TABLE %s ("id" NUMBER GENERATED AS IDENTITY PRIMARY KEY, "name" VARCHAR2(255))`, tableName)
// MODIFIED: Use a PL/SQL block for multiple inserts
insertStatement := fmt.Sprintf(`
BEGIN
INSERT INTO %s ("name") VALUES (:1);
INSERT INTO %s ("name") VALUES (:2);
INSERT INTO %s ("name") VALUES (:3);
INSERT INTO %s ("name") VALUES (:4);
END;`, tableName, tableName, tableName, tableName)
toolStatement := fmt.Sprintf(`SELECT * FROM %s WHERE "id" = :1 OR "name" = :2`, tableName)
idParamStatement := fmt.Sprintf(`SELECT * FROM %s WHERE "id" = :1`, tableName)
nameParamStatement := fmt.Sprintf(`SELECT * FROM %s WHERE "name" = :1`, tableName)
// Oracle's equivalent for array parameters is using the 'MEMBER OF' operator
// with a collection type defined in the database schema.
arrayToolStatement := fmt.Sprintf(`SELECT * FROM %s WHERE "id" MEMBER OF :1 AND "name" MEMBER OF :2`, tableName)
params := []any{"Alice", "Jane", "Sid", nil}
return createStatement, insertStatement, toolStatement, idParamStatement, nameParamStatement, arrayToolStatement, params
}
// getOracleAuthToolInfo returns statements and params for my-auth-tool for Oracle SQL
func getOracleAuthToolInfo(tableName string) (string, string, string, []any) {
createStatement := fmt.Sprintf(`CREATE TABLE %s ("id" NUMBER GENERATED AS IDENTITY PRIMARY KEY, "name" VARCHAR2(255), "email" VARCHAR2(255))`, tableName)
// MODIFIED: Use a PL/SQL block for multiple inserts
insertStatement := fmt.Sprintf(`
BEGIN
INSERT INTO %s ("name", "email") VALUES (:1, :2);
INSERT INTO %s ("name", "email") VALUES (:3, :4);
END;`, tableName, tableName)
toolStatement := fmt.Sprintf(`SELECT "name" FROM %s WHERE "email" = :1`, tableName)
params := []any{"Alice", tests.ServiceAccountEmail, "Jane", "janedoe@gmail.com"}
return createStatement, insertStatement, toolStatement, params
}
// dropAllUserTables finds and drops all tables owned by the current user.
func dropAllUserTables(t *testing.T, ctx context.Context, db *sql.DB) {
// Query for only the tables we know are created by this test suite.
const query = `
SELECT table_name FROM user_tables
WHERE table_name LIKE 'param_table_%'
OR table_name LIKE 'auth_table_%'
OR table_name LIKE 'template_param_table_%'`
rows, err := db.QueryContext(ctx, query)
if err != nil {
t.Fatalf("failed to query for user tables: %v", err)
}
defer rows.Close()
var tablesToDrop []string
for rows.Next() {
var tableName string
if err := rows.Scan(&tableName); err != nil {
t.Fatalf("failed to scan table name: %v", err)
}
tablesToDrop = append(tablesToDrop, tableName)
}
if err := rows.Err(); err != nil {
t.Fatalf("error iterating over tables: %v", err)
}
for _, tableName := range tablesToDrop {
_, err := db.ExecContext(ctx, fmt.Sprintf("DROP TABLE %s CASCADE CONSTRAINTS", tableName))
if err != nil {
t.Logf("failed to drop table %s: %v", tableName, err)
}
}
}

View File

@@ -138,7 +138,7 @@ func TestPostgres(t *testing.T) {
}
// cleanup test environment
tests.CleanupPostgresTables(t, ctx, pool);
tests.CleanupPostgresTables(t, ctx, pool)
// create table name with UUID
tableNameParam := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")