chore: dedup userAgentRoundTripper into util (#2198)

Dedup userAgentRoundTripper into util where userAgent related code are
placed.
This commit is contained in:
Yuan Teoh
2025-12-18 19:19:14 -08:00
committed by GitHub
parent f520b4ed8a
commit 8217d1424d
5 changed files with 34 additions and 116 deletions

View File

@@ -30,26 +30,6 @@ import (
const SourceKind string = "alloydb-admin"
type userAgentRoundTripper struct {
userAgent string
next http.RoundTripper
}
func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
newReq := *req
newReq.Header = make(http.Header)
for k, v := range req.Header {
newReq.Header[k] = v
}
ua := newReq.Header.Get("User-Agent")
if ua == "" {
newReq.Header.Set("User-Agent", rt.userAgent)
} else {
newReq.Header.Set("User-Agent", ua+" "+rt.userAgent)
}
return rt.next.RoundTrip(&newReq)
}
// validate interface
var _ sources.SourceConfig = Config{}
@@ -87,10 +67,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
var client *http.Client
if r.UseClientOAuth {
client = &http.Client{
Transport: &userAgentRoundTripper{
userAgent: ua,
next: http.DefaultTransport,
},
Transport: util.NewUserAgentRoundTripper(ua, http.DefaultTransport),
}
} else {
// Use Application Default Credentials
@@ -99,10 +76,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
return nil, fmt.Errorf("failed to find default credentials: %w", err)
}
baseClient := oauth2.NewClient(ctx, creds.TokenSource)
baseClient.Transport = &userAgentRoundTripper{
userAgent: ua,
next: baseClient.Transport,
}
baseClient.Transport = util.NewUserAgentRoundTripper(ua, baseClient.Transport)
client = baseClient
}

View File

@@ -29,26 +29,6 @@ import (
const SourceKind string = "cloud-gemini-data-analytics"
const Endpoint string = "https://geminidataanalytics.googleapis.com"
type userAgentRoundTripper struct {
userAgent string
next http.RoundTripper
}
func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
newReq := *req
newReq.Header = make(http.Header)
for k, v := range req.Header {
newReq.Header[k] = v
}
ua := newReq.Header.Get("User-Agent")
if ua == "" {
newReq.Header.Set("User-Agent", rt.userAgent)
} else {
newReq.Header.Set("User-Agent", ua+" "+rt.userAgent)
}
return rt.next.RoundTrip(&newReq)
}
// validate interface
var _ sources.SourceConfig = Config{}
@@ -87,10 +67,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
var client *http.Client
if r.UseClientOAuth {
client = &http.Client{
Transport: &userAgentRoundTripper{
userAgent: ua,
next: http.DefaultTransport,
},
Transport: util.NewUserAgentRoundTripper(ua, http.DefaultTransport),
}
} else {
// Use Application Default Credentials
@@ -100,10 +77,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
return nil, fmt.Errorf("failed to find default credentials: %w", err)
}
baseClient := oauth2.NewClient(ctx, creds.TokenSource)
baseClient.Transport = &userAgentRoundTripper{
userAgent: ua,
next: baseClient.Transport,
}
baseClient.Transport = util.NewUserAgentRoundTripper(ua, baseClient.Transport)
client = baseClient
}
@@ -140,10 +114,7 @@ func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Clien
}
token := &oauth2.Token{AccessToken: accessToken}
baseClient := oauth2.NewClient(ctx, oauth2.StaticTokenSource(token))
baseClient.Transport = &userAgentRoundTripper{
userAgent: s.userAgent,
next: baseClient.Transport,
}
baseClient.Transport = util.NewUserAgentRoundTripper(s.userAgent, baseClient.Transport)
return baseClient, nil
}
return s.Client, nil

View File

@@ -29,26 +29,6 @@ import (
const SourceKind string = "cloud-monitoring"
type userAgentRoundTripper struct {
userAgent string
next http.RoundTripper
}
func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
newReq := *req
newReq.Header = make(http.Header)
for k, v := range req.Header {
newReq.Header[k] = v
}
ua := newReq.Header.Get("User-Agent")
if ua == "" {
newReq.Header.Set("User-Agent", rt.userAgent)
} else {
newReq.Header.Set("User-Agent", ua+" "+rt.userAgent)
}
return rt.next.RoundTrip(&newReq)
}
// validate interface
var _ sources.SourceConfig = Config{}
@@ -86,10 +66,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
var client *http.Client
if r.UseClientOAuth {
client = &http.Client{
Transport: &userAgentRoundTripper{
userAgent: ua,
next: http.DefaultTransport,
},
Transport: util.NewUserAgentRoundTripper(ua, http.DefaultTransport),
}
} else {
// Use Application Default Credentials
@@ -98,10 +75,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
return nil, fmt.Errorf("failed to find default credentials: %w", err)
}
baseClient := oauth2.NewClient(ctx, creds.TokenSource)
baseClient.Transport = &userAgentRoundTripper{
userAgent: ua,
next: baseClient.Transport,
}
baseClient.Transport = util.NewUserAgentRoundTripper(ua, baseClient.Transport)
client = baseClient
}

View File

@@ -30,26 +30,6 @@ import (
const SourceKind string = "cloud-sql-admin"
type userAgentRoundTripper struct {
userAgent string
next http.RoundTripper
}
func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
newReq := *req
newReq.Header = make(http.Header)
for k, v := range req.Header {
newReq.Header[k] = v
}
ua := newReq.Header.Get("User-Agent")
if ua == "" {
newReq.Header.Set("User-Agent", rt.userAgent)
} else {
newReq.Header.Set("User-Agent", ua+" "+rt.userAgent)
}
return rt.next.RoundTrip(&newReq)
}
// validate interface
var _ sources.SourceConfig = Config{}
@@ -88,10 +68,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
var client *http.Client
if r.UseClientOAuth {
client = &http.Client{
Transport: &userAgentRoundTripper{
userAgent: ua,
next: http.DefaultTransport,
},
Transport: util.NewUserAgentRoundTripper(ua, http.DefaultTransport),
}
} else {
// Use Application Default Credentials
@@ -100,10 +77,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
return nil, fmt.Errorf("failed to find default credentials: %w", err)
}
baseClient := oauth2.NewClient(ctx, creds.TokenSource)
baseClient.Transport = &userAgentRoundTripper{
userAgent: ua,
next: baseClient.Transport,
}
baseClient.Transport = util.NewUserAgentRoundTripper(ua, baseClient.Transport)
client = baseClient
}

View File

@@ -20,6 +20,7 @@ import (
"errors"
"fmt"
"io"
"net/http"
"strings"
"github.com/go-playground/validator/v10"
@@ -119,6 +120,30 @@ func UserAgentFromContext(ctx context.Context) (string, error) {
}
}
type UserAgentRoundTripper struct {
userAgent string
next http.RoundTripper
}
func NewUserAgentRoundTripper(ua string, next http.RoundTripper) *UserAgentRoundTripper {
return &UserAgentRoundTripper{
userAgent: ua,
next: next,
}
}
func (rt *UserAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
// create a deep copy of the request
newReq := req.Clone(req.Context())
ua := newReq.Header.Get("User-Agent")
if ua == "" {
newReq.Header.Set("User-Agent", rt.userAgent)
} else {
newReq.Header.Set("User-Agent", ua+" "+rt.userAgent)
}
return rt.next.RoundTrip(newReq)
}
func NewStrictDecoder(v interface{}) (*yaml.Decoder, error) {
b, err := yaml.Marshal(v)
if err != nil {