mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-09 07:28:05 -05:00
chore: dedup userAgentRoundTripper into util (#2198)
Dedup userAgentRoundTripper into util where userAgent related code are placed.
This commit is contained in:
@@ -30,26 +30,6 @@ import (
|
||||
|
||||
const SourceKind string = "alloydb-admin"
|
||||
|
||||
type userAgentRoundTripper struct {
|
||||
userAgent string
|
||||
next http.RoundTripper
|
||||
}
|
||||
|
||||
func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
newReq := *req
|
||||
newReq.Header = make(http.Header)
|
||||
for k, v := range req.Header {
|
||||
newReq.Header[k] = v
|
||||
}
|
||||
ua := newReq.Header.Get("User-Agent")
|
||||
if ua == "" {
|
||||
newReq.Header.Set("User-Agent", rt.userAgent)
|
||||
} else {
|
||||
newReq.Header.Set("User-Agent", ua+" "+rt.userAgent)
|
||||
}
|
||||
return rt.next.RoundTrip(&newReq)
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
@@ -87,10 +67,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
|
||||
var client *http.Client
|
||||
if r.UseClientOAuth {
|
||||
client = &http.Client{
|
||||
Transport: &userAgentRoundTripper{
|
||||
userAgent: ua,
|
||||
next: http.DefaultTransport,
|
||||
},
|
||||
Transport: util.NewUserAgentRoundTripper(ua, http.DefaultTransport),
|
||||
}
|
||||
} else {
|
||||
// Use Application Default Credentials
|
||||
@@ -99,10 +76,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
|
||||
return nil, fmt.Errorf("failed to find default credentials: %w", err)
|
||||
}
|
||||
baseClient := oauth2.NewClient(ctx, creds.TokenSource)
|
||||
baseClient.Transport = &userAgentRoundTripper{
|
||||
userAgent: ua,
|
||||
next: baseClient.Transport,
|
||||
}
|
||||
baseClient.Transport = util.NewUserAgentRoundTripper(ua, baseClient.Transport)
|
||||
client = baseClient
|
||||
}
|
||||
|
||||
|
||||
@@ -29,26 +29,6 @@ import (
|
||||
const SourceKind string = "cloud-gemini-data-analytics"
|
||||
const Endpoint string = "https://geminidataanalytics.googleapis.com"
|
||||
|
||||
type userAgentRoundTripper struct {
|
||||
userAgent string
|
||||
next http.RoundTripper
|
||||
}
|
||||
|
||||
func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
newReq := *req
|
||||
newReq.Header = make(http.Header)
|
||||
for k, v := range req.Header {
|
||||
newReq.Header[k] = v
|
||||
}
|
||||
ua := newReq.Header.Get("User-Agent")
|
||||
if ua == "" {
|
||||
newReq.Header.Set("User-Agent", rt.userAgent)
|
||||
} else {
|
||||
newReq.Header.Set("User-Agent", ua+" "+rt.userAgent)
|
||||
}
|
||||
return rt.next.RoundTrip(&newReq)
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
@@ -87,10 +67,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
|
||||
var client *http.Client
|
||||
if r.UseClientOAuth {
|
||||
client = &http.Client{
|
||||
Transport: &userAgentRoundTripper{
|
||||
userAgent: ua,
|
||||
next: http.DefaultTransport,
|
||||
},
|
||||
Transport: util.NewUserAgentRoundTripper(ua, http.DefaultTransport),
|
||||
}
|
||||
} else {
|
||||
// Use Application Default Credentials
|
||||
@@ -100,10 +77,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
|
||||
return nil, fmt.Errorf("failed to find default credentials: %w", err)
|
||||
}
|
||||
baseClient := oauth2.NewClient(ctx, creds.TokenSource)
|
||||
baseClient.Transport = &userAgentRoundTripper{
|
||||
userAgent: ua,
|
||||
next: baseClient.Transport,
|
||||
}
|
||||
baseClient.Transport = util.NewUserAgentRoundTripper(ua, baseClient.Transport)
|
||||
client = baseClient
|
||||
}
|
||||
|
||||
@@ -140,10 +114,7 @@ func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Clien
|
||||
}
|
||||
token := &oauth2.Token{AccessToken: accessToken}
|
||||
baseClient := oauth2.NewClient(ctx, oauth2.StaticTokenSource(token))
|
||||
baseClient.Transport = &userAgentRoundTripper{
|
||||
userAgent: s.userAgent,
|
||||
next: baseClient.Transport,
|
||||
}
|
||||
baseClient.Transport = util.NewUserAgentRoundTripper(s.userAgent, baseClient.Transport)
|
||||
return baseClient, nil
|
||||
}
|
||||
return s.Client, nil
|
||||
|
||||
@@ -29,26 +29,6 @@ import (
|
||||
|
||||
const SourceKind string = "cloud-monitoring"
|
||||
|
||||
type userAgentRoundTripper struct {
|
||||
userAgent string
|
||||
next http.RoundTripper
|
||||
}
|
||||
|
||||
func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
newReq := *req
|
||||
newReq.Header = make(http.Header)
|
||||
for k, v := range req.Header {
|
||||
newReq.Header[k] = v
|
||||
}
|
||||
ua := newReq.Header.Get("User-Agent")
|
||||
if ua == "" {
|
||||
newReq.Header.Set("User-Agent", rt.userAgent)
|
||||
} else {
|
||||
newReq.Header.Set("User-Agent", ua+" "+rt.userAgent)
|
||||
}
|
||||
return rt.next.RoundTrip(&newReq)
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
@@ -86,10 +66,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
|
||||
var client *http.Client
|
||||
if r.UseClientOAuth {
|
||||
client = &http.Client{
|
||||
Transport: &userAgentRoundTripper{
|
||||
userAgent: ua,
|
||||
next: http.DefaultTransport,
|
||||
},
|
||||
Transport: util.NewUserAgentRoundTripper(ua, http.DefaultTransport),
|
||||
}
|
||||
} else {
|
||||
// Use Application Default Credentials
|
||||
@@ -98,10 +75,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
|
||||
return nil, fmt.Errorf("failed to find default credentials: %w", err)
|
||||
}
|
||||
baseClient := oauth2.NewClient(ctx, creds.TokenSource)
|
||||
baseClient.Transport = &userAgentRoundTripper{
|
||||
userAgent: ua,
|
||||
next: baseClient.Transport,
|
||||
}
|
||||
baseClient.Transport = util.NewUserAgentRoundTripper(ua, baseClient.Transport)
|
||||
client = baseClient
|
||||
}
|
||||
|
||||
|
||||
@@ -30,26 +30,6 @@ import (
|
||||
|
||||
const SourceKind string = "cloud-sql-admin"
|
||||
|
||||
type userAgentRoundTripper struct {
|
||||
userAgent string
|
||||
next http.RoundTripper
|
||||
}
|
||||
|
||||
func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
newReq := *req
|
||||
newReq.Header = make(http.Header)
|
||||
for k, v := range req.Header {
|
||||
newReq.Header[k] = v
|
||||
}
|
||||
ua := newReq.Header.Get("User-Agent")
|
||||
if ua == "" {
|
||||
newReq.Header.Set("User-Agent", rt.userAgent)
|
||||
} else {
|
||||
newReq.Header.Set("User-Agent", ua+" "+rt.userAgent)
|
||||
}
|
||||
return rt.next.RoundTrip(&newReq)
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
@@ -88,10 +68,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
|
||||
var client *http.Client
|
||||
if r.UseClientOAuth {
|
||||
client = &http.Client{
|
||||
Transport: &userAgentRoundTripper{
|
||||
userAgent: ua,
|
||||
next: http.DefaultTransport,
|
||||
},
|
||||
Transport: util.NewUserAgentRoundTripper(ua, http.DefaultTransport),
|
||||
}
|
||||
} else {
|
||||
// Use Application Default Credentials
|
||||
@@ -100,10 +77,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
|
||||
return nil, fmt.Errorf("failed to find default credentials: %w", err)
|
||||
}
|
||||
baseClient := oauth2.NewClient(ctx, creds.TokenSource)
|
||||
baseClient.Transport = &userAgentRoundTripper{
|
||||
userAgent: ua,
|
||||
next: baseClient.Transport,
|
||||
}
|
||||
baseClient.Transport = util.NewUserAgentRoundTripper(ua, baseClient.Transport)
|
||||
client = baseClient
|
||||
}
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
@@ -119,6 +120,30 @@ func UserAgentFromContext(ctx context.Context) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
type UserAgentRoundTripper struct {
|
||||
userAgent string
|
||||
next http.RoundTripper
|
||||
}
|
||||
|
||||
func NewUserAgentRoundTripper(ua string, next http.RoundTripper) *UserAgentRoundTripper {
|
||||
return &UserAgentRoundTripper{
|
||||
userAgent: ua,
|
||||
next: next,
|
||||
}
|
||||
}
|
||||
|
||||
func (rt *UserAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// create a deep copy of the request
|
||||
newReq := req.Clone(req.Context())
|
||||
ua := newReq.Header.Get("User-Agent")
|
||||
if ua == "" {
|
||||
newReq.Header.Set("User-Agent", rt.userAgent)
|
||||
} else {
|
||||
newReq.Header.Set("User-Agent", ua+" "+rt.userAgent)
|
||||
}
|
||||
return rt.next.RoundTrip(newReq)
|
||||
}
|
||||
|
||||
func NewStrictDecoder(v interface{}) (*yaml.Decoder, error) {
|
||||
b, err := yaml.Marshal(v)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user