feat(sources/trino): add ssl verification options and fix docs example (#2155)

## Description

Adds options such as disableSslVerification, sslCert and sslCertPath to
trino source. Also fixes trino-sql docs on params

## PR Checklist

- [x] Make sure you reviewed

[CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md)
- [x] Make sure to open an issue as a

[bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose)
  before writing your code! That way we can discuss the change, evaluate
  designs, and agree on the general idea
- [x] Ensure the tests and linter pass
- [x] Code coverage does not decrease (if any source code was changed)
- [x] Appropriate docs were updated (if necessary)
- [x] Make sure to add `!` if this involve a breaking change

🛠️ Fixes #1910

---------
This commit is contained in:
gRedHeadphone
2026-01-08 06:49:23 +05:30
committed by GitHub
parent b706b5bc68
commit 4a4cf1e712
4 changed files with 117 additions and 38 deletions

View File

@@ -50,16 +50,19 @@ instead of hardcoding your secrets into the configuration file.
## Reference
| **field** | **type** | **required** | **description** |
|-----------------|:--------:|:------------:|------------------------------------------------------------------------------|
| kind | string | true | Must be "trino". |
| host | string | true | Trino coordinator hostname (e.g. "trino.example.com") |
| port | string | true | Trino coordinator port (e.g. "8080", "8443") |
| user | string | false | Username for authentication (e.g. "analyst"). Optional for anonymous access. |
| password | string | false | Password for basic authentication |
| catalog | string | true | Default catalog to use for queries (e.g. "hive") |
| schema | string | true | Default schema to use for queries (e.g. "default") |
| queryTimeout | string | false | Query timeout duration (e.g. "30m", "1h") |
| accessToken | string | false | JWT access token for authentication |
| kerberosEnabled | boolean | false | Enable Kerberos authentication (default: false) |
| sslEnabled | boolean | false | Enable SSL/TLS (default: false) |
| **field** | **type** | **required** | **description** |
| ---------------------- | :------: | :----------: | ---------------------------------------------------------------------------- |
| kind | string | true | Must be "trino". |
| host | string | true | Trino coordinator hostname (e.g. "trino.example.com") |
| port | string | true | Trino coordinator port (e.g. "8080", "8443") |
| user | string | false | Username for authentication (e.g. "analyst"). Optional for anonymous access. |
| password | string | false | Password for basic authentication |
| catalog | string | true | Default catalog to use for queries (e.g. "hive") |
| schema | string | true | Default schema to use for queries (e.g. "default") |
| queryTimeout | string | false | Query timeout duration (e.g. "30m", "1h") |
| accessToken | string | false | JWT access token for authentication |
| kerberosEnabled | boolean | false | Enable Kerberos authentication (default: false) |
| sslEnabled | boolean | false | Enable SSL/TLS (default: false) |
| disableSslVerification | boolean | false | Skip SSL/TLS certificate verification (default: false) |
| sslCertPath | string | false | Path to a custom SSL/TLS certificate file |
| sslCert | string | false | Custom SSL/TLS certificate content |

View File

@@ -16,11 +16,7 @@ database. It's compatible with any of the following sources:
- [trino](../../sources/trino.md)
The specified SQL statement is executed as a [prepared statement][trino-prepare],
and specified parameters will be inserted according to their position: e.g. `$1`
will be the first parameter specified, `$2` will be the second parameter, and so
on. If template parameters are included, they will be resolved before execution
of the prepared statement.
The specified SQL statement is executed as a [prepared statement][trino-prepare], and expects parameters in the SQL query to be in the form of placeholders `?`.
[trino-prepare]: https://trino.io/docs/current/sql/prepare.html
@@ -38,8 +34,8 @@ tools:
source: my-trino-instance
statement: |
SELECT * FROM hive.sales.orders
WHERE region = $1
AND order_date >= DATE($2)
WHERE region = ?
AND order_date >= DATE(?)
LIMIT 10
description: |
Use this tool to get information for orders in a specific region.

View File

@@ -16,14 +16,17 @@ package trino
import (
"context"
"crypto/tls"
"database/sql"
"fmt"
"net/http"
"net/url"
"time"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
_ "github.com/trinodb/trino-go-client/trino"
"github.com/googleapis/genai-toolbox/internal/util"
trinogo "github.com/trinodb/trino-go-client/trino"
"go.opentelemetry.io/otel/trace"
)
@@ -47,18 +50,21 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Host string `yaml:"host" validate:"required"`
Port string `yaml:"port" validate:"required"`
User string `yaml:"user"`
Password string `yaml:"password"`
Catalog string `yaml:"catalog" validate:"required"`
Schema string `yaml:"schema" validate:"required"`
QueryTimeout string `yaml:"queryTimeout"`
AccessToken string `yaml:"accessToken"`
KerberosEnabled bool `yaml:"kerberosEnabled"`
SSLEnabled bool `yaml:"sslEnabled"`
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Host string `yaml:"host" validate:"required"`
Port string `yaml:"port" validate:"required"`
User string `yaml:"user"`
Password string `yaml:"password"`
Catalog string `yaml:"catalog" validate:"required"`
Schema string `yaml:"schema" validate:"required"`
QueryTimeout string `yaml:"queryTimeout"`
AccessToken string `yaml:"accessToken"`
KerberosEnabled bool `yaml:"kerberosEnabled"`
SSLEnabled bool `yaml:"sslEnabled"`
SSLCertPath string `yaml:"sslCertPath"`
SSLCert string `yaml:"sslCert"`
DisableSslVerification bool `yaml:"disableSslVerification"`
}
func (r Config) SourceConfigKind() string {
@@ -66,7 +72,7 @@ func (r Config) SourceConfigKind() string {
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
pool, err := initTrinoConnectionPool(ctx, tracer, r.Name, r.Host, r.Port, r.User, r.Password, r.Catalog, r.Schema, r.QueryTimeout, r.AccessToken, r.KerberosEnabled, r.SSLEnabled)
pool, err := initTrinoConnectionPool(ctx, tracer, r.Name, r.Host, r.Port, r.User, r.Password, r.Catalog, r.Schema, r.QueryTimeout, r.AccessToken, r.KerberosEnabled, r.SSLEnabled, r.SSLCertPath, r.SSLCert, r.DisableSslVerification)
if err != nil {
return nil, fmt.Errorf("unable to create pool: %w", err)
}
@@ -152,17 +158,35 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
return out, nil
}
func initTrinoConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, password, catalog, schema, queryTimeout, accessToken string, kerberosEnabled, sslEnabled bool) (*sql.DB, error) {
func initTrinoConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, password, catalog, schema, queryTimeout, accessToken string, kerberosEnabled, sslEnabled bool, sslCertPath, sslCert string, disableSslVerification bool) (*sql.DB, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
// Build Trino DSN
dsn, err := buildTrinoDSN(host, port, user, password, catalog, schema, queryTimeout, accessToken, kerberosEnabled, sslEnabled)
dsn, err := buildTrinoDSN(host, port, user, password, catalog, schema, queryTimeout, accessToken, kerberosEnabled, sslEnabled, sslCertPath, sslCert)
if err != nil {
return nil, fmt.Errorf("failed to build DSN: %w", err)
}
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)
}
if disableSslVerification {
logger.WarnContext(ctx, "SSL verification is disabled for trino source %s. This is an insecure setting and should not be used in production.\n", name)
tr := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
client := &http.Client{Transport: tr}
clientName := fmt.Sprintf("insecure_trino_client_%s", name)
if err := trinogo.RegisterCustomClient(clientName, client); err != nil {
return nil, fmt.Errorf("failed to register custom client: %w", err)
}
dsn = fmt.Sprintf("%s&custom_client=%s", dsn, clientName)
}
db, err := sql.Open("trino", dsn)
if err != nil {
return nil, fmt.Errorf("failed to open connection: %w", err)
@@ -176,7 +200,7 @@ func initTrinoConnectionPool(ctx context.Context, tracer trace.Tracer, name, hos
return db, nil
}
func buildTrinoDSN(host, port, user, password, catalog, schema, queryTimeout, accessToken string, kerberosEnabled, sslEnabled bool) (string, error) {
func buildTrinoDSN(host, port, user, password, catalog, schema, queryTimeout, accessToken string, kerberosEnabled, sslEnabled bool, sslCertPath, sslCert string) (string, error) {
// Build query parameters
query := url.Values{}
query.Set("catalog", catalog)
@@ -190,6 +214,12 @@ func buildTrinoDSN(host, port, user, password, catalog, schema, queryTimeout, ac
if kerberosEnabled {
query.Set("KerberosEnabled", "true")
}
if sslCertPath != "" {
query.Set("sslCertPath", sslCertPath)
}
if sslCert != "" {
query.Set("sslCert", sslCert)
}
// Build URL
scheme := "http"

View File

@@ -36,6 +36,8 @@ func TestBuildTrinoDSN(t *testing.T) {
accessToken string
kerberosEnabled bool
sslEnabled bool
sslCertPath string
sslCert string
want string
wantErr bool
}{
@@ -49,6 +51,19 @@ func TestBuildTrinoDSN(t *testing.T) {
want: "http://testuser@localhost:8080?catalog=hive&schema=default",
wantErr: false,
},
{
name: "with SSL cert path and cert",
host: "localhost",
port: "8443",
user: "testuser",
catalog: "hive",
schema: "default",
sslEnabled: true,
sslCertPath: "/path/to/cert.pem",
sslCert: "-----BEGIN CERTIFICATE-----\n...\n-----END CERTIFICATE-----\n",
want: "https://testuser@localhost:8443?catalog=hive&schema=default&sslCert=-----BEGIN+CERTIFICATE-----%0A...%0A-----END+CERTIFICATE-----%0A&sslCertPath=%2Fpath%2Fto%2Fcert.pem",
wantErr: false,
},
{
name: "with password",
host: "localhost",
@@ -117,7 +132,7 @@ func TestBuildTrinoDSN(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := buildTrinoDSN(tt.host, tt.port, tt.user, tt.password, tt.catalog, tt.schema, tt.queryTimeout, tt.accessToken, tt.kerberosEnabled, tt.sslEnabled)
got, err := buildTrinoDSN(tt.host, tt.port, tt.user, tt.password, tt.catalog, tt.schema, tt.queryTimeout, tt.accessToken, tt.kerberosEnabled, tt.sslEnabled, tt.sslCertPath, tt.sslCert)
if (err != nil) != tt.wantErr {
t.Errorf("buildTrinoDSN() error = %v, wantErr %v", err, tt.wantErr)
return
@@ -215,6 +230,41 @@ func TestParseFromYamlTrino(t *testing.T) {
},
},
},
{
desc: "example with SSL cert path and cert",
in: `
sources:
my-trino-ssl-cert:
kind: trino
host: localhost
port: "8443"
user: testuser
catalog: hive
schema: default
sslEnabled: true
sslCertPath: /path/to/cert.pem
sslCert: |-
-----BEGIN CERTIFICATE-----
...
-----END CERTIFICATE-----
disableSslVerification: true
`,
want: server.SourceConfigs{
"my-trino-ssl-cert": Config{
Name: "my-trino-ssl-cert",
Kind: SourceKind,
Host: "localhost",
Port: "8443",
User: "testuser",
Catalog: "hive",
Schema: "default",
SSLEnabled: true,
SSLCertPath: "/path/to/cert.pem",
SSLCert: "-----BEGIN CERTIFICATE-----\n...\n-----END CERTIFICATE-----",
DisableSslVerification: true,
},
},
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {