fix(postgres,mssql,cloudsqlmssql)!: encode source connection url for sources (#727)

Have to encode special character in connection url. Only needed for
`postgres`, `mssql`, `cloud-sql-mssql` sources.

Fixes #717
This commit is contained in:
Yuan
2025-06-18 15:32:46 -07:00
committed by GitHub
parent f77c829271
commit 67964d939f
6 changed files with 55 additions and 12 deletions

View File

@@ -18,6 +18,7 @@ import (
"context"
"database/sql"
"fmt"
"net/url"
"slices"
"cloud.google.com/go/cloudsqlconn/sqlserver/mssql"
@@ -111,7 +112,13 @@ func initCloudSQLMssqlConnection(ctx context.Context, tracer trace.Tracer, name,
defer span.End()
// Create dsn
dsn := fmt.Sprintf("sqlserver://%s:%s@%s?database=%s&cloudsql=%s:%s:%s", user, pass, ipAddress, dbname, project, region, instance)
query := fmt.Sprintf("database=%s&cloudsql=%s:%s:%s", dbname, project, region, instance)
url := &url.URL{
Scheme: "sqlserver",
User: url.UserPassword(user, pass),
Host: ipAddress,
RawQuery: query,
}
// Get dial options
userAgent, err := util.UserAgentFromContext(ctx)
@@ -134,7 +141,7 @@ func initCloudSQLMssqlConnection(ctx context.Context, tracer trace.Tracer, name,
// Open database connection
db, err := sql.Open(
"cloudsql-sqlserver-driver",
dsn,
url.String(),
)
if err != nil {
return nil, err

View File

@@ -18,6 +18,7 @@ import (
"context"
"database/sql"
"fmt"
"net/url"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
@@ -106,10 +107,17 @@ func initMssqlConnection(ctx context.Context, tracer trace.Tracer, name, host, p
defer span.End()
// Create dsn
dsn := fmt.Sprintf("sqlserver://%s:%s@%s:%s?database=%s", user, pass, host, port, dbname)
query := url.Values{}
query.Add("database", dbname)
url := &url.URL{
Scheme: "sqlserver",
User: url.UserPassword(user, pass),
Host: fmt.Sprintf("%s:%s", host, port),
RawQuery: query.Encode(),
}
// Open database connection
db, err := sql.Open("sqlserver", dsn)
db, err := sql.Open("sqlserver", url.String())
if err != nil {
return nil, fmt.Errorf("sql.Open: %w", err)
}

View File

@@ -17,6 +17,7 @@ package postgres
import (
"context"
"fmt"
"net/url"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
@@ -96,9 +97,15 @@ func initPostgresConnectionPool(ctx context.Context, tracer trace.Tracer, name,
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
// urlExample := "postgres:dd//username:password@localhost:5432/database_name"
i := fmt.Sprintf("postgres://%s:%s@%s:%s/%s", user, pass, host, port, dbname)
pool, err := pgxpool.New(ctx, i)
url := &url.URL{
Scheme: "postgres",
User: url.UserPassword(user, pass),
Host: fmt.Sprintf("%s:%s", host, port),
Path: dbname,
}
pool, err := pgxpool.New(ctx, url.String())
if err != nil {
return nil, fmt.Errorf("unable to create connection pool: %w", err)
}

View File

@@ -18,6 +18,7 @@ import (
"context"
"database/sql"
"fmt"
"net/url"
"os"
"regexp"
"slices"
@@ -76,7 +77,13 @@ func getCloudSQLMssqlVars(t *testing.T) map[string]any {
// Copied over from cloud_sql_mssql.go
func initCloudSQLMssqlConnection(project, region, instance, ipAddress, ipType, user, pass, dbname string) (*sql.DB, error) {
// Create dsn
dsn := fmt.Sprintf("sqlserver://%s:%s@%s?database=%s&cloudsql=%s:%s:%s", user, pass, ipAddress, dbname, project, region, instance)
query := fmt.Sprintf("database=%s&cloudsql=%s:%s:%s", dbname, project, region, instance)
url := &url.URL{
Scheme: "sqlserver",
User: url.UserPassword(user, pass),
Host: ipAddress,
RawQuery: query,
}
// Get dial options
dialOpts, err := tests.GetCloudSQLDialOpts(ipType)
@@ -95,7 +102,7 @@ func initCloudSQLMssqlConnection(project, region, instance, ipAddress, ipType, u
// Open database connection
db, err := sql.Open(
"cloudsql-sqlserver-driver",
dsn,
url.String(),
)
if err != nil {
return nil, err

View File

@@ -18,6 +18,7 @@ import (
"context"
"database/sql"
"fmt"
"net/url"
"os"
"regexp"
"strings"
@@ -65,10 +66,17 @@ func getMsSQLVars(t *testing.T) map[string]any {
// Copied over from mssql.go
func initMssqlConnection(host, port, user, pass, dbname string) (*sql.DB, error) {
// Create dsn
dsn := fmt.Sprintf("sqlserver://%s:%s@%s:%s?database=%s", user, pass, host, port, dbname)
query := url.Values{}
query.Add("database", dbname)
url := &url.URL{
Scheme: "sqlserver",
User: url.UserPassword(user, pass),
Host: fmt.Sprintf("%s:%s", host, port),
RawQuery: query.Encode(),
}
// Open database connection
db, err := sql.Open("sqlserver", dsn)
db, err := sql.Open("sqlserver", url.String())
if err != nil {
return nil, fmt.Errorf("sql.Open: %w", err)
}

View File

@@ -17,6 +17,7 @@ package postgres
import (
"context"
"fmt"
"net/url"
"os"
"regexp"
"strings"
@@ -65,8 +66,13 @@ func getPostgresVars(t *testing.T) map[string]any {
// Copied over from postgres.go
func initPostgresConnectionPool(host, port, user, pass, dbname string) (*pgxpool.Pool, error) {
// urlExample := "postgres:dd//username:password@localhost:5432/database_name"
i := fmt.Sprintf("postgres://%s:%s@%s:%s/%s", user, pass, host, port, dbname)
pool, err := pgxpool.New(context.Background(), i)
url := &url.URL{
Scheme: "postgres",
User: url.UserPassword(user, pass),
Host: fmt.Sprintf("%s:%s", host, port),
Path: dbname,
}
pool, err := pgxpool.New(context.Background(), url.String())
if err != nil {
return nil, fmt.Errorf("Unable to create connection pool: %w", err)
}