feat(source/bigquery): add client cache for user-passed credentials (#1119)

Add a client cache with automatic cleanup of expired entries.
The cache is a map keyed by the user's OAuth access token.
On each tool invocation, the client for that token is fetched from the
cache, or created and cached if absent.
Author: Wenxin Du
Date: 2025-11-04 17:16:44 -05:00
Committed by: GitHub
Parent: 76d626e43b
Commit: cf7012a82b
3 changed files with 243 additions and 29 deletions
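In outline, each user OAuth token maps to at most one live client. A minimal sketch of the get-or-create flow, using the Cache type added in this commit (getOrCreate and newClientForToken are hypothetical stand-ins, not names from the diff):

func getOrCreate(cache *Cache, token string, newClientForToken func(string) (any, error)) (any, error) {
	if v, found := cache.Get(token); found {
		return v, nil // reuse the client already built for this token
	}
	client, err := newClientForToken(token)
	if err != nil {
		return nil, err
	}
	cache.Set(token, client) // cache for subsequent tool invocations
	return client, nil
}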

@@ -98,15 +98,15 @@ import (
_ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestorevalidaterules"
_ "github.com/googleapis/genai-toolbox/internal/tools/http"
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookeradddashboardelement"
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercreateprojectfile"
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerconversationalanalytics"
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercreateprojectfile"
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerdeleteprojectfile"
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerdevmode"
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectiondatabases"
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnections"
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectionschemas"
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectiontables"
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectiontablecolumns"
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectiontables"
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetdashboards"
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetdimensions"
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetexplores"

@@ -72,13 +72,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
// BigQuery configs
-	Name            string   `yaml:"name" validate:"required"`
-	Kind            string   `yaml:"kind" validate:"required"`
-	Project         string   `yaml:"project" validate:"required"`
-	Location        string   `yaml:"location"`
-	WriteMode       string   `yaml:"writeMode"`
-	AllowedDatasets []string `yaml:"allowedDatasets"`
-	UseClientOAuth  bool     `yaml:"useClientOAuth"`
+	Name                      string   `yaml:"name" validate:"required"`
+	Kind                      string   `yaml:"kind" validate:"required"`
+	Project                   string   `yaml:"project" validate:"required"`
+	Location                  string   `yaml:"location"`
+	WriteMode                 string   `yaml:"writeMode"`
+	AllowedDatasets           []string `yaml:"allowedDatasets"`
+	UseClientOAuth            bool     `yaml:"useClientOAuth"`
ImpersonateServiceAccount string `yaml:"impersonateServiceAccount"`
}
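For reference, the yaml tags above imply a source entry along the following lines. A hedged sketch that decodes one with gopkg.in/yaml.v3 (the project wires decoding through its own yaml.Decoder, and the field set here is trimmed to a subset):

package main

import (
	"fmt"

	"gopkg.in/yaml.v3"
)

// bqSource is a trimmed mirror of the Config struct above.
type bqSource struct {
	Name           string `yaml:"name"`
	Kind           string `yaml:"kind"`
	Project        string `yaml:"project"`
	Location       string `yaml:"location"`
	UseClientOAuth bool   `yaml:"useClientOAuth"`
}

func main() {
	doc := []byte(`
name: my-bq
kind: bigquery
project: my-project
location: us-central1
useClientOAuth: true
`)
	var s bqSource
	if err := yaml.Unmarshal(doc, &s); err != nil {
		panic(err)
	}
	// {Name:my-bq Kind:bigquery Project:my-project Location:us-central1 UseClientOAuth:true}
	fmt.Printf("%+v\n", s)
}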
@@ -86,13 +86,16 @@ func (r Config) SourceConfigKind() string {
// Returns BigQuery source kind
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
if r.WriteMode == "" {
r.WriteMode = WriteModeAllowed
}
if r.WriteMode == WriteModeProtected && r.UseClientOAuth {
	// The protected mode only allows write operations to the session's
	// temporary datasets. When using client OAuth, a new session is created
	// every time a BigQuery tool is invoked, so no session data can be
	// preserved as the protected mode requires.
return nil, fmt.Errorf("writeMode 'protected' cannot be used with useClientOAuth 'true'")
}
@@ -106,17 +109,38 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
var clientCreator BigqueryClientCreator
var err error
s := &Source{
Name: r.Name,
Kind: SourceKind,
Project: r.Project,
Location: r.Location,
Client: client,
RestService: restService,
TokenSource: tokenSource,
MaxQueryResultRows: 50,
WriteMode: r.WriteMode,
UseClientOAuth: r.UseClientOAuth,
ClientCreator: clientCreator,
ImpersonateServiceAccount: r.ImpersonateServiceAccount,
}
if r.UseClientOAuth {
-		clientCreator, err = newBigQueryClientCreator(ctx, tracer, r.Project, r.Location, r.Name)
+		// use client OAuth
+		baseClientCreator, err := newBigQueryClientCreator(ctx, tracer, r.Project, r.Location, r.Name)
if err != nil {
return nil, fmt.Errorf("error constructing client creator: %w", err)
}
+		setupClientCaching(s, baseClientCreator)
} else {
// Initializes a BigQuery Google SQL source
client, restService, tokenSource, err = initBigQueryConnection(ctx, tracer, r.Name, r.Project, r.Location, r.ImpersonateServiceAccount)
if err != nil {
return nil, fmt.Errorf("error creating client from ADC: %w", err)
}
+		s.Client = client
+		s.RestService = restService
+		s.TokenSource = tokenSource
}
allowedDatasets := make(map[string]struct{})
@@ -138,8 +162,8 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
allowedFullID = fmt.Sprintf("%s.%s", projectID, datasetID)
}
-		if client != nil {
-			dataset := client.DatasetInProject(projectID, datasetID)
+		if s.Client != nil {
+			dataset := s.Client.DatasetInProject(projectID, datasetID)
_, err := dataset.Metadata(ctx)
if err != nil {
if gerr, ok := err.(*googleapi.Error); ok && gerr.Code == http.StatusNotFound {
@@ -152,21 +176,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
}
}
s := &Source{
Name: r.Name,
Kind: SourceKind,
Project: r.Project,
Location: r.Location,
Client: client,
RestService: restService,
TokenSource: tokenSource,
MaxQueryResultRows: 50,
WriteMode: r.WriteMode,
AllowedDatasets: allowedDatasets,
UseClientOAuth: r.UseClientOAuth,
ClientCreator: clientCreator,
ImpersonateServiceAccount: r.ImpersonateServiceAccount,
}
s.AllowedDatasets = allowedDatasets
s.SessionProvider = s.newBigQuerySessionProvider()
if r.WriteMode != WriteModeAllowed && r.WriteMode != WriteModeBlocked && r.WriteMode != WriteModeProtected {
@@ -176,6 +186,58 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
return s, nil
}
+// setupClientCaching initializes caches and wraps the base client creator with caching logic.
+func setupClientCaching(s *Source, baseCreator BigqueryClientCreator) {
+	// Define eviction handlers
+	onBqEvict := func(key string, value interface{}) {
+		if client, ok := value.(*bigqueryapi.Client); ok && client != nil {
+			client.Close()
+		}
+	}
+	onDataplexEvict := func(key string, value interface{}) {
+		if client, ok := value.(*dataplexapi.CatalogClient); ok && client != nil {
+			client.Close()
+		}
+	}
+	// Initialize caches
+	s.bqClientCache = NewCache(onBqEvict)
+	s.bqRestCache = NewCache(nil)
+	s.dataplexCache = NewCache(onDataplexEvict)
+	// Create the caching wrapper for the client creator
+	s.ClientCreator = func(tokenString string, wantRestService bool) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) {
+		// Check cache
+		bqClientVal, bqFound := s.bqClientCache.Get(tokenString)
+		if wantRestService {
+			restServiceVal, restFound := s.bqRestCache.Get(tokenString)
+			if bqFound && restFound {
+				// Cache hit for both
+				return bqClientVal.(*bigqueryapi.Client), restServiceVal.(*bigqueryrestapi.Service), nil
+			}
+		} else {
+			if bqFound {
+				return bqClientVal.(*bigqueryapi.Client), nil, nil
+			}
+		}
+		// Cache miss - call the client creator
+		client, restService, err := baseCreator(tokenString, wantRestService)
+		if err != nil {
+			return nil, nil, err
+		}
+		// Set in cache
+		s.bqClientCache.Set(tokenString, client)
+		if wantRestService && restService != nil {
+			s.bqRestCache.Set(tokenString, restService)
+		}
+		return client, restService, nil
+	}
+}
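Once setupClientCaching has installed the wrapper, callers keep using the plain BigqueryClientCreator signature. A minimal sketch of a call site (clientForRequest and userToken are hypothetical):

// clientForRequest shows how a tool handler would use the caching creator:
// the same token returns the same cached client until the entry expires.
func clientForRequest(s *Source, userToken string) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) {
	return s.ClientCreator(userToken, true /* wantRestService */)
}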
var _ sources.Source = &Source{}
type Source struct {
@@ -197,6 +259,11 @@ type Source struct {
makeDataplexCatalogClient func() (*dataplexapi.CatalogClient, DataplexClientCreator, error)
SessionProvider BigQuerySessionProvider
Session *Session
+	// Caches for OAuth clients
+	bqClientCache *Cache
+	bqRestCache   *Cache
+	dataplexCache *Cache
}
type Session struct {
@@ -397,7 +464,29 @@ func (s *Source) lazyInitDataplexClient(ctx context.Context, tracer trace.Tracer
return
}
client = c
-		clientCreator = cc
+		// If using OAuth, wrap the provided client creator (cc) with caching logic
+		if s.UseClientOAuth && cc != nil {
+			clientCreator = func(tokenString string) (*dataplexapi.CatalogClient, error) {
+				// Check cache
+				if val, found := s.dataplexCache.Get(tokenString); found {
+					return val.(*dataplexapi.CatalogClient), nil
+				}
+				// Cache miss - call client creator
+				dpClient, err := cc(tokenString)
+				if err != nil {
+					return nil, err
+				}
+				// Set in cache
+				s.dataplexCache.Set(tokenString, dpClient)
+				return dpClient, nil
+			}
+		} else {
+			// Not using OAuth or no creator was returned
+			clientCreator = cc
+		}
})
return client, clientCreator, err
}
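The Dataplex path follows the same wrap-the-creator pattern. A hedged sketch of resolving a catalog client for a token (dataplexForToken is hypothetical; with UseClientOAuth set, the creator consults s.dataplexCache before dialing):

func dataplexForToken(ctx context.Context, tracer trace.Tracer, s *Source, token string) (*dataplexapi.CatalogClient, error) {
	// lazyInitDataplexClient returns the cache-wrapping creator installed above.
	_, creator, err := s.lazyInitDataplexClient(ctx, tracer)
	if err != nil {
		return nil, err
	}
	return creator(token)
}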

@@ -0,0 +1,125 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package bigquery
import (
"sync"
"time"
)
// Item holds the cached value and its expiration timestamp
type Item struct {
Value any
ExpiresAt int64 // Unix nano timestamp
}
// IsExpired checks if the item is expired
func (item Item) IsExpired() bool {
return time.Now().UnixNano() > item.ExpiresAt
}
// OnEvictFunc is the signature for the callback
type OnEvictFunc func(key string, value any)
// Cache is a thread-safe, expiring key-value store
type Cache struct {
mu sync.RWMutex
items map[string]Item
onEvict OnEvictFunc
}
// NewCache creates a new cache and starts a background sweep every 55 minutes,
// an interval that sits just under the typical one-hour lifetime of an OAuth access token
func NewCache(onEvict OnEvictFunc) *Cache {
const cleanupInterval = 55 * time.Minute
c := &Cache{
items: make(map[string]Item),
onEvict: onEvict,
}
go c.startCleanup(cleanupInterval)
return c
}
// startCleanup runs a ticker that periodically deletes expired items; the
// goroutine runs for the lifetime of the process
func (c *Cache) startCleanup(interval time.Duration) {
ticker := time.NewTicker(interval)
defer ticker.Stop()
for range ticker.C {
c.DeleteExpired()
}
}
// delete is an internal helper that assumes the write lock is held
func (c *Cache) delete(key string, item Item) {
if c.onEvict != nil {
c.onEvict(key, item.Value)
}
delete(c.items, key)
}
// Set adds an item to the cache
func (c *Cache) Set(key string, value any) {
	const ttl = 55 * time.Minute // matches the cleanup interval used by NewCache
expires := time.Now().Add(ttl).UnixNano()
c.mu.Lock()
defer c.mu.Unlock()
// If item already exists, evict the old one before replacing
if oldItem, found := c.items[key]; found {
c.delete(key, oldItem)
}
c.items[key] = Item{
Value: value,
ExpiresAt: expires,
}
}
// Get retrieves an unexpired item from the cache
func (c *Cache) Get(key string) (any, bool) {
	c.mu.RLock()
	defer c.mu.RUnlock()
	item, found := c.items[key]
	if !found || item.IsExpired() {
		return nil, false
	}
	return item.Value, true
}
// Delete manually evicts an item
func (c *Cache) Delete(key string) {
c.mu.Lock()
defer c.mu.Unlock()
if item, found := c.items[key]; found {
c.delete(key, item)
}
}
// DeleteExpired removes all expired items
func (c *Cache) DeleteExpired() {
c.mu.Lock()
defer c.mu.Unlock()
for key, item := range c.items {
if item.IsExpired() {
c.delete(key, item)
}
}
}
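A small usage sketch of this cache (the string values stand in for real clients, fmt is assumed imported, and exampleCacheUsage is hypothetical):

func exampleCacheUsage() {
	onEvict := func(key string, value any) {
		// In the source above, this is where a cached client gets closed.
		fmt.Printf("evicted %s -> %v\n", key, value)
	}
	c := NewCache(onEvict)
	c.Set("token-abc", "client-1")
	if v, found := c.Get("token-abc"); found {
		fmt.Println("hit:", v) // hit: client-1
	}
	c.Set("token-abc", "client-2") // replacing a key evicts (and closes) the old value
	c.Delete("token-abc")          // manual delete also triggers onEvict
}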