mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-09 15:38:08 -05:00
fix(postgres,mssql,cloudsqlmssql)!: encode source connection url for sources (#727)
Have to encode special character in connection url. Only needed for `postgres`, `mssql`, `cloud-sql-mssql` sources. Fixes #717
This commit is contained in:
@@ -18,6 +18,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"slices"
|
||||
|
||||
"cloud.google.com/go/cloudsqlconn/sqlserver/mssql"
|
||||
@@ -111,7 +112,13 @@ func initCloudSQLMssqlConnection(ctx context.Context, tracer trace.Tracer, name,
|
||||
defer span.End()
|
||||
|
||||
// Create dsn
|
||||
dsn := fmt.Sprintf("sqlserver://%s:%s@%s?database=%s&cloudsql=%s:%s:%s", user, pass, ipAddress, dbname, project, region, instance)
|
||||
query := fmt.Sprintf("database=%s&cloudsql=%s:%s:%s", dbname, project, region, instance)
|
||||
url := &url.URL{
|
||||
Scheme: "sqlserver",
|
||||
User: url.UserPassword(user, pass),
|
||||
Host: ipAddress,
|
||||
RawQuery: query,
|
||||
}
|
||||
|
||||
// Get dial options
|
||||
userAgent, err := util.UserAgentFromContext(ctx)
|
||||
@@ -134,7 +141,7 @@ func initCloudSQLMssqlConnection(ctx context.Context, tracer trace.Tracer, name,
|
||||
// Open database connection
|
||||
db, err := sql.Open(
|
||||
"cloudsql-sqlserver-driver",
|
||||
dsn,
|
||||
url.String(),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
@@ -106,10 +107,17 @@ func initMssqlConnection(ctx context.Context, tracer trace.Tracer, name, host, p
|
||||
defer span.End()
|
||||
|
||||
// Create dsn
|
||||
dsn := fmt.Sprintf("sqlserver://%s:%s@%s:%s?database=%s", user, pass, host, port, dbname)
|
||||
query := url.Values{}
|
||||
query.Add("database", dbname)
|
||||
url := &url.URL{
|
||||
Scheme: "sqlserver",
|
||||
User: url.UserPassword(user, pass),
|
||||
Host: fmt.Sprintf("%s:%s", host, port),
|
||||
RawQuery: query.Encode(),
|
||||
}
|
||||
|
||||
// Open database connection
|
||||
db, err := sql.Open("sqlserver", dsn)
|
||||
db, err := sql.Open("sqlserver", url.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sql.Open: %w", err)
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ package postgres
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
@@ -96,9 +97,15 @@ func initPostgresConnectionPool(ctx context.Context, tracer trace.Tracer, name,
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
defer span.End()
|
||||
|
||||
// urlExample := "postgres:dd//username:password@localhost:5432/database_name"
|
||||
i := fmt.Sprintf("postgres://%s:%s@%s:%s/%s", user, pass, host, port, dbname)
|
||||
pool, err := pgxpool.New(ctx, i)
|
||||
url := &url.URL{
|
||||
Scheme: "postgres",
|
||||
User: url.UserPassword(user, pass),
|
||||
Host: fmt.Sprintf("%s:%s", host, port),
|
||||
Path: dbname,
|
||||
}
|
||||
pool, err := pgxpool.New(ctx, url.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create connection pool: %w", err)
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"regexp"
|
||||
"slices"
|
||||
@@ -76,7 +77,13 @@ func getCloudSQLMssqlVars(t *testing.T) map[string]any {
|
||||
// Copied over from cloud_sql_mssql.go
|
||||
func initCloudSQLMssqlConnection(project, region, instance, ipAddress, ipType, user, pass, dbname string) (*sql.DB, error) {
|
||||
// Create dsn
|
||||
dsn := fmt.Sprintf("sqlserver://%s:%s@%s?database=%s&cloudsql=%s:%s:%s", user, pass, ipAddress, dbname, project, region, instance)
|
||||
query := fmt.Sprintf("database=%s&cloudsql=%s:%s:%s", dbname, project, region, instance)
|
||||
url := &url.URL{
|
||||
Scheme: "sqlserver",
|
||||
User: url.UserPassword(user, pass),
|
||||
Host: ipAddress,
|
||||
RawQuery: query,
|
||||
}
|
||||
|
||||
// Get dial options
|
||||
dialOpts, err := tests.GetCloudSQLDialOpts(ipType)
|
||||
@@ -95,7 +102,7 @@ func initCloudSQLMssqlConnection(project, region, instance, ipAddress, ipType, u
|
||||
// Open database connection
|
||||
db, err := sql.Open(
|
||||
"cloudsql-sqlserver-driver",
|
||||
dsn,
|
||||
url.String(),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
@@ -65,10 +66,17 @@ func getMsSQLVars(t *testing.T) map[string]any {
|
||||
// Copied over from mssql.go
|
||||
func initMssqlConnection(host, port, user, pass, dbname string) (*sql.DB, error) {
|
||||
// Create dsn
|
||||
dsn := fmt.Sprintf("sqlserver://%s:%s@%s:%s?database=%s", user, pass, host, port, dbname)
|
||||
query := url.Values{}
|
||||
query.Add("database", dbname)
|
||||
url := &url.URL{
|
||||
Scheme: "sqlserver",
|
||||
User: url.UserPassword(user, pass),
|
||||
Host: fmt.Sprintf("%s:%s", host, port),
|
||||
RawQuery: query.Encode(),
|
||||
}
|
||||
|
||||
// Open database connection
|
||||
db, err := sql.Open("sqlserver", dsn)
|
||||
db, err := sql.Open("sqlserver", url.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sql.Open: %w", err)
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ package postgres
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
@@ -65,8 +66,13 @@ func getPostgresVars(t *testing.T) map[string]any {
|
||||
// Copied over from postgres.go
|
||||
func initPostgresConnectionPool(host, port, user, pass, dbname string) (*pgxpool.Pool, error) {
|
||||
// urlExample := "postgres:dd//username:password@localhost:5432/database_name"
|
||||
i := fmt.Sprintf("postgres://%s:%s@%s:%s/%s", user, pass, host, port, dbname)
|
||||
pool, err := pgxpool.New(context.Background(), i)
|
||||
url := &url.URL{
|
||||
Scheme: "postgres",
|
||||
User: url.UserPassword(user, pass),
|
||||
Host: fmt.Sprintf("%s:%s", host, port),
|
||||
Path: dbname,
|
||||
}
|
||||
pool, err := pgxpool.New(context.Background(), url.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Unable to create connection pool: %w", err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user