feat: add user agent to cloud databases (#244)

Add user agent to cloud databases that provides us anonymized data
request count, number of users, number of projects, and other
environment settings.

User agent is using the format: `genai-toolbox/$version+metadata`
This commit is contained in:
Yuan
2025-01-29 17:19:52 -08:00
committed by GitHub
parent 8152a98b7a
commit 8452f8eb44
8 changed files with 70 additions and 50 deletions

View File

@@ -29,6 +29,7 @@ import (
"github.com/googleapis/genai-toolbox/internal/log"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
)
@@ -58,6 +59,9 @@ func NewServer(ctx context.Context, cfg ServerConfig, l log.Logger) (*Server, er
parentCtx, span := instrumentation.Tracer.Start(context.Background(), "toolbox/server/init")
defer span.End()
userAgent := fmt.Sprintf("genai-toolbox/%s", cfg.Version)
parentCtx = context.WithValue(parentCtx, util.UserAgentKey, userAgent)
// set up http serving
r := chi.NewRouter()
r.Use(middleware.Recoverer)

View File

@@ -22,6 +22,7 @@ import (
"cloud.google.com/go/alloydbconn"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/util"
"github.com/jackc/pgx/v5/pgxpool"
"go.opentelemetry.io/otel/trace"
)
@@ -83,15 +84,17 @@ func (s *Source) PostgresPool() *pgxpool.Pool {
return s.Pool
}
func getDialOpts(ipType string) ([]alloydbconn.DialOption, error) {
func getOpts(ipType, userAgent string) ([]alloydbconn.Option, error) {
opts := []alloydbconn.Option{alloydbconn.WithUserAgent(userAgent)}
switch strings.ToLower(ipType) {
case "private":
return []alloydbconn.DialOption{alloydbconn.WithPrivateIP()}, nil
opts = append(opts, alloydbconn.WithDefaultDialOptions(alloydbconn.WithPrivateIP()))
case "public":
return []alloydbconn.DialOption{alloydbconn.WithPublicIP()}, nil
opts = append(opts, alloydbconn.WithDefaultDialOptions(alloydbconn.WithPublicIP()))
default:
return nil, fmt.Errorf("invalid ipType %s", ipType)
}
return opts, nil
}
func initAlloyDBPgConnectionPool(ctx context.Context, tracer trace.Tracer, name, project, region, cluster, instance, ipType, user, pass, dbname string) (*pgxpool.Pool, error) {
@@ -107,11 +110,12 @@ func initAlloyDBPgConnectionPool(ctx context.Context, tracer trace.Tracer, name,
}
// Create a new dialer with options
dialOpts, err := getDialOpts(ipType)
userAgent := ctx.Value(util.UserAgentKey).(string)
opts, err := getOpts(ipType, userAgent)
if err != nil {
return nil, err
}
d, err := alloydbconn.NewDialer(context.Background(), alloydbconn.WithDefaultDialOptions(dialOpts...))
d, err := alloydbconn.NewDialer(context.Background(), opts...)
if err != nil {
return nil, fmt.Errorf("unable to parse connection uri: %w", err)
}

View File

@@ -19,11 +19,10 @@ import (
"database/sql"
"fmt"
"slices"
"strings"
"cloud.google.com/go/cloudsqlconn"
"cloud.google.com/go/cloudsqlconn/sqlserver/mssql"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/util"
"go.opentelemetry.io/otel/trace"
)
@@ -91,17 +90,6 @@ func (s *Source) MSSQLDB() *sql.DB {
return s.Db
}
func getDialOpts(ipType string) ([]cloudsqlconn.DialOption, error) {
switch strings.ToLower(ipType) {
case "private":
return []cloudsqlconn.DialOption{cloudsqlconn.WithPrivateIP()}, nil
case "public":
return []cloudsqlconn.DialOption{cloudsqlconn.WithPublicIP()}, nil
default:
return nil, fmt.Errorf("invalid ipType %s", ipType)
}
}
func initCloudSQLMssqlConnection(ctx context.Context, tracer trace.Tracer, name, project, region, instance, ipAddress, ipType, user, pass, dbname string) (*sql.DB, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
@@ -111,14 +99,15 @@ func initCloudSQLMssqlConnection(ctx context.Context, tracer trace.Tracer, name,
dsn := fmt.Sprintf("sqlserver://%s:%s@%s?database=%s&cloudsql=%s:%s:%s", user, pass, ipAddress, dbname, project, region, instance)
// Get dial options
dialOpts, err := getDialOpts(ipType)
userAgent := ctx.Value(util.UserAgentKey).(string)
opts, err := sources.GetCloudSQLOpts(ipType, userAgent)
if err != nil {
return nil, err
}
// Register sql server driver
if !slices.Contains(sql.Drivers(), "cloudsql-sqlserver-driver") {
_, err := mssql.RegisterDriver("cloudsql-sqlserver-driver", cloudsqlconn.WithDefaultDialOptions(dialOpts...))
_, err := mssql.RegisterDriver("cloudsql-sqlserver-driver", opts...)
if err != nil {
return nil, err
}

View File

@@ -19,11 +19,10 @@ import (
"database/sql"
"fmt"
"slices"
"strings"
"cloud.google.com/go/cloudsqlconn"
"cloud.google.com/go/cloudsqlconn/mysql/mysql"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/util"
"go.opentelemetry.io/otel/trace"
)
@@ -83,30 +82,20 @@ func (s *Source) MySQLPool() *sql.DB {
return s.Pool
}
func getDialOpts(ipType string) ([]cloudsqlconn.DialOption, error) {
switch strings.ToLower(ipType) {
case "private":
return []cloudsqlconn.DialOption{cloudsqlconn.WithPrivateIP()}, nil
case "public":
return []cloudsqlconn.DialOption{cloudsqlconn.WithPublicIP()}, nil
default:
return nil, fmt.Errorf("invalid ipType %s", ipType)
}
}
func initCloudSQLMySQLConnectionPool(ctx context.Context, tracer trace.Tracer, name, project, region, instance, ipType, user, pass, dbname string) (*sql.DB, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
// Create a new dialer with options
dialOpts, err := getDialOpts(ipType)
userAgent := ctx.Value(util.UserAgentKey).(string)
opts, err := sources.GetCloudSQLOpts(ipType, userAgent)
if err != nil {
return nil, err
}
if !slices.Contains(sql.Drivers(), "cloudsql-mysql") {
_, err = mysql.RegisterDriver("cloudsql-mysql", cloudsqlconn.WithDefaultDialOptions(dialOpts...))
_, err = mysql.RegisterDriver("cloudsql-mysql", opts...)
if err != nil {
return nil, fmt.Errorf("unable to register driver: %w", err)
}

View File

@@ -18,10 +18,10 @@ import (
"context"
"fmt"
"net"
"strings"
"cloud.google.com/go/cloudsqlconn"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/util"
"github.com/jackc/pgx/v5/pgxpool"
"go.opentelemetry.io/otel/trace"
)
@@ -82,17 +82,6 @@ func (s *Source) PostgresPool() *pgxpool.Pool {
return s.Pool
}
func getDialOpts(ipType string) ([]cloudsqlconn.DialOption, error) {
switch strings.ToLower(ipType) {
case "private":
return []cloudsqlconn.DialOption{cloudsqlconn.WithPrivateIP()}, nil
case "public":
return []cloudsqlconn.DialOption{cloudsqlconn.WithPublicIP()}, nil
default:
return nil, fmt.Errorf("invalid ipType %s", ipType)
}
}
func initCloudSQLPgConnectionPool(ctx context.Context, tracer trace.Tracer, name, project, region, instance, ipType, user, pass, dbname string) (*pgxpool.Pool, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
@@ -106,11 +95,12 @@ func initCloudSQLPgConnectionPool(ctx context.Context, tracer trace.Tracer, name
}
// Create a new dialer with options
dialOpts, err := getDialOpts(ipType)
userAgent := ctx.Value(util.UserAgentKey).(string)
opts, err := sources.GetCloudSQLOpts(ipType, userAgent)
if err != nil {
return nil, err
}
d, err := cloudsqlconn.NewDialer(context.Background(), cloudsqlconn.WithDefaultDialOptions(dialOpts...))
d, err := cloudsqlconn.NewDialer(context.Background(), opts...)
if err != nil {
return nil, fmt.Errorf("unable to parse connection uri: %w", err)
}

View File

@@ -20,6 +20,7 @@ import (
"cloud.google.com/go/spanner"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/util"
"go.opentelemetry.io/otel/trace"
)
@@ -94,7 +95,8 @@ func initSpannerClient(ctx context.Context, tracer trace.Tracer, name, project,
}
// Create spanner client
client, err := spanner.NewClientWithConfig(context.Background(), db, spanner.ClientConfig{SessionPoolConfig: sessionPoolConfig})
userAgent := ctx.Value(util.UserAgentKey).(string)
client, err := spanner.NewClientWithConfig(context.Background(), db, spanner.ClientConfig{SessionPoolConfig: sessionPoolConfig, UserAgent: userAgent})
if err != nil {
return nil, fmt.Errorf("unable to create new client: %w", err)
}

37
internal/sources/util.go Normal file
View File

@@ -0,0 +1,37 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sources
import (
"fmt"
"strings"
"cloud.google.com/go/cloudsqlconn"
)
// GetCloudSQLDialOpts retrieve dial options with the right ip type and user agent for cloud sql
// databases.
func GetCloudSQLOpts(ipType, userAgent string) ([]cloudsqlconn.Option, error) {
opts := []cloudsqlconn.Option{cloudsqlconn.WithUserAgent(userAgent)}
switch strings.ToLower(ipType) {
case "private":
opts = append(opts, cloudsqlconn.WithDefaultDialOptions(cloudsqlconn.WithPrivateIP()))
case "public":
opts = append(opts, cloudsqlconn.WithDefaultDialOptions(cloudsqlconn.WithPublicIP()))
default:
return nil, fmt.Errorf("invalid ipType %s", ipType)
}
return opts, nil
}

View File

@@ -34,3 +34,8 @@ func (d *DelayedUnmarshaler) UnmarshalYAML(unmarshal func(interface{}) error) er
func (d *DelayedUnmarshaler) Unmarshal(v interface{}) error {
return d.unmarshal(v)
}
type contextKey string
// UserAgentKey is the key used to store userAgent within context
const UserAgentKey contextKey = "userAgent"