mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-05-02 03:00:36 -04:00
feat(sources/oracle): add Oracle Source and Tool (#1456)
## Description --- > Should include a concise description of the changes (bug or feature), it's > impact, along with a summary of the solution ## PR Checklist --- > Thank you for opening a Pull Request! Before submitting your PR, there are a > few things you can do to make sure it goes smoothly: - [ ] Make sure you reviewed [CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md) - [ ] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [ ] Ensure the tests and linter pass - [ ] Code coverage does not decrease (if any source code was changed) - [ ] Appropriate docs were updated (if necessary) - [ ] Make sure to add `!` if this involve a breaking change 🛠️ Fixes https://github.com/googleapis/genai-toolbox/issues/488 --------- Co-authored-by: duwenxin <duwenxin@google.com> Co-authored-by: Wenxin Du <117315983+duwenxin99@users.noreply.github.com>
This commit is contained in:
@@ -703,7 +703,32 @@ steps:
|
||||
"Cassandra" \
|
||||
cassandra \
|
||||
cassandra
|
||||
|
||||
|
||||
- id: "oracle"
|
||||
name: ghcr.io/oracle/oraclelinux8-instantclient:21
|
||||
waitFor: ["install-dependencies"]
|
||||
entrypoint: /bin/bash
|
||||
env:
|
||||
- "GOPATH=/gopath"
|
||||
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
|
||||
- "ORACLE_SERVER_NAME=$_ORACLE_SERVER_NAME"
|
||||
secretEnv: ["CLIENT_ID", "ORACLE_USER", "ORACLE_PASS", "ORACLE_HOST"]
|
||||
volumes:
|
||||
- name: "go"
|
||||
path: "/gopath"
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
# Install the C compiler and Oracle SDK headers needed for cgo
|
||||
dnf install -y gcc oracle-instantclient-devel
|
||||
|
||||
# Install Go
|
||||
curl -L -o go.tar.gz "https://go.dev/dl/go1.25.1.linux-amd64.tar.gz"
|
||||
tar -C /usr/local -xzf go.tar.gz
|
||||
export PATH="/usr/local/go/bin:$$PATH"
|
||||
|
||||
go test ./tests/oracle
|
||||
|
||||
availableSecrets:
|
||||
secretManager:
|
||||
- versionName: projects/$PROJECT_ID/secrets/cloud_sql_pg_user/versions/latest
|
||||
@@ -794,6 +819,12 @@ availableSecrets:
|
||||
env: CASSANDRA_PASS
|
||||
- versionName: projects/$PROJECT_ID/secrets/cassandra_host/versions/latest
|
||||
env: CASSANDRA_HOST
|
||||
- versionName: projects/$PROJECT_ID/secrets/oracle_user/versions/latest
|
||||
env: ORACLE_USER
|
||||
- versionName: projects/$PROJECT_ID/secrets/oracle_pass/versions/latest
|
||||
env: ORACLE_PASS
|
||||
- versionName: projects/$PROJECT_ID/secrets/oracle_host/versions/latest
|
||||
env: ORACLE_HOST
|
||||
|
||||
options:
|
||||
logging: CLOUD_LOGGING_ONLY
|
||||
@@ -845,3 +876,4 @@ substitutions:
|
||||
_YUGABYTEDB_DATABASE: "yugabyte"
|
||||
_YUGABYTEDB_PORT: "5433"
|
||||
_YUGABYTEDB_LOADBALANCE: "false"
|
||||
_ORACLE_SERVER_NAME: "FREEPDB1"
|
||||
@@ -139,6 +139,8 @@ import (
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jschema"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/oceanbase/oceanbaseexecutesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/oceanbase/oceanbasesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/oracle/oracleexecutesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/oracle/oraclesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresexecutesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistactivequeries"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistavailableextensions"
|
||||
@@ -184,6 +186,7 @@ import (
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/mysql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/neo4j"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/oceanbase"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/oracle"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/postgres"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/redis"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/spanner"
|
||||
|
||||
104
docs/en/resources/sources/oracle.md
Normal file
104
docs/en/resources/sources/oracle.md
Normal file
@@ -0,0 +1,104 @@
|
||||
---
|
||||
title: "Oracle"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
Oracle Database is a widely-used relational database management system.
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
[Oracle Database][oracle-docs] is a multi-model database management system produced and marketed by Oracle Corporation. It is commonly used for running online transaction processing (OLTP), data warehousing (DW), and mixed (OLTP & DW) database workloads.
|
||||
|
||||
[oracle-docs]: https://www.oracle.com/database/
|
||||
|
||||
## Available Tools
|
||||
|
||||
- [`oracle-sql`](../tools/oracle/oracle-sql.md)
|
||||
Execute pre-defined prepared SQL queries in Oracle.
|
||||
|
||||
- [`oracle-execute-sql`](../tools/oracle/oracle-execute-sql.md)
|
||||
Run parameterized SQL queries in Oracle.
|
||||
|
||||
## Requirements
|
||||
|
||||
### Database User
|
||||
|
||||
This source uses standard authentication. You will need to [create an Oracle user][oracle-users] to log in to the database with the necessary permissions.
|
||||
|
||||
[oracle-users]:
|
||||
https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/CREATE-USER.html
|
||||
|
||||
### Oracle Instant Client (OIC)
|
||||
|
||||
The underlying database driver requires the [Oracle Instant Client][oracle-ic] libraries to connect to the database. These libraries must be installed on the machine where the application is running.
|
||||
|
||||
After installing the client, ensure the library path is correctly configured for your operating system (e.g., by setting the `LD_LIBRARY_PATH` environment variable on Linux or adding the directory to the `PATH` on Windows) so the application can find the necessary files at runtime.
|
||||
|
||||
[oracle-ic]: https://www.oracle.com/database/technologies/instant-client/downloads.html
|
||||
|
||||
## Connection Methods
|
||||
|
||||
You can configure the connection to your Oracle database using one of the following three methods. **You should only use one method** in your source configuration.
|
||||
|
||||
### Basic Connection (Host/Port/Service Name)
|
||||
|
||||
This is the most straightforward method, where you provide the connection details as separate fields:
|
||||
|
||||
- `host`: The IP address or hostname of the database server.
|
||||
- `port`: The port number the Oracle listener is running on (typically 1521).
|
||||
- `serviceName`: The service name for the database instance you wish to connect to.
|
||||
|
||||
### Connection String
|
||||
|
||||
As an alternative, you can provide all the connection details in a single `connectionString`. This is a convenient way to consolidate the connection information. The typical format is `hostname:port/servicename`.
|
||||
|
||||
### TNS Alias
|
||||
|
||||
For environments that use a `tnsnames.ora` configuration file, you can connect using a TNS (Transparent Network Substrate) alias.
|
||||
|
||||
- `tnsAlias`: Specify the alias name defined in your `tnsnames.ora` file.
|
||||
- `tnsAdmin` (Optional): If your configuration file is not in a standard location, you can use this field to provide the path to the directory containing it. This setting will override the `TNS_ADMIN` environment variable.
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
my-oracle-source:
|
||||
kind: oracle
|
||||
# --- Choose one connection method ---
|
||||
# 1. Host, Port, and Service Name
|
||||
host: 127.0.0.1
|
||||
port: 1521
|
||||
serviceName: XEPDB1
|
||||
|
||||
# 2. Direct Connection String
|
||||
connectionString: "127.0.0.1:1521/XEPDB1"
|
||||
|
||||
# 3. TNS Alias (requires tnsnames.ora)
|
||||
tnsAlias: "MY_DB_ALIAS"
|
||||
tnsAdmin: "/opt/oracle/network/admin" # Optional: overrides TNS_ADMIN env var
|
||||
|
||||
user: ${USER_NAME}
|
||||
password: ${PASSWORD}
|
||||
|
||||
```
|
||||
|
||||
{{< notice tip >}}
|
||||
Use environment variable replacement with the format ${ENV_NAME}
|
||||
instead of hardcoding your secrets into the configuration file.
|
||||
{{< /notice >}}
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|------------------|:--------:|:------------:|-----------------------------------------------------------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be "oracle". |
|
||||
| user | string | true | Name of the Oracle user to connect as (e.g. "my-oracle-user"). |
|
||||
| password | string | true | Password of the Oracle user (e.g. "my-password"). |
|
||||
| host | string | false | IP address or hostname to connect to (e.g. "127.0.0.1"). Required if not using `connectionString` or `tnsAlias`. |
|
||||
| port | integer | false | Port to connect to (e.g. "1521"). Required if not using `connectionString` or `tnsAlias`. |
|
||||
| serviceName | string | false | The Oracle service name of the database to connect to. Required if not using `connectionString` or `tnsAlias`. |
|
||||
| connectionString | string | false | A direct connection string (e.g. "hostname:port/servicename"). Use as an alternative to `host`, `port`, and `serviceName`. |
|
||||
| tnsAlias | string | false | A TNS alias from a `tnsnames.ora` file. Use as an alternative to `host`/`port` or `connectionString`. |
|
||||
| tnsAdmin | string | false | Path to the directory containing the `tnsnames.ora` file. This overrides the `TNS_ADMIN` environment variable if it is set. |
|
||||
7
docs/en/resources/tools/oracle/_index.md
Normal file
7
docs/en/resources/tools/oracle/_index.md
Normal file
@@ -0,0 +1,7 @@
|
||||
---
|
||||
title: "Oracle"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
Tools that work with Oracle Sources.
|
||||
---
|
||||
31
docs/en/resources/tools/oracle/oracle-execute-sql.md
Normal file
31
docs/en/resources/tools/oracle/oracle-execute-sql.md
Normal file
@@ -0,0 +1,31 @@
|
||||
---
|
||||
title: "oracle-execute-sql"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
An "oracle-execute-sql" tool executes a SQL statement against an Oracle database.
|
||||
aliases:
|
||||
- /resources/tools/oracle-execute-sql
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
An `oracle-execute-sql` tool executes a SQL statement against an Oracle
|
||||
database. It's compatible with the following source:
|
||||
|
||||
- [oracle](../../sources/oracle.md)
|
||||
|
||||
`oracle-execute-sql` takes one input parameter `sql` and runs the sql
|
||||
statement against the `source`.
|
||||
|
||||
> **Note:** This tool is intended for developer assistant workflows with
|
||||
> human-in-the-loop and shouldn't be used for production agents.
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
execute_sql_tool:
|
||||
kind: oracle-execute-sql
|
||||
source: my-oracle-instance
|
||||
description: Use this tool to execute sql statement.
|
||||
57
docs/en/resources/tools/oracle/oracle-sql.md
Normal file
57
docs/en/resources/tools/oracle/oracle-sql.md
Normal file
@@ -0,0 +1,57 @@
|
||||
---
|
||||
title: "oracle-sql"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
An "oracle-sql" tool executes a pre-defined SQL statement against an Oracle database.
|
||||
aliases:
|
||||
- /resources/tools/oracle-sql
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
An `oracle-sql` tool executes a pre-defined SQL statement against an
|
||||
Oracle database. It's compatible with the following source:
|
||||
|
||||
- [oracle](../../sources/oracle.md)
|
||||
|
||||
The specified SQL statement is executed using [prepared statements][oracle-stmt]
|
||||
for security and performance. It expects parameter placeholders in the SQL query
|
||||
to be in the native Oracle format (e.g., `:1`, `:2`).
|
||||
|
||||
[oracle-stmt]: https://docs.oracle.com/javase/tutorial/jdbc/basics/prepared.html
|
||||
|
||||
## Example
|
||||
|
||||
> **Note:** This tool uses parameterized queries to prevent SQL injections.
|
||||
> Query parameters can be used as substitutes for arbitrary expressions.
|
||||
> Parameters cannot be used as substitutes for identifiers, column names, table
|
||||
> names, or other parts of the query.
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
search_flights_by_number:
|
||||
kind: oracle-sql
|
||||
source: my-oracle-instance
|
||||
statement: |
|
||||
SELECT * FROM flights
|
||||
WHERE airline = :1
|
||||
AND flight_number = :2
|
||||
FETCH FIRST 10 ROWS ONLY
|
||||
description: |
|
||||
Use this tool to get information for a specific flight.
|
||||
Takes an airline code and flight number and returns info on the flight.
|
||||
Do NOT use this tool with a flight id. Do NOT guess an airline code or flight number.
|
||||
Example:
|
||||
{{
|
||||
"airline": "CY",
|
||||
"flight_number": "888",
|
||||
}}
|
||||
parameters:
|
||||
- name: airline
|
||||
type: string
|
||||
description: Airline unique 2 letter identifier
|
||||
- name: flight_number
|
||||
type: string
|
||||
description: 1 to 4 digit number
|
||||
```
|
||||
4
go.mod
4
go.mod
@@ -28,6 +28,7 @@ require (
|
||||
github.com/go-sql-driver/mysql v1.9.3
|
||||
github.com/goccy/go-yaml v1.18.0
|
||||
github.com/gocql/gocql v1.7.0
|
||||
github.com/godror/godror v0.49.3
|
||||
github.com/google/go-cmp v0.7.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/jackc/pgx/v5 v5.7.6
|
||||
@@ -85,6 +86,7 @@ require (
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.29.0 // indirect
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.53.0 // indirect
|
||||
github.com/PuerkitoBio/goquery v1.10.3 // indirect
|
||||
github.com/VictoriaMetrics/easyproto v0.1.4 // indirect
|
||||
github.com/ajg/form v1.5.1 // indirect
|
||||
github.com/apache/arrow/go/v15 v15.0.2 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
@@ -101,11 +103,13 @@ require (
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.10 // indirect
|
||||
github.com/go-jose/go-jose/v4 v4.1.1 // indirect
|
||||
github.com/go-logfmt/logfmt v0.6.0 // indirect
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/goccy/go-json v0.10.5 // indirect
|
||||
github.com/godror/knownpb v0.3.0 // indirect
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
|
||||
github.com/golang-sql/sqlexp v0.1.0 // indirect
|
||||
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect
|
||||
|
||||
14
go.sum
14
go.sum
@@ -681,6 +681,10 @@ github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5/go.mod h1:lmUJ/7eu/Q8
|
||||
github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
|
||||
github.com/PuerkitoBio/goquery v1.10.3 h1:pFYcNSqHxBD06Fpj/KsbStFRsgRATgnf3LeXiUkhzPo=
|
||||
github.com/PuerkitoBio/goquery v1.10.3/go.mod h1:tMUX0zDMHXYlAQk6p35XxQMqMweEKB7iK7iLNd4RH4Y=
|
||||
github.com/UNO-SOFT/zlog v0.8.1 h1:TEFkGJHtUfTRgMkLZiAjLSHALjwSBdw6/zByMC5GJt4=
|
||||
github.com/UNO-SOFT/zlog v0.8.1/go.mod h1:yqFOjn3OhvJ4j7ArJqQNA+9V+u6t9zSAyIZdWdMweWc=
|
||||
github.com/VictoriaMetrics/easyproto v0.1.4 h1:r8cNvo8o6sR4QShBXQd1bKw/VVLSQma/V2KhTBPf+Sc=
|
||||
github.com/VictoriaMetrics/easyproto v0.1.4/go.mod h1:QlGlzaJnDfFd8Lk6Ci/fuLxfTo3/GThPs2KH23mv710=
|
||||
github.com/ahmetb/dlog v0.0.0-20170105205344-4fb5f8204f26 h1:3YVZUqkoev4mL+aCwVOSWV4M7pN+NURHL38Z2zq5JKA=
|
||||
github.com/ahmetb/dlog v0.0.0-20170105205344-4fb5f8204f26/go.mod h1:ymXt5bw5uSNu4jveerFxE0vNYxF8ncqbptntMaFMg3k=
|
||||
github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU=
|
||||
@@ -878,6 +882,8 @@ github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vb
|
||||
github.com/go-latex/latex v0.0.0-20210118124228-b3d85cf34e07/go.mod h1:CO1AlKB2CSIqUrmQPqA0gdRIlnLEY0gK5JGjh37zN5U=
|
||||
github.com/go-latex/latex v0.0.0-20210823091927-c0d11ff05a81/go.mod h1:SX0U8uGpxhq9o2S/CELCSUxEWWAuoCUcVCQWv7G2OCk=
|
||||
github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A=
|
||||
github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4=
|
||||
github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs=
|
||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
||||
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
@@ -905,6 +911,10 @@ github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw=
|
||||
github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
|
||||
github.com/gocql/gocql v1.7.0 h1:O+7U7/1gSN7QTEAaMEsJc1Oq2QHXvCWoF3DFK9HDHus=
|
||||
github.com/gocql/gocql v1.7.0/go.mod h1:vnlvXyFZeLBF0Wy+RS8hrOdbn0UWsWtdg07XJnFxZ+4=
|
||||
github.com/godror/godror v0.49.3 h1:84CPEu1p3qPvpN7PTHv8NDept+t+d+AoO/7WjYVsFNc=
|
||||
github.com/godror/godror v0.49.3/go.mod h1:kTMcxZzRw73RT5kn9v3JkBK4kHI6dqowHotqV72ebU8=
|
||||
github.com/godror/knownpb v0.3.0 h1:+caUdy8hTtl7X05aPl3tdL540TvCcaQA6woZQroLZMw=
|
||||
github.com/godror/knownpb v0.3.0/go.mod h1:PpTyfJwiOEAzQl7NtVCM8kdPCnp3uhxsZYIzZ5PV4zU=
|
||||
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
@@ -1168,6 +1178,8 @@ github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdh
|
||||
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/neo4j/neo4j-go-driver/v5 v5.28.4 h1:7toxehVcYkZbyxV4W3Ib9VcnyRBQPucF+VwNNmtSXi4=
|
||||
github.com/neo4j/neo4j-go-driver/v5 v5.28.4/go.mod h1:Vff8OwT7QpLm7L2yYr85XNWe9Rbqlbeb9asNXJTHO4k=
|
||||
github.com/oklog/ulid/v2 v2.0.2 h1:r4fFzBm+bv0wNKNh5eXTwU7i85y5x+uwkxCUTNVQqLc=
|
||||
github.com/oklog/ulid/v2 v2.0.2/go.mod h1:mtBL0Qe/0HAx6/a4Z30qxVIAL1eQDweXq5lxOEiwQ68=
|
||||
github.com/onsi/gomega v1.36.2 h1:koNYke6TVk6ZmnyHrCXba/T/MoLBXFjeC1PtvYgw0A8=
|
||||
github.com/onsi/gomega v1.36.2/go.mod h1:DdwyADRjrc825LhMEkD76cHR5+pUnjhUN8GlHlRPHzY=
|
||||
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||
@@ -1661,6 +1673,8 @@ golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU=
|
||||
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
|
||||
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
|
||||
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
|
||||
golang.org/x/term v0.35.0 h1:bZBVKBudEyhRcajGcNc3jIfWPqV4y/Kt2XcoigOWtDQ=
|
||||
golang.org/x/term v0.35.0/go.mod h1:TPGtkTLesOwf2DE8CgVYiZinHAOuy5AYUYT1lENIZnA=
|
||||
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
|
||||
175
internal/sources/oracle/oracle.go
Normal file
175
internal/sources/oracle/oracle.go
Normal file
@@ -0,0 +1,175 @@
|
||||
// Copyright © 2025, Oracle and/or its affiliates.
|
||||
package oracle
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
_ "github.com/godror/godror"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceKind string = "oracle"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources.SourceConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Validate that we have one of: tns_alias, connection_string, or host+service_name
|
||||
if err := actual.validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid Oracle configuration: %w", err)
|
||||
}
|
||||
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
ConnectionString string `yaml:"connectionString,omitempty"` // Direct connection string (hostname[:port]/servicename)
|
||||
TnsAlias string `yaml:"tnsAlias,omitempty"` // TNS alias from tnsnames.ora
|
||||
Host string `yaml:"host,omitempty"` // Optional when using connectionString/tnsAlias
|
||||
Port int `yaml:"port,omitempty"` // Explicit port support
|
||||
ServiceName string `yaml:"serviceName,omitempty"` // Optional when using connectionString/tnsAlias
|
||||
User string `yaml:"user" validate:"required"`
|
||||
Password string `yaml:"password" validate:"required"`
|
||||
TnsAdmin string `yaml:"tnsAdmin,omitempty"` // Optional: override TNS_ADMIN environment variable
|
||||
}
|
||||
|
||||
// validate ensures we have one of: tns_alias, connection_string, or host+service_name
|
||||
func (c Config) validate() error {
|
||||
hasTnsAlias := strings.TrimSpace(c.TnsAlias) != ""
|
||||
hasConnStr := strings.TrimSpace(c.ConnectionString) != ""
|
||||
hasHostService := strings.TrimSpace(c.Host) != "" && strings.TrimSpace(c.ServiceName) != ""
|
||||
|
||||
connectionMethods := 0
|
||||
if hasTnsAlias {
|
||||
connectionMethods++
|
||||
}
|
||||
if hasConnStr {
|
||||
connectionMethods++
|
||||
}
|
||||
if hasHostService {
|
||||
connectionMethods++
|
||||
}
|
||||
|
||||
if connectionMethods == 0 {
|
||||
return fmt.Errorf("must provide one of: 'tns_alias', 'connection_string', or both 'host' and 'service_name'")
|
||||
}
|
||||
|
||||
if connectionMethods > 1 {
|
||||
return fmt.Errorf("provide only one connection method: 'tns_alias', 'connection_string', or 'host'+'service_name'")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
db, err := initOracleConnection(ctx, tracer, r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create Oracle connection: %w", err)
|
||||
}
|
||||
|
||||
err = db.PingContext(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to connect to Oracle successfully: %w", err)
|
||||
}
|
||||
|
||||
s := &Source{
|
||||
Name: r.Name,
|
||||
Kind: SourceKind,
|
||||
DB: db,
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
var _ sources.Source = &Source{}
|
||||
|
||||
type Source struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
DB *sql.DB
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) OracleDB() *sql.DB {
|
||||
return s.DB
|
||||
}
|
||||
|
||||
func initOracleConnection(ctx context.Context, tracer trace.Tracer, config Config) (*sql.DB, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, config.Name)
|
||||
defer span.End()
|
||||
|
||||
var connectString string
|
||||
logger, err := util.LoggerFromContext(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Set TNS_ADMIN environment variable if specified in config
|
||||
if config.TnsAdmin != "" {
|
||||
originalTnsAdmin := os.Getenv("TNS_ADMIN")
|
||||
os.Setenv("TNS_ADMIN", config.TnsAdmin)
|
||||
logger.DebugContext(ctx, fmt.Sprintf("Setting TNS_ADMIN to: %s\n", config.TnsAdmin))
|
||||
// Restore original TNS_ADMIN after connection
|
||||
defer func() {
|
||||
if originalTnsAdmin != "" {
|
||||
os.Setenv("TNS_ADMIN", originalTnsAdmin)
|
||||
} else {
|
||||
os.Unsetenv("TNS_ADMIN")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Determine the connection string to use (priority order)
|
||||
if config.TnsAlias != "" {
|
||||
// Use TNS alias - godror will resolve from tnsnames.ora
|
||||
connectString = strings.TrimSpace(config.TnsAlias)
|
||||
} else if config.ConnectionString != "" {
|
||||
// Use provided connection string directly (hostname[:port]/servicename format)
|
||||
connectString = strings.TrimSpace(config.ConnectionString)
|
||||
} else {
|
||||
// Build connection string from host and service_name
|
||||
if config.Port > 0 {
|
||||
connectString = fmt.Sprintf("%s:%d/%s", config.Host, config.Port, config.ServiceName)
|
||||
} else {
|
||||
connectString = fmt.Sprintf("%s/%s", config.Host, config.ServiceName)
|
||||
}
|
||||
}
|
||||
|
||||
// Build the full Oracle connection string for godror driver
|
||||
connStr := fmt.Sprintf(`user="%s" password="%s" connectString="%s"`,
|
||||
config.User, config.Password, connectString)
|
||||
|
||||
db, err := sql.Open("godror", connStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to open Oracle connection: %w", err)
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
223
internal/tools/oracle/oracleexecutesql/oracleexecutesql.go
Normal file
223
internal/tools/oracle/oracleexecutesql/oracleexecutesql.go
Normal file
@@ -0,0 +1,223 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// Copyright © 2025, Oracle and/or its affiliates.
|
||||
|
||||
package oracleexecutesql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/godror/godror"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/oracle"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
)
|
||||
|
||||
const kind string = "oracle-execute-sql"
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
OracleDB() *sql.DB
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &oracle.Source{}
|
||||
|
||||
var compatibleSources = [...]string{oracle.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
sqlParameter := tools.NewStringParameter("sql", "The SQL to execute.")
|
||||
parameters := tools.Parameters{sqlParameter}
|
||||
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, parameters)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Parameters: parameters,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Pool: s.OracleDB(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
|
||||
Pool *sql.DB
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
paramsMap := params.AsMap()
|
||||
sql, ok := paramsMap["sql"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unable to get cast %s", paramsMap["sql"])
|
||||
}
|
||||
|
||||
// Log the query executed for debugging.
|
||||
logger, err := util.LoggerFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting logger: %s", err)
|
||||
}
|
||||
logger.DebugContext(ctx, "executing `%s` tool query: %s", kind, sql)
|
||||
|
||||
results, err := t.Pool.QueryContext(ctx, sql)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
defer results.Close()
|
||||
|
||||
cols, _ := results.Columns()
|
||||
// If Columns() errors, it might be a DDL/DML without an OUTPUT clause.
|
||||
// We proceed, and results.Err() will catch actual query execution errors.
|
||||
// 'out' will remain nil if cols is empty or err is not nil here.
|
||||
|
||||
// Get Column types
|
||||
colTypes, err := results.ColumnTypes()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to get column types: %w", err)
|
||||
}
|
||||
|
||||
var out []any
|
||||
for results.Next() {
|
||||
// Create slice to hold values
|
||||
values := make([]any, len(cols))
|
||||
valuePtrs := make([]any, len(cols))
|
||||
for i := range values {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
// Scan the values
|
||||
if err := results.Scan(valuePtrs...); err != nil {
|
||||
return nil, fmt.Errorf("unable to scan row: %w", err)
|
||||
}
|
||||
|
||||
// Create result map
|
||||
vMap := make(map[string]any)
|
||||
for i, col := range cols {
|
||||
val := values[i]
|
||||
switch colTypes[i].DatabaseTypeName() {
|
||||
case "JSON":
|
||||
// unmarshal JSON data before storing to prevent double marshaling
|
||||
var unmarshaledData any
|
||||
err := json.Unmarshal(val.([]byte), &unmarshaledData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal json data %s", val)
|
||||
}
|
||||
vMap[col] = unmarshaledData
|
||||
case "TEXT", "VARCHAR", "NVARCHAR":
|
||||
vMap[col] = string(val.([]byte))
|
||||
case "NUMBER":
|
||||
s := string(val.(godror.Number))
|
||||
if strings.Contains(s, ".") {
|
||||
vMap[col], err = strconv.ParseFloat(s, 64)
|
||||
} else {
|
||||
vMap[col], err = strconv.ParseInt(s, 10, 64)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to convert NUMBER data '%s' for column %s: %w", s, col, err)
|
||||
}
|
||||
default:
|
||||
vMap[col] = val
|
||||
}
|
||||
}
|
||||
out = append(out, vMap)
|
||||
}
|
||||
|
||||
// Check for errors from iterating over rows or from the query execution itself.
|
||||
// results.Close() is handled by defer.
|
||||
if err := results.Err(); err != nil {
|
||||
return nil, fmt.Errorf("errors encountered during query execution or row processing: %w", err)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.Parameters, data, claims)
|
||||
}
|
||||
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization() bool {
|
||||
return false
|
||||
}
|
||||
234
internal/tools/oracle/oraclesql/oraclesql.go
Normal file
234
internal/tools/oracle/oraclesql/oraclesql.go
Normal file
@@ -0,0 +1,234 @@
|
||||
// Copyright 2024 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package oraclesql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/godror/godror"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/oracle"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
const kind string = "oracle-sql"
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
OracleDB() *sql.DB
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &oracle.Source{}
|
||||
|
||||
var compatibleSources = [...]string{oracle.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
Statement string `yaml:"statement" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
TemplateParameters tools.Parameters `yaml:"templateParameters"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
allParameters, paramManifest, err := tools.ProcessParameters(cfg.TemplateParameters, cfg.Parameters)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error processing parameters: %w", err)
|
||||
}
|
||||
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Parameters: cfg.Parameters,
|
||||
TemplateParameters: cfg.TemplateParameters,
|
||||
AllParams: allParameters,
|
||||
Statement: cfg.Statement,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
DB: s.OracleDB(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
TemplateParameters tools.Parameters `yaml:"templateParameters"`
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
|
||||
DB *sql.DB
|
||||
Statement string
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
paramsMap := params.AsMap()
|
||||
newStatement, err := tools.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to extract template params %w", err)
|
||||
}
|
||||
|
||||
newParams, err := tools.GetParams(t.Parameters, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to extract standard params %w", err)
|
||||
}
|
||||
sliceParams := newParams.AsSlice()
|
||||
|
||||
for i, p := range sliceParams {
|
||||
fmt.Printf("[%d]=%T ", i, p)
|
||||
}
|
||||
fmt.Printf("\n")
|
||||
|
||||
// NO PARAMETER CONVERSION - godror supports :1, :2, :3 natively
|
||||
// Execute Oracle query with original statement
|
||||
|
||||
rows, err := t.DB.QueryContext(ctx, newStatement, sliceParams...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
cols, _ := rows.Columns()
|
||||
|
||||
// Get Column types
|
||||
colTypes, err := rows.ColumnTypes()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to get column types: %w", err)
|
||||
}
|
||||
|
||||
var out []any
|
||||
for rows.Next() {
|
||||
// Create slice to hold values
|
||||
values := make([]any, len(cols))
|
||||
valuePtrs := make([]any, len(cols))
|
||||
for i := range values {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
// Scan the values
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
return nil, fmt.Errorf("unable to scan row: %w", err)
|
||||
}
|
||||
|
||||
// Create result map
|
||||
vMap := make(map[string]any)
|
||||
for i, col := range cols {
|
||||
val := values[i]
|
||||
switch colTypes[i].DatabaseTypeName() {
|
||||
case "JSON":
|
||||
// unmarshal JSON data before storing to prevent double marshaling
|
||||
var unmarshaledData any
|
||||
err := json.Unmarshal(val.([]byte), &unmarshaledData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal json data %s", val)
|
||||
}
|
||||
vMap[col] = unmarshaledData
|
||||
case "TEXT", "VARCHAR", "NVARCHAR":
|
||||
vMap[col] = string(val.([]byte))
|
||||
case "NUMBER":
|
||||
s := string(val.(godror.Number))
|
||||
if strings.Contains(s, ".") {
|
||||
vMap[col], err = strconv.ParseFloat(s, 64)
|
||||
} else {
|
||||
vMap[col], err = strconv.ParseInt(s, 10, 64)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to convert NUMBER data '%s' for column %s: %w", s, col, err)
|
||||
}
|
||||
default:
|
||||
vMap[col] = val
|
||||
}
|
||||
}
|
||||
out = append(out, vMap)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("errors encountered during query execution or row processing: %w", err)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.AllParams, data, claims)
|
||||
}
|
||||
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization() bool {
|
||||
return false
|
||||
}
|
||||
243
tests/oracle/oracle_integration_test.go
Normal file
243
tests/oracle/oracle_integration_test.go
Normal file
@@ -0,0 +1,243 @@
|
||||
// Copyright © 2025, Oracle and/or its affiliates.
|
||||
|
||||
package oracle
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/tests"
|
||||
)
|
||||
|
||||
var (
|
||||
OracleSourceKind = "oracle"
|
||||
OracleToolKind = "oracle-sql"
|
||||
OracleHost = os.Getenv("ORACLE_HOST")
|
||||
OracleUser = os.Getenv("ORACLE_USER")
|
||||
OraclePass = os.Getenv("ORACLE_PASS")
|
||||
OracleServerName = os.Getenv("ORACLE_SERVER_NAME")
|
||||
OracleConnStr = fmt.Sprintf(
|
||||
"%s:%s/%s", OracleHost, "1521", OracleServerName)
|
||||
)
|
||||
|
||||
func getOracleVars(t *testing.T) map[string]any {
|
||||
switch "" {
|
||||
case OracleHost:
|
||||
t.Fatal("'ORACLE_HOST not set")
|
||||
case OracleUser:
|
||||
t.Fatal("'ORACLE_USER' not set")
|
||||
case OraclePass:
|
||||
t.Fatal("'ORACLE_PASS' not set")
|
||||
case OracleServerName:
|
||||
t.Fatal("'ORACLE_SERVER_NAME' not set")
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"kind": OracleSourceKind,
|
||||
"connectionString": OracleConnStr,
|
||||
"user": OracleUser,
|
||||
"password": OraclePass,
|
||||
}
|
||||
}
|
||||
|
||||
// Copied over from oracle.go
|
||||
func initOracleConnection(ctx context.Context, user, pass, connStr string) (*sql.DB, error) {
|
||||
// Build the full Oracle connection string for godror driver
|
||||
fullConnStr := fmt.Sprintf(`user="%s" password="%s" connectString="%s"`,
|
||||
user, pass, connStr)
|
||||
|
||||
db, err := sql.Open("godror", fullConnStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to open Oracle connection: %w", err)
|
||||
}
|
||||
|
||||
err = db.PingContext(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to ping Oracle connection: %w", err)
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func TestOracleSimpleToolEndpoints(t *testing.T) {
|
||||
sourceConfig := getOracleVars(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
var args []string
|
||||
|
||||
db, err := initOracleConnection(ctx, OracleUser, OraclePass, OracleConnStr)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create Oracle connection pool: %s", err)
|
||||
}
|
||||
|
||||
dropAllUserTables(t, ctx, db)
|
||||
|
||||
// create table name with UUID
|
||||
tableNameParam := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
tableNameAuth := "auth_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
|
||||
// set up data for param tool
|
||||
createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := getOracleParamToolInfo(tableNameParam)
|
||||
teardownTable1 := setupOracleTable(t, ctx, db, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams)
|
||||
defer teardownTable1(t)
|
||||
|
||||
// set up data for auth tool
|
||||
createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := getOracleAuthToolInfo(tableNameAuth)
|
||||
teardownTable2 := setupOracleTable(t, ctx, db, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
|
||||
defer teardownTable2(t)
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, OracleToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
|
||||
toolsFile = tests.AddExecuteSqlConfig(t, toolsFile, "oracle-execute-sql")
|
||||
tmplSelectCombined, tmplSelectFilterCombined := tests.GetMySQLTmplToolStatement()
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, OracleToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
t.Fatalf("command initialization returned an error: %s", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||
if err != nil {
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
}
|
||||
|
||||
// Get configs for tests
|
||||
select1Want := "[{\"1\":1}]"
|
||||
mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: dpiStmt_execute: ORA-00900: invalid SQL statement"}],"isError":true}}`
|
||||
createTableStatement := `"CREATE TABLE t (id NUMBER GENERATED AS IDENTITY PRIMARY KEY, name VARCHAR2(255))"`
|
||||
mcpSelect1Want := `{"jsonrpc":"2.0","id":"invoke my-auth-required-tool","result":{"content":[{"type":"text","text":"{\"1\":1}"}]}}`
|
||||
|
||||
// Run tests
|
||||
tests.RunToolGetTest(t)
|
||||
tests.RunToolInvokeTest(t, select1Want,
|
||||
tests.DisableOptionalNullParamTest(),
|
||||
tests.WithMyToolById4Want("[{\"id\":4,\"name\":\"\"}]"),
|
||||
tests.DisableArrayTest(),
|
||||
)
|
||||
tests.RunMCPToolCallMethod(t, mcpMyFailToolWant, mcpSelect1Want)
|
||||
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam)
|
||||
}
|
||||
|
||||
func setupOracleTable(t *testing.T, ctx context.Context, pool *sql.DB, createStatement, insertStatement, tableName string, params []any) func(*testing.T) {
|
||||
err := pool.PingContext(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to connect to test database: %s", err)
|
||||
}
|
||||
|
||||
// Create table
|
||||
_, err = pool.QueryContext(ctx, createStatement)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create test table %s: %s", tableName, err)
|
||||
}
|
||||
|
||||
// Insert test data
|
||||
_, err = pool.QueryContext(ctx, insertStatement, params...)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to insert test data: %s", err)
|
||||
}
|
||||
|
||||
return func(t *testing.T) {
|
||||
// tear down test
|
||||
_, err = pool.ExecContext(ctx, fmt.Sprintf("DROP TABLE %s", tableName))
|
||||
if err != nil {
|
||||
t.Errorf("Teardown failed: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getOracleParamToolInfo(tableName string) (string, string, string, string, string, string, []any) {
|
||||
// Use GENERATED AS IDENTITY for auto-incrementing primary keys.
|
||||
// VARCHAR2 is the standard string type in Oracle.
|
||||
createStatement := fmt.Sprintf(`CREATE TABLE %s ("id" NUMBER GENERATED AS IDENTITY PRIMARY KEY, "name" VARCHAR2(255))`, tableName)
|
||||
|
||||
// MODIFIED: Use a PL/SQL block for multiple inserts
|
||||
insertStatement := fmt.Sprintf(`
|
||||
BEGIN
|
||||
INSERT INTO %s ("name") VALUES (:1);
|
||||
INSERT INTO %s ("name") VALUES (:2);
|
||||
INSERT INTO %s ("name") VALUES (:3);
|
||||
INSERT INTO %s ("name") VALUES (:4);
|
||||
END;`, tableName, tableName, tableName, tableName)
|
||||
|
||||
toolStatement := fmt.Sprintf(`SELECT * FROM %s WHERE "id" = :1 OR "name" = :2`, tableName)
|
||||
idParamStatement := fmt.Sprintf(`SELECT * FROM %s WHERE "id" = :1`, tableName)
|
||||
nameParamStatement := fmt.Sprintf(`SELECT * FROM %s WHERE "name" = :1`, tableName)
|
||||
|
||||
// Oracle's equivalent for array parameters is using the 'MEMBER OF' operator
|
||||
// with a collection type defined in the database schema.
|
||||
arrayToolStatement := fmt.Sprintf(`SELECT * FROM %s WHERE "id" MEMBER OF :1 AND "name" MEMBER OF :2`, tableName)
|
||||
|
||||
params := []any{"Alice", "Jane", "Sid", nil}
|
||||
|
||||
return createStatement, insertStatement, toolStatement, idParamStatement, nameParamStatement, arrayToolStatement, params
|
||||
}
|
||||
|
||||
// getOracleAuthToolInfo returns statements and params for my-auth-tool for Oracle SQL
|
||||
func getOracleAuthToolInfo(tableName string) (string, string, string, []any) {
|
||||
createStatement := fmt.Sprintf(`CREATE TABLE %s ("id" NUMBER GENERATED AS IDENTITY PRIMARY KEY, "name" VARCHAR2(255), "email" VARCHAR2(255))`, tableName)
|
||||
|
||||
// MODIFIED: Use a PL/SQL block for multiple inserts
|
||||
insertStatement := fmt.Sprintf(`
|
||||
BEGIN
|
||||
INSERT INTO %s ("name", "email") VALUES (:1, :2);
|
||||
INSERT INTO %s ("name", "email") VALUES (:3, :4);
|
||||
END;`, tableName, tableName)
|
||||
|
||||
toolStatement := fmt.Sprintf(`SELECT "name" FROM %s WHERE "email" = :1`, tableName)
|
||||
|
||||
params := []any{"Alice", tests.ServiceAccountEmail, "Jane", "janedoe@gmail.com"}
|
||||
|
||||
return createStatement, insertStatement, toolStatement, params
|
||||
}
|
||||
|
||||
// dropAllUserTables finds and drops all tables owned by the current user.
|
||||
func dropAllUserTables(t *testing.T, ctx context.Context, db *sql.DB) {
|
||||
// Query for only the tables we know are created by this test suite.
|
||||
const query = `
|
||||
SELECT table_name FROM user_tables
|
||||
WHERE table_name LIKE 'param_table_%'
|
||||
OR table_name LIKE 'auth_table_%'
|
||||
OR table_name LIKE 'template_param_table_%'`
|
||||
|
||||
rows, err := db.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to query for user tables: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tablesToDrop []string
|
||||
for rows.Next() {
|
||||
var tableName string
|
||||
if err := rows.Scan(&tableName); err != nil {
|
||||
t.Fatalf("failed to scan table name: %v", err)
|
||||
}
|
||||
tablesToDrop = append(tablesToDrop, tableName)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
t.Fatalf("error iterating over tables: %v", err)
|
||||
}
|
||||
|
||||
for _, tableName := range tablesToDrop {
|
||||
_, err := db.ExecContext(ctx, fmt.Sprintf("DROP TABLE %s CASCADE CONSTRAINTS", tableName))
|
||||
if err != nil {
|
||||
t.Logf("failed to drop table %s: %v", tableName, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -138,7 +138,7 @@ func TestPostgres(t *testing.T) {
|
||||
}
|
||||
|
||||
// cleanup test environment
|
||||
tests.CleanupPostgresTables(t, ctx, pool);
|
||||
tests.CleanupPostgresTables(t, ctx, pool)
|
||||
|
||||
// create table name with UUID
|
||||
tableNameParam := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
|
||||
Reference in New Issue
Block a user