mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-09 07:28:05 -05:00
feat(sources/trino): add ssl verification options and fix docs example (#2155)
## Description Adds options such as disableSslVerification, sslCert and sslCertPath to trino source. Also fixes trino-sql docs on params ## PR Checklist - [x] Make sure you reviewed [CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md) - [x] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [x] Ensure the tests and linter pass - [x] Code coverage does not decrease (if any source code was changed) - [x] Appropriate docs were updated (if necessary) - [x] Make sure to add `!` if this involve a breaking change 🛠️ Fixes #1910 ---------
This commit is contained in:
@@ -50,16 +50,19 @@ instead of hardcoding your secrets into the configuration file.
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|-----------------|:--------:|:------------:|------------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be "trino". |
|
||||
| host | string | true | Trino coordinator hostname (e.g. "trino.example.com") |
|
||||
| port | string | true | Trino coordinator port (e.g. "8080", "8443") |
|
||||
| user | string | false | Username for authentication (e.g. "analyst"). Optional for anonymous access. |
|
||||
| password | string | false | Password for basic authentication |
|
||||
| catalog | string | true | Default catalog to use for queries (e.g. "hive") |
|
||||
| schema | string | true | Default schema to use for queries (e.g. "default") |
|
||||
| queryTimeout | string | false | Query timeout duration (e.g. "30m", "1h") |
|
||||
| accessToken | string | false | JWT access token for authentication |
|
||||
| kerberosEnabled | boolean | false | Enable Kerberos authentication (default: false) |
|
||||
| sslEnabled | boolean | false | Enable SSL/TLS (default: false) |
|
||||
| **field** | **type** | **required** | **description** |
|
||||
| ---------------------- | :------: | :----------: | ---------------------------------------------------------------------------- |
|
||||
| kind | string | true | Must be "trino". |
|
||||
| host | string | true | Trino coordinator hostname (e.g. "trino.example.com") |
|
||||
| port | string | true | Trino coordinator port (e.g. "8080", "8443") |
|
||||
| user | string | false | Username for authentication (e.g. "analyst"). Optional for anonymous access. |
|
||||
| password | string | false | Password for basic authentication |
|
||||
| catalog | string | true | Default catalog to use for queries (e.g. "hive") |
|
||||
| schema | string | true | Default schema to use for queries (e.g. "default") |
|
||||
| queryTimeout | string | false | Query timeout duration (e.g. "30m", "1h") |
|
||||
| accessToken | string | false | JWT access token for authentication |
|
||||
| kerberosEnabled | boolean | false | Enable Kerberos authentication (default: false) |
|
||||
| sslEnabled | boolean | false | Enable SSL/TLS (default: false) |
|
||||
| disableSslVerification | boolean | false | Skip SSL/TLS certificate verification (default: false) |
|
||||
| sslCertPath | string | false | Path to a custom SSL/TLS certificate file |
|
||||
| sslCert | string | false | Custom SSL/TLS certificate content |
|
||||
|
||||
@@ -16,11 +16,7 @@ database. It's compatible with any of the following sources:
|
||||
|
||||
- [trino](../../sources/trino.md)
|
||||
|
||||
The specified SQL statement is executed as a [prepared statement][trino-prepare],
|
||||
and specified parameters will be inserted according to their position: e.g. `$1`
|
||||
will be the first parameter specified, `$2` will be the second parameter, and so
|
||||
on. If template parameters are included, they will be resolved before execution
|
||||
of the prepared statement.
|
||||
The specified SQL statement is executed as a [prepared statement][trino-prepare], and expects parameters in the SQL query to be in the form of placeholders `?`.
|
||||
|
||||
[trino-prepare]: https://trino.io/docs/current/sql/prepare.html
|
||||
|
||||
@@ -38,8 +34,8 @@ tools:
|
||||
source: my-trino-instance
|
||||
statement: |
|
||||
SELECT * FROM hive.sales.orders
|
||||
WHERE region = $1
|
||||
AND order_date >= DATE($2)
|
||||
WHERE region = ?
|
||||
AND order_date >= DATE(?)
|
||||
LIMIT 10
|
||||
description: |
|
||||
Use this tool to get information for orders in a specific region.
|
||||
|
||||
@@ -16,14 +16,17 @@ package trino
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
_ "github.com/trinodb/trino-go-client/trino"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
trinogo "github.com/trinodb/trino-go-client/trino"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
@@ -47,18 +50,21 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Host string `yaml:"host" validate:"required"`
|
||||
Port string `yaml:"port" validate:"required"`
|
||||
User string `yaml:"user"`
|
||||
Password string `yaml:"password"`
|
||||
Catalog string `yaml:"catalog" validate:"required"`
|
||||
Schema string `yaml:"schema" validate:"required"`
|
||||
QueryTimeout string `yaml:"queryTimeout"`
|
||||
AccessToken string `yaml:"accessToken"`
|
||||
KerberosEnabled bool `yaml:"kerberosEnabled"`
|
||||
SSLEnabled bool `yaml:"sslEnabled"`
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Host string `yaml:"host" validate:"required"`
|
||||
Port string `yaml:"port" validate:"required"`
|
||||
User string `yaml:"user"`
|
||||
Password string `yaml:"password"`
|
||||
Catalog string `yaml:"catalog" validate:"required"`
|
||||
Schema string `yaml:"schema" validate:"required"`
|
||||
QueryTimeout string `yaml:"queryTimeout"`
|
||||
AccessToken string `yaml:"accessToken"`
|
||||
KerberosEnabled bool `yaml:"kerberosEnabled"`
|
||||
SSLEnabled bool `yaml:"sslEnabled"`
|
||||
SSLCertPath string `yaml:"sslCertPath"`
|
||||
SSLCert string `yaml:"sslCert"`
|
||||
DisableSslVerification bool `yaml:"disableSslVerification"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
@@ -66,7 +72,7 @@ func (r Config) SourceConfigKind() string {
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
pool, err := initTrinoConnectionPool(ctx, tracer, r.Name, r.Host, r.Port, r.User, r.Password, r.Catalog, r.Schema, r.QueryTimeout, r.AccessToken, r.KerberosEnabled, r.SSLEnabled)
|
||||
pool, err := initTrinoConnectionPool(ctx, tracer, r.Name, r.Host, r.Port, r.User, r.Password, r.Catalog, r.Schema, r.QueryTimeout, r.AccessToken, r.KerberosEnabled, r.SSLEnabled, r.SSLCertPath, r.SSLCert, r.DisableSslVerification)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create pool: %w", err)
|
||||
}
|
||||
@@ -152,17 +158,35 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func initTrinoConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, password, catalog, schema, queryTimeout, accessToken string, kerberosEnabled, sslEnabled bool) (*sql.DB, error) {
|
||||
func initTrinoConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, password, catalog, schema, queryTimeout, accessToken string, kerberosEnabled, sslEnabled bool, sslCertPath, sslCert string, disableSslVerification bool) (*sql.DB, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
defer span.End()
|
||||
|
||||
// Build Trino DSN
|
||||
dsn, err := buildTrinoDSN(host, port, user, password, catalog, schema, queryTimeout, accessToken, kerberosEnabled, sslEnabled)
|
||||
dsn, err := buildTrinoDSN(host, port, user, password, catalog, schema, queryTimeout, accessToken, kerberosEnabled, sslEnabled, sslCertPath, sslCert)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build DSN: %w", err)
|
||||
}
|
||||
|
||||
logger, err := util.LoggerFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)
|
||||
}
|
||||
|
||||
if disableSslVerification {
|
||||
logger.WarnContext(ctx, "SSL verification is disabled for trino source %s. This is an insecure setting and should not be used in production.\n", name)
|
||||
tr := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
client := &http.Client{Transport: tr}
|
||||
clientName := fmt.Sprintf("insecure_trino_client_%s", name)
|
||||
if err := trinogo.RegisterCustomClient(clientName, client); err != nil {
|
||||
return nil, fmt.Errorf("failed to register custom client: %w", err)
|
||||
}
|
||||
dsn = fmt.Sprintf("%s&custom_client=%s", dsn, clientName)
|
||||
}
|
||||
|
||||
db, err := sql.Open("trino", dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open connection: %w", err)
|
||||
@@ -176,7 +200,7 @@ func initTrinoConnectionPool(ctx context.Context, tracer trace.Tracer, name, hos
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func buildTrinoDSN(host, port, user, password, catalog, schema, queryTimeout, accessToken string, kerberosEnabled, sslEnabled bool) (string, error) {
|
||||
func buildTrinoDSN(host, port, user, password, catalog, schema, queryTimeout, accessToken string, kerberosEnabled, sslEnabled bool, sslCertPath, sslCert string) (string, error) {
|
||||
// Build query parameters
|
||||
query := url.Values{}
|
||||
query.Set("catalog", catalog)
|
||||
@@ -190,6 +214,12 @@ func buildTrinoDSN(host, port, user, password, catalog, schema, queryTimeout, ac
|
||||
if kerberosEnabled {
|
||||
query.Set("KerberosEnabled", "true")
|
||||
}
|
||||
if sslCertPath != "" {
|
||||
query.Set("sslCertPath", sslCertPath)
|
||||
}
|
||||
if sslCert != "" {
|
||||
query.Set("sslCert", sslCert)
|
||||
}
|
||||
|
||||
// Build URL
|
||||
scheme := "http"
|
||||
|
||||
@@ -36,6 +36,8 @@ func TestBuildTrinoDSN(t *testing.T) {
|
||||
accessToken string
|
||||
kerberosEnabled bool
|
||||
sslEnabled bool
|
||||
sslCertPath string
|
||||
sslCert string
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
@@ -49,6 +51,19 @@ func TestBuildTrinoDSN(t *testing.T) {
|
||||
want: "http://testuser@localhost:8080?catalog=hive&schema=default",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "with SSL cert path and cert",
|
||||
host: "localhost",
|
||||
port: "8443",
|
||||
user: "testuser",
|
||||
catalog: "hive",
|
||||
schema: "default",
|
||||
sslEnabled: true,
|
||||
sslCertPath: "/path/to/cert.pem",
|
||||
sslCert: "-----BEGIN CERTIFICATE-----\n...\n-----END CERTIFICATE-----\n",
|
||||
want: "https://testuser@localhost:8443?catalog=hive&schema=default&sslCert=-----BEGIN+CERTIFICATE-----%0A...%0A-----END+CERTIFICATE-----%0A&sslCertPath=%2Fpath%2Fto%2Fcert.pem",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "with password",
|
||||
host: "localhost",
|
||||
@@ -117,7 +132,7 @@ func TestBuildTrinoDSN(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := buildTrinoDSN(tt.host, tt.port, tt.user, tt.password, tt.catalog, tt.schema, tt.queryTimeout, tt.accessToken, tt.kerberosEnabled, tt.sslEnabled)
|
||||
got, err := buildTrinoDSN(tt.host, tt.port, tt.user, tt.password, tt.catalog, tt.schema, tt.queryTimeout, tt.accessToken, tt.kerberosEnabled, tt.sslEnabled, tt.sslCertPath, tt.sslCert)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("buildTrinoDSN() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
@@ -215,6 +230,41 @@ func TestParseFromYamlTrino(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "example with SSL cert path and cert",
|
||||
in: `
|
||||
sources:
|
||||
my-trino-ssl-cert:
|
||||
kind: trino
|
||||
host: localhost
|
||||
port: "8443"
|
||||
user: testuser
|
||||
catalog: hive
|
||||
schema: default
|
||||
sslEnabled: true
|
||||
sslCertPath: /path/to/cert.pem
|
||||
sslCert: |-
|
||||
-----BEGIN CERTIFICATE-----
|
||||
...
|
||||
-----END CERTIFICATE-----
|
||||
disableSslVerification: true
|
||||
`,
|
||||
want: server.SourceConfigs{
|
||||
"my-trino-ssl-cert": Config{
|
||||
Name: "my-trino-ssl-cert",
|
||||
Kind: SourceKind,
|
||||
Host: "localhost",
|
||||
Port: "8443",
|
||||
User: "testuser",
|
||||
Catalog: "hive",
|
||||
Schema: "default",
|
||||
SSLEnabled: true,
|
||||
SSLCertPath: "/path/to/cert.pem",
|
||||
SSLCert: "-----BEGIN CERTIFICATE-----\n...\n-----END CERTIFICATE-----",
|
||||
DisableSslVerification: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user