Mirror of https://github.com/googleapis/genai-toolbox.git (synced 2026-01-08 15:14:00 -05:00)
feat(source/bigquery): add client cache for user-passed credentials (#1119)
Add a client cache with automatic cleanup. The cache is a map keyed by OAuth access token: on each tool invocation, the client for the caller's token is fetched from the cache, or created and cached on a miss.
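In outline, the change gives each OAuth access token its own BigQuery client: the token is the cache key, and a lookup either returns the cached client or creates and stores a new one. Below is a minimal, self-contained Go sketch of that get-or-create pattern; all names are illustrative, not the commit's code.

package main

import (
	"fmt"
	"sync"
)

// client is an illustrative stand-in for *bigquery.Client.
type client struct{ token string }

var (
	mu    sync.RWMutex
	cache = map[string]*client{} // keyed by OAuth access token
)

// getOrCreate returns the cached client for a token, creating and storing
// one on a miss — the flow the commit adds around the BigQuery client creator.
func getOrCreate(token string) *client {
	mu.RLock()
	c, ok := cache[token]
	mu.RUnlock()
	if ok {
		return c
	}

	mu.Lock()
	defer mu.Unlock()
	if c, ok := cache[token]; ok { // re-check under the write lock
		return c
	}
	c = &client{token: token}
	cache[token] = c
	return c
}

func main() {
	a := getOrCreate("token-abc")
	b := getOrCreate("token-abc")
	fmt.Println(a == b) // true: the second invocation reuses the cached client
}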
@@ -98,15 +98,15 @@ import (
 	_ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestorevalidaterules"
 	_ "github.com/googleapis/genai-toolbox/internal/tools/http"
 	_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookeradddashboardelement"
-	_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercreateprojectfile"
+	_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerconversationalanalytics"
+	_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercreateprojectfile"
 	_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerdeleteprojectfile"
 	_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerdevmode"
 	_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectiondatabases"
 	_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnections"
 	_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectionschemas"
-	_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectiontables"
+	_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectiontablecolumns"
+	_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectiontables"
 	_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetdashboards"
 	_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetdimensions"
 	_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetexplores"
@@ -72,13 +72,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
 type Config struct {
 	// BigQuery configs
-	Name            string   `yaml:"name" validate:"required"`
-	Kind            string   `yaml:"kind" validate:"required"`
-	Project         string   `yaml:"project" validate:"required"`
-	Location        string   `yaml:"location"`
-	WriteMode       string   `yaml:"writeMode"`
-	AllowedDatasets []string `yaml:"allowedDatasets"`
-	UseClientOAuth  bool     `yaml:"useClientOAuth"`
+	Name                      string   `yaml:"name" validate:"required"`
+	Kind                      string   `yaml:"kind" validate:"required"`
+	Project                   string   `yaml:"project" validate:"required"`
+	Location                  string   `yaml:"location"`
+	WriteMode                 string   `yaml:"writeMode"`
+	AllowedDatasets           []string `yaml:"allowedDatasets"`
+	UseClientOAuth            bool     `yaml:"useClientOAuth"`
 	ImpersonateServiceAccount string   `yaml:"impersonateServiceAccount"`
 }
@@ -86,13 +86,16 @@ func (r Config) SourceConfigKind() string {
 	// Returns BigQuery source kind
 	return SourceKind
 }
 
 func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
 	if r.WriteMode == "" {
 		r.WriteMode = WriteModeAllowed
 	}
 
 	if r.WriteMode == WriteModeProtected && r.UseClientOAuth {
 		// The protected mode only allows write operations to the session's temporary datasets.
+		// when using client OAuth, a new session is created every
+		// time a BigQuery tool is invoked. Therefore, no session data can
+		// be preserved as needed by the protected mode.
 		return nil, fmt.Errorf("writeMode 'protected' cannot be used with useClientOAuth 'true'")
 	}
@@ -106,17 +109,38 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
 	var clientCreator BigqueryClientCreator
 	var err error
 
+	s := &Source{
+		Name:                      r.Name,
+		Kind:                      SourceKind,
+		Project:                   r.Project,
+		Location:                  r.Location,
+		Client:                    client,
+		RestService:               restService,
+		TokenSource:               tokenSource,
+		MaxQueryResultRows:        50,
+		WriteMode:                 r.WriteMode,
+		UseClientOAuth:            r.UseClientOAuth,
+		ClientCreator:             clientCreator,
+		ImpersonateServiceAccount: r.ImpersonateServiceAccount,
+	}
+
 	if r.UseClientOAuth {
-		clientCreator, err = newBigQueryClientCreator(ctx, tracer, r.Project, r.Location, r.Name)
+		// use client OAuth
+		baseClientCreator, err := newBigQueryClientCreator(ctx, tracer, r.Project, r.Location, r.Name)
 		if err != nil {
 			return nil, fmt.Errorf("error constructing client creator: %w", err)
 		}
+		setupClientCaching(s, baseClientCreator)
+
 	} else {
 		// Initializes a BigQuery Google SQL source
 		client, restService, tokenSource, err = initBigQueryConnection(ctx, tracer, r.Name, r.Project, r.Location, r.ImpersonateServiceAccount)
 		if err != nil {
 			return nil, fmt.Errorf("error creating client from ADC: %w", err)
 		}
+		s.Client = client
+		s.RestService = restService
+		s.TokenSource = tokenSource
 	}
 
 	allowedDatasets := make(map[string]struct{})
@@ -138,8 +162,8 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
 			allowedFullID = fmt.Sprintf("%s.%s", projectID, datasetID)
 		}
 
-		if client != nil {
-			dataset := client.DatasetInProject(projectID, datasetID)
+		if s.Client != nil {
+			dataset := s.Client.DatasetInProject(projectID, datasetID)
 			_, err := dataset.Metadata(ctx)
 			if err != nil {
 				if gerr, ok := err.(*googleapi.Error); ok && gerr.Code == http.StatusNotFound {
@@ -152,21 +176,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
 		}
 	}
 
-	s := &Source{
-		Name:                      r.Name,
-		Kind:                      SourceKind,
-		Project:                   r.Project,
-		Location:                  r.Location,
-		Client:                    client,
-		RestService:               restService,
-		TokenSource:               tokenSource,
-		MaxQueryResultRows:        50,
-		WriteMode:                 r.WriteMode,
-		AllowedDatasets:           allowedDatasets,
-		UseClientOAuth:            r.UseClientOAuth,
-		ClientCreator:             clientCreator,
-		ImpersonateServiceAccount: r.ImpersonateServiceAccount,
-	}
+	s.AllowedDatasets = allowedDatasets
 	s.SessionProvider = s.newBigQuerySessionProvider()
 
 	if r.WriteMode != WriteModeAllowed && r.WriteMode != WriteModeBlocked && r.WriteMode != WriteModeProtected {
@@ -176,6 +186,58 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
 	return s, nil
 }
 
+// setupClientCaching initializes caches and wraps the base client creator with caching logic.
+func setupClientCaching(s *Source, baseCreator BigqueryClientCreator) {
+	// Define eviction handlers
+	onBqEvict := func(key string, value interface{}) {
+		if client, ok := value.(*bigqueryapi.Client); ok && client != nil {
+			client.Close()
+		}
+	}
+	onDataplexEvict := func(key string, value interface{}) {
+		if client, ok := value.(*dataplexapi.CatalogClient); ok && client != nil {
+			client.Close()
+		}
+	}
+
+	// Initialize caches
+	s.bqClientCache = NewCache(onBqEvict)
+	s.bqRestCache = NewCache(nil)
+	s.dataplexCache = NewCache(onDataplexEvict)
+
+	// Create the caching wrapper for the client creator
+	s.ClientCreator = func(tokenString string, wantRestService bool) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) {
+		// Check cache
+		bqClientVal, bqFound := s.bqClientCache.Get(tokenString)
+
+		if wantRestService {
+			restServiceVal, restFound := s.bqRestCache.Get(tokenString)
+			if bqFound && restFound {
+				// Cache hit for both
+				return bqClientVal.(*bigqueryapi.Client), restServiceVal.(*bigqueryrestapi.Service), nil
+			}
+		} else {
+			if bqFound {
+				return bqClientVal.(*bigqueryapi.Client), nil, nil
+			}
+		}
+
+		// Cache miss - call the client creator
+		client, restService, err := baseCreator(tokenString, wantRestService)
+		if err != nil {
+			return nil, nil, err
+		}
+
+		// Set in cache
+		s.bqClientCache.Set(tokenString, client)
+		if wantRestService && restService != nil {
+			s.bqRestCache.Set(tokenString, restService)
+		}
+
+		return client, restService, nil
+	}
+}
+
 var _ sources.Source = &Source{}
 
 type Source struct {
@@ -197,6 +259,11 @@ type Source struct {
 	makeDataplexCatalogClient func() (*dataplexapi.CatalogClient, DataplexClientCreator, error)
 	SessionProvider           BigQuerySessionProvider
 	Session                   *Session
+
+	// Caches for OAuth clients
+	bqClientCache *Cache
+	bqRestCache   *Cache
+	dataplexCache *Cache
 }
 
 type Session struct {
@@ -397,7 +464,29 @@ func (s *Source) lazyInitDataplexClient(ctx context.Context, tracer trace.Tracer
 			return
 		}
 		client = c
-		clientCreator = cc
+
+		// If using OAuth, wrap the provided client creator (cc) with caching logic
+		if s.UseClientOAuth && cc != nil {
+			clientCreator = func(tokenString string) (*dataplexapi.CatalogClient, error) {
+				// Check cache
+				if val, found := s.dataplexCache.Get(tokenString); found {
+					return val.(*dataplexapi.CatalogClient), nil
+				}
+
+				// Cache miss - call client creator
+				dpClient, err := cc(tokenString)
+				if err != nil {
+					return nil, err
+				}
+
+				// Set in cache
+				s.dataplexCache.Set(tokenString, dpClient)
+				return dpClient, nil
+			}
+		} else {
+			// Not using OAuth or no creator was returned
+			clientCreator = cc
+		}
 	})
 	return client, clientCreator, err
 }
internal/sources/bigquery/cache.go (new file, +125 lines)
@@ -0,0 +1,125 @@
+// Copyright 2025 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package bigquery
+
+import (
+	"sync"
+	"time"
+)
+
+// Item holds the cached value and its expiration timestamp
+type Item struct {
+	Value     any
+	ExpiresAt int64 // Unix nano timestamp
+}
+
+// IsExpired checks if the item is expired
+func (item Item) IsExpired() bool {
+	return time.Now().UnixNano() > item.ExpiresAt
+}
+
+// OnEvictFunc is the signature for the callback
+type OnEvictFunc func(key string, value any)
+
+// Cache is a thread-safe, expiring key-value store
+type Cache struct {
+	mu      sync.RWMutex
+	items   map[string]Item
+	onEvict OnEvictFunc
+}
+
+// NewCache creates a new cache and cleans up every 55 min
+func NewCache(onEvict OnEvictFunc) *Cache {
+	const cleanupInterval = 55 * time.Minute
+
+	c := &Cache{
+		items:   make(map[string]Item),
+		onEvict: onEvict,
+	}
+
+	go c.startCleanup(cleanupInterval)
+	return c
+}
+
+// startCleanup runs a ticker to periodically delete expired items
+func (c *Cache) startCleanup(interval time.Duration) {
+	ticker := time.NewTicker(interval)
+	defer ticker.Stop()
+
+	for range ticker.C {
+		c.DeleteExpired()
+	}
+}
+
+// delete is an internal helper that assumes the write lock is held
+func (c *Cache) delete(key string, item Item) {
+	if c.onEvict != nil {
+		c.onEvict(key, item.Value)
+	}
+	delete(c.items, key)
+}
+
+// Set adds an item to the cache
+func (c *Cache) Set(key string, value any) {
+	const ttl = 55 * time.Minute
+	expires := time.Now().Add(ttl).UnixNano()
+
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	// If item already exists, evict the old one before replacing
+	if oldItem, found := c.items[key]; found {
+		c.delete(key, oldItem)
+	}
+
+	c.items[key] = Item{
+		Value:     value,
+		ExpiresAt: expires,
+	}
+}
+
+// Get retrieves an item from the cache
+func (c *Cache) Get(key string) (any, bool) {
+	c.mu.RLock()
+	item, found := c.items[key]
+	if !found || item.IsExpired() {
+		c.mu.RUnlock()
+		return nil, false
+	}
+	c.mu.RUnlock()
+
+	return item.Value, true
+}
+
+// Delete manually evicts an item
+func (c *Cache) Delete(key string) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	if item, found := c.items[key]; found {
+		c.delete(key, item)
+	}
+}
+
+// DeleteExpired removes all expired items
+func (c *Cache) DeleteExpired() {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	for key, item := range c.items {
+		if item.IsExpired() {
+			c.delete(key, item)
+		}
+	}
+}
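For reference, here is a short sketch of how the Cache added above could be exercised from inside the bigquery package (for example in an example test); it is not part of the commit. The 55-minute TTL presumably tracks the typical one-hour lifetime of an OAuth access token, so entries are evicted shortly before their token would expire anyway.

package bigquery

import "fmt"

// ExampleCache sketches the Set/Get/Delete flow of Cache, including the
// eviction callback that bigquery.go uses to close cached clients.
func ExampleCache() {
	c := NewCache(func(key string, value any) {
		fmt.Println("evicted:", key)
	})

	c.Set("token-abc", "client-A") // cache a stand-in client under its token
	if v, ok := c.Get("token-abc"); ok {
		fmt.Println("hit:", v)
	}
	c.Delete("token-abc") // manual eviction triggers the callback

	// Output:
	// hit: client-A
	// evicted: token-abc
}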