mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-09 15:38:08 -05:00
feat: add user agent to cloud databases (#244)
Add user agent to cloud databases that provides us anonymized data request count, number of users, number of projects, and other environment settings. User agent is using the format: `genai-toolbox/$version+metadata`
This commit is contained in:
@@ -29,6 +29,7 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/log"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
@@ -58,6 +59,9 @@ func NewServer(ctx context.Context, cfg ServerConfig, l log.Logger) (*Server, er
|
||||
parentCtx, span := instrumentation.Tracer.Start(context.Background(), "toolbox/server/init")
|
||||
defer span.End()
|
||||
|
||||
userAgent := fmt.Sprintf("genai-toolbox/%s", cfg.Version)
|
||||
parentCtx = context.WithValue(parentCtx, util.UserAgentKey, userAgent)
|
||||
|
||||
// set up http serving
|
||||
r := chi.NewRouter()
|
||||
r.Use(middleware.Recoverer)
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
|
||||
"cloud.google.com/go/alloydbconn"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
@@ -83,15 +84,17 @@ func (s *Source) PostgresPool() *pgxpool.Pool {
|
||||
return s.Pool
|
||||
}
|
||||
|
||||
func getDialOpts(ipType string) ([]alloydbconn.DialOption, error) {
|
||||
func getOpts(ipType, userAgent string) ([]alloydbconn.Option, error) {
|
||||
opts := []alloydbconn.Option{alloydbconn.WithUserAgent(userAgent)}
|
||||
switch strings.ToLower(ipType) {
|
||||
case "private":
|
||||
return []alloydbconn.DialOption{alloydbconn.WithPrivateIP()}, nil
|
||||
opts = append(opts, alloydbconn.WithDefaultDialOptions(alloydbconn.WithPrivateIP()))
|
||||
case "public":
|
||||
return []alloydbconn.DialOption{alloydbconn.WithPublicIP()}, nil
|
||||
opts = append(opts, alloydbconn.WithDefaultDialOptions(alloydbconn.WithPublicIP()))
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid ipType %s", ipType)
|
||||
}
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
func initAlloyDBPgConnectionPool(ctx context.Context, tracer trace.Tracer, name, project, region, cluster, instance, ipType, user, pass, dbname string) (*pgxpool.Pool, error) {
|
||||
@@ -107,11 +110,12 @@ func initAlloyDBPgConnectionPool(ctx context.Context, tracer trace.Tracer, name,
|
||||
}
|
||||
|
||||
// Create a new dialer with options
|
||||
dialOpts, err := getDialOpts(ipType)
|
||||
userAgent := ctx.Value(util.UserAgentKey).(string)
|
||||
opts, err := getOpts(ipType, userAgent)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d, err := alloydbconn.NewDialer(context.Background(), alloydbconn.WithDefaultDialOptions(dialOpts...))
|
||||
d, err := alloydbconn.NewDialer(context.Background(), opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse connection uri: %w", err)
|
||||
}
|
||||
|
||||
@@ -19,11 +19,10 @@ import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"cloud.google.com/go/cloudsqlconn"
|
||||
"cloud.google.com/go/cloudsqlconn/sqlserver/mssql"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
@@ -91,17 +90,6 @@ func (s *Source) MSSQLDB() *sql.DB {
|
||||
return s.Db
|
||||
}
|
||||
|
||||
func getDialOpts(ipType string) ([]cloudsqlconn.DialOption, error) {
|
||||
switch strings.ToLower(ipType) {
|
||||
case "private":
|
||||
return []cloudsqlconn.DialOption{cloudsqlconn.WithPrivateIP()}, nil
|
||||
case "public":
|
||||
return []cloudsqlconn.DialOption{cloudsqlconn.WithPublicIP()}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid ipType %s", ipType)
|
||||
}
|
||||
}
|
||||
|
||||
func initCloudSQLMssqlConnection(ctx context.Context, tracer trace.Tracer, name, project, region, instance, ipAddress, ipType, user, pass, dbname string) (*sql.DB, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
@@ -111,14 +99,15 @@ func initCloudSQLMssqlConnection(ctx context.Context, tracer trace.Tracer, name,
|
||||
dsn := fmt.Sprintf("sqlserver://%s:%s@%s?database=%s&cloudsql=%s:%s:%s", user, pass, ipAddress, dbname, project, region, instance)
|
||||
|
||||
// Get dial options
|
||||
dialOpts, err := getDialOpts(ipType)
|
||||
userAgent := ctx.Value(util.UserAgentKey).(string)
|
||||
opts, err := sources.GetCloudSQLOpts(ipType, userAgent)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Register sql server driver
|
||||
if !slices.Contains(sql.Drivers(), "cloudsql-sqlserver-driver") {
|
||||
_, err := mssql.RegisterDriver("cloudsql-sqlserver-driver", cloudsqlconn.WithDefaultDialOptions(dialOpts...))
|
||||
_, err := mssql.RegisterDriver("cloudsql-sqlserver-driver", opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -19,11 +19,10 @@ import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"cloud.google.com/go/cloudsqlconn"
|
||||
"cloud.google.com/go/cloudsqlconn/mysql/mysql"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
@@ -83,30 +82,20 @@ func (s *Source) MySQLPool() *sql.DB {
|
||||
return s.Pool
|
||||
}
|
||||
|
||||
func getDialOpts(ipType string) ([]cloudsqlconn.DialOption, error) {
|
||||
switch strings.ToLower(ipType) {
|
||||
case "private":
|
||||
return []cloudsqlconn.DialOption{cloudsqlconn.WithPrivateIP()}, nil
|
||||
case "public":
|
||||
return []cloudsqlconn.DialOption{cloudsqlconn.WithPublicIP()}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid ipType %s", ipType)
|
||||
}
|
||||
}
|
||||
|
||||
func initCloudSQLMySQLConnectionPool(ctx context.Context, tracer trace.Tracer, name, project, region, instance, ipType, user, pass, dbname string) (*sql.DB, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
defer span.End()
|
||||
|
||||
// Create a new dialer with options
|
||||
dialOpts, err := getDialOpts(ipType)
|
||||
userAgent := ctx.Value(util.UserAgentKey).(string)
|
||||
opts, err := sources.GetCloudSQLOpts(ipType, userAgent)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !slices.Contains(sql.Drivers(), "cloudsql-mysql") {
|
||||
_, err = mysql.RegisterDriver("cloudsql-mysql", cloudsqlconn.WithDefaultDialOptions(dialOpts...))
|
||||
_, err = mysql.RegisterDriver("cloudsql-mysql", opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to register driver: %w", err)
|
||||
}
|
||||
|
||||
@@ -18,10 +18,10 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"cloud.google.com/go/cloudsqlconn"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
@@ -82,17 +82,6 @@ func (s *Source) PostgresPool() *pgxpool.Pool {
|
||||
return s.Pool
|
||||
}
|
||||
|
||||
func getDialOpts(ipType string) ([]cloudsqlconn.DialOption, error) {
|
||||
switch strings.ToLower(ipType) {
|
||||
case "private":
|
||||
return []cloudsqlconn.DialOption{cloudsqlconn.WithPrivateIP()}, nil
|
||||
case "public":
|
||||
return []cloudsqlconn.DialOption{cloudsqlconn.WithPublicIP()}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid ipType %s", ipType)
|
||||
}
|
||||
}
|
||||
|
||||
func initCloudSQLPgConnectionPool(ctx context.Context, tracer trace.Tracer, name, project, region, instance, ipType, user, pass, dbname string) (*pgxpool.Pool, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
@@ -106,11 +95,12 @@ func initCloudSQLPgConnectionPool(ctx context.Context, tracer trace.Tracer, name
|
||||
}
|
||||
|
||||
// Create a new dialer with options
|
||||
dialOpts, err := getDialOpts(ipType)
|
||||
userAgent := ctx.Value(util.UserAgentKey).(string)
|
||||
opts, err := sources.GetCloudSQLOpts(ipType, userAgent)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d, err := cloudsqlconn.NewDialer(context.Background(), cloudsqlconn.WithDefaultDialOptions(dialOpts...))
|
||||
d, err := cloudsqlconn.NewDialer(context.Background(), opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse connection uri: %w", err)
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
|
||||
"cloud.google.com/go/spanner"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
@@ -94,7 +95,8 @@ func initSpannerClient(ctx context.Context, tracer trace.Tracer, name, project,
|
||||
}
|
||||
|
||||
// Create spanner client
|
||||
client, err := spanner.NewClientWithConfig(context.Background(), db, spanner.ClientConfig{SessionPoolConfig: sessionPoolConfig})
|
||||
userAgent := ctx.Value(util.UserAgentKey).(string)
|
||||
client, err := spanner.NewClientWithConfig(context.Background(), db, spanner.ClientConfig{SessionPoolConfig: sessionPoolConfig, UserAgent: userAgent})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create new client: %w", err)
|
||||
}
|
||||
|
||||
37
internal/sources/util.go
Normal file
37
internal/sources/util.go
Normal file
@@ -0,0 +1,37 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sources
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"cloud.google.com/go/cloudsqlconn"
|
||||
)
|
||||
|
||||
// GetCloudSQLDialOpts retrieve dial options with the right ip type and user agent for cloud sql
|
||||
// databases.
|
||||
func GetCloudSQLOpts(ipType, userAgent string) ([]cloudsqlconn.Option, error) {
|
||||
opts := []cloudsqlconn.Option{cloudsqlconn.WithUserAgent(userAgent)}
|
||||
switch strings.ToLower(ipType) {
|
||||
case "private":
|
||||
opts = append(opts, cloudsqlconn.WithDefaultDialOptions(cloudsqlconn.WithPrivateIP()))
|
||||
case "public":
|
||||
opts = append(opts, cloudsqlconn.WithDefaultDialOptions(cloudsqlconn.WithPublicIP()))
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid ipType %s", ipType)
|
||||
}
|
||||
return opts, nil
|
||||
}
|
||||
@@ -34,3 +34,8 @@ func (d *DelayedUnmarshaler) UnmarshalYAML(unmarshal func(interface{}) error) er
|
||||
func (d *DelayedUnmarshaler) Unmarshal(v interface{}) error {
|
||||
return d.unmarshal(v)
|
||||
}
|
||||
|
||||
type contextKey string
|
||||
|
||||
// UserAgentKey is the key used to store userAgent within context
|
||||
const UserAgentKey contextKey = "userAgent"
|
||||
|
||||
Reference in New Issue
Block a user