mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-09 07:28:05 -05:00
feat(oracle): Switch Oracle driver from godror to go-ora (#1685)
This avoids CGO cross compilation. - [godror](https://github.com/godror/godror) - depends on Oracle Client Libraries & [requires C compiler](https://github.com/godror/godror?tab=readme-ov-file#build-time-requirements) - [ go-ora](https://github.com/sijms/go-ora) - pure Go driver
This commit is contained in:
@@ -705,8 +705,8 @@ steps:
|
||||
cassandra
|
||||
|
||||
- id: "oracle"
|
||||
name: ghcr.io/oracle/oraclelinux8-instantclient:21
|
||||
waitFor: ["install-dependencies"]
|
||||
name: golang:1
|
||||
waitFor: ["compile-test-binary"]
|
||||
entrypoint: /bin/bash
|
||||
env:
|
||||
- "GOPATH=/gopath"
|
||||
@@ -719,15 +719,10 @@ steps:
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
# Install the C compiler and Oracle SDK headers needed for cgo
|
||||
dnf install -y gcc oracle-instantclient-devel
|
||||
|
||||
# Install Go
|
||||
curl -L -o go.tar.gz "https://go.dev/dl/go1.25.1.linux-amd64.tar.gz"
|
||||
tar -C /usr/local -xzf go.tar.gz
|
||||
export PATH="/usr/local/go/bin:$$PATH"
|
||||
|
||||
go test ./tests/oracle
|
||||
.ci/test_with_coverage.sh \
|
||||
"Oracle" \
|
||||
oracle \
|
||||
oracle
|
||||
|
||||
availableSecrets:
|
||||
secretManager:
|
||||
|
||||
5
go.mod
5
go.mod
@@ -28,7 +28,6 @@ require (
|
||||
github.com/go-sql-driver/mysql v1.9.3
|
||||
github.com/goccy/go-yaml v1.18.0
|
||||
github.com/gocql/gocql v1.7.0
|
||||
github.com/godror/godror v0.49.3
|
||||
github.com/google/go-cmp v0.7.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/jackc/pgx/v5 v5.7.6
|
||||
@@ -38,6 +37,7 @@ require (
|
||||
github.com/nakagami/firebirdsql v0.9.15
|
||||
github.com/neo4j/neo4j-go-driver/v5 v5.28.4
|
||||
github.com/redis/go-redis/v9 v9.14.0
|
||||
github.com/sijms/go-ora/v2 v2.9.0
|
||||
github.com/spf13/cobra v1.10.1
|
||||
github.com/thlib/go-timezone-local v0.0.7
|
||||
github.com/trinodb/trino-go-client v0.329.0
|
||||
@@ -86,7 +86,6 @@ require (
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.29.0 // indirect
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.53.0 // indirect
|
||||
github.com/PuerkitoBio/goquery v1.10.3 // indirect
|
||||
github.com/VictoriaMetrics/easyproto v0.1.4 // indirect
|
||||
github.com/ajg/form v1.5.1 // indirect
|
||||
github.com/apache/arrow/go/v15 v15.0.2 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
@@ -103,13 +102,11 @@ require (
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.10 // indirect
|
||||
github.com/go-jose/go-jose/v4 v4.1.1 // indirect
|
||||
github.com/go-logfmt/logfmt v0.6.0 // indirect
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/goccy/go-json v0.10.5 // indirect
|
||||
github.com/godror/knownpb v0.3.0 // indirect
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
|
||||
github.com/golang-sql/sqlexp v0.1.0 // indirect
|
||||
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect
|
||||
|
||||
16
go.sum
16
go.sum
@@ -681,10 +681,6 @@ github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5/go.mod h1:lmUJ/7eu/Q8
|
||||
github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
|
||||
github.com/PuerkitoBio/goquery v1.10.3 h1:pFYcNSqHxBD06Fpj/KsbStFRsgRATgnf3LeXiUkhzPo=
|
||||
github.com/PuerkitoBio/goquery v1.10.3/go.mod h1:tMUX0zDMHXYlAQk6p35XxQMqMweEKB7iK7iLNd4RH4Y=
|
||||
github.com/UNO-SOFT/zlog v0.8.1 h1:TEFkGJHtUfTRgMkLZiAjLSHALjwSBdw6/zByMC5GJt4=
|
||||
github.com/UNO-SOFT/zlog v0.8.1/go.mod h1:yqFOjn3OhvJ4j7ArJqQNA+9V+u6t9zSAyIZdWdMweWc=
|
||||
github.com/VictoriaMetrics/easyproto v0.1.4 h1:r8cNvo8o6sR4QShBXQd1bKw/VVLSQma/V2KhTBPf+Sc=
|
||||
github.com/VictoriaMetrics/easyproto v0.1.4/go.mod h1:QlGlzaJnDfFd8Lk6Ci/fuLxfTo3/GThPs2KH23mv710=
|
||||
github.com/ahmetb/dlog v0.0.0-20170105205344-4fb5f8204f26 h1:3YVZUqkoev4mL+aCwVOSWV4M7pN+NURHL38Z2zq5JKA=
|
||||
github.com/ahmetb/dlog v0.0.0-20170105205344-4fb5f8204f26/go.mod h1:ymXt5bw5uSNu4jveerFxE0vNYxF8ncqbptntMaFMg3k=
|
||||
github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU=
|
||||
@@ -882,8 +878,6 @@ github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vb
|
||||
github.com/go-latex/latex v0.0.0-20210118124228-b3d85cf34e07/go.mod h1:CO1AlKB2CSIqUrmQPqA0gdRIlnLEY0gK5JGjh37zN5U=
|
||||
github.com/go-latex/latex v0.0.0-20210823091927-c0d11ff05a81/go.mod h1:SX0U8uGpxhq9o2S/CELCSUxEWWAuoCUcVCQWv7G2OCk=
|
||||
github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A=
|
||||
github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4=
|
||||
github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs=
|
||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
||||
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
@@ -911,10 +905,6 @@ github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw=
|
||||
github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
|
||||
github.com/gocql/gocql v1.7.0 h1:O+7U7/1gSN7QTEAaMEsJc1Oq2QHXvCWoF3DFK9HDHus=
|
||||
github.com/gocql/gocql v1.7.0/go.mod h1:vnlvXyFZeLBF0Wy+RS8hrOdbn0UWsWtdg07XJnFxZ+4=
|
||||
github.com/godror/godror v0.49.3 h1:84CPEu1p3qPvpN7PTHv8NDept+t+d+AoO/7WjYVsFNc=
|
||||
github.com/godror/godror v0.49.3/go.mod h1:kTMcxZzRw73RT5kn9v3JkBK4kHI6dqowHotqV72ebU8=
|
||||
github.com/godror/knownpb v0.3.0 h1:+caUdy8hTtl7X05aPl3tdL540TvCcaQA6woZQroLZMw=
|
||||
github.com/godror/knownpb v0.3.0/go.mod h1:PpTyfJwiOEAzQl7NtVCM8kdPCnp3uhxsZYIzZ5PV4zU=
|
||||
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
@@ -1178,8 +1168,6 @@ github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdh
|
||||
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/neo4j/neo4j-go-driver/v5 v5.28.4 h1:7toxehVcYkZbyxV4W3Ib9VcnyRBQPucF+VwNNmtSXi4=
|
||||
github.com/neo4j/neo4j-go-driver/v5 v5.28.4/go.mod h1:Vff8OwT7QpLm7L2yYr85XNWe9Rbqlbeb9asNXJTHO4k=
|
||||
github.com/oklog/ulid/v2 v2.0.2 h1:r4fFzBm+bv0wNKNh5eXTwU7i85y5x+uwkxCUTNVQqLc=
|
||||
github.com/oklog/ulid/v2 v2.0.2/go.mod h1:mtBL0Qe/0HAx6/a4Z30qxVIAL1eQDweXq5lxOEiwQ68=
|
||||
github.com/onsi/gomega v1.36.2 h1:koNYke6TVk6ZmnyHrCXba/T/MoLBXFjeC1PtvYgw0A8=
|
||||
github.com/onsi/gomega v1.36.2/go.mod h1:DdwyADRjrc825LhMEkD76cHR5+pUnjhUN8GlHlRPHzY=
|
||||
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||
@@ -1236,6 +1224,8 @@ github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys=
|
||||
github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs=
|
||||
github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
|
||||
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
|
||||
github.com/sijms/go-ora/v2 v2.9.0 h1:+iQbUeTeCOFMb5BsOMgUhV8KWyrv9yjKpcK4x7+MFrg=
|
||||
github.com/sijms/go-ora/v2 v2.9.0/go.mod h1:QgFInVi3ZWyqAiJwzBQA+nbKYKH77tdp1PYoCqhR2dU=
|
||||
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
|
||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||
@@ -1673,8 +1663,6 @@ golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU=
|
||||
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
|
||||
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
|
||||
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
|
||||
golang.org/x/term v0.35.0 h1:bZBVKBudEyhRcajGcNc3jIfWPqV4y/Kt2XcoigOWtDQ=
|
||||
golang.org/x/term v0.35.0/go.mod h1:TPGtkTLesOwf2DE8CgVYiZinHAOuy5AYUYT1lENIZnA=
|
||||
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
|
||||
@@ -9,9 +9,9 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
_ "github.com/godror/godror"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
_ "github.com/sijms/go-ora/v2"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
@@ -125,13 +125,12 @@ func initOracleConnection(ctx context.Context, tracer trace.Tracer, config Confi
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, config.Name)
|
||||
defer span.End()
|
||||
|
||||
var connectString string
|
||||
logger, err := util.LoggerFromContext(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Set TNS_ADMIN environment variable if specified in config
|
||||
// Set TNS_ADMIN environment variable if specified in config.
|
||||
if config.TnsAdmin != "" {
|
||||
originalTnsAdmin := os.Getenv("TNS_ADMIN")
|
||||
os.Setenv("TNS_ADMIN", config.TnsAdmin)
|
||||
@@ -146,27 +145,26 @@ func initOracleConnection(ctx context.Context, tracer trace.Tracer, config Confi
|
||||
}()
|
||||
}
|
||||
|
||||
// Determine the connection string to use (priority order)
|
||||
var serverString string
|
||||
if config.TnsAlias != "" {
|
||||
// Use TNS alias - godror will resolve from tnsnames.ora
|
||||
connectString = strings.TrimSpace(config.TnsAlias)
|
||||
// Use TNS alias
|
||||
serverString = strings.TrimSpace(config.TnsAlias)
|
||||
} else if config.ConnectionString != "" {
|
||||
// Use provided connection string directly (hostname[:port]/servicename format)
|
||||
connectString = strings.TrimSpace(config.ConnectionString)
|
||||
serverString = strings.TrimSpace(config.ConnectionString)
|
||||
} else {
|
||||
// Build connection string from host and service_name
|
||||
if config.Port > 0 {
|
||||
connectString = fmt.Sprintf("%s:%d/%s", config.Host, config.Port, config.ServiceName)
|
||||
serverString = fmt.Sprintf("%s:%d/%s", config.Host, config.Port, config.ServiceName)
|
||||
} else {
|
||||
connectString = fmt.Sprintf("%s/%s", config.Host, config.ServiceName)
|
||||
serverString = fmt.Sprintf("%s/%s", config.Host, config.ServiceName)
|
||||
}
|
||||
}
|
||||
|
||||
// Build the full Oracle connection string for godror driver
|
||||
connStr := fmt.Sprintf(`user="%s" password="%s" connectString="%s"`,
|
||||
config.User, config.Password, connectString)
|
||||
connStr := fmt.Sprintf("oracle://%s:%s@%s",
|
||||
config.User, config.Password, serverString)
|
||||
|
||||
db, err := sql.Open("godror", connStr)
|
||||
db, err := sql.Open("oracle", connStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to open Oracle connection: %w", err)
|
||||
}
|
||||
|
||||
@@ -1,16 +1,3 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// Copyright © 2025, Oracle and/or its affiliates.
|
||||
|
||||
package oracleexecutesql
|
||||
@@ -20,11 +7,9 @@ import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/godror/godror"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/oracle"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
@@ -118,7 +103,7 @@ type Tool struct {
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
paramsMap := params.AsMap()
|
||||
sql, ok := paramsMap["sql"].(string)
|
||||
sqlParam, ok := paramsMap["sql"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unable to get cast %s", paramsMap["sql"])
|
||||
}
|
||||
@@ -128,73 +113,103 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting logger: %s", err)
|
||||
}
|
||||
logger.DebugContext(ctx, "executing `%s` tool query: %s", kind, sql)
|
||||
logger.DebugContext(ctx, "executing `%s` tool query: %s", kind, sqlParam)
|
||||
|
||||
results, err := t.Pool.QueryContext(ctx, sql)
|
||||
results, err := t.Pool.QueryContext(ctx, sqlParam)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
defer results.Close()
|
||||
|
||||
cols, _ := results.Columns()
|
||||
// If Columns() errors, it might be a DDL/DML without an OUTPUT clause.
|
||||
// We proceed, and results.Err() will catch actual query execution errors.
|
||||
// 'out' will remain nil if cols is empty or err is not nil here.
|
||||
cols, _ := results.Columns()
|
||||
|
||||
// Get Column types
|
||||
colTypes, err := results.ColumnTypes()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to get column types: %w", err)
|
||||
if err := results.Err(); err != nil {
|
||||
return nil, fmt.Errorf("query execution error: %w", err)
|
||||
}
|
||||
return []any{}, nil
|
||||
}
|
||||
|
||||
var out []any
|
||||
for results.Next() {
|
||||
// Create slice to hold values
|
||||
values := make([]any, len(cols))
|
||||
valuePtrs := make([]any, len(cols))
|
||||
for i := range values {
|
||||
valuePtrs[i] = &values[i]
|
||||
for i, colType := range colTypes {
|
||||
// Based on the database type, we prepare a pointer to a Go type.
|
||||
switch strings.ToUpper(colType.DatabaseTypeName()) {
|
||||
case "NUMBER", "FLOAT", "BINARY_FLOAT", "BINARY_DOUBLE":
|
||||
if _, scale, ok := colType.DecimalSize(); ok && scale == 0 {
|
||||
// Scale is 0, treat as an integer.
|
||||
values[i] = new(sql.NullInt64)
|
||||
} else {
|
||||
// Scale is non-zero or unknown, treat as a float.
|
||||
values[i] = new(sql.NullFloat64)
|
||||
}
|
||||
case "DATE", "TIMESTAMP", "TIMESTAMP WITH TIME ZONE", "TIMESTAMP WITH LOCAL TIME ZONE":
|
||||
values[i] = new(sql.NullTime)
|
||||
case "JSON":
|
||||
values[i] = new(sql.RawBytes)
|
||||
default:
|
||||
values[i] = new(sql.NullString)
|
||||
}
|
||||
}
|
||||
|
||||
// Scan the values
|
||||
if err := results.Scan(valuePtrs...); err != nil {
|
||||
if err := results.Scan(values...); err != nil {
|
||||
return nil, fmt.Errorf("unable to scan row: %w", err)
|
||||
}
|
||||
|
||||
// Create result map
|
||||
vMap := make(map[string]any)
|
||||
for i, col := range cols {
|
||||
val := values[i]
|
||||
switch colTypes[i].DatabaseTypeName() {
|
||||
case "JSON":
|
||||
// unmarshal JSON data before storing to prevent double marshaling
|
||||
var unmarshaledData any
|
||||
err := json.Unmarshal(val.([]byte), &unmarshaledData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal json data %s", val)
|
||||
}
|
||||
vMap[col] = unmarshaledData
|
||||
case "TEXT", "VARCHAR", "NVARCHAR":
|
||||
vMap[col] = string(val.([]byte))
|
||||
case "NUMBER":
|
||||
s := string(val.(godror.Number))
|
||||
if strings.Contains(s, ".") {
|
||||
vMap[col], err = strconv.ParseFloat(s, 64)
|
||||
receiver := values[i]
|
||||
|
||||
// Dereference the pointer and check for validity (not NULL).
|
||||
switch v := receiver.(type) {
|
||||
case *sql.NullInt64:
|
||||
if v.Valid {
|
||||
vMap[col] = v.Int64
|
||||
} else {
|
||||
vMap[col], err = strconv.ParseInt(s, 10, 64)
|
||||
vMap[col] = nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to convert NUMBER data '%s' for column %s: %w", s, col, err)
|
||||
case *sql.NullFloat64:
|
||||
if v.Valid {
|
||||
vMap[col] = v.Float64
|
||||
} else {
|
||||
vMap[col] = nil
|
||||
}
|
||||
case *sql.NullString:
|
||||
if v.Valid {
|
||||
vMap[col] = v.String
|
||||
} else {
|
||||
vMap[col] = nil
|
||||
}
|
||||
case *sql.NullTime:
|
||||
if v.Valid {
|
||||
vMap[col] = v.Time
|
||||
} else {
|
||||
vMap[col] = nil
|
||||
}
|
||||
case *sql.RawBytes:
|
||||
if *v != nil {
|
||||
var unmarshaledData any
|
||||
if err := json.Unmarshal(*v, &unmarshaledData); err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal json data for column %s", col)
|
||||
}
|
||||
vMap[col] = unmarshaledData
|
||||
} else {
|
||||
vMap[col] = nil
|
||||
}
|
||||
default:
|
||||
vMap[col] = val
|
||||
return nil, fmt.Errorf("unexpected receiver type: %T", v)
|
||||
}
|
||||
}
|
||||
out = append(out, vMap)
|
||||
}
|
||||
|
||||
// Check for errors from iterating over rows or from the query execution itself.
|
||||
// results.Close() is handled by defer.
|
||||
if err := results.Err(); err != nil {
|
||||
return nil, fmt.Errorf("errors encountered during query execution or row processing: %w", err)
|
||||
}
|
||||
|
||||
@@ -1,16 +1,4 @@
|
||||
// Copyright 2024 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// Copyright © 2025, Oracle and/or its affiliates.
|
||||
|
||||
package oraclesql
|
||||
|
||||
@@ -19,11 +7,9 @@ import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/godror/godror"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/oracle"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
@@ -143,9 +129,6 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
}
|
||||
fmt.Printf("\n")
|
||||
|
||||
// NO PARAMETER CONVERSION - godror supports :1, :2, :3 natively
|
||||
// Execute Oracle query with original statement
|
||||
|
||||
rows, err := t.DB.QueryContext(ctx, newStatement, sliceParams...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
@@ -162,45 +145,72 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
|
||||
var out []any
|
||||
for rows.Next() {
|
||||
// Create slice to hold values
|
||||
values := make([]any, len(cols))
|
||||
valuePtrs := make([]any, len(cols))
|
||||
for i := range values {
|
||||
valuePtrs[i] = &values[i]
|
||||
for i, colType := range colTypes {
|
||||
switch strings.ToUpper(colType.DatabaseTypeName()) {
|
||||
case "NUMBER", "FLOAT", "BINARY_FLOAT", "BINARY_DOUBLE":
|
||||
if _, scale, ok := colType.DecimalSize(); ok && scale == 0 {
|
||||
// Scale is 0, treat it as an integer.
|
||||
values[i] = new(sql.NullInt64)
|
||||
} else {
|
||||
// Scale is non-zero or unknown, treat
|
||||
// it as a float.
|
||||
values[i] = new(sql.NullFloat64)
|
||||
}
|
||||
case "DATE", "TIMESTAMP", "TIMESTAMP WITH TIME ZONE", "TIMESTAMP WITH LOCAL TIME ZONE":
|
||||
values[i] = new(sql.NullTime)
|
||||
case "JSON":
|
||||
values[i] = new(sql.RawBytes)
|
||||
default:
|
||||
values[i] = new(sql.NullString)
|
||||
}
|
||||
}
|
||||
|
||||
// Scan the values
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
if err := rows.Scan(values...); err != nil {
|
||||
return nil, fmt.Errorf("unable to scan row: %w", err)
|
||||
}
|
||||
|
||||
// Create result map
|
||||
vMap := make(map[string]any)
|
||||
for i, col := range cols {
|
||||
val := values[i]
|
||||
switch colTypes[i].DatabaseTypeName() {
|
||||
case "JSON":
|
||||
// unmarshal JSON data before storing to prevent double marshaling
|
||||
var unmarshaledData any
|
||||
err := json.Unmarshal(val.([]byte), &unmarshaledData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal json data %s", val)
|
||||
}
|
||||
vMap[col] = unmarshaledData
|
||||
case "TEXT", "VARCHAR", "NVARCHAR":
|
||||
vMap[col] = string(val.([]byte))
|
||||
case "NUMBER":
|
||||
s := string(val.(godror.Number))
|
||||
if strings.Contains(s, ".") {
|
||||
vMap[col], err = strconv.ParseFloat(s, 64)
|
||||
receiver := values[i]
|
||||
|
||||
switch v := receiver.(type) {
|
||||
case *sql.NullInt64:
|
||||
if v.Valid {
|
||||
vMap[col] = v.Int64
|
||||
} else {
|
||||
vMap[col], err = strconv.ParseInt(s, 10, 64)
|
||||
vMap[col] = nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to convert NUMBER data '%s' for column %s: %w", s, col, err)
|
||||
case *sql.NullFloat64:
|
||||
if v.Valid {
|
||||
vMap[col] = v.Float64
|
||||
} else {
|
||||
vMap[col] = nil
|
||||
}
|
||||
case *sql.NullString:
|
||||
if v.Valid {
|
||||
vMap[col] = v.String
|
||||
} else {
|
||||
vMap[col] = nil
|
||||
}
|
||||
case *sql.NullTime:
|
||||
if v.Valid {
|
||||
vMap[col] = v.Time
|
||||
} else {
|
||||
vMap[col] = nil
|
||||
}
|
||||
case *sql.RawBytes:
|
||||
if *v != nil {
|
||||
var unmarshaledData any
|
||||
if err := json.Unmarshal(*v, &unmarshaledData); err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal json data for column %s", col)
|
||||
}
|
||||
vMap[col] = unmarshaledData
|
||||
} else {
|
||||
vMap[col] = nil
|
||||
}
|
||||
default:
|
||||
vMap[col] = val
|
||||
return nil, fmt.Errorf("unexpected receiver type: %T", v)
|
||||
}
|
||||
}
|
||||
out = append(out, vMap)
|
||||
|
||||
@@ -50,11 +50,9 @@ func getOracleVars(t *testing.T) map[string]any {
|
||||
|
||||
// Copied over from oracle.go
|
||||
func initOracleConnection(ctx context.Context, user, pass, connStr string) (*sql.DB, error) {
|
||||
// Build the full Oracle connection string for godror driver
|
||||
fullConnStr := fmt.Sprintf(`user="%s" password="%s" connectString="%s"`,
|
||||
user, pass, connStr)
|
||||
fullConnStr := fmt.Sprintf("oracle://%s:%s@%s", user, pass, connStr)
|
||||
|
||||
db, err := sql.Open("godror", fullConnStr)
|
||||
db, err := sql.Open("oracle", fullConnStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to open Oracle connection: %w", err)
|
||||
}
|
||||
@@ -118,15 +116,13 @@ func TestOracleSimpleToolEndpoints(t *testing.T) {
|
||||
|
||||
// Get configs for tests
|
||||
select1Want := "[{\"1\":1}]"
|
||||
mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: dpiStmt_execute: ORA-00900: invalid SQL statement"}],"isError":true}}`
|
||||
mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: ORA-00900: invalid SQL statement\n error occur at position: 0"}],"isError":true}}`
|
||||
createTableStatement := `"CREATE TABLE t (id NUMBER GENERATED AS IDENTITY PRIMARY KEY, name VARCHAR2(255))"`
|
||||
mcpSelect1Want := `{"jsonrpc":"2.0","id":"invoke my-auth-required-tool","result":{"content":[{"type":"text","text":"{\"1\":1}"}]}}`
|
||||
|
||||
// Run tests
|
||||
tests.RunToolGetTest(t)
|
||||
tests.RunToolInvokeTest(t, select1Want,
|
||||
tests.DisableOptionalNullParamTest(),
|
||||
tests.WithMyToolById4Want("[{\"id\":4,\"name\":\"\"}]"),
|
||||
tests.DisableArrayTest(),
|
||||
)
|
||||
tests.RunMCPToolCallMethod(t, mcpMyFailToolWant, mcpSelect1Want)
|
||||
|
||||
Reference in New Issue
Block a user