From 8452f8eb4457dcb0e360a9d9ae5b6e14e78806b1 Mon Sep 17 00:00:00 2001 From: Yuan <45984206+Yuan325@users.noreply.github.com> Date: Wed, 29 Jan 2025 17:19:52 -0800 Subject: [PATCH] feat: add user agent to cloud databases (#244) Add user agent to cloud databases that provides us anonymized data request count, number of users, number of projects, and other environment settings. User agent is using the format: `genai-toolbox/$version+metadata` --- internal/server/server.go | 4 ++ internal/sources/alloydbpg/alloydb_pg.go | 14 ++++--- .../sources/cloudsqlmssql/cloud_sql_mssql.go | 19 ++-------- .../sources/cloudsqlmysql/cloud_sql_mysql.go | 19 ++-------- internal/sources/cloudsqlpg/cloud_sql_pg.go | 18 ++------- internal/sources/spanner/spanner.go | 4 +- internal/sources/util.go | 37 +++++++++++++++++++ internal/util/util.go | 5 +++ 8 files changed, 70 insertions(+), 50 deletions(-) create mode 100644 internal/sources/util.go diff --git a/internal/server/server.go b/internal/server/server.go index f0a4b60216..060dc1266d 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -29,6 +29,7 @@ import ( "github.com/googleapis/genai-toolbox/internal/log" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" ) @@ -58,6 +59,9 @@ func NewServer(ctx context.Context, cfg ServerConfig, l log.Logger) (*Server, er parentCtx, span := instrumentation.Tracer.Start(context.Background(), "toolbox/server/init") defer span.End() + userAgent := fmt.Sprintf("genai-toolbox/%s", cfg.Version) + parentCtx = context.WithValue(parentCtx, util.UserAgentKey, userAgent) + // set up http serving r := chi.NewRouter() r.Use(middleware.Recoverer) diff --git a/internal/sources/alloydbpg/alloydb_pg.go b/internal/sources/alloydbpg/alloydb_pg.go index 64cb9fa0f5..00f4c9e37c 100644 --- a/internal/sources/alloydbpg/alloydb_pg.go +++ b/internal/sources/alloydbpg/alloydb_pg.go @@ -22,6 +22,7 @@ import ( "cloud.google.com/go/alloydbconn" "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/jackc/pgx/v5/pgxpool" "go.opentelemetry.io/otel/trace" ) @@ -83,15 +84,17 @@ func (s *Source) PostgresPool() *pgxpool.Pool { return s.Pool } -func getDialOpts(ipType string) ([]alloydbconn.DialOption, error) { +func getOpts(ipType, userAgent string) ([]alloydbconn.Option, error) { + opts := []alloydbconn.Option{alloydbconn.WithUserAgent(userAgent)} switch strings.ToLower(ipType) { case "private": - return []alloydbconn.DialOption{alloydbconn.WithPrivateIP()}, nil + opts = append(opts, alloydbconn.WithDefaultDialOptions(alloydbconn.WithPrivateIP())) case "public": - return []alloydbconn.DialOption{alloydbconn.WithPublicIP()}, nil + opts = append(opts, alloydbconn.WithDefaultDialOptions(alloydbconn.WithPublicIP())) default: return nil, fmt.Errorf("invalid ipType %s", ipType) } + return opts, nil } func initAlloyDBPgConnectionPool(ctx context.Context, tracer trace.Tracer, name, project, region, cluster, instance, ipType, user, pass, dbname string) (*pgxpool.Pool, error) { @@ -107,11 +110,12 @@ func initAlloyDBPgConnectionPool(ctx context.Context, tracer trace.Tracer, name, } // Create a new dialer with options - dialOpts, err := getDialOpts(ipType) + userAgent := ctx.Value(util.UserAgentKey).(string) + opts, err := getOpts(ipType, userAgent) if err != nil { return nil, err } - d, err := alloydbconn.NewDialer(context.Background(), alloydbconn.WithDefaultDialOptions(dialOpts...)) + d, err := alloydbconn.NewDialer(context.Background(), opts...) if err != nil { return nil, fmt.Errorf("unable to parse connection uri: %w", err) } diff --git a/internal/sources/cloudsqlmssql/cloud_sql_mssql.go b/internal/sources/cloudsqlmssql/cloud_sql_mssql.go index 31a1e255ee..fdf60c29df 100644 --- a/internal/sources/cloudsqlmssql/cloud_sql_mssql.go +++ b/internal/sources/cloudsqlmssql/cloud_sql_mssql.go @@ -19,11 +19,10 @@ import ( "database/sql" "fmt" "slices" - "strings" - "cloud.google.com/go/cloudsqlconn" "cloud.google.com/go/cloudsqlconn/sqlserver/mssql" "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/util" "go.opentelemetry.io/otel/trace" ) @@ -91,17 +90,6 @@ func (s *Source) MSSQLDB() *sql.DB { return s.Db } -func getDialOpts(ipType string) ([]cloudsqlconn.DialOption, error) { - switch strings.ToLower(ipType) { - case "private": - return []cloudsqlconn.DialOption{cloudsqlconn.WithPrivateIP()}, nil - case "public": - return []cloudsqlconn.DialOption{cloudsqlconn.WithPublicIP()}, nil - default: - return nil, fmt.Errorf("invalid ipType %s", ipType) - } -} - func initCloudSQLMssqlConnection(ctx context.Context, tracer trace.Tracer, name, project, region, instance, ipAddress, ipType, user, pass, dbname string) (*sql.DB, error) { //nolint:all // Reassigned ctx ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name) @@ -111,14 +99,15 @@ func initCloudSQLMssqlConnection(ctx context.Context, tracer trace.Tracer, name, dsn := fmt.Sprintf("sqlserver://%s:%s@%s?database=%s&cloudsql=%s:%s:%s", user, pass, ipAddress, dbname, project, region, instance) // Get dial options - dialOpts, err := getDialOpts(ipType) + userAgent := ctx.Value(util.UserAgentKey).(string) + opts, err := sources.GetCloudSQLOpts(ipType, userAgent) if err != nil { return nil, err } // Register sql server driver if !slices.Contains(sql.Drivers(), "cloudsql-sqlserver-driver") { - _, err := mssql.RegisterDriver("cloudsql-sqlserver-driver", cloudsqlconn.WithDefaultDialOptions(dialOpts...)) + _, err := mssql.RegisterDriver("cloudsql-sqlserver-driver", opts...) if err != nil { return nil, err } diff --git a/internal/sources/cloudsqlmysql/cloud_sql_mysql.go b/internal/sources/cloudsqlmysql/cloud_sql_mysql.go index e74a6d14eb..6b610c00d5 100644 --- a/internal/sources/cloudsqlmysql/cloud_sql_mysql.go +++ b/internal/sources/cloudsqlmysql/cloud_sql_mysql.go @@ -19,11 +19,10 @@ import ( "database/sql" "fmt" "slices" - "strings" - "cloud.google.com/go/cloudsqlconn" "cloud.google.com/go/cloudsqlconn/mysql/mysql" "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/util" "go.opentelemetry.io/otel/trace" ) @@ -83,30 +82,20 @@ func (s *Source) MySQLPool() *sql.DB { return s.Pool } -func getDialOpts(ipType string) ([]cloudsqlconn.DialOption, error) { - switch strings.ToLower(ipType) { - case "private": - return []cloudsqlconn.DialOption{cloudsqlconn.WithPrivateIP()}, nil - case "public": - return []cloudsqlconn.DialOption{cloudsqlconn.WithPublicIP()}, nil - default: - return nil, fmt.Errorf("invalid ipType %s", ipType) - } -} - func initCloudSQLMySQLConnectionPool(ctx context.Context, tracer trace.Tracer, name, project, region, instance, ipType, user, pass, dbname string) (*sql.DB, error) { //nolint:all // Reassigned ctx ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name) defer span.End() // Create a new dialer with options - dialOpts, err := getDialOpts(ipType) + userAgent := ctx.Value(util.UserAgentKey).(string) + opts, err := sources.GetCloudSQLOpts(ipType, userAgent) if err != nil { return nil, err } if !slices.Contains(sql.Drivers(), "cloudsql-mysql") { - _, err = mysql.RegisterDriver("cloudsql-mysql", cloudsqlconn.WithDefaultDialOptions(dialOpts...)) + _, err = mysql.RegisterDriver("cloudsql-mysql", opts...) if err != nil { return nil, fmt.Errorf("unable to register driver: %w", err) } diff --git a/internal/sources/cloudsqlpg/cloud_sql_pg.go b/internal/sources/cloudsqlpg/cloud_sql_pg.go index aeae9079a5..c8659241ac 100644 --- a/internal/sources/cloudsqlpg/cloud_sql_pg.go +++ b/internal/sources/cloudsqlpg/cloud_sql_pg.go @@ -18,10 +18,10 @@ import ( "context" "fmt" "net" - "strings" "cloud.google.com/go/cloudsqlconn" "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/jackc/pgx/v5/pgxpool" "go.opentelemetry.io/otel/trace" ) @@ -82,17 +82,6 @@ func (s *Source) PostgresPool() *pgxpool.Pool { return s.Pool } -func getDialOpts(ipType string) ([]cloudsqlconn.DialOption, error) { - switch strings.ToLower(ipType) { - case "private": - return []cloudsqlconn.DialOption{cloudsqlconn.WithPrivateIP()}, nil - case "public": - return []cloudsqlconn.DialOption{cloudsqlconn.WithPublicIP()}, nil - default: - return nil, fmt.Errorf("invalid ipType %s", ipType) - } -} - func initCloudSQLPgConnectionPool(ctx context.Context, tracer trace.Tracer, name, project, region, instance, ipType, user, pass, dbname string) (*pgxpool.Pool, error) { //nolint:all // Reassigned ctx ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name) @@ -106,11 +95,12 @@ func initCloudSQLPgConnectionPool(ctx context.Context, tracer trace.Tracer, name } // Create a new dialer with options - dialOpts, err := getDialOpts(ipType) + userAgent := ctx.Value(util.UserAgentKey).(string) + opts, err := sources.GetCloudSQLOpts(ipType, userAgent) if err != nil { return nil, err } - d, err := cloudsqlconn.NewDialer(context.Background(), cloudsqlconn.WithDefaultDialOptions(dialOpts...)) + d, err := cloudsqlconn.NewDialer(context.Background(), opts...) if err != nil { return nil, fmt.Errorf("unable to parse connection uri: %w", err) } diff --git a/internal/sources/spanner/spanner.go b/internal/sources/spanner/spanner.go index de8cfa1666..d8276f3bf2 100644 --- a/internal/sources/spanner/spanner.go +++ b/internal/sources/spanner/spanner.go @@ -20,6 +20,7 @@ import ( "cloud.google.com/go/spanner" "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/util" "go.opentelemetry.io/otel/trace" ) @@ -94,7 +95,8 @@ func initSpannerClient(ctx context.Context, tracer trace.Tracer, name, project, } // Create spanner client - client, err := spanner.NewClientWithConfig(context.Background(), db, spanner.ClientConfig{SessionPoolConfig: sessionPoolConfig}) + userAgent := ctx.Value(util.UserAgentKey).(string) + client, err := spanner.NewClientWithConfig(context.Background(), db, spanner.ClientConfig{SessionPoolConfig: sessionPoolConfig, UserAgent: userAgent}) if err != nil { return nil, fmt.Errorf("unable to create new client: %w", err) } diff --git a/internal/sources/util.go b/internal/sources/util.go new file mode 100644 index 0000000000..cbb8de5ebd --- /dev/null +++ b/internal/sources/util.go @@ -0,0 +1,37 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sources + +import ( + "fmt" + "strings" + + "cloud.google.com/go/cloudsqlconn" +) + +// GetCloudSQLDialOpts retrieve dial options with the right ip type and user agent for cloud sql +// databases. +func GetCloudSQLOpts(ipType, userAgent string) ([]cloudsqlconn.Option, error) { + opts := []cloudsqlconn.Option{cloudsqlconn.WithUserAgent(userAgent)} + switch strings.ToLower(ipType) { + case "private": + opts = append(opts, cloudsqlconn.WithDefaultDialOptions(cloudsqlconn.WithPrivateIP())) + case "public": + opts = append(opts, cloudsqlconn.WithDefaultDialOptions(cloudsqlconn.WithPublicIP())) + default: + return nil, fmt.Errorf("invalid ipType %s", ipType) + } + return opts, nil +} diff --git a/internal/util/util.go b/internal/util/util.go index 48e54234a1..255869dc35 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -34,3 +34,8 @@ func (d *DelayedUnmarshaler) UnmarshalYAML(unmarshal func(interface{}) error) er func (d *DelayedUnmarshaler) Unmarshal(v interface{}) error { return d.unmarshal(v) } + +type contextKey string + +// UserAgentKey is the key used to store userAgent within context +const UserAgentKey contextKey = "userAgent"