refactor: decouple Source from Tool (#2204)

This PR update the linking mechanism between Source and Tool.

Tools are directly linked to their Source, either by pointing to the
Source's functions or by assigning values from the source during Tool's
initialization. However, the existing approach means that any
modification to the Source after Tool's initialization might not be
reflected. To address this limitation, each tool should only store a
name reference to the Source, rather than direct link or assigned
values.

Tools will provide interface for `compatibleSource`. This will be used
to determine if a Source is compatible with the Tool.
```
type compatibleSource interface{
    Client() http.Client
    ProjectID() string
}
```

During `Invoke()`, the tool will run the following operations:
* retrieve Source from the `resourceManager` with source's named defined
in Tool's config
* validate Source via `compatibleSource interface{}`
* run the remaining `Invoke()` function. Fields that are needed is
retrieved directly from the source.

With this update, resource manager is also added as input to other
Tool's function that require access to source (e.g.
`RequiresClientAuthorization()`).
This commit is contained in:
Yuan Teoh
2025-12-19 21:27:55 -08:00
committed by GitHub
parent 7daa4111f4
commit 967a72da11
203 changed files with 4242 additions and 6391 deletions

View File

@@ -172,7 +172,14 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
accessToken := tools.AccessToken(r.Header.Get("Authorization"))
// Check if this specific tool requires the standard authorization header
if tool.RequiresClientAuthorization(s.ResourceMgr) {
clientAuth, err := tool.RequiresClientAuthorization(s.ResourceMgr)
if err != nil {
errMsg := fmt.Errorf("error during invocation: %w", err)
s.logger.DebugContext(ctx, errMsg.Error())
_ = render.Render(w, r, newErrResponse(errMsg, http.StatusNotFound))
return
}
if clientAuth {
if accessToken == "" {
err = fmt.Errorf("tool requires client authorization but access token is missing from the request header")
s.logger.DebugContext(ctx, err.Error())
@@ -255,7 +262,7 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
}
if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden {
if tool.RequiresClientAuthorization(s.ResourceMgr) {
if clientAuth {
// Propagate the original 401/403 error.
s.logger.DebugContext(ctx, fmt.Sprintf("error invoking tool. Client credentials lack authorization to the source: %v", err))
_ = render.Render(w, r, newErrResponse(err, statusCode))

View File

@@ -77,9 +77,9 @@ func (t MockTool) Authorized(verifiedAuthServices []string) bool {
return !t.unauthorized
}
func (t MockTool) RequiresClientAuthorization(tools.SourceProvider) bool {
func (t MockTool) RequiresClientAuthorization(tools.SourceProvider) (bool, error) {
// defaulted to false
return t.requiresClientAuthrorization
return t.requiresClientAuthrorization, nil
}
func (t MockTool) McpManifest() tools.McpManifest {
@@ -119,8 +119,8 @@ func (t MockTool) McpManifest() tools.McpManifest {
return mcpManifest
}
func (t MockTool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t MockTool) GetAuthTokenHeaderName(tools.SourceProvider) (string, error) {
return "Authorization", nil
}
// MockPrompt is used to mock prompts in tests

View File

@@ -108,10 +108,20 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
}
// Get access token
accessToken := tools.AccessToken(header.Get(tool.GetAuthTokenHeaderName()))
authTokenHeadername, err := tool.GetAuthTokenHeaderName(resourceMgr)
if err != nil {
errMsg := fmt.Errorf("error during invocation: %w", err)
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
}
accessToken := tools.AccessToken(header.Get(authTokenHeadername))
// Check if this specific tool requires the standard authorization header
if tool.RequiresClientAuthorization(resourceMgr) {
clientAuth, err := tool.RequiresClientAuthorization(resourceMgr)
if err != nil {
errMsg := fmt.Errorf("error during invocation: %w", err)
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
}
if clientAuth {
if accessToken == "" {
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
}
@@ -183,7 +193,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
}
// Upstream auth error
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
if tool.RequiresClientAuthorization(resourceMgr) {
if clientAuth {
// Error with client credentials should pass down to the client
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}

View File

@@ -108,10 +108,20 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
}
// Get access token
accessToken := tools.AccessToken(header.Get(tool.GetAuthTokenHeaderName()))
authTokenHeadername, err := tool.GetAuthTokenHeaderName(resourceMgr)
if err != nil {
errMsg := fmt.Errorf("error during invocation: %w", err)
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
}
accessToken := tools.AccessToken(header.Get(authTokenHeadername))
// Check if this specific tool requires the standard authorization header
if tool.RequiresClientAuthorization(resourceMgr) {
clientAuth, err := tool.RequiresClientAuthorization(resourceMgr)
if err != nil {
errMsg := fmt.Errorf("error during invocation: %w", err)
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
}
if clientAuth {
if accessToken == "" {
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
}
@@ -183,7 +193,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
}
// Upstream auth error
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
if tool.RequiresClientAuthorization(resourceMgr) {
if clientAuth {
// Error with client credentials should pass down to the client
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}

View File

@@ -101,10 +101,20 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
}
// Get access token
accessToken := tools.AccessToken(header.Get(tool.GetAuthTokenHeaderName()))
authTokenHeadername, err := tool.GetAuthTokenHeaderName(resourceMgr)
if err != nil {
errMsg := fmt.Errorf("error during invocation: %w", err)
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
}
accessToken := tools.AccessToken(header.Get(authTokenHeadername))
// Check if this specific tool requires the standard authorization header
if tool.RequiresClientAuthorization(resourceMgr) {
clientAuth, err := tool.RequiresClientAuthorization(resourceMgr)
if err != nil {
errMsg := fmt.Errorf("error during invocation: %w", err)
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
}
if clientAuth {
if accessToken == "" {
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
}
@@ -176,7 +186,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
}
// Upstream auth error
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
if tool.RequiresClientAuthorization(resourceMgr) {
if clientAuth {
// Error with client credentials should pass down to the client
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}

View File

@@ -110,6 +110,10 @@ func (s *Source) ToConfig() sources.SourceConfig {
return s.Config
}
func (s *Source) GetDefaultProject() string {
return s.DefaultProject
}
func (s *Source) GetService(ctx context.Context, accessToken string) (*alloydbrestapi.Service, error) {
if s.UseClientOAuth {
token := &oauth2.Token{AccessToken: accessToken}

View File

@@ -107,6 +107,14 @@ func (s *Source) ToConfig() sources.SourceConfig {
return s.Config
}
func (s *Source) GetProjectID() string {
return s.ProjectID
}
func (s *Source) GetBaseURL() string {
return s.BaseURL
}
func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Client, error) {
if s.UseClientOAuth {
if accessToken == "" {

View File

@@ -81,9 +81,9 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
s := &Source{
Config: r,
BaseURL: "https://monitoring.googleapis.com",
Client: client,
UserAgent: ua,
baseURL: "https://monitoring.googleapis.com",
client: client,
userAgent: ua,
}
return s, nil
}
@@ -92,9 +92,9 @@ var _ sources.Source = &Source{}
type Source struct {
Config
BaseURL string `yaml:"baseUrl"`
Client *http.Client
UserAgent string
baseURL string
client *http.Client
userAgent string
}
func (s *Source) SourceKind() string {
@@ -105,6 +105,18 @@ func (s *Source) ToConfig() sources.SourceConfig {
return s.Config
}
func (s *Source) BaseURL() string {
return s.baseURL
}
func (s *Source) Client() *http.Client {
return s.client
}
func (s *Source) UserAgent() string {
return s.userAgent
}
func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Client, error) {
if s.UseClientOAuth {
if accessToken == "" {
@@ -113,7 +125,7 @@ func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Clien
token := &oauth2.Token{AccessToken: accessToken}
return oauth2.NewClient(ctx, oauth2.StaticTokenSource(token)), nil
}
return s.Client, nil
return s.client, nil
}
func (s *Source) UseClientAuthorization() bool {

View File

@@ -110,6 +110,10 @@ func (s *Source) ToConfig() sources.SourceConfig {
return s.Config
}
func (s *Source) GetDefaultProject() string {
return s.DefaultProject
}
func (s *Source) GetService(ctx context.Context, accessToken string) (*sqladmin.Service, error) {
if s.UseClientOAuth {
token := &oauth2.Token{AccessToken: accessToken}

View File

@@ -107,7 +107,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
s := &Source{
Config: r,
Client: &client,
client: &client,
}
return s, nil
@@ -117,7 +117,7 @@ var _ sources.Source = &Source{}
type Source struct {
Config
Client *http.Client
client *http.Client
}
func (s *Source) SourceKind() string {
@@ -127,3 +127,19 @@ func (s *Source) SourceKind() string {
func (s *Source) ToConfig() sources.SourceConfig {
return s.Config
}
func (s *Source) HttpDefaultHeaders() map[string]string {
return s.DefaultHeaders
}
func (s *Source) HttpBaseURL() string {
return s.BaseURL
}
func (s *Source) HttpQueryParams() map[string]string {
return s.QueryParams
}
func (s *Source) Client() *http.Client {
return s.client
}

View File

@@ -160,10 +160,6 @@ func (s *Source) ToConfig() sources.SourceConfig {
return s.Config
}
func (s *Source) GetApiSettings() *rtl.ApiSettings {
return s.ApiSettings
}
func (s *Source) UseClientAuthorization() bool {
return strings.ToLower(s.UseClientOAuth) != "false"
}
@@ -188,6 +184,30 @@ func (s *Source) GoogleCloudTokenSourceWithScope(ctx context.Context, scope stri
return google.DefaultTokenSource(ctx, scope)
}
func (s *Source) LookerClient() *v4.LookerSDK {
return s.Client
}
func (s *Source) LookerApiSettings() *rtl.ApiSettings {
return s.ApiSettings
}
func (s *Source) LookerShowHiddenFields() bool {
return s.ShowHiddenFields
}
func (s *Source) LookerShowHiddenModels() bool {
return s.ShowHiddenModels
}
func (s *Source) LookerShowHiddenExplores() bool {
return s.ShowHiddenExplores
}
func (s *Source) LookerSessionLength() int64 {
return s.SessionLength
}
func initGoogleCloudConnection(ctx context.Context) (oauth2.TokenSource, error) {
cred, err := google.FindDefaultCredentials(ctx, geminidataanalytics.DefaultAuthScopes()...)
if err != nil {

View File

@@ -96,6 +96,14 @@ func (s *Source) ToConfig() sources.SourceConfig {
return s.Config
}
func (s *Source) GetProject() string {
return s.Project
}
func (s *Source) GetLocation() string {
return s.Location
}
func (s *Source) GetBatchControllerClient() *dataproc.BatchControllerClient {
return s.Client
}

View File

@@ -20,7 +20,6 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/alloydb/v1"
@@ -42,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
GetDefaultProject() string
UseClientAuthorization() bool
GetService(context.Context, string) (*alloydb.Service, error)
}
// Configuration for the create-cluster tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
@@ -66,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("source %q not found", cfg.Source)
}
s, ok := rawS.(*alloydbadmin.Source)
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `alloydb-admin`", kind)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
project := s.DefaultProject
project := s.GetDefaultProject()
var projectParam parameters.Parameter
if project != "" {
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
@@ -97,7 +102,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return Tool{
Config: cfg,
Source: s,
AllParams: allParameters,
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
@@ -107,7 +111,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// Tool represents the create-cluster tool.
type Tool struct {
Config
Source *alloydbadmin.Source
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
@@ -120,6 +123,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke executes the tool's logic.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
project, ok := paramsMap["project"].(string)
if !ok || project == "" {
@@ -151,7 +159,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid 'user' parameter; expected a string")
}
service, err := t.Source.GetService(ctx, string(accessToken))
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
@@ -198,10 +206,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.Source.UseClientAuthorization()
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -20,7 +20,6 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/alloydb/v1"
@@ -42,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
GetDefaultProject() string
UseClientAuthorization() bool
GetService(context.Context, string) (*alloydb.Service, error)
}
// Configuration for the create-instance tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
@@ -66,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("source %q not found", cfg.Source)
}
s, ok := rawS.(*alloydbadmin.Source)
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `alloydb-admin`", kind)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
project := s.DefaultProject
project := s.GetDefaultProject()
var projectParam parameters.Parameter
if project != "" {
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
@@ -98,7 +103,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return Tool{
Config: cfg,
Source: s,
AllParams: allParameters,
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
@@ -108,7 +112,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// Tool represents the create-instance tool.
type Tool struct {
Config
Source *alloydbadmin.Source
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
@@ -121,6 +124,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke executes the tool's logic.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
project, ok := paramsMap["project"].(string)
if !ok || project == "" {
@@ -147,7 +155,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid 'instanceType' parameter; expected 'PRIMARY' or 'READ_POOL'")
}
service, err := t.Source.GetService(ctx, string(accessToken))
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
@@ -208,10 +216,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.Source.UseClientAuthorization()
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -20,7 +20,6 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/alloydb/v1"
@@ -42,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
GetDefaultProject() string
UseClientAuthorization() bool
GetService(context.Context, string) (*alloydb.Service, error)
}
// Configuration for the create-user tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
@@ -66,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("source %q not found", cfg.Source)
}
s, ok := rawS.(*alloydbadmin.Source)
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `alloydb-admin`", kind)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
project := s.DefaultProject
project := s.GetDefaultProject()
var projectParam parameters.Parameter
if project != "" {
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
@@ -98,7 +103,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return Tool{
Config: cfg,
Source: s,
AllParams: allParameters,
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
@@ -108,9 +112,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// Tool represents the create-user tool.
type Tool struct {
Config
Source *alloydbadmin.Source
AllParams parameters.Parameters `yaml:"allParams"`
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -121,6 +123,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke executes the tool's logic.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
project, ok := paramsMap["project"].(string)
if !ok || project == "" {
@@ -147,7 +154,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid or missing 'userType' parameter; expected 'ALLOYDB_BUILT_IN' or 'ALLOYDB_IAM_USER'")
}
service, err := t.Source.GetService(ctx, string(accessToken))
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
@@ -208,10 +215,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.Source.UseClientAuthorization()
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -20,9 +20,9 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/alloydb/v1"
)
const kind string = "alloydb-get-cluster"
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
GetDefaultProject() string
UseClientAuthorization() bool
GetService(context.Context, string) (*alloydb.Service, error)
}
// Configuration for the get-cluster tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
@@ -66,12 +72,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("source %q not found", cfg.Source)
}
s, ok := rawS.(*alloydbadmin.Source)
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
project := s.DefaultProject
project := s.GetDefaultProject()
var projectParam parameters.Parameter
if project != "" {
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
@@ -94,7 +100,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return Tool{
Config: cfg,
Source: s,
AllParams: allParameters,
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
@@ -104,7 +109,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// Tool represents the get-cluster tool.
type Tool struct {
Config
Source *alloydbadmin.Source
AllParams parameters.Parameters
manifest tools.Manifest
@@ -117,6 +121,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke executes the tool's logic.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
project, ok := paramsMap["project"].(string)
@@ -132,7 +141,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string")
}
service, err := t.Source.GetService(ctx, string(accessToken))
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
@@ -167,10 +176,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.Source.UseClientAuthorization()
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -20,9 +20,9 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/alloydb/v1"
)
const kind string = "alloydb-get-instance"
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
GetDefaultProject() string
UseClientAuthorization() bool
GetService(context.Context, string) (*alloydb.Service, error)
}
// Configuration for the get-instance tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
@@ -66,12 +72,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("source %q not found", cfg.Source)
}
s, ok := rawS.(*alloydbadmin.Source)
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
project := s.DefaultProject
project := s.GetDefaultProject()
var projectParam parameters.Parameter
if project != "" {
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
@@ -95,7 +101,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return Tool{
Config: cfg,
Source: s,
AllParams: allParameters,
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
@@ -105,9 +110,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// Tool represents the get-instance tool.
type Tool struct {
Config
Source *alloydbadmin.Source
AllParams parameters.Parameters
AllParams parameters.Parameters
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -118,6 +121,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke executes the tool's logic.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
project, ok := paramsMap["project"].(string)
@@ -137,7 +145,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid 'instance' parameter; expected a string")
}
service, err := t.Source.GetService(ctx, string(accessToken))
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
@@ -172,10 +180,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.Source.UseClientAuthorization()
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -20,9 +20,9 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/alloydb/v1"
)
const kind string = "alloydb-get-user"
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
GetDefaultProject() string
UseClientAuthorization() bool
GetService(context.Context, string) (*alloydb.Service, error)
}
// Configuration for the get-user tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
@@ -66,12 +72,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("source %q not found", cfg.Source)
}
s, ok := rawS.(*alloydbadmin.Source)
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
project := s.DefaultProject
project := s.GetDefaultProject()
var projectParam parameters.Parameter
if project != "" {
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
@@ -95,7 +101,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return Tool{
Config: cfg,
Source: s,
AllParams: allParameters,
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
@@ -105,9 +110,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// Tool represents the get-user tool.
type Tool struct {
Config
Source *alloydbadmin.Source
AllParams parameters.Parameters
AllParams parameters.Parameters
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -118,6 +121,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke executes the tool's logic.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
project, ok := paramsMap["project"].(string)
@@ -137,7 +145,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid 'user' parameter; expected a string")
}
service, err := t.Source.GetService(ctx, string(accessToken))
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
@@ -172,10 +180,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.Source.UseClientAuthorization()
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -20,9 +20,9 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/alloydb/v1"
)
const kind string = "alloydb-list-clusters"
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
GetDefaultProject() string
UseClientAuthorization() bool
GetService(context.Context, string) (*alloydb.Service, error)
}
// Configuration for the list-clusters tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
@@ -66,12 +72,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("source %q not found", cfg.Source)
}
s, ok := rawS.(*alloydbadmin.Source)
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
project := s.DefaultProject
project := s.GetDefaultProject()
var projectParam parameters.Parameter
if project != "" {
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
@@ -93,7 +99,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return Tool{
Config: cfg,
Source: s,
AllParams: allParameters,
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
@@ -103,9 +108,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// Tool represents the list-clusters tool.
type Tool struct {
Config
Source *alloydbadmin.Source
AllParams parameters.Parameters `yaml:"allParams"`
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -116,6 +119,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke executes the tool's logic.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
project, ok := paramsMap["project"].(string)
@@ -127,7 +135,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid 'location' parameter; expected a string")
}
service, err := t.Source.GetService(ctx, string(accessToken))
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
@@ -162,10 +170,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.Source.UseClientAuthorization()
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -20,9 +20,9 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/alloydb/v1"
)
const kind string = "alloydb-list-instances"
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
GetDefaultProject() string
UseClientAuthorization() bool
GetService(context.Context, string) (*alloydb.Service, error)
}
// Configuration for the list-instances tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
@@ -66,12 +72,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("source %q not found", cfg.Source)
}
s, ok := rawS.(*alloydbadmin.Source)
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
project := s.DefaultProject
project := s.GetDefaultProject()
var projectParam parameters.Parameter
if project != "" {
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
@@ -94,7 +100,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return Tool{
Config: cfg,
Source: s,
AllParams: allParameters,
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
@@ -104,9 +109,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// Tool represents the list-instances tool.
type Tool struct {
Config
Source *alloydbadmin.Source
AllParams parameters.Parameters `yaml:"allParams"`
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -117,6 +120,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke executes the tool's logic.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
project, ok := paramsMap["project"].(string)
@@ -132,7 +140,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string")
}
service, err := t.Source.GetService(ctx, string(accessToken))
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
@@ -167,10 +175,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.Source.UseClientAuthorization()
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -20,9 +20,9 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/alloydb/v1"
)
const kind string = "alloydb-list-users"
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
GetDefaultProject() string
UseClientAuthorization() bool
GetService(context.Context, string) (*alloydb.Service, error)
}
// Configuration for the list-users tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
@@ -66,12 +72,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("source %q not found", cfg.Source)
}
s, ok := rawS.(*alloydbadmin.Source)
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
project := s.DefaultProject
project := s.GetDefaultProject()
var projectParam parameters.Parameter
if project != "" {
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
@@ -94,7 +100,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return Tool{
Config: cfg,
Source: s,
AllParams: allParameters,
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
@@ -104,9 +109,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// Tool represents the list-users tool.
type Tool struct {
Config
Source *alloydbadmin.Source
AllParams parameters.Parameters `yaml:"allParams"`
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -117,6 +120,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke executes the tool's logic.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
project, ok := paramsMap["project"].(string)
@@ -132,7 +140,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string")
}
service, err := t.Source.GetService(ctx, string(accessToken))
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
@@ -167,10 +175,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.Source.UseClientAuthorization()
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -25,9 +25,9 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/alloydb/v1"
)
const kind string = "alloydb-wait-for-operation"
@@ -89,6 +89,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
GetDefaultProject() string
UseClientAuthorization() bool
GetService(context.Context, string) (*alloydb.Service, error)
}
// Config defines the configuration for the wait-for-operation tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
@@ -119,12 +125,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
s, ok := rawS.(*alloydbadmin.Source)
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
project := s.DefaultProject
project := s.GetDefaultProject()
var projectParam parameters.Parameter
if project != "" {
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
@@ -180,7 +186,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return Tool{
Config: cfg,
Source: s,
AllParams: allParameters,
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
@@ -194,19 +199,16 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// Tool represents the wait-for-operation tool.
type Tool struct {
Config
Source *alloydbadmin.Source
AllParams parameters.Parameters `yaml:"allParams"`
AllParams parameters.Parameters `yaml:"allParams"`
Client *http.Client
manifest tools.Manifest
mcpManifest tools.McpManifest
// Polling configuration
Delay time.Duration
MaxDelay time.Duration
Multiplier float64
MaxRetries int
Client *http.Client
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -215,6 +217,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke executes the tool's logic.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
project, ok := paramsMap["project"].(string)
@@ -230,7 +237,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("missing 'operation' parameter")
}
service, err := t.Source.GetService(ctx, string(accessToken))
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
@@ -363,10 +370,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.Source.UseClientAuthorization()
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -21,7 +21,6 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/alloydbpg"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"github.com/jackc/pgx/v5/pgxpool"
@@ -47,11 +46,6 @@ type compatibleSource interface {
PostgresPool() *pgxpool.Pool
}
// validate compatible sources are still compatible
var _ compatibleSource = &alloydbpg.Source{}
var compatibleSources = [...]string{alloydbpg.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -70,18 +64,6 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
numParams := len(cfg.NLConfigParameters)
quotedNameParts := make([]string, 0, numParams)
placeholderParts := make([]string, 0, numParams)
@@ -126,7 +108,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
Config: cfg,
Parameters: cfg.NLConfigParameters,
Statement: stmt,
Pool: s.PostgresPool(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.NLConfigParameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
@@ -139,9 +120,7 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
Parameters parameters.Parameters `yaml:"parameters"`
Pool *pgxpool.Pool
Parameters parameters.Parameters `yaml:"parameters"`
Statement string
manifest tools.Manifest
mcpManifest tools.McpManifest
@@ -152,6 +131,12 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
pool := source.PostgresPool()
sliceParams := params.AsSlice()
allParamValues := make([]any, len(sliceParams)+1)
allParamValues[0] = fmt.Sprintf("%s", sliceParams[0]) // nl_question
@@ -160,7 +145,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
allParamValues[i+2] = fmt.Sprintf("%s", param)
}
results, err := t.Pool.Query(ctx, t.Statement, allParamValues...)
results, err := pool.Query(ctx, t.Statement, allParamValues...)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w. Query: %v , Values: %v. Toolbox v0.19.0+ is only compatible with AlloyDB AI NL v1.0.3+. Please ensure that you are using the latest AlloyDB AI NL extension", err, t.Statement, allParamValues)
}
@@ -203,10 +188,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return false
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
return false, nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -57,11 +57,6 @@ type compatibleSource interface {
BigQuerySession() bigqueryds.BigQuerySessionProvider
}
// validate compatible sources are still compatible
var _ compatibleSource = &bigqueryds.Source{}
var compatibleSources = [...]string{bigqueryds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -87,7 +82,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
allowedDatasets := s.BigQueryAllowedDatasets()
@@ -136,17 +131,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
RestService: s.BigQueryRestService(),
IsDatasetAllowed: s.IsDatasetAllowed,
AllowedDatasets: allowedDatasets,
SessionProvider: s.BigQuerySession(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -156,17 +144,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Client *bigqueryapi.Client
RestService *bigqueryrestapi.Service
ClientCreator bigqueryds.BigqueryClientCreator
IsDatasetAllowed func(projectID, datasetID string) bool
AllowedDatasets []string
SessionProvider bigqueryds.BigQuerySessionProvider
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -175,23 +155,27 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke runs the contribution analysis.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
inputData, ok := paramsMap["input_data"].(string)
if !ok {
return nil, fmt.Errorf("unable to cast input_data parameter %s", paramsMap["input_data"])
}
bqClient := t.Client
restService := t.RestService
var err error
bqClient := source.BigQueryClient()
restService := source.BigQueryRestService()
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, restService, err = t.ClientCreator(tokenStr, true)
bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, true)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
@@ -229,9 +213,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
var inputDataSource string
trimmedUpperInputData := strings.TrimSpace(strings.ToUpper(inputData))
if strings.HasPrefix(trimmedUpperInputData, "SELECT") || strings.HasPrefix(trimmedUpperInputData, "WITH") {
if len(t.AllowedDatasets) > 0 {
if len(source.BigQueryAllowedDatasets()) > 0 {
var connProps []*bigqueryapi.ConnectionProperty
session, err := t.SessionProvider(ctx)
session, err := source.BigQuerySession()(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
}
@@ -240,7 +224,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
{Key: "session_id", Value: session.ID},
}
}
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, t.Client.Project(), t.Client.Location, inputData, nil, connProps)
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, inputData, nil, connProps)
if err != nil {
return nil, fmt.Errorf("query validation failed: %w", err)
}
@@ -252,7 +236,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
queryStats := dryRunJob.Statistics.Query
if queryStats != nil {
for _, tableRef := range queryStats.ReferencedTables {
if !t.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) {
if !source.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) {
return nil, fmt.Errorf("query in input_data accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectId, tableRef.DatasetId)
}
}
@@ -262,18 +246,18 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
inputDataSource = fmt.Sprintf("(%s)", inputData)
} else {
if len(t.AllowedDatasets) > 0 {
if len(source.BigQueryAllowedDatasets()) > 0 {
parts := strings.Split(inputData, ".")
var projectID, datasetID string
switch len(parts) {
case 3: // project.dataset.table
projectID, datasetID = parts[0], parts[1]
case 2: // dataset.table
projectID, datasetID = t.Client.Project(), parts[0]
projectID, datasetID = source.BigQueryClient().Project(), parts[0]
default:
return nil, fmt.Errorf("invalid table ID format for 'input_data': %q. Expected 'dataset.table' or 'project.dataset.table'", inputData)
}
if !t.IsDatasetAllowed(projectID, datasetID) {
if !source.IsDatasetAllowed(projectID, datasetID) {
return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", projectID, datasetID, inputData)
}
}
@@ -292,7 +276,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
// Get session from provider if in protected mode.
// Otherwise, a new session will be created by the first query.
session, err := t.SessionProvider(ctx)
session, err := source.BigQuerySession()(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
}
@@ -385,10 +369,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -26,7 +26,6 @@ import (
bigqueryapi "cloud.google.com/go/bigquery"
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
@@ -105,11 +104,6 @@ type CAPayload struct {
ClientIdEnum string `json:"clientIdEnum"`
}
// validate compatible sources are still compatible
var _ compatibleSource = &bigqueryds.Source{}
var compatibleSources = [...]string{bigqueryds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -135,7 +129,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
allowedDatasets := s.BigQueryAllowedDatasets()
@@ -153,31 +147,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
params := parameters.Parameters{userQueryParameter, tableRefsParameter}
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil)
// Get cloud-platform token source for Gemini Data Analytics API during initialization
var bigQueryTokenSourceWithScope oauth2.TokenSource
if !s.UseClientAuthorization() {
ctx := context.Background()
ts, err := s.BigQueryTokenSourceWithScope(ctx, "https://www.googleapis.com/auth/cloud-platform")
if err != nil {
return nil, fmt.Errorf("failed to get cloud-platform token source: %w", err)
}
bigQueryTokenSourceWithScope = ts
}
// finish tool setup
t := Tool{
Config: cfg,
Project: s.BigQueryProject(),
Location: s.BigQueryLocation(),
Parameters: params,
Client: s.BigQueryClient(),
UseClientOAuth: s.UseClientAuthorization(),
TokenSource: bigQueryTokenSourceWithScope,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
MaxQueryResultRows: s.GetMaxQueryResultRows(),
IsDatasetAllowed: s.IsDatasetAllowed,
AllowedDatasets: allowedDatasets,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -187,18 +162,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Project string
Location string
Client *bigqueryapi.Client
TokenSource oauth2.TokenSource
manifest tools.Manifest
mcpManifest tools.McpManifest
MaxQueryResultRows int
IsDatasetAllowed func(projectID, datasetID string) bool
AllowedDatasets []string
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -206,11 +172,15 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
var tokenStr string
var err error
// Get credentials for the API call
if t.UseClientOAuth {
if source.UseClientAuthorization() {
// Use client-side access token
if accessToken == "" {
return nil, fmt.Errorf("tool is configured for client OAuth but no token was provided in the request header: %w", util.ErrUnauthorized)
@@ -220,11 +190,17 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("error parsing access token: %w", err)
}
} else {
// Get cloud-platform token source for Gemini Data Analytics API during initialization
tokenSource, err := source.BigQueryTokenSourceWithScope(ctx, "https://www.googleapis.com/auth/cloud-platform")
if err != nil {
return nil, fmt.Errorf("failed to get cloud-platform token source: %w", err)
}
// Use cloud-platform token source for Gemini Data Analytics API
if t.TokenSource == nil {
if tokenSource == nil {
return nil, fmt.Errorf("cloud-platform token source is missing")
}
token, err := t.TokenSource.Token()
token, err := tokenSource.Token()
if err != nil {
return nil, fmt.Errorf("failed to get token from cloud-platform token source: %w", err)
}
@@ -245,17 +221,17 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
}
if len(t.AllowedDatasets) > 0 {
if len(source.BigQueryAllowedDatasets()) > 0 {
for _, tableRef := range tableRefs {
if !t.IsDatasetAllowed(tableRef.ProjectID, tableRef.DatasetID) {
if !source.IsDatasetAllowed(tableRef.ProjectID, tableRef.DatasetID) {
return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", tableRef.ProjectID, tableRef.DatasetID, tableRef.TableID)
}
}
}
// Construct URL, headers, and payload
projectID := t.Project
location := t.Location
projectID := source.BigQueryProject()
location := source.BigQueryLocation()
if location == "" {
location = "us"
}
@@ -279,7 +255,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
// Call the streaming API
response, err := getStream(caURL, payload, headers, t.MaxQueryResultRows)
response, err := getStream(caURL, payload, headers, source.GetMaxQueryResultRows())
if err != nil {
return nil, fmt.Errorf("failed to get response from conversational analytics API: %w", err)
}
@@ -303,8 +279,12 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
// StreamMessage represents a single message object from the streaming API response.
@@ -580,6 +560,6 @@ func appendMessage(messages []map[string]any, newMessage map[string]any) []map[s
return append(messages, newMessage)
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -60,11 +60,6 @@ type compatibleSource interface {
BigQueryAllowedDatasets() []string
}
// validate compatible sources are still compatible
var _ compatibleSource = &bigqueryds.Source{}
var compatibleSources = [...]string{bigqueryds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -90,7 +85,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
var sqlDescriptionBuilder strings.Builder
@@ -136,18 +131,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
RestService: s.BigQueryRestService(),
WriteMode: s.BigQueryWriteMode(),
SessionProvider: s.BigQuerySession(),
IsDatasetAllowed: s.IsDatasetAllowed,
AllowedDatasets: allowedDatasets,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -157,18 +144,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Client *bigqueryapi.Client
RestService *bigqueryrestapi.Service
WriteMode string
SessionProvider bigqueryds.BigQuerySessionProvider
ClientCreator bigqueryds.BigqueryClientCreator
IsDatasetAllowed func(projectID, datasetID string) bool
AllowedDatasets []string
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -176,6 +154,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
sql, ok := paramsMap["sql"].(string)
if !ok {
@@ -186,17 +169,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("unable to cast dry_run parameter %s", paramsMap["dry_run"])
}
bqClient := t.Client
restService := t.RestService
bqClient := source.BigQueryClient()
restService := source.BigQueryRestService()
var err error
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, restService, err = t.ClientCreator(tokenStr, true)
bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, true)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
@@ -204,8 +186,8 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
var connProps []*bigqueryapi.ConnectionProperty
var session *bigqueryds.Session
if t.WriteMode == bigqueryds.WriteModeProtected {
session, err = t.SessionProvider(ctx)
if source.BigQueryWriteMode() == bigqueryds.WriteModeProtected {
session, err = source.BigQuerySession()(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get BigQuery session for protected mode: %w", err)
}
@@ -221,7 +203,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
statementType := dryRunJob.Statistics.Query.StatementType
switch t.WriteMode {
switch source.BigQueryWriteMode() {
case bigqueryds.WriteModeBlocked:
if statementType != "SELECT" {
return nil, fmt.Errorf("write mode is 'blocked', only SELECT statements are allowed")
@@ -235,7 +217,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
}
if len(t.AllowedDatasets) > 0 {
if len(source.BigQueryAllowedDatasets()) > 0 {
switch statementType {
case "CREATE_SCHEMA", "DROP_SCHEMA", "ALTER_SCHEMA":
return nil, fmt.Errorf("dataset-level operations like '%s' are not allowed when dataset restrictions are in place", statementType)
@@ -270,7 +252,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
} else if statementType != "SELECT" {
// If dry run yields no tables, fall back to the parser for non-SELECT statements
// to catch unsafe operations like EXECUTE IMMEDIATE.
parsedTables, parseErr := bqutil.TableParser(sql, t.Client.Project())
parsedTables, parseErr := bqutil.TableParser(sql, source.BigQueryClient().Project())
if parseErr != nil {
// If parsing fails (e.g., EXECUTE IMMEDIATE), we cannot guarantee safety, so we must fail.
return nil, fmt.Errorf("could not parse tables from query to validate against allowed datasets: %w", parseErr)
@@ -282,7 +264,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
parts := strings.Split(tableID, ".")
if len(parts) == 3 {
projectID, datasetID := parts[0], parts[1]
if !t.IsDatasetAllowed(projectID, datasetID) {
if !source.IsDatasetAllowed(projectID, datasetID) {
return nil, fmt.Errorf("query accesses dataset '%s.%s', which is not in the allowed list", projectID, datasetID)
}
}
@@ -374,10 +356,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -57,11 +57,6 @@ type compatibleSource interface {
BigQuerySession() bigqueryds.BigQuerySessionProvider
}
// validate compatible sources are still compatible
var _ compatibleSource = &bigqueryds.Source{}
var compatibleSources = [...]string{bigqueryds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -87,7 +82,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
allowedDatasets := s.BigQueryAllowedDatasets()
@@ -116,17 +111,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
RestService: s.BigQueryRestService(),
IsDatasetAllowed: s.IsDatasetAllowed,
SessionProvider: s.BigQuerySession(),
AllowedDatasets: allowedDatasets,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -136,17 +124,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Client *bigqueryapi.Client
RestService *bigqueryrestapi.Service
ClientCreator bigqueryds.BigqueryClientCreator
IsDatasetAllowed func(projectID, datasetID string) bool
AllowedDatasets []string
SessionProvider bigqueryds.BigQuerySessionProvider
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -154,6 +134,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
historyData, ok := paramsMap["history_data"].(string)
if !ok {
@@ -188,17 +173,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
}
bqClient := t.Client
restService := t.RestService
var err error
bqClient := source.BigQueryClient()
restService := source.BigQueryRestService()
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, restService, err = t.ClientCreator(tokenStr, false)
bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, false)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
@@ -207,9 +191,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
var historyDataSource string
trimmedUpperHistoryData := strings.TrimSpace(strings.ToUpper(historyData))
if strings.HasPrefix(trimmedUpperHistoryData, "SELECT") || strings.HasPrefix(trimmedUpperHistoryData, "WITH") {
if len(t.AllowedDatasets) > 0 {
if len(source.BigQueryAllowedDatasets()) > 0 {
var connProps []*bigqueryapi.ConnectionProperty
session, err := t.SessionProvider(ctx)
session, err := source.BigQuerySession()(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
}
@@ -218,7 +202,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
{Key: "session_id", Value: session.ID},
}
}
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, t.Client.Project(), t.Client.Location, historyData, nil, connProps)
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, historyData, nil, connProps)
if err != nil {
return nil, fmt.Errorf("query validation failed: %w", err)
}
@@ -230,7 +214,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
queryStats := dryRunJob.Statistics.Query
if queryStats != nil {
for _, tableRef := range queryStats.ReferencedTables {
if !t.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) {
if !source.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) {
return nil, fmt.Errorf("query in history_data accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectId, tableRef.DatasetId)
}
}
@@ -240,7 +224,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
historyDataSource = fmt.Sprintf("(%s)", historyData)
} else {
if len(t.AllowedDatasets) > 0 {
if len(source.BigQueryAllowedDatasets()) > 0 {
parts := strings.Split(historyData, ".")
var projectID, datasetID string
@@ -249,13 +233,13 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
projectID = parts[0]
datasetID = parts[1]
case 2: // dataset.table
projectID = t.Client.Project()
projectID = source.BigQueryClient().Project()
datasetID = parts[0]
default:
return nil, fmt.Errorf("invalid table ID format for 'history_data': %q. Expected 'dataset.table' or 'project.dataset.table'", historyData)
}
if !t.IsDatasetAllowed(projectID, datasetID) {
if !source.IsDatasetAllowed(projectID, datasetID) {
return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", projectID, datasetID, historyData)
}
}
@@ -279,7 +263,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
// JobStatistics.QueryStatistics.StatementType
query := bqClient.Query(sql)
query.Location = bqClient.Location
session, err := t.SessionProvider(ctx)
session, err := source.BigQuerySession()(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
}
@@ -349,10 +333,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -54,11 +54,6 @@ type compatibleSource interface {
BigQueryAllowedDatasets() []string
}
// validate compatible sources are still compatible
var _ compatibleSource = &bigqueryds.Source{}
var compatibleSources = [...]string{bigqueryds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -84,7 +79,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
defaultProjectID := s.BigQueryProject()
@@ -104,14 +99,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
IsDatasetAllowed: s.IsDatasetAllowed,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -121,15 +112,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Client *bigqueryapi.Client
ClientCreator bigqueryds.BigqueryClientCreator
Statement string
IsDatasetAllowed func(projectID, datasetID string) bool
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -137,6 +122,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
mapParams := params.AsMap()
projectId, ok := mapParams[projectKey].(string)
if !ok {
@@ -148,22 +138,21 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", datasetKey)
}
bqClient := t.Client
var err error
bqClient := source.BigQueryClient()
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, _, err = t.ClientCreator(tokenStr, false)
bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
}
if !t.IsDatasetAllowed(projectId, datasetId) {
if !source.IsDatasetAllowed(projectId, datasetId) {
return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId)
}
@@ -193,10 +182,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -55,11 +55,6 @@ type compatibleSource interface {
BigQueryAllowedDatasets() []string
}
// validate compatible sources are still compatible
var _ compatibleSource = &bigqueryds.Source{}
var compatibleSources = [...]string{bigqueryds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -85,7 +80,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
defaultProjectID := s.BigQueryProject()
@@ -108,14 +103,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
IsDatasetAllowed: s.IsDatasetAllowed,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -125,15 +116,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Client *bigqueryapi.Client
ClientCreator bigqueryds.BigqueryClientCreator
Statement string
IsDatasetAllowed func(projectID, datasetID string) bool
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -141,6 +126,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
mapParams := params.AsMap()
projectId, ok := mapParams[projectKey].(string)
if !ok {
@@ -157,20 +147,19 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", tableKey)
}
if !t.IsDatasetAllowed(projectId, datasetId) {
if !source.IsDatasetAllowed(projectId, datasetId) {
return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId)
}
bqClient := t.Client
bqClient := source.BigQueryClient()
var err error
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, _, err = t.ClientCreator(tokenStr, false)
bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
@@ -203,10 +192,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -52,11 +52,6 @@ type compatibleSource interface {
BigQueryAllowedDatasets() []string
}
// validate compatible sources are still compatible
var _ compatibleSource = &bigqueryds.Source{}
var compatibleSources = [...]string{bigqueryds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -82,7 +77,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
var projectParameter parameters.Parameter
@@ -103,14 +98,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
AllowedDatasets: allowedDatasets,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -120,15 +111,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Client *bigqueryapi.Client
ClientCreator bigqueryds.BigqueryClientCreator
Statement string
AllowedDatasets []string
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -136,8 +121,13 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
if len(t.AllowedDatasets) > 0 {
return t.AllowedDatasets, nil
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
if len(source.BigQueryAllowedDatasets()) > 0 {
return source.BigQueryAllowedDatasets(), nil
}
mapParams := params.AsMap()
projectId, ok := mapParams[projectKey].(string)
@@ -145,14 +135,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", projectKey)
}
bqClient := t.Client
bqClient := source.BigQueryClient()
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, _, err = t.ClientCreator(tokenStr, false)
bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
@@ -197,10 +187,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -55,11 +55,6 @@ type compatibleSource interface {
BigQueryAllowedDatasets() []string
}
// validate compatible sources are still compatible
var _ compatibleSource = &bigqueryds.Source{}
var compatibleSources = [...]string{bigqueryds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -85,7 +80,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
defaultProjectID := s.BigQueryProject()
@@ -107,14 +102,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
IsDatasetAllowed: s.IsDatasetAllowed,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -124,15 +115,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Client *bigqueryapi.Client
ClientCreator bigqueryds.BigqueryClientCreator
IsDatasetAllowed func(projectID, datasetID string) bool
Statement string
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -140,6 +125,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
mapParams := params.AsMap()
projectId, ok := mapParams[projectKey].(string)
if !ok {
@@ -151,18 +141,18 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", datasetKey)
}
if !t.IsDatasetAllowed(projectId, datasetId) {
if !source.IsDatasetAllowed(projectId, datasetId) {
return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId)
}
bqClient := t.Client
bqClient := source.BigQueryClient()
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, _, err = t.ClientCreator(tokenStr, false)
bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
@@ -208,10 +198,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -51,11 +51,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
}
// validate compatible sources are still compatible
var _ compatibleSource = &bigqueryds.Source{}
var compatibleSources = [...]string{bigqueryds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -72,20 +67,6 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// Initialize the search configuration with the provided sources
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
// Get the Dataplex client using the method from the source
makeCatalogClient := s.MakeDataplexCatalogClient()
prompt := parameters.NewStringParameter("prompt", "Prompt representing search intention. Do not rewrite the prompt.")
datasetIds := parameters.NewArrayParameterWithDefault("datasetIds", []any{}, "Array of dataset IDs.", parameters.NewStringParameter("datasetId", "The IDs of the bigquery dataset."))
projectIds := parameters.NewArrayParameterWithDefault("projectIds", []any{}, "Array of project IDs.", parameters.NewStringParameter("projectId", "The IDs of the bigquery project."))
@@ -100,11 +81,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, params, nil)
t := Tool{
Config: cfg,
Parameters: params,
UseClientOAuth: s.UseClientAuthorization(),
MakeCatalogClient: makeCatalogClient,
ProjectID: s.BigQueryProject(),
Config: cfg,
Parameters: params,
manifest: tools.Manifest{
Description: cfg.Description,
Parameters: params.Manifest(),
@@ -117,12 +95,9 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
type Tool struct {
Config
Parameters parameters.Parameters
UseClientOAuth bool
MakeCatalogClient func() (*dataplexapi.CatalogClient, bigqueryds.DataplexClientCreator, error)
ProjectID string
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -133,8 +108,12 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func constructSearchQueryHelper(predicate string, operator string, items []string) string {
@@ -207,6 +186,11 @@ func ExtractType(resourceString string) string {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
pageSize := int32(paramsMap["pageSize"].(int))
prompt, _ := paramsMap["prompt"].(string)
@@ -228,14 +212,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
req := &dataplexpb.SearchEntriesRequest{
Query: fmt.Sprintf("%s %s", prompt, constructSearchQuery(projectIds, datasetIds, types)),
Name: fmt.Sprintf("projects/%s/locations/global", t.ProjectID),
Name: fmt.Sprintf("projects/%s/locations/global", source.BigQueryProject()),
PageSize: pageSize,
SemanticSearch: true,
}
catalogClient, dataplexClientCreator, _ := t.MakeCatalogClient()
catalogClient, dataplexClientCreator, _ := source.MakeDataplexCatalogClient()()
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
@@ -248,7 +232,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
it := catalogClient.SearchEntries(ctx, req)
if it == nil {
return nil, fmt.Errorf("failed to create search entries iterator for project %q", t.ProjectID)
return nil, fmt.Errorf("failed to create search entries iterator for project %q", source.BigQueryProject())
}
var results []Response
@@ -288,6 +272,6 @@ func (t Tool) McpManifest() tools.McpManifest {
return t.mcpManifest
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -57,11 +57,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
}
// validate compatible sources are still compatible
var _ compatibleSource = &bigqueryds.Source{}
var compatibleSources = [...]string{bigqueryds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -81,18 +76,6 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters)
if err != nil {
return nil, err
@@ -102,15 +85,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
AllParams: allParameters,
UseClientOAuth: s.UseClientAuthorization(),
Client: s.BigQueryClient(),
RestService: s.BigQueryRestService(),
SessionProvider: s.BigQuerySession(),
ClientCreator: s.BigQueryClientCreator(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
AllParams: allParameters,
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -120,15 +98,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
AllParams parameters.Parameters `yaml:"allParams"`
Client *bigqueryapi.Client
RestService *bigqueryrestapi.Service
SessionProvider bigqueryds.BigQuerySessionProvider
ClientCreator bigqueryds.BigqueryClientCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -136,6 +108,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
highLevelParams := make([]bigqueryapi.QueryParameter, 0, len(t.Parameters))
lowLevelParams := make([]*bigqueryrestapi.QueryParameter, 0, len(t.Parameters))
@@ -212,16 +189,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
lowLevelParams = append(lowLevelParams, lowLevelParam)
}
bqClient := t.Client
restService := t.RestService
bqClient := source.BigQueryClient()
restService := source.BigQueryRestService()
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, restService, err = t.ClientCreator(tokenStr, true)
bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, true)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
@@ -232,8 +209,8 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
query.Location = bqClient.Location
connProps := []*bigqueryapi.ConnectionProperty{}
if t.SessionProvider != nil {
session, err := t.SessionProvider(ctx)
if source.BigQuerySession() != nil {
session, err := source.BigQuerySession()(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
}
@@ -311,10 +288,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -21,7 +21,6 @@ import (
"cloud.google.com/go/bigtable"
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
bigtabledb "github.com/googleapis/genai-toolbox/internal/sources/bigtable"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
@@ -46,11 +45,6 @@ type compatibleSource interface {
BigtableClient() *bigtable.Client
}
// validate compatible sources are still compatible
var _ compatibleSource = &bigtabledb.Source{}
var compatibleSources = [...]string{bigtabledb.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -70,18 +64,6 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters)
if err != nil {
return nil, err
@@ -93,7 +75,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
t := Tool{
Config: cfg,
AllParams: allParameters,
Client: s.BigtableClient(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
@@ -105,9 +86,7 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
AllParams parameters.Parameters `yaml:"allParams"`
Client *bigtable.Client
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -156,6 +135,11 @@ func getMapParamsType(tparams parameters.Parameters, params parameters.ParamValu
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
if err != nil {
@@ -172,7 +156,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("fail to get map params: %w", err)
}
ps, err := t.Client.PrepareStatement(
ps, err := source.BigtableClient().PrepareStatement(
ctx,
newStatement,
mapParamsType,
@@ -224,10 +208,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return false
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
return false, nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -21,7 +21,6 @@ import (
gocql "github.com/apache/cassandra-gocql-driver/v2"
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/cassandra"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
@@ -46,10 +45,6 @@ type compatibleSource interface {
CassandraSession() *gocql.Session
}
var _ compatibleSource = &cassandra.Source{}
var compatibleSources = [...]string{cassandra.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -61,20 +56,15 @@ type Config struct {
TemplateParameters parameters.Parameters `yaml:"templateParameters"`
}
var _ tools.ToolConfig = Config{}
// ToolConfigKind implements tools.ToolConfig.
func (c Config) ToolConfigKind() string {
return kind
}
// Initialize implements tools.ToolConfig.
func (c Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[c.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", c.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
allParameters, paramManifest, err := parameters.ProcessParameters(c.TemplateParameters, c.Parameters)
if err != nil {
return nil, err
@@ -85,25 +75,17 @@ func (c Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
t := Tool{
Config: c,
AllParams: allParameters,
Session: s.CassandraSession(),
manifest: tools.Manifest{Description: c.Description, Parameters: paramManifest, AuthRequired: c.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
// ToolConfigKind implements tools.ToolConfig.
func (c Config) ToolConfigKind() string {
return kind
}
var _ tools.ToolConfig = Config{}
var _ tools.Tool = Tool{}
type Tool struct {
Config
AllParams parameters.Parameters `yaml:"allParams"`
Session *gocql.Session
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -113,8 +95,8 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
// RequiresClientAuthorization implements tools.Tool.
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return false
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
return false, nil
}
// Authorized implements tools.Tool.
@@ -124,6 +106,11 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
// Invoke implements tools.Tool.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
if err != nil {
@@ -135,7 +122,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("unable to extract standard params %w", err)
}
sliceParams := newParams.AsSlice()
iter := t.Session.Query(newStatement, sliceParams...).IterContext(ctx)
iter := source.CassandraSession().Query(newStatement, sliceParams...).IterContext(ctx)
// Create a slice to store the out
var out []map[string]interface{}
@@ -170,8 +157,6 @@ func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any)
return parameters.ParseParams(t.AllParams, data, claims)
}
var _ tools.Tool = Tool{}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -25,12 +25,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
type compatibleSource interface {
ClickHousePool() *sql.DB
}
var compatibleSources = []string{"clickhouse"}
const executeSQLKind string = "clickhouse-execute-sql"
func init() {
@@ -47,6 +41,10 @@ func newExecuteSQLConfig(ctx context.Context, name string, decoder *yaml.Decoder
return actual, nil
}
type compatibleSource interface {
ClickHousePool() *sql.DB
}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -62,16 +60,6 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", executeSQLKind, compatibleSources)
}
sqlParameter := parameters.NewStringParameter("sql", "The SQL statement to execute.")
params := parameters.Parameters{sqlParameter}
@@ -80,7 +68,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
t := Tool{
Config: cfg,
Parameters: params,
Pool: s.ClickHousePool(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
@@ -91,9 +78,7 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
Parameters parameters.Parameters `yaml:"parameters"`
Pool *sql.DB
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -103,13 +88,18 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
sql, ok := paramsMap["sql"].(string)
if !ok {
return nil, fmt.Errorf("unable to cast sql parameter %s", paramsMap["sql"])
}
results, err := t.Pool.QueryContext(ctx, sql)
results, err := source.ClickHousePool().QueryContext(ctx, sql)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
@@ -183,10 +173,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return false
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
return false, nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -25,12 +25,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
type compatibleSource interface {
ClickHousePool() *sql.DB
}
var compatibleSources = []string{"clickhouse"}
const listDatabasesKind string = "clickhouse-list-databases"
func init() {
@@ -47,6 +41,10 @@ func newListDatabasesConfig(ctx context.Context, name string, decoder *yaml.Deco
return actual, nil
}
type compatibleSource interface {
ClickHousePool() *sql.DB
}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -63,23 +61,12 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", listDatabasesKind, compatibleSources)
}
allParameters, paramManifest, _ := parameters.ProcessParameters(nil, cfg.Parameters)
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil)
t := Tool{
Config: cfg,
AllParams: allParameters,
Pool: s.ClickHousePool(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
@@ -90,9 +77,7 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
AllParams parameters.Parameters `yaml:"allParams"`
Pool *sql.DB
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -102,10 +87,15 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
// Query to list all databases
query := "SHOW DATABASES"
results, err := t.Pool.QueryContext(ctx, query)
results, err := source.ClickHousePool().QueryContext(ctx, query)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
@@ -146,10 +136,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return false
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
return false, nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -20,7 +20,6 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
@@ -32,21 +31,6 @@ func TestListDatabasesConfigToolConfigKind(t *testing.T) {
}
}
func TestListDatabasesConfigInitializeMissingSource(t *testing.T) {
cfg := Config{
Name: "test-list-databases",
Kind: listDatabasesKind,
Source: "missing-source",
Description: "Test list databases tool",
}
srcs := map[string]sources.Source{}
_, err := cfg.Initialize(srcs)
if err == nil {
t.Error("expected error for missing source")
}
}
func TestParseFromYamlClickHouseListDatabases(t *testing.T) {
ctx, err := testutils.ContextWithNewLogger()
if err != nil {

View File

@@ -25,12 +25,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
type compatibleSource interface {
ClickHousePool() *sql.DB
}
var compatibleSources = []string{"clickhouse"}
const listTablesKind string = "clickhouse-list-tables"
const databaseKey string = "database"
@@ -48,6 +42,10 @@ func newListTablesConfig(ctx context.Context, name string, decoder *yaml.Decoder
return actual, nil
}
type compatibleSource interface {
ClickHousePool() *sql.DB
}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -64,16 +62,6 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", listTablesKind, compatibleSources)
}
databaseParameter := parameters.NewStringParameter(databaseKey, "The database to list tables from.")
params := parameters.Parameters{databaseParameter}
@@ -83,7 +71,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
t := Tool{
Config: cfg,
AllParams: allParameters,
Pool: s.ClickHousePool(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
@@ -94,9 +81,7 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
AllParams parameters.Parameters `yaml:"allParams"`
Pool *sql.DB
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -106,6 +91,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
mapParams := params.AsMap()
database, ok := mapParams[databaseKey].(string)
if !ok {
@@ -115,7 +105,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
// Query to list all tables in the specified database
query := fmt.Sprintf("SHOW TABLES FROM %s", database)
results, err := t.Pool.QueryContext(ctx, query)
results, err := source.ClickHousePool().QueryContext(ctx, query)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
@@ -157,10 +147,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return false
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
return false, nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -20,7 +20,6 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
@@ -32,21 +31,6 @@ func TestListTablesConfigToolConfigKind(t *testing.T) {
}
}
func TestListTablesConfigInitializeMissingSource(t *testing.T) {
cfg := Config{
Name: "test-list-tables",
Kind: listTablesKind,
Source: "missing-source",
Description: "Test list tables tool",
}
srcs := map[string]sources.Source{}
_, err := cfg.Initialize(srcs)
if err == nil {
t.Error("expected error for missing source")
}
}
func TestParseFromYamlClickHouseListTables(t *testing.T) {
ctx, err := testutils.ContextWithNewLogger()
if err != nil {

View File

@@ -25,21 +25,15 @@ import (
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
type compatibleSource interface {
ClickHousePool() *sql.DB
}
var compatibleSources = []string{"clickhouse"}
const sqlKind string = "clickhouse-sql"
func init() {
if !tools.Register(sqlKind, newSQLConfig) {
if !tools.Register(sqlKind, newConfig) {
panic(fmt.Sprintf("tool kind %q already registered", sqlKind))
}
}
func newSQLConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
actual := Config{Name: name}
if err := decoder.DecodeContext(ctx, &actual); err != nil {
return nil, err
@@ -47,6 +41,10 @@ func newSQLConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tool
return actual, nil
}
type compatibleSource interface {
ClickHousePool() *sql.DB
}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -65,23 +63,12 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", sqlKind, compatibleSources)
}
allParameters, paramManifest, _ := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters)
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil)
t := Tool{
Config: cfg,
AllParams: allParameters,
Pool: s.ClickHousePool(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
@@ -93,7 +80,6 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
AllParams parameters.Parameters `yaml:"allParams"`
Pool *sql.DB
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -103,6 +89,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
if err != nil {
@@ -115,7 +106,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
sliceParams := newParams.AsSlice()
results, err := t.Pool.QueryContext(ctx, newStatement, sliceParams...)
results, err := source.ClickHousePool().QueryContext(ctx, newStatement, sliceParams...)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
@@ -191,10 +182,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return false
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
return false, nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -142,66 +142,6 @@ func TestSQLConfigInitializeValidSource(t *testing.T) {
}
}
func TestSQLConfigInitializeMissingSource(t *testing.T) {
config := Config{
Name: "test-tool",
Kind: sqlKind,
Source: "missing-source",
Description: "Test tool",
Statement: "SELECT 1",
Parameters: parameters.Parameters{},
}
sources := map[string]sources.Source{}
_, err := config.Initialize(sources)
if err == nil {
t.Fatal("Expected error for missing source, got nil")
}
expectedErr := `no source named "missing-source" configured`
if err.Error() != expectedErr {
t.Errorf("Expected error %q, got %q", expectedErr, err.Error())
}
}
// mockIncompatibleSource is a mock source that doesn't implement the compatibleSource interface
type mockIncompatibleSource struct{}
func (m *mockIncompatibleSource) SourceKind() string {
return "mock"
}
func (m *mockIncompatibleSource) ToConfig() sources.SourceConfig {
return nil
}
func TestSQLConfigInitializeIncompatibleSource(t *testing.T) {
config := Config{
Name: "test-tool",
Kind: sqlKind,
Source: "incompatible-source",
Description: "Test tool",
Statement: "SELECT 1",
Parameters: parameters.Parameters{},
}
mockSource := &mockIncompatibleSource{}
sources := map[string]sources.Source{
"incompatible-source": mockSource,
}
_, err := config.Initialize(sources)
if err == nil {
t.Fatal("Expected error for incompatible source, got nil")
}
if err.Error() == "" {
t.Error("Expected non-empty error message")
}
}
func TestToolManifest(t *testing.T) {
tool := Tool{
manifest: tools.Manifest{

View File

@@ -24,7 +24,6 @@ import (
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
cloudgdasrc "github.com/googleapis/genai-toolbox/internal/sources/cloudgda"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
@@ -45,6 +44,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
GetProjectID() string
GetBaseURL() string
UseClientAuthorization() bool
GetClient(context.Context, string) (*http.Client, error)
}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -64,18 +70,6 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(*cloudgdasrc.Source)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-gemini-data-analytics`", kind)
}
// Define the parameters for the Gemini Data Analytics Query API
// The prompt is the only input parameter.
allParameters := parameters.Parameters{
@@ -87,7 +81,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return Tool{
Config: cfg,
AllParams: allParameters,
Source: s,
manifest: tools.Manifest{Description: cfg.Description, Parameters: allParameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}, nil
@@ -99,7 +92,6 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
AllParams parameters.Parameters
Source *cloudgdasrc.Source
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -110,6 +102,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke executes the tool logic
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
prompt, ok := paramsMap["prompt"].(string)
if !ok {
@@ -118,11 +115,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
// The API endpoint itself always uses the "global" location.
apiLocation := "global"
apiParent := fmt.Sprintf("projects/%s/locations/%s", t.Source.ProjectID, apiLocation)
apiURL := fmt.Sprintf("%s/v1beta/%s:queryData", t.Source.BaseURL, apiParent)
apiParent := fmt.Sprintf("projects/%s/locations/%s", source.GetProjectID(), apiLocation)
apiURL := fmt.Sprintf("%s/v1beta/%s:queryData", source.GetBaseURL(), apiParent)
// The parent in the request payload uses the tool's configured location.
payloadParent := fmt.Sprintf("projects/%s/locations/%s", t.Source.ProjectID, t.Location)
payloadParent := fmt.Sprintf("projects/%s/locations/%s", source.GetProjectID(), t.Location)
payload := &QueryDataRequest{
Parent: payloadParent,
@@ -138,7 +135,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
// Parse the access token if provided
var tokenStr string
if t.RequiresClientAuthorization(resourceMgr) {
if source.UseClientAuthorization() {
var err error
tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
@@ -146,7 +143,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
}
client, err := t.Source.GetClient(ctx, tokenStr)
client, err := source.GetClient(ctx, tokenStr)
if err != nil {
return nil, fmt.Errorf("failed to get HTTP client: %w", err)
}
@@ -196,10 +193,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.Source.UseClientAuthorization()
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(_ tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -26,6 +26,7 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/server/resources"
"github.com/googleapis/genai-toolbox/internal/sources"
cloudgdasrc "github.com/googleapis/genai-toolbox/internal/sources/cloudgda"
"github.com/googleapis/genai-toolbox/internal/testutils"
@@ -172,9 +173,8 @@ func TestInitialize(t *testing.T) {
}
tcs := []struct {
desc string
cfg cloudgdatool.Config
expectErr bool
desc string
cfg cloudgdatool.Config
}{
{
desc: "successful initialization",
@@ -185,29 +185,6 @@ func TestInitialize(t *testing.T) {
Description: "Test Description",
Location: "us-central1",
},
expectErr: false,
},
{
desc: "missing source",
cfg: cloudgdatool.Config{
Name: "my-gda-query-tool",
Kind: "cloud-gemini-data-analytics-query",
Source: "non-existent-source",
Description: "Test Description",
Location: "us-central1",
},
expectErr: true,
},
{
desc: "incompatible source kind",
cfg: cloudgdatool.Config{
Name: "my-gda-query-tool",
Kind: "cloud-gemini-data-analytics-query",
Source: "incompatible-source",
Description: "Test Description",
Location: "us-central1",
},
expectErr: true,
},
}
@@ -219,16 +196,11 @@ func TestInitialize(t *testing.T) {
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
tool, err := tc.cfg.Initialize(srcs)
if tc.expectErr && err == nil {
t.Fatalf("expected an error but got none")
}
if !tc.expectErr && err != nil {
if err != nil {
t.Fatalf("did not expect an error but got: %v", err)
}
if !tc.expectErr {
// Basic sanity check on the returned tool
_ = tool // Avoid unused variable error
}
// Basic sanity check on the returned tool
_ = tool // Avoid unused variable error
})
}
}
@@ -361,8 +333,10 @@ func TestInvoke(t *testing.T) {
{Name: "prompt", Value: "How many accounts who have region in Prague are eligible for loans?"},
}
resourceMgr := resources.NewResourceManager(srcs, nil, nil, nil, nil, nil)
// Invoke the tool
result, err := tool.Invoke(ctx, nil, params, "") // No accessToken needed for ADC client
result, err := tool.Invoke(ctx, resourceMgr, params, "") // No accessToken needed for ADC client
if err != nil {
t.Fatalf("tool invocation failed: %v", err)
}

View File

@@ -62,11 +62,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
}
// validate compatible sources are still compatible
var _ compatibleSource = &healthcareds.Source{}
var compatibleSources = [...]string{healthcareds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -83,35 +78,16 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
urlParameter := parameters.NewStringParameter(pageURLKey, "The full URL of the FHIR page to fetch. This would be the value of `Bundle.entry.link.url` field within the response returned from FHIR search or FHIR patient everything operations.")
params := parameters.Parameters{urlParameter}
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
Project: s.Project(),
Region: s.Region(),
Dataset: s.DatasetID(),
AllowedStores: s.AllowedFHIRStores(),
UseClientOAuth: s.UseClientAuthorization(),
ServiceCreator: s.ServiceCreator(),
Service: s.Service(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -121,14 +97,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Project, Region, Dataset string
AllowedStores map[string]struct{}
Service *healthcare.Service
ServiceCreator healthcareds.HealthcareServiceCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -136,13 +107,18 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
url, ok := params.AsMap()[pageURLKey].(string)
if !ok {
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", pageURLKey)
}
var httpClient *http.Client
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
@@ -150,7 +126,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: tokenStr})
httpClient = oauth2.NewClient(ctx, ts)
} else {
// The t.Service object holds a client with the default credentials.
// The source.Service() object holds a client with the default credentials.
// However, the client is not exported, so we have to create a new one.
var err error
httpClient, err = google.DefaultClient(ctx, healthcare.CloudHealthcareScope)
@@ -201,10 +177,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -62,11 +62,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
}
// validate compatible sources are still compatible
var _ compatibleSource = &healthcareds.Source{}
var compatibleSources = [...]string{healthcareds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -92,7 +87,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
idParameter := parameters.NewStringParameter(patientIDKey, "The ID of the patient FHIR resource for which the information is required")
@@ -106,17 +101,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
Project: s.Project(),
Region: s.Region(),
Dataset: s.DatasetID(),
AllowedStores: s.AllowedFHIRStores(),
UseClientOAuth: s.UseClientAuthorization(),
ServiceCreator: s.ServiceCreator(),
Service: s.Service(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -126,15 +114,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Project, Region, Dataset string
AllowedStores map[string]struct{}
Service *healthcare.Service
ServiceCreator healthcareds.HealthcareServiceCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -142,7 +124,12 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores())
if err != nil {
return nil, err
}
@@ -151,20 +138,20 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", patientIDKey)
}
svc := t.Service
svc := source.Service()
// Initialize new service if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = t.ServiceCreator(tokenStr)
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/Patient/%s", t.Project, t.Region, t.Dataset, storeID, patientID)
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/Patient/%s", source.Project(), source.Region(), source.DatasetID(), storeID, patientID)
var opts []googleapi.CallOption
if val, ok := params.AsMap()[typeFilterKey]; ok {
types, ok := val.([]any)
@@ -225,10 +212,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -78,11 +78,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
}
// validate compatible sources are still compatible
var _ compatibleSource = &healthcareds.Source{}
var compatibleSources = [...]string{healthcareds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -108,7 +103,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
params := parameters.Parameters{
@@ -140,17 +135,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
Project: s.Project(),
Region: s.Region(),
Dataset: s.DatasetID(),
AllowedStores: s.AllowedFHIRStores(),
UseClientOAuth: s.UseClientAuthorization(),
ServiceCreator: s.ServiceCreator(),
Service: s.Service(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -160,15 +148,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Project, Region, Dataset string
AllowedStores map[string]struct{}
Service *healthcare.Service
ServiceCreator healthcareds.HealthcareServiceCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -176,19 +158,24 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
svc := t.Service
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores())
if err != nil {
return nil, err
}
svc := source.Service()
// Initialize new service if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = t.ServiceCreator(tokenStr)
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
@@ -261,7 +248,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
opts = append(opts, googleapi.QueryParameter("_summary", "text"))
}
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", t.Project, t.Region, t.Dataset, storeID)
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
resp, err := svc.Projects.Locations.Datasets.FhirStores.Fhir.SearchType(name, "Patient", &healthcare.SearchResourcesRequest{ResourceType: "Patient"}).Do(opts...)
if err != nil {
return nil, fmt.Errorf("failed to search patient resources: %w", err)
@@ -298,10 +285,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -51,11 +51,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
}
// validate compatible sources are still compatible
var _ compatibleSource = &healthcareds.Source{}
var compatibleSources = [...]string{healthcareds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -72,33 +67,15 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
params := parameters.Parameters{}
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
Project: s.Project(),
Region: s.Region(),
Dataset: s.DatasetID(),
UseClientOAuth: s.UseClientAuthorization(),
ServiceCreator: s.ServiceCreator(),
Service: s.Service(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -108,13 +85,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Project, Region, Dataset string
Service *healthcare.Service
ServiceCreator healthcareds.HealthcareServiceCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -122,22 +95,26 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
svc := t.Service
var err error
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
svc := source.Service()
// Initialize new service if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = t.ServiceCreator(tokenStr)
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", t.Project, t.Region, t.Dataset)
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", source.Project(), source.Region(), source.DatasetID())
dataset, err := svc.Projects.Locations.Datasets.Get(datasetName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
@@ -161,10 +138,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -53,11 +53,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
}
// validate compatible sources are still compatible
var _ compatibleSource = &healthcareds.Source{}
var compatibleSources = [...]string{healthcareds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -83,7 +78,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
params := parameters.Parameters{}
@@ -94,17 +89,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
Project: s.Project(),
Region: s.Region(),
Dataset: s.DatasetID(),
AllowedStores: s.AllowedDICOMStores(),
UseClientOAuth: s.UseClientAuthorization(),
ServiceCreator: s.ServiceCreator(),
Service: s.Service(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -114,15 +102,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Project, Region, Dataset string
AllowedStores map[string]struct{}
Service *healthcare.Service
ServiceCreator healthcareds.HealthcareServiceCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -130,25 +112,30 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
svc := t.Service
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
if err != nil {
return nil, err
}
svc := source.Service()
// Initialize new service if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = t.ServiceCreator(tokenStr)
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", t.Project, t.Region, t.Dataset, storeID)
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
store, err := svc.Projects.Locations.Datasets.DicomStores.Get(storeName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get DICOM store %q: %w", storeName, err)
@@ -172,10 +159,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -53,11 +53,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
}
// validate compatible sources are still compatible
var _ compatibleSource = &healthcareds.Source{}
var compatibleSources = [...]string{healthcareds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -83,7 +78,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
params := parameters.Parameters{}
@@ -94,17 +89,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
Project: s.Project(),
Region: s.Region(),
Dataset: s.DatasetID(),
AllowedStores: s.AllowedDICOMStores(),
UseClientOAuth: s.UseClientAuthorization(),
ServiceCreator: s.ServiceCreator(),
Service: s.Service(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -114,15 +102,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Project, Region, Dataset string
AllowedStores map[string]struct{}
Service *healthcare.Service
ServiceCreator healthcareds.HealthcareServiceCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -130,25 +112,30 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
svc := t.Service
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
if err != nil {
return nil, err
}
svc := source.Service()
// Initialize new service if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = t.ServiceCreator(tokenStr)
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", t.Project, t.Region, t.Dataset, storeID)
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
store, err := svc.Projects.Locations.Datasets.DicomStores.GetDICOMStoreMetrics(storeName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get metrics for DICOM store %q: %w", storeName, err)
@@ -172,10 +159,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -59,11 +59,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
}
// validate compatible sources are still compatible
var _ compatibleSource = &healthcareds.Source{}
var compatibleSources = [...]string{healthcareds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -89,7 +84,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
typeParameter := parameters.NewStringParameter(typeKey, "The FHIR resource type to retrieve (e.g., Patient, Observation).")
@@ -102,17 +97,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
Project: s.Project(),
Region: s.Region(),
Dataset: s.DatasetID(),
AllowedStores: s.AllowedFHIRStores(),
UseClientOAuth: s.UseClientAuthorization(),
ServiceCreator: s.ServiceCreator(),
Service: s.Service(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -122,15 +110,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Project, Region, Dataset string
AllowedStores map[string]struct{}
Service *healthcare.Service
ServiceCreator healthcareds.HealthcareServiceCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -138,7 +120,12 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores())
if err != nil {
return nil, err
}
@@ -152,20 +139,20 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", idKey)
}
svc := t.Service
svc := source.Service()
// Initialize new service if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = t.ServiceCreator(tokenStr)
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/%s/%s", t.Project, t.Region, t.Dataset, storeID, resType, resID)
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/%s/%s", source.Project(), source.Region(), source.DatasetID(), storeID, resType, resID)
call := svc.Projects.Locations.Datasets.FhirStores.Fhir.Read(name)
call.Header().Set("Content-Type", "application/fhir+json;charset=utf-8")
resp, err := call.Do()
@@ -204,10 +191,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -53,11 +53,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
}
// validate compatible sources are still compatible
var _ compatibleSource = &healthcareds.Source{}
var compatibleSources = [...]string{healthcareds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -83,7 +78,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
params := parameters.Parameters{}
@@ -94,17 +89,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
Project: s.Project(),
Region: s.Region(),
Dataset: s.DatasetID(),
AllowedStores: s.AllowedFHIRStores(),
UseClientOAuth: s.UseClientAuthorization(),
ServiceCreator: s.ServiceCreator(),
Service: s.Service(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -114,15 +102,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Project, Region, Dataset string
AllowedStores map[string]struct{}
Service *healthcare.Service
ServiceCreator healthcareds.HealthcareServiceCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -130,25 +112,30 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
svc := t.Service
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores())
if err != nil {
return nil, err
}
svc := source.Service()
// Initialize new service if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = t.ServiceCreator(tokenStr)
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", t.Project, t.Region, t.Dataset, storeID)
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
store, err := svc.Projects.Locations.Datasets.FhirStores.Get(storeName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get FHIR store %q: %w", storeName, err)
@@ -172,10 +159,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -53,11 +53,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
}
// validate compatible sources are still compatible
var _ compatibleSource = &healthcareds.Source{}
var compatibleSources = [...]string{healthcareds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -83,7 +78,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
params := parameters.Parameters{}
@@ -94,17 +89,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
Project: s.Project(),
Region: s.Region(),
Dataset: s.DatasetID(),
AllowedStores: s.AllowedFHIRStores(),
UseClientOAuth: s.UseClientAuthorization(),
ServiceCreator: s.ServiceCreator(),
Service: s.Service(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -114,15 +102,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Project, Region, Dataset string
AllowedStores map[string]struct{}
Service *healthcare.Service
ServiceCreator healthcareds.HealthcareServiceCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -130,25 +112,30 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
svc := t.Service
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores())
if err != nil {
return nil, err
}
svc := source.Service()
// Initialize new service if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = t.ServiceCreator(tokenStr)
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", t.Project, t.Region, t.Dataset, storeID)
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
store, err := svc.Projects.Locations.Datasets.FhirStores.GetFHIRStoreMetrics(storeName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get metrics for FHIR store %q: %w", storeName, err)
@@ -172,10 +159,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -53,11 +53,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
}
// validate compatible sources are still compatible
var _ compatibleSource = &healthcareds.Source{}
var compatibleSources = [...]string{healthcareds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -74,34 +69,15 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
params := parameters.Parameters{}
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
Project: s.Project(),
Region: s.Region(),
Dataset: s.DatasetID(),
AllowedStores: s.AllowedDICOMStores(),
UseClientOAuth: s.UseClientAuthorization(),
ServiceCreator: s.ServiceCreator(),
Service: s.Service(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -111,15 +87,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Project, Region, Dataset string
AllowedStores map[string]struct{}
Service *healthcare.Service
ServiceCreator healthcareds.HealthcareServiceCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -127,29 +97,33 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
svc := t.Service
var err error
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
svc := source.Service()
// Initialize new service if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = t.ServiceCreator(tokenStr)
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", t.Project, t.Region, t.Dataset)
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", source.Project(), source.Region(), source.DatasetID())
stores, err := svc.Projects.Locations.Datasets.DicomStores.List(datasetName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
}
var filtered []*healthcare.DicomStore
for _, store := range stores.DicomStores {
if len(t.AllowedStores) == 0 {
if len(source.AllowedDICOMStores()) == 0 {
filtered = append(filtered, store)
continue
}
@@ -157,7 +131,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
continue
}
parts := strings.Split(store.Name, "/")
if _, ok := t.AllowedStores[parts[len(parts)-1]]; ok {
if _, ok := source.AllowedDICOMStores()[parts[len(parts)-1]]; ok {
filtered = append(filtered, store)
}
}
@@ -180,10 +154,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -53,11 +53,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
}
// validate compatible sources are still compatible
var _ compatibleSource = &healthcareds.Source{}
var compatibleSources = [...]string{healthcareds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -74,34 +69,15 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
params := parameters.Parameters{}
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
Project: s.Project(),
Region: s.Region(),
Dataset: s.DatasetID(),
AllowedStores: s.AllowedFHIRStores(),
UseClientOAuth: s.UseClientAuthorization(),
ServiceCreator: s.ServiceCreator(),
Service: s.Service(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -111,15 +87,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Project, Region, Dataset string
AllowedStores map[string]struct{}
Service *healthcare.Service
ServiceCreator healthcareds.HealthcareServiceCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -127,29 +97,33 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
svc := t.Service
var err error
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
svc := source.Service()
// Initialize new service if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = t.ServiceCreator(tokenStr)
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", t.Project, t.Region, t.Dataset)
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", source.Project(), source.Region(), source.DatasetID())
stores, err := svc.Projects.Locations.Datasets.FhirStores.List(datasetName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
}
var filtered []*healthcare.FhirStore
for _, store := range stores.FhirStores {
if len(t.AllowedStores) == 0 {
if len(source.AllowedFHIRStores()) == 0 {
filtered = append(filtered, store)
continue
}
@@ -157,7 +131,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
continue
}
parts := strings.Split(store.Name, "/")
if _, ok := t.AllowedStores[parts[len(parts)-1]]; ok {
if _, ok := source.AllowedFHIRStores()[parts[len(parts)-1]]; ok {
filtered = append(filtered, store)
}
}
@@ -180,10 +154,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -61,11 +61,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
}
// validate compatible sources are still compatible
var _ compatibleSource = &healthcareds.Source{}
var compatibleSources = [...]string{healthcareds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -91,7 +86,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
params := parameters.Parameters{
@@ -107,17 +102,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
Project: s.Project(),
Region: s.Region(),
Dataset: s.DatasetID(),
AllowedStores: s.AllowedDICOMStores(),
UseClientOAuth: s.UseClientAuthorization(),
ServiceCreator: s.ServiceCreator(),
Service: s.Service(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -127,15 +115,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Project, Region, Dataset string
AllowedStores map[string]struct{}
Service *healthcare.Service
ServiceCreator healthcareds.HealthcareServiceCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -143,19 +125,24 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
svc := t.Service
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
if err != nil {
return nil, err
}
svc := source.Service()
// Initialize new service if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = t.ServiceCreator(tokenStr)
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
@@ -177,7 +164,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if !ok {
return nil, fmt.Errorf("invalid '%s' parameter; expected an integer", frameNumberKey)
}
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", t.Project, t.Region, t.Dataset, storeID)
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
dicomWebPath := fmt.Sprintf("studies/%s/series/%s/instances/%s/frames/%d/rendered", study, series, sop, frame)
call := svc.Projects.Locations.Datasets.DicomStores.Studies.Series.Instances.Frames.RetrieveRendered(name, dicomWebPath)
call.Header().Set("Accept", "image/jpeg")
@@ -214,10 +201,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -68,11 +68,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
}
// validate compatible sources are still compatible
var _ compatibleSource = &healthcareds.Source{}
var compatibleSources = [...]string{healthcareds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -98,7 +93,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
params := parameters.Parameters{
@@ -121,17 +116,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
Project: s.Project(),
Region: s.Region(),
Dataset: s.DatasetID(),
AllowedStores: s.AllowedDICOMStores(),
UseClientOAuth: s.UseClientAuthorization(),
ServiceCreator: s.ServiceCreator(),
Service: s.Service(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -141,15 +129,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Project, Region, Dataset string
AllowedStores map[string]struct{}
Service *healthcare.Service
ServiceCreator healthcareds.HealthcareServiceCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -157,19 +139,24 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
svc := t.Service
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
if err != nil {
return nil, err
}
svc := source.Service()
// Initialize new service if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = t.ServiceCreator(tokenStr)
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
@@ -204,7 +191,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
}
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", t.Project, t.Region, t.Dataset, storeID)
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
resp, err := svc.Projects.Locations.Datasets.DicomStores.SearchForInstances(name, dicomWebPath).Do(opts...)
if err != nil {
return nil, fmt.Errorf("failed to search dicom instances: %w", err)
@@ -244,10 +231,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -65,11 +65,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
}
// validate compatible sources are still compatible
var _ compatibleSource = &healthcareds.Source{}
var compatibleSources = [...]string{healthcareds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -95,7 +90,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
params := parameters.Parameters{
@@ -117,17 +112,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
Project: s.Project(),
Region: s.Region(),
Dataset: s.DatasetID(),
AllowedStores: s.AllowedDICOMStores(),
UseClientOAuth: s.UseClientAuthorization(),
ServiceCreator: s.ServiceCreator(),
Service: s.Service(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -137,15 +125,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Project, Region, Dataset string
AllowedStores map[string]struct{}
Service *healthcare.Service
ServiceCreator healthcareds.HealthcareServiceCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -153,19 +135,24 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
svc := t.Service
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
if err != nil {
return nil, err
}
svc := source.Service()
// Initialize new service if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = t.ServiceCreator(tokenStr)
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
@@ -187,7 +174,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
}
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", t.Project, t.Region, t.Dataset, storeID)
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
resp, err := svc.Projects.Locations.Datasets.DicomStores.SearchForSeries(name, dicomWebPath).Do(opts...)
if err != nil {
return nil, fmt.Errorf("failed to search dicom series: %w", err)
@@ -227,10 +214,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -63,11 +63,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
}
// validate compatible sources are still compatible
var _ compatibleSource = &healthcareds.Source{}
var compatibleSources = [...]string{healthcareds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -93,7 +88,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
params := parameters.Parameters{
@@ -113,17 +108,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
Project: s.Project(),
Region: s.Region(),
Dataset: s.DatasetID(),
AllowedStores: s.AllowedDICOMStores(),
UseClientOAuth: s.UseClientAuthorization(),
ServiceCreator: s.ServiceCreator(),
Service: s.Service(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -133,15 +121,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Project, Region, Dataset string
AllowedStores map[string]struct{}
Service *healthcare.Service
ServiceCreator healthcareds.HealthcareServiceCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -149,19 +131,24 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
svc := t.Service
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
if err != nil {
return nil, err
}
svc := source.Service()
// Initialize new service if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = t.ServiceCreator(tokenStr)
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
@@ -171,7 +158,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", t.Project, t.Region, t.Dataset, storeID)
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
resp, err := svc.Projects.Locations.Datasets.DicomStores.SearchForStudies(name, "studies").Do(opts...)
if err != nil {
return nil, fmt.Errorf("failed to search dicom studies: %w", err)
@@ -211,10 +198,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -23,7 +23,6 @@ import (
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
cloudmonitoringsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudmonitoring"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
@@ -44,6 +43,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
BaseURL() string
Client() *http.Client
UserAgent() string
}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -60,18 +65,6 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(*cloudmonitoringsrc.Source)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloudmonitoring`", kind)
}
// Define the parameters internally instead of from the config file.
allParameters := parameters.Parameters{
parameters.NewStringParameterWithRequired("projectId", "The Id of the Google Cloud project.", true),
@@ -83,9 +76,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return Tool{
Config: cfg,
AllParams: allParameters,
BaseURL: s.BaseURL,
UserAgent: s.UserAgent,
Client: s.Client,
manifest: tools.Manifest{Description: cfg.Description, Parameters: allParameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}, nil
@@ -97,9 +87,6 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
AllParams parameters.Parameters `yaml:"allParams"`
BaseURL string `yaml:"baseURL"`
UserAgent string
Client *http.Client
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -109,6 +96,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
projectID, ok := paramsMap["projectId"].(string)
if !ok {
@@ -119,7 +111,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("query parameter not found or not a string")
}
url := fmt.Sprintf("%s/v1/projects/%s/location/global/prometheus/api/v1/query", t.BaseURL, projectID)
url := fmt.Sprintf("%s/v1/projects/%s/location/global/prometheus/api/v1/query", source.BaseURL(), projectID)
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
@@ -130,9 +122,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
q.Add("query", query)
req.URL.RawQuery = q.Encode()
req.Header.Set("User-Agent", t.UserAgent)
req.Header.Set("User-Agent", source.UserAgent())
resp, err := t.Client.Do(req)
resp, err := source.Client().Do(req)
if err != nil {
return nil, err
}
@@ -175,10 +167,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return false
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
return false, nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -81,22 +81,6 @@ func TestInitialize(t *testing.T) {
AuthRequired: []string{"google-auth-service"},
},
},
{
desc: "Error: source not found",
cfg: cloudmonitoring.Config{
Name: "test-tool",
Source: "non-existent-source",
},
wantErr: `no source named "non-existent-source" configured`,
},
{
desc: "Error: incompatible source kind",
cfg: cloudmonitoring.Config{
Name: "test-tool",
Source: "incompatible-source",
},
wantErr: "invalid source for \"cloud-monitoring-query-prometheus\" tool",
},
}
for _, tc := range testCases {

View File

@@ -20,7 +20,6 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
sqladmin "google.golang.org/api/sqladmin/v1"
@@ -42,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
GetDefaultProject() string
GetService(context.Context, string) (*sqladmin.Service, error)
UseClientAuthorization() bool
}
// Config defines the configuration for the clone-instance tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
@@ -65,12 +70,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
s, ok := rawS.(*cloudsqladmin.Source)
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind)
}
project := s.DefaultProject
project := s.GetDefaultProject()
var projectParam parameters.Parameter
if project != "" {
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
@@ -98,7 +103,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return Tool{
Config: cfg,
Source: s,
AllParams: allParameters,
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
@@ -108,7 +112,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// Tool represents the clone-instance tool.
type Tool struct {
Config
Source *cloudsqladmin.Source
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
mcpManifest tools.McpManifest
@@ -120,6 +123,10 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke executes the tool's logic.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
project, ok := paramsMap["project"].(string)
@@ -156,7 +163,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
CloneContext: cloneContext,
}
service, err := t.Source.GetService(ctx, string(accessToken))
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
@@ -189,10 +196,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.Source.UseClientAuthorization()
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -20,7 +20,6 @@ import (
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
sqladmin "google.golang.org/api/sqladmin/v1"
@@ -42,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
GetDefaultProject() string
GetService(context.Context, string) (*sqladmin.Service, error)
UseClientAuthorization() bool
}
// Config defines the configuration for the create-database tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
@@ -65,12 +70,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
s, ok := rawS.(*cloudsqladmin.Source)
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind)
}
project := s.DefaultProject
project := s.GetDefaultProject()
var projectParam parameters.Parameter
if project != "" {
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
@@ -93,7 +98,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return Tool{
Config: cfg,
Source: s,
AllParams: allParameters,
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
@@ -103,7 +107,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// Tool represents the create-database tool.
type Tool struct {
Config
Source *cloudsqladmin.Source
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
mcpManifest tools.McpManifest
@@ -115,6 +118,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke executes the tool's logic.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
project, ok := paramsMap["project"].(string)
@@ -136,7 +144,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
Instance: instance,
}
service, err := t.Source.GetService(ctx, string(accessToken))
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
@@ -169,10 +177,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.Source.UseClientAuthorization()
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -20,7 +20,6 @@ import (
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
sqladmin "google.golang.org/api/sqladmin/v1"
@@ -42,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
GetDefaultProject() string
GetService(context.Context, string) (*sqladmin.Service, error)
UseClientAuthorization() bool
}
// Config defines the configuration for the create-user tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
@@ -65,12 +70,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
s, ok := rawS.(*cloudsqladmin.Source)
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind)
}
project := s.DefaultProject
project := s.GetDefaultProject()
var projectParam parameters.Parameter
if project != "" {
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
@@ -95,7 +100,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return Tool{
Config: cfg,
Source: s,
AllParams: allParameters,
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
@@ -105,7 +109,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// Tool represents the create-user tool.
type Tool struct {
Config
Source *cloudsqladmin.Source
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
mcpManifest tools.McpManifest
@@ -117,6 +120,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke executes the tool's logic.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
project, ok := paramsMap["project"].(string)
@@ -149,7 +157,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
user.Password = password
}
service, err := t.Source.GetService(ctx, string(accessToken))
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
@@ -182,10 +190,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.Source.UseClientAuthorization()
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -20,9 +20,9 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/sqladmin/v1"
)
const kind string = "cloud-sql-get-instance"
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
GetDefaultProject() string
GetService(context.Context, string) (*sqladmin.Service, error)
UseClientAuthorization() bool
}
// Config defines the configuration for the get-instances tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
@@ -65,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
s, ok := rawS.(*cloudsqladmin.Source)
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind)
}
project := s.DefaultProject
project := s.GetDefaultProject()
var projectParam parameters.Parameter
if project != "" {
projectParam = parameters.NewStringParameterWithDefault("projectId", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
@@ -92,7 +98,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return Tool{
Config: cfg,
Source: s,
AllParams: allParameters,
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
@@ -102,7 +107,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// Tool represents the get-instances tool.
type Tool struct {
Config
Source *cloudsqladmin.Source
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
mcpManifest tools.McpManifest
@@ -114,6 +118,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke executes the tool's logic.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
projectId, ok := paramsMap["projectId"].(string)
@@ -125,7 +134,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("missing 'instanceId' parameter")
}
service, err := t.Source.GetService(ctx, string(accessToken))
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
@@ -158,10 +167,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.Source.UseClientAuthorization()
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -20,9 +20,9 @@ import (
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
cloudsqladminsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/sqladmin/v1"
)
const kind string = "cloud-sql-list-databases"
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
GetDefaultProject() string
GetService(context.Context, string) (*sqladmin.Service, error)
UseClientAuthorization() bool
}
// Config defines the configuration for the list-databases tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
@@ -64,12 +70,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
s, ok := rawS.(*cloudsqladminsrc.Source)
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind)
}
project := s.DefaultProject
project := s.GetDefaultProject()
var projectParam parameters.Parameter
if project != "" {
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
@@ -91,7 +97,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return Tool{
Config: cfg,
Source: s,
AllParams: allParameters,
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
@@ -102,7 +107,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
type Tool struct {
Config
AllParams parameters.Parameters `yaml:"allParams"`
Source *cloudsqladminsrc.Source
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -113,6 +117,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke executes the tool's logic.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
project, ok := paramsMap["project"].(string)
@@ -124,7 +133,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("missing 'instance' parameter")
}
service, err := t.Source.GetService(ctx, string(accessToken))
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
@@ -176,10 +185,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.Source.UseClientAuthorization()
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -20,9 +20,9 @@ import (
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
cloudsqladminsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/sqladmin/v1"
)
const kind string = "cloud-sql-list-instances"
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
GetDefaultProject() string
GetService(context.Context, string) (*sqladmin.Service, error)
UseClientAuthorization() bool
}
// Config defines the configuration for the list-instance tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
@@ -64,12 +70,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
s, ok := rawS.(*cloudsqladminsrc.Source)
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind)
}
project := s.DefaultProject
project := s.GetDefaultProject()
var projectParam parameters.Parameter
if project != "" {
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
@@ -90,7 +96,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return Tool{
Config: cfg,
source: s,
AllParams: allParameters,
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
@@ -101,7 +106,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
type Tool struct {
Config
AllParams parameters.Parameters `yaml:"allParams"`
source *cloudsqladminsrc.Source
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -112,6 +116,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke executes the tool's logic.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
project, ok := paramsMap["project"].(string)
@@ -119,7 +128,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("missing 'project' parameter")
}
service, err := t.source.GetService(ctx, string(accessToken))
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
@@ -169,10 +178,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.source.UseClientAuthorization()
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -25,9 +25,9 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/sqladmin/v1"
)
const kind string = "cloud-sql-wait-for-operation"
@@ -87,6 +87,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
GetDefaultProject() string
GetService(context.Context, string) (*sqladmin.Service, error)
UseClientAuthorization() bool
}
// Config defines the configuration for the wait-for-operation tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
@@ -118,12 +124,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
s, ok := rawS.(*cloudsqladmin.Source)
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind)
}
project := s.DefaultProject
project := s.GetDefaultProject()
var projectParam parameters.Parameter
if project != "" {
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
@@ -177,7 +183,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return Tool{
Config: cfg,
Source: s,
AllParams: allParameters,
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
@@ -191,17 +196,15 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// Tool represents the wait-for-operation tool.
type Tool struct {
Config
Source *cloudsqladmin.Source
AllParams parameters.Parameters `yaml:"allParams"`
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
mcpManifest tools.McpManifest
// Polling configuration
Delay time.Duration
MaxDelay time.Duration
Multiplier float64
MaxRetries int
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -210,6 +213,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke executes the tool's logic.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
project, ok := paramsMap["project"].(string)
@@ -221,7 +229,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("missing 'operation' parameter")
}
service, err := t.Source.GetService(ctx, string(accessToken))
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
@@ -267,7 +275,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("could not unmarshal operation: %w", err)
}
if msg, ok := t.generateCloudSQLConnectionMessage(data); ok {
if msg, ok := t.generateCloudSQLConnectionMessage(source, data); ok {
return msg, nil
}
return string(opBytes), nil
@@ -305,11 +313,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.Source.UseClientAuthorization()
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) generateCloudSQLConnectionMessage(opResponse map[string]any) (string, bool) {
func (t Tool) generateCloudSQLConnectionMessage(source compatibleSource, opResponse map[string]any) (string, bool) {
operationType, ok := opResponse["operationType"].(string)
if !ok || operationType != "CREATE_DATABASE" {
return "", false
@@ -329,7 +341,7 @@ func (t Tool) generateCloudSQLConnectionMessage(opResponse map[string]any) (stri
instance := matches[2]
database := matches[3]
instanceData, err := t.fetchInstanceData(context.Background(), project, instance)
instanceData, err := t.fetchInstanceData(context.Background(), source, project, instance)
if err != nil {
fmt.Printf("error fetching instance data: %v\n", err)
return "", false
@@ -385,8 +397,8 @@ func (t Tool) generateCloudSQLConnectionMessage(opResponse map[string]any) (stri
return b.String(), true
}
func (t Tool) fetchInstanceData(ctx context.Context, project, instance string) (map[string]any, error) {
service, err := t.Source.GetService(ctx, "")
func (t Tool) fetchInstanceData(ctx context.Context, source compatibleSource, project, instance string) (map[string]any, error) {
service, err := source.GetService(ctx, "")
if err != nil {
return nil, err
}
@@ -408,6 +420,6 @@ func (t Tool) fetchInstanceData(ctx context.Context, project, instance string) (
return data, nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -21,7 +21,6 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
sqladmin "google.golang.org/api/sqladmin/v1"
@@ -43,6 +42,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
GetDefaultProject() string
GetService(context.Context, string) (*sqladmin.Service, error)
UseClientAuthorization() bool
}
// Config defines the configuration for the create-instances tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
@@ -66,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
s, ok := rawS.(*cloudsqladmin.Source)
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind)
}
project := s.DefaultProject
project := s.GetDefaultProject()
var projectParam parameters.Parameter
if project != "" {
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
@@ -96,7 +101,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return Tool{
Config: cfg,
Source: s,
AllParams: allParameters,
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
@@ -106,7 +110,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// Tool represents the create-instances tool.
type Tool struct {
Config
Source *cloudsqladmin.Source
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
mcpManifest tools.McpManifest
@@ -118,6 +121,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke executes the tool's logic.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
project, ok := paramsMap["project"].(string)
@@ -167,7 +175,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
Project: project,
}
service, err := t.Source.GetService(ctx, string(accessToken))
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
@@ -200,10 +208,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.Source.UseClientAuthorization()
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -21,7 +21,6 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
sqladmin "google.golang.org/api/sqladmin/v1"
@@ -43,6 +42,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
GetDefaultProject() string
GetService(context.Context, string) (*sqladmin.Service, error)
UseClientAuthorization() bool
}
// Config defines the configuration for the create-instances tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
@@ -66,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
s, ok := rawS.(*cloudsqladmin.Source)
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind)
}
project := s.DefaultProject
project := s.GetDefaultProject()
var projectParam parameters.Parameter
if project != "" {
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
@@ -96,7 +101,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return Tool{
Config: cfg,
Source: s,
AllParams: allParameters,
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
@@ -106,7 +110,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// Tool represents the create-instances tool.
type Tool struct {
Config
Source *cloudsqladmin.Source
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
mcpManifest tools.McpManifest
@@ -118,6 +121,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke executes the tool's logic.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
project, ok := paramsMap["project"].(string)
@@ -167,7 +175,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
Project: project,
}
service, err := t.Source.GetService(ctx, string(accessToken))
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
@@ -200,10 +208,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.Source.UseClientAuthorization()
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -21,7 +21,6 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
sqladmin "google.golang.org/api/sqladmin/v1"
@@ -43,6 +42,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
GetDefaultProject() string
GetService(context.Context, string) (*sqladmin.Service, error)
UseClientAuthorization() bool
}
// Config defines the configuration for the create-instances tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
@@ -66,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
s, ok := rawS.(*cloudsqladmin.Source)
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind)
}
project := s.DefaultProject
project := s.GetDefaultProject()
var projectParam parameters.Parameter
if project != "" {
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
@@ -96,7 +101,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return Tool{
Config: cfg,
Source: s,
AllParams: allParameters,
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
@@ -106,7 +110,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// Tool represents the create-instances tool.
type Tool struct {
Config
Source *cloudsqladmin.Source
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
mcpManifest tools.McpManifest
@@ -118,6 +121,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke executes the tool's logic.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
project, ok := paramsMap["project"].(string)
@@ -167,7 +175,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
Project: project,
}
service, err := t.Source.GetService(ctx, string(accessToken))
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
@@ -200,10 +208,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.Source.UseClientAuthorization()
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -21,7 +21,6 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
sqladmin "google.golang.org/api/sqladmin/v1"
@@ -43,6 +42,11 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
GetService(context.Context, string) (*sqladmin.Service, error)
UseClientAuthorization() bool
}
// Config defines the configuration for the precheck-upgrade tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
@@ -62,15 +66,6 @@ func (cfg Config) ToolConfigKind() string {
// Initialize initializes the tool from the configuration.
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
s, ok := rawS.(*cloudsqladmin.Source)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind)
}
allParameters := parameters.Parameters{
parameters.NewStringParameter("project", "The project ID"),
parameters.NewStringParameter("instance", "The name of the instance to check"),
@@ -88,28 +83,19 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, allParameters, nil)
return Tool{
Name: cfg.Name,
Kind: kind,
AuthRequired: cfg.AuthRequired,
Source: s,
AllParams: allParameters,
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
AllParams: allParameters,
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}, nil
}
// Tool represents the precheck-upgrade tool.
type Tool struct {
Name string `yaml:"name"`
Kind string `yaml:"kind"`
Description string `yaml:"description"`
AuthRequired []string `yaml:"authRequired"`
Source *cloudsqladmin.Source
Config
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
mcpManifest tools.McpManifest
Config
}
// PreCheckResultItem holds the details of a single check result.
@@ -146,6 +132,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke executes the tool's logic.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
project, ok := paramsMap["project"].(string)
@@ -162,7 +153,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("missing or empty 'targetDatabaseVersion' parameter")
}
service, err := t.Source.GetService(ctx, string(accessToken))
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, fmt.Errorf("failed to get HTTP client from source: %w", err)
}
@@ -234,10 +225,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.Source.UseClientAuthorization()
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -22,7 +22,6 @@ import (
"github.com/couchbase/gocb/v2"
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/couchbase"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
@@ -48,11 +47,6 @@ type compatibleSource interface {
CouchbaseQueryScanConsistency() uint
}
// validate compatible sources are still compatible
var _ compatibleSource = &couchbase.Source{}
var compatibleSources = [...]string{couchbase.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -72,18 +66,6 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters)
if err != nil {
return nil, err
@@ -92,12 +74,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil)
// finish tool setup
t := Tool{
Config: cfg,
AllParams: allParameters,
Scope: s.CouchbaseScope(),
QueryScanConsistency: s.CouchbaseQueryScanConsistency(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
AllParams: allParameters,
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -107,12 +87,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
AllParams parameters.Parameters `yaml:"allParams"`
Scope *gocb.Scope
QueryScanConsistency uint
manifest tools.Manifest
mcpManifest tools.McpManifest
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -120,6 +97,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
namedParamsMap := params.AsMap()
newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, namedParamsMap)
if err != nil {
@@ -130,8 +112,8 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, fmt.Errorf("unable to extract standard params %w", err)
}
results, err := t.Scope.Query(newStatement, &gocb.QueryOptions{
ScanConsistency: gocb.QueryScanConsistency(t.QueryScanConsistency),
results, err := source.CouchbaseScope().Query(newStatement, &gocb.QueryOptions{
ScanConsistency: gocb.QueryScanConsistency(source.CouchbaseQueryScanConsistency()),
NamedParameters: newParams.AsMap(),
})
if err != nil {
@@ -166,10 +148,10 @@ func (t Tool) Authorized(verifiedAuthSources []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthSources)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return false
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
return false, nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -118,10 +118,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return false
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
return false, nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -22,7 +22,6 @@ import (
dataplexpb "cloud.google.com/go/dataplex/apiv1/dataplexpb"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
dataplexds "github.com/googleapis/genai-toolbox/internal/sources/dataplex"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
@@ -47,11 +46,6 @@ type compatibleSource interface {
CatalogClient() *dataplexapi.CatalogClient
}
// validate compatible sources are still compatible
var _ compatibleSource = &dataplexds.Source{}
var compatibleSources = [...]string{dataplexds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -69,17 +63,6 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// Initialize the search configuration with the provided sources
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
viewDesc := `
## Argument: view
@@ -104,9 +87,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil)
t := Tool{
Config: cfg,
Parameters: params,
CatalogClient: s.CatalogClient(),
Config: cfg,
Parameters: params,
manifest: tools.Manifest{
Description: cfg.Description,
Parameters: params.Manifest(),
@@ -119,10 +101,9 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
type Tool struct {
Config
Parameters parameters.Parameters
CatalogClient *dataplexapi.CatalogClient
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -130,6 +111,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
viewMap := map[int]dataplexpb.EntryView{
1: dataplexpb.EntryView_BASIC,
@@ -153,7 +139,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
Entry: entry,
}
result, err := t.CatalogClient.LookupEntry(ctx, req)
result, err := source.CatalogClient().LookupEntry(ctx, req)
if err != nil {
return nil, err
}
@@ -179,10 +165,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return false
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
return false, nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -23,7 +23,6 @@ import (
"github.com/cenkalti/backoff/v5"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
dataplexds "github.com/googleapis/genai-toolbox/internal/sources/dataplex"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
@@ -49,11 +48,6 @@ type compatibleSource interface {
ProjectID() string
}
// validate compatible sources are still compatible
var _ compatibleSource = &dataplexds.Source{}
var compatibleSources = [...]string{dataplexds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -70,17 +64,6 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// Initialize the search configuration with the provided sources
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
query := parameters.NewStringParameter("query", "The query against which aspect type should be matched.")
pageSize := parameters.NewIntParameterWithDefault("pageSize", 5, "Number of returned aspect types in the search page.")
orderBy := parameters.NewStringParameterWithDefault("orderBy", "relevance", "Specifies the ordering of results. Supported values are: relevance, last_modified_timestamp, last_modified_timestamp asc")
@@ -89,10 +72,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil)
t := Tool{
Config: cfg,
Parameters: params,
CatalogClient: s.CatalogClient(),
ProjectID: s.ProjectID(),
Config: cfg,
Parameters: params,
manifest: tools.Manifest{
Description: cfg.Description,
Parameters: params.Manifest(),
@@ -105,11 +86,9 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
type Tool struct {
Config
Parameters parameters.Parameters
CatalogClient *dataplexapi.CatalogClient
ProjectID string
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -117,6 +96,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
// Invoke the tool with the provided parameters
paramsMap := params.AsMap()
query, _ := paramsMap["query"].(string)
@@ -126,16 +110,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
// Create SearchEntriesRequest with the provided parameters
req := &dataplexpb.SearchEntriesRequest{
Query: query + " type=projects/dataplex-types/locations/global/entryTypes/aspecttype",
Name: fmt.Sprintf("projects/%s/locations/global", t.ProjectID),
Name: fmt.Sprintf("projects/%s/locations/global", source.ProjectID()),
PageSize: pageSize,
OrderBy: orderBy,
SemanticSearch: true,
}
// Perform the search using the CatalogClient - this will return an iterator
it := t.CatalogClient.SearchEntries(ctx, req)
it := source.CatalogClient().SearchEntries(ctx, req)
if it == nil {
return nil, fmt.Errorf("failed to create search entries iterator for project %q", t.ProjectID)
return nil, fmt.Errorf("failed to create search entries iterator for project %q", source.ProjectID())
}
// Create an instance of exponential backoff with default values for retrying GetAspectType calls
@@ -155,7 +139,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
operation := func() (*dataplexpb.AspectType, error) {
aspectType, err := t.CatalogClient.GetAspectType(ctx, getAspectTypeReq)
aspectType, err := source.CatalogClient().GetAspectType(ctx, getAspectTypeReq)
if err != nil {
return nil, fmt.Errorf("failed to get aspect type for entry %q: %w", resourceName, err)
}
@@ -192,10 +176,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return false
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
return false, nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -22,7 +22,6 @@ import (
dataplexpb "cloud.google.com/go/dataplex/apiv1/dataplexpb"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
dataplexds "github.com/googleapis/genai-toolbox/internal/sources/dataplex"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
@@ -48,11 +47,6 @@ type compatibleSource interface {
ProjectID() string
}
// validate compatible sources are still compatible
var _ compatibleSource = &dataplexds.Source{}
var compatibleSources = [...]string{dataplexds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -69,17 +63,6 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// Initialize the search configuration with the provided sources
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
query := parameters.NewStringParameter("query", "The query against which entries in scope should be matched.")
pageSize := parameters.NewIntParameterWithDefault("pageSize", 5, "Number of results in the search page.")
orderBy := parameters.NewStringParameterWithDefault("orderBy", "relevance", "Specifies the ordering of results. Supported values are: relevance, last_modified_timestamp, last_modified_timestamp asc")
@@ -88,10 +71,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil)
t := Tool{
Config: cfg,
Parameters: params,
CatalogClient: s.CatalogClient(),
ProjectID: s.ProjectID(),
Config: cfg,
Parameters: params,
manifest: tools.Manifest{
Description: cfg.Description,
Parameters: params.Manifest(),
@@ -104,11 +85,9 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
type Tool struct {
Config
Parameters parameters.Parameters
CatalogClient *dataplexapi.CatalogClient
ProjectID string
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -116,6 +95,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
query, _ := paramsMap["query"].(string)
pageSize := int32(paramsMap["pageSize"].(int))
@@ -123,15 +107,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
req := &dataplexpb.SearchEntriesRequest{
Query: query,
Name: fmt.Sprintf("projects/%s/locations/global", t.ProjectID),
Name: fmt.Sprintf("projects/%s/locations/global", source.ProjectID()),
PageSize: pageSize,
OrderBy: orderBy,
SemanticSearch: true,
}
it := t.CatalogClient.SearchEntries(ctx, req)
it := source.CatalogClient().SearchEntries(ctx, req)
if it == nil {
return nil, fmt.Errorf("failed to create search entries iterator for project %q", t.ProjectID)
return nil, fmt.Errorf("failed to create search entries iterator for project %q", source.ProjectID())
}
var results []*dataplexpb.SearchEntriesResult
@@ -163,10 +147,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return false
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
return false, nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -46,11 +46,6 @@ type compatibleSource interface {
DgraphClient() *dgraph.DgraphClient
}
// validate compatible sources are still compatible
var _ compatibleSource = &dgraph.Source{}
var compatibleSources = [...]string{dgraph.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -71,26 +66,13 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, cfg.Parameters, nil)
// finish tool setup
t := Tool{
Config: cfg,
DgraphClient: s.DgraphClient(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -100,9 +82,8 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
DgraphClient *dgraph.DgraphClient
manifest tools.Manifest
mcpManifest tools.McpManifest
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -110,9 +91,14 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMapWithDollarPrefix()
resp, err := t.DgraphClient.ExecuteQuery(t.Statement, paramsMap, t.IsQuery, t.Timeout)
resp, err := source.DgraphClient().ExecuteQuery(t.Statement, paramsMap, t.IsQuery, t.Timeout)
if err != nil {
return nil, err
}
@@ -148,10 +134,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return false
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
return false, nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -43,10 +43,6 @@ type compatibleSource interface {
ElasticsearchClient() es.EsClient
}
var _ compatibleSource = &es.Source{}
var compatibleSources = [...]string{es.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -77,29 +73,15 @@ type Tool struct {
Config
manifest tools.Manifest
mcpManifest tools.McpManifest
EsClient es.EsClient
}
var _ tools.Tool = Tool{}
func (c Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
src, ok := srcs[c.Source]
if !ok {
return nil, fmt.Errorf("source %q not found", c.Source)
}
// verify the source is compatible
s, ok := src.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
mcpManifest := tools.GetMcpManifest(c.Name, c.Description, c.AuthRequired, c.Parameters, nil)
return Tool{
Config: c,
EsClient: s.ElasticsearchClient(),
manifest: tools.Manifest{Description: c.Description, Parameters: c.Parameters.Manifest(), AuthRequired: c.AuthRequired},
mcpManifest: mcpManifest,
}, nil
@@ -120,6 +102,11 @@ type esqlResult struct {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
var cancel context.CancelFunc
if t.Timeout > 0 {
ctx, cancel = context.WithTimeout(ctx, time.Duration(t.Timeout)*time.Second)
@@ -164,8 +151,8 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
Body: bytes.NewReader(body),
Format: t.Format,
FilterPath: []string{"columns", "values"},
Instrument: t.EsClient.InstrumentationEnabled(),
}.Do(ctx, t.EsClient)
Instrument: source.ElasticsearchClient().InstrumentationEnabled(),
}.Do(ctx, source.ElasticsearchClient())
if err != nil {
return nil, err
@@ -230,10 +217,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return false
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
return false, nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -21,7 +21,6 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/firebird"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
@@ -47,10 +46,6 @@ type compatibleSource interface {
FirebirdDB() *sql.DB
}
var _ compatibleSource = &firebird.Source{}
var compatibleSources = [...]string{firebird.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -66,16 +61,6 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
sqlParameter := parameters.NewStringParameter("sql", "The sql to execute.")
params := parameters.Parameters{sqlParameter}
@@ -84,7 +69,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
t := Tool{
Config: cfg,
Parameters: params,
Db: s.FirebirdDB(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
@@ -95,9 +79,7 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
Parameters parameters.Parameters `yaml:"parameters"`
Db *sql.DB
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -107,6 +89,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
sql, ok := paramsMap["sql"].(string)
if !ok {
@@ -120,7 +107,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql))
rows, err := t.Db.QueryContext(ctx, sql)
rows, err := source.FirebirdDB().QueryContext(ctx, sql)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
@@ -180,10 +167,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return false
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
return false, nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -22,7 +22,6 @@ import (
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/firebird"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
@@ -47,11 +46,6 @@ type compatibleSource interface {
FirebirdDB() *sql.DB
}
// validate compatible sources are still compatible
var _ compatibleSource = &firebird.Source{}
var compatibleSources = [...]string{firebird.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -71,18 +65,6 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters)
if err != nil {
return nil, err
@@ -94,7 +76,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
t := Tool{
Config: cfg,
AllParams: allParameters,
Db: s.FirebirdDB(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
@@ -106,9 +87,7 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
AllParams parameters.Parameters `yaml:"allParams"`
Db *sql.DB
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -118,6 +97,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
statement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
if err != nil {
@@ -142,7 +126,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
}
rows, err := t.Db.QueryContext(ctx, statement, namedArgs...)
rows, err := source.FirebirdDB().QueryContext(ctx, statement, namedArgs...)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
@@ -204,10 +188,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return false
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
return false, nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -21,7 +21,6 @@ import (
firestoreapi "cloud.google.com/go/firestore"
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/firestore/util"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
@@ -50,11 +49,6 @@ type compatibleSource interface {
FirestoreClient() *firestoreapi.Client
}
// validate compatible sources are still compatible
var _ compatibleSource = &firestoreds.Source{}
var compatibleSources = [...]string{firestoreds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -71,18 +65,6 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
// Create parameters
collectionPathParameter := parameters.NewStringParameter(
collectionPathKey,
@@ -124,7 +106,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
t := Tool{
Config: cfg,
Parameters: params,
Client: s.FirestoreClient(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
@@ -136,9 +117,7 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
Parameters parameters.Parameters `yaml:"parameters"`
Client *firestoreapi.Client
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -148,6 +127,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
mapParams := params.AsMap()
// Get collection path
@@ -169,7 +153,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
// Convert the document data from JSON format to Firestore format
// The client is passed to handle referenceValue types
documentData, err := util.JSONToFirestoreValue(documentDataRaw, t.Client)
documentData, err := util.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient())
if err != nil {
return nil, fmt.Errorf("failed to convert document data: %w", err)
}
@@ -181,7 +165,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
// Get the collection reference
collection := t.Client.Collection(collectionPath)
collection := source.FirestoreClient().Collection(collectionPath)
// Add the document to the collection
docRef, writeResult, err := collection.Add(ctx, documentData)
@@ -221,10 +205,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return false
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
return false, nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -21,7 +21,6 @@ import (
firestoreapi "cloud.google.com/go/firestore"
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/firestore/util"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
@@ -48,11 +47,6 @@ type compatibleSource interface {
FirestoreClient() *firestoreapi.Client
}
// validate compatible sources are still compatible
var _ compatibleSource = &firestoreds.Source{}
var compatibleSources = [...]string{firestoreds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -69,18 +63,6 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
documentPathsParameter := parameters.NewArrayParameter(documentPathsKey, "Array of relative document paths to delete from Firestore (e.g., 'users/userId' or 'users/userId/posts/postId'). Note: These are relative paths, NOT absolute paths like 'projects/{project_id}/databases/{database_id}/documents/...'", parameters.NewStringParameter("item", "Relative document path"))
params := parameters.Parameters{documentPathsParameter}
@@ -90,7 +72,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
t := Tool{
Config: cfg,
Parameters: params,
Client: s.FirestoreClient(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
@@ -102,9 +83,7 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
Parameters parameters.Parameters `yaml:"parameters"`
Client *firestoreapi.Client
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -114,6 +93,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
mapParams := params.AsMap()
documentPathsRaw, ok := mapParams[documentPathsKey].([]any)
if !ok {
@@ -143,14 +127,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
// Create a BulkWriter to handle multiple deletions efficiently
bulkWriter := t.Client.BulkWriter(ctx)
bulkWriter := source.FirestoreClient().BulkWriter(ctx)
// Keep track of jobs for each document
jobs := make([]*firestoreapi.BulkWriterJob, len(documentPaths))
// Add all delete operations to the BulkWriter
for i, path := range documentPaths {
docRef := t.Client.Doc(path)
docRef := source.FirestoreClient().Doc(path)
job, err := bulkWriter.Delete(docRef)
if err != nil {
return nil, fmt.Errorf("failed to add delete operation for document %q: %w", path, err)
@@ -198,10 +182,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return false
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
return false, nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -21,7 +21,6 @@ import (
firestoreapi "cloud.google.com/go/firestore"
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/firestore/util"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
@@ -48,11 +47,6 @@ type compatibleSource interface {
FirestoreClient() *firestoreapi.Client
}
// validate compatible sources are still compatible
var _ compatibleSource = &firestoreds.Source{}
var compatibleSources = [...]string{firestoreds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -69,18 +63,6 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
documentPathsParameter := parameters.NewArrayParameter(documentPathsKey, "Array of relative document paths to retrieve from Firestore (e.g., 'users/userId' or 'users/userId/posts/postId'). Note: These are relative paths, NOT absolute paths like 'projects/{project_id}/databases/{database_id}/documents/...'", parameters.NewStringParameter("item", "Relative document path"))
params := parameters.Parameters{documentPathsParameter}
@@ -90,7 +72,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
t := Tool{
Config: cfg,
Parameters: params,
Client: s.FirestoreClient(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
@@ -102,9 +83,7 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
Parameters parameters.Parameters `yaml:"parameters"`
Client *firestoreapi.Client
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -114,6 +93,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
mapParams := params.AsMap()
documentPathsRaw, ok := mapParams[documentPathsKey].([]any)
if !ok {
@@ -145,11 +129,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
// Create document references from paths
docRefs := make([]*firestoreapi.DocumentRef, len(documentPaths))
for i, path := range documentPaths {
docRefs[i] = t.Client.Doc(path)
docRefs[i] = source.FirestoreClient().Doc(path)
}
// Get all documents
snapshots, err := t.Client.GetAll(ctx, docRefs)
snapshots, err := source.FirestoreClient().GetAll(ctx, docRefs)
if err != nil {
return nil, fmt.Errorf("failed to get documents: %w", err)
}
@@ -190,10 +174,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return false
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
return false, nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -20,7 +20,6 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/firebaserules/v1"
@@ -48,11 +47,6 @@ type compatibleSource interface {
GetDatabaseId() string
}
// validate compatible sources are still compatible
var _ compatibleSource = &firestoreds.Source{}
var compatibleSources = [...]string{firestoreds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -69,18 +63,6 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
// No parameters needed for this tool
params := parameters.Parameters{}
@@ -90,9 +72,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
t := Tool{
Config: cfg,
Parameters: params,
RulesClient: s.FirebaseRulesClient(),
ProjectId: s.GetProjectId(),
DatabaseId: s.GetDatabaseId(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
@@ -104,11 +83,7 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
Parameters parameters.Parameters `yaml:"parameters"`
RulesClient *firebaserules.Service
ProjectId string
DatabaseId string
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -118,19 +93,24 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
// Get the latest release for Firestore
releaseName := fmt.Sprintf("projects/%s/releases/cloud.firestore/%s", t.ProjectId, t.DatabaseId)
release, err := t.RulesClient.Projects.Releases.Get(releaseName).Context(ctx).Do()
releaseName := fmt.Sprintf("projects/%s/releases/cloud.firestore/%s", source.GetProjectId(), source.GetDatabaseId())
release, err := source.FirebaseRulesClient().Projects.Releases.Get(releaseName).Context(ctx).Do()
if err != nil {
return nil, fmt.Errorf("failed to get latest Firestore release: %w", err)
}
if release.RulesetName == "" {
return nil, fmt.Errorf("no active Firestore rules were found in project '%s' and database '%s'", t.ProjectId, t.DatabaseId)
return nil, fmt.Errorf("no active Firestore rules were found in project '%s' and database '%s'", source.GetProjectId(), source.GetDatabaseId())
}
// Get the ruleset content
ruleset, err := t.RulesClient.Projects.Rulesets.Get(release.RulesetName).Context(ctx).Do()
ruleset, err := source.FirebaseRulesClient().Projects.Rulesets.Get(release.RulesetName).Context(ctx).Do()
if err != nil {
return nil, fmt.Errorf("failed to get ruleset content: %w", err)
}
@@ -158,10 +138,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return false
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
return false, nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -21,7 +21,6 @@ import (
firestoreapi "cloud.google.com/go/firestore"
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/firestore/util"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
@@ -48,11 +47,6 @@ type compatibleSource interface {
FirestoreClient() *firestoreapi.Client
}
// validate compatible sources are still compatible
var _ compatibleSource = &firestoreds.Source{}
var compatibleSources = [...]string{firestoreds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -69,18 +63,6 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
emptyString := ""
parentPathParameter := parameters.NewStringParameterWithDefault(parentPathKey, emptyString, "Relative parent document path to list subcollections from (e.g., 'users/userId'). If not provided, lists root collections. Note: This is a relative path, NOT an absolute path like 'projects/{project_id}/databases/{database_id}/documents/...'")
params := parameters.Parameters{parentPathParameter}
@@ -91,7 +73,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
t := Tool{
Config: cfg,
Parameters: params,
Client: s.FirestoreClient(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
@@ -103,9 +84,7 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
Parameters parameters.Parameters `yaml:"parameters"`
Client *firestoreapi.Client
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -115,10 +94,14 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
mapParams := params.AsMap()
var collectionRefs []*firestoreapi.CollectionRef
var err error
// Check if parentPath is provided
parentPath, hasParent := mapParams[parentPathKey].(string)
@@ -130,14 +113,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
// List subcollections of the specified document
docRef := t.Client.Doc(parentPath)
docRef := source.FirestoreClient().Doc(parentPath)
collectionRefs, err = docRef.Collections(ctx).GetAll()
if err != nil {
return nil, fmt.Errorf("failed to list subcollections of document %q: %w", parentPath, err)
}
} else {
// List root collections
collectionRefs, err = t.Client.Collections(ctx).GetAll()
collectionRefs, err = source.FirestoreClient().Collections(ctx).GetAll()
if err != nil {
return nil, fmt.Errorf("failed to list root collections: %w", err)
}
@@ -177,10 +160,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return false
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
return false, nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -24,7 +24,6 @@ import (
firestoreapi "cloud.google.com/go/firestore"
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/firestore/util"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
@@ -52,12 +51,9 @@ var validOperators = map[string]bool{
// Error messages
const (
errFilterParseFailed = "failed to parse filters: %w"
errQueryExecutionFailed = "failed to execute query: %w"
errTemplateParseFailed = "failed to parse template: %w"
errTemplateExecFailed = "failed to execute template: %w"
errLimitParseFailed = "failed to parse limit value '%s': %w"
errSelectFieldParseFailed = "failed to parse select field: %w"
errFilterParseFailed = "failed to parse filters: %w"
errQueryExecutionFailed = "failed to execute query: %w"
errLimitParseFailed = "failed to parse limit value '%s': %w"
)
func init() {
@@ -79,11 +75,6 @@ type compatibleSource interface {
FirestoreClient() *firestoreapi.Client
}
// validate compatible sources are still compatible
var _ compatibleSource = &firestoreds.Source{}
var compatibleSources = [...]string{firestoreds.SourceKind}
// Config represents the configuration for the Firestore query tool
type Config struct {
Name string `yaml:"name" validate:"required"`
@@ -114,18 +105,6 @@ func (cfg Config) ToolConfigKind() string {
// Initialize creates a new Tool instance from the configuration
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
// Set default limit if not specified
if cfg.Limit == "" {
cfg.Limit = fmt.Sprintf("%d", defaultLimit)
@@ -137,7 +116,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Client: s.FirestoreClient(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
@@ -201,6 +179,11 @@ type QueryResponse struct {
// Invoke executes the Firestore query based on the provided parameters
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
// Process collection path with template substitution
@@ -210,7 +193,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
// Build the query
query, err := t.buildQuery(collectionPath, paramsMap)
query, err := t.buildQuery(source, collectionPath, paramsMap)
if err != nil {
return nil, err
}
@@ -220,8 +203,8 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
// buildQuery constructs the Firestore query from parameters
func (t Tool) buildQuery(collectionPath string, params map[string]any) (*firestoreapi.Query, error) {
collection := t.Client.Collection(collectionPath)
func (t Tool) buildQuery(source compatibleSource, collectionPath string, params map[string]any) (*firestoreapi.Query, error) {
collection := source.FirestoreClient().Collection(collectionPath)
query := collection.Query
// Process and apply filters if template is provided
@@ -239,7 +222,7 @@ func (t Tool) buildQuery(collectionPath string, params map[string]any) (*firesto
}
// Convert simplified filter to Firestore filter
if filter := t.convertToFirestoreFilter(simplifiedFilter); filter != nil {
if filter := t.convertToFirestoreFilter(source, simplifiedFilter); filter != nil {
query = query.WhereEntity(filter)
}
}
@@ -280,12 +263,12 @@ func (t Tool) buildQuery(collectionPath string, params map[string]any) (*firesto
}
// convertToFirestoreFilter converts simplified filter format to Firestore EntityFilter
func (t Tool) convertToFirestoreFilter(filter SimplifiedFilter) firestoreapi.EntityFilter {
func (t Tool) convertToFirestoreFilter(source compatibleSource, filter SimplifiedFilter) firestoreapi.EntityFilter {
// Handle AND filters
if len(filter.And) > 0 {
filters := make([]firestoreapi.EntityFilter, 0, len(filter.And))
for _, f := range filter.And {
if converted := t.convertToFirestoreFilter(f); converted != nil {
if converted := t.convertToFirestoreFilter(source, f); converted != nil {
filters = append(filters, converted)
}
}
@@ -299,7 +282,7 @@ func (t Tool) convertToFirestoreFilter(filter SimplifiedFilter) firestoreapi.Ent
if len(filter.Or) > 0 {
filters := make([]firestoreapi.EntityFilter, 0, len(filter.Or))
for _, f := range filter.Or {
if converted := t.convertToFirestoreFilter(f); converted != nil {
if converted := t.convertToFirestoreFilter(source, f); converted != nil {
filters = append(filters, converted)
}
}
@@ -313,7 +296,7 @@ func (t Tool) convertToFirestoreFilter(filter SimplifiedFilter) firestoreapi.Ent
if filter.Field != "" && filter.Op != "" && filter.Value != nil {
if validOperators[filter.Op] {
// Convert the value using the Firestore native JSON converter
convertedValue, err := util.JSONToFirestoreValue(filter.Value, t.Client)
convertedValue, err := util.JSONToFirestoreValue(filter.Value, source.FirestoreClient())
if err != nil {
// If conversion fails, use the original value
convertedValue = filter.Value
@@ -525,10 +508,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return false
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
return false, nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -23,7 +23,6 @@ import (
firestoreapi "cloud.google.com/go/firestore"
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/firestore/util"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
@@ -92,11 +91,6 @@ type compatibleSource interface {
FirestoreClient() *firestoreapi.Client
}
// validate compatible sources are still compatible
var _ compatibleSource = &firestoreds.Source{}
var compatibleSources = [...]string{firestoreds.SourceKind}
// Config represents the configuration for the Firestore query collection tool
type Config struct {
Name string `yaml:"name" validate:"required"`
@@ -116,18 +110,6 @@ func (cfg Config) ToolConfigKind() string {
// Initialize creates a new Tool instance from the configuration
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
// Create parameters
params := createParameters()
@@ -137,7 +119,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
t := Tool{
Config: cfg,
Parameters: params,
Client: s.FirestoreClient(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
@@ -199,9 +180,7 @@ var _ tools.Tool = Tool{}
// Tool represents the Firestore query collection tool
type Tool struct {
Config
Parameters parameters.Parameters `yaml:"parameters"`
Client *firestoreapi.Client
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -266,6 +245,11 @@ type QueryResponse struct {
// Invoke executes the Firestore query based on the provided parameters
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
// Parse parameters
queryParams, err := t.parseQueryParameters(params)
if err != nil {
@@ -273,7 +257,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
// Build the query
query, err := t.buildQuery(queryParams)
query, err := t.buildQuery(source, queryParams)
if err != nil {
return nil, err
}
@@ -396,8 +380,8 @@ func (t Tool) parseOrderBy(orderByRaw interface{}) (*OrderByConfig, error) {
}
// buildQuery constructs the Firestore query from parameters
func (t Tool) buildQuery(params *queryParameters) (*firestoreapi.Query, error) {
collection := t.Client.Collection(params.CollectionPath)
func (t Tool) buildQuery(source compatibleSource, params *queryParameters) (*firestoreapi.Query, error) {
collection := source.FirestoreClient().Collection(params.CollectionPath)
query := collection.Query
// Apply filters
@@ -531,10 +515,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return false
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
return false, nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -22,7 +22,6 @@ import (
firestoreapi "cloud.google.com/go/firestore"
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/firestore/util"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
@@ -52,11 +51,6 @@ type compatibleSource interface {
FirestoreClient() *firestoreapi.Client
}
// validate compatible sources are still compatible
var _ compatibleSource = &firestoreds.Source{}
var compatibleSources = [...]string{firestoreds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -73,18 +67,6 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
// Create parameters
documentPathParameter := parameters.NewStringParameter(
documentPathKey,
@@ -134,7 +116,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
t := Tool{
Config: cfg,
Parameters: params,
Client: s.FirestoreClient(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
@@ -146,9 +127,7 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
Parameters parameters.Parameters `yaml:"parameters"`
Client *firestoreapi.Client
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -158,6 +137,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
mapParams := params.AsMap()
// Get document path
@@ -200,7 +184,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
// Get the document reference
docRef := t.Client.Doc(documentPath)
docRef := source.FirestoreClient().Doc(documentPath)
// Prepare update data
var writeResult *firestoreapi.WriteResult
@@ -211,7 +195,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
updates := make([]firestoreapi.Update, 0, len(updatePaths))
// Convert document data without delete markers
dataMap, err := util.JSONToFirestoreValue(documentDataRaw, t.Client)
dataMap, err := util.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient())
if err != nil {
return nil, fmt.Errorf("failed to convert document data: %w", err)
}
@@ -239,7 +223,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
writeResult, writeErr = docRef.Update(ctx, updates)
} else {
// Update all fields in the document data (merge)
documentData, err := util.JSONToFirestoreValue(documentDataRaw, t.Client)
documentData, err := util.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient())
if err != nil {
return nil, fmt.Errorf("failed to convert document data: %w", err)
}
@@ -314,10 +298,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return false
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
return false, nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -132,32 +132,6 @@ func TestConfig_Initialize(t *testing.T) {
},
wantErr: false,
},
{
name: "source not found",
config: Config{
Name: "test-update-document",
Kind: "firestore-update-document",
Source: "missing-source",
Description: "Update a document",
},
sources: map[string]sources.Source{},
wantErr: true,
errMsg: "no source named \"missing-source\" configured",
},
{
name: "incompatible source",
config: Config{
Name: "test-update-document",
Kind: "firestore-update-document",
Source: "wrong-source",
Description: "Update a document",
},
sources: map[string]sources.Source{
"wrong-source": &mockIncompatibleSource{},
},
wantErr: true,
errMsg: "invalid source for \"firestore-update-document\" tool",
},
}
for _, tt := range tests {
@@ -464,14 +438,3 @@ func TestGetFieldValue(t *testing.T) {
})
}
}
// mockIncompatibleSource is a mock source that doesn't implement compatibleSource
type mockIncompatibleSource struct{}
func (m *mockIncompatibleSource) SourceKind() string {
return "mock"
}
func (m *mockIncompatibleSource) ToConfig() sources.SourceConfig {
return nil
}

View File

@@ -21,7 +21,6 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/firebaserules/v1"
@@ -53,11 +52,6 @@ type compatibleSource interface {
GetProjectId() string
}
// validate compatible sources are still compatible
var _ compatibleSource = &firestoreds.Source{}
var compatibleSources = [...]string{firestoreds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -74,18 +68,6 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
// Create parameters
params := createParameters()
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil)
@@ -94,8 +76,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
t := Tool{
Config: cfg,
Parameters: params,
RulesClient: s.FirebaseRulesClient(),
ProjectId: s.GetProjectId(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
@@ -117,10 +97,7 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
Parameters parameters.Parameters `yaml:"parameters"`
RulesClient *firebaserules.Service
ProjectId string
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -154,11 +131,16 @@ type ValidationResult struct {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
mapParams := params.AsMap()
// Get source parameter
source, ok := mapParams[sourceKey].(string)
if !ok || source == "" {
sourceParam, ok := mapParams[sourceKey].(string)
if !ok || sourceParam == "" {
return nil, fmt.Errorf("invalid or missing '%s' parameter", sourceKey)
}
@@ -168,7 +150,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
Files: []*firebaserules.File{
{
Name: "firestore.rules",
Content: source,
Content: sourceParam,
},
},
},
@@ -179,14 +161,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
// Call the test API
projectName := fmt.Sprintf("projects/%s", t.ProjectId)
response, err := t.RulesClient.Projects.Test(projectName, testRequest).Context(ctx).Do()
projectName := fmt.Sprintf("projects/%s", source.GetProjectId())
response, err := source.FirebaseRulesClient().Projects.Test(projectName, testRequest).Context(ctx).Do()
if err != nil {
return nil, fmt.Errorf("failed to validate rules: %w", err)
}
// Process the response
result := t.processValidationResponse(response, source)
result := t.processValidationResponse(response, sourceParam)
return result, nil
}
@@ -287,10 +269,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return false
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
return false, nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -29,7 +29,6 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
httpsrc "github.com/googleapis/genai-toolbox/internal/sources/http"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
@@ -50,6 +49,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
HttpDefaultHeaders() map[string]string
HttpBaseURL() string
HttpQueryParams() map[string]string
Client() *http.Client
}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -81,7 +87,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
}
// verify the source is compatible
s, ok := rawS.(*httpsrc.Source)
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `http`", kind)
}
@@ -89,7 +95,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// Combine Source and Tool headers.
// In case of conflict, Tool header overrides Source header
combinedHeaders := make(map[string]string)
maps.Copy(combinedHeaders, s.DefaultHeaders)
maps.Copy(combinedHeaders, s.HttpDefaultHeaders())
maps.Copy(combinedHeaders, cfg.Headers)
// Create a slice for all parameters
@@ -113,14 +119,11 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
return Tool{
Config: cfg,
BaseURL: s.BaseURL,
Headers: combinedHeaders,
DefaultQueryParams: s.QueryParams,
Client: s.Client,
AllParams: allParameters,
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Headers: combinedHeaders,
AllParams: allParameters,
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}, nil
}
@@ -129,12 +132,8 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
BaseURL string `yaml:"baseURL"`
Headers map[string]string `yaml:"headers"`
DefaultQueryParams map[string]string `yaml:"defaultQueryParams"`
AllParams parameters.Parameters `yaml:"allParams"`
Client *http.Client
Headers map[string]string `yaml:"headers"`
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -229,6 +228,11 @@ func getHeaders(headerParams parameters.Parameters, defaultHeaders map[string]st
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
// Calculate request body
@@ -238,7 +242,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
// Calculate URL
urlString, err := getURL(t.BaseURL, t.Path, t.PathParams, t.QueryParams, t.DefaultQueryParams, paramsMap)
urlString, err := getURL(source.HttpBaseURL(), t.Path, t.PathParams, t.QueryParams, source.HttpQueryParams(), paramsMap)
if err != nil {
return nil, fmt.Errorf("error populating path parameters: %s", err)
}
@@ -256,7 +260,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
// Make request and fetch response
resp, err := t.Client.Do(req)
resp, err := source.Client().Do(req)
if err != nil {
return nil, fmt.Errorf("error making HTTP request: %s", err)
}
@@ -295,10 +299,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return false
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
return false, nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -19,7 +19,6 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
"github.com/googleapis/genai-toolbox/internal/util"
@@ -45,6 +44,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
UseClientAuthorization() bool
GetAuthTokenHeaderName() string
LookerClient() *v4.LookerSDK
LookerApiSettings() *rtl.ApiSettings
}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -62,18 +68,6 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(*lookersrc.Source)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind)
}
params := lookercommon.GetQueryParameters()
dashIdParameter := parameters.NewStringParameter("dashboard_id", "The id of the dashboard where this tile will exist")
@@ -109,12 +103,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
return Tool{
Config: cfg,
Parameters: params,
UseClientOAuth: s.UseClientAuthorization(),
AuthTokenHeaderName: s.GetAuthTokenHeaderName(),
Client: s.Client,
ApiSettings: s.ApiSettings,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{
Description: cfg.Description,
Parameters: params.Manifest(),
@@ -129,13 +119,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool
AuthTokenHeaderName string
Client *v4.LookerSDK
ApiSettings *rtl.ApiSettings
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -148,6 +134,11 @@ var (
)
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)
@@ -167,12 +158,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
visConfig := paramsMap["vis_config"].(map[string]any)
wq.VisConfig = &visConfig
sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken)
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
if err != nil {
return nil, fmt.Errorf("error getting sdk: %w", err)
}
qresp, err := sdk.CreateQuery(*wq, "id", t.ApiSettings)
qresp, err := sdk.CreateQuery(*wq, "id", source.LookerApiSettings())
if err != nil {
return nil, fmt.Errorf("error making create query request: %w", err)
}
@@ -239,7 +230,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
Fields: &fields,
}
resp, err := sdk.CreateDashboardElement(req, t.ApiSettings)
resp, err := sdk.CreateDashboardElement(req, source.LookerApiSettings())
if err != nil {
return nil, fmt.Errorf("error making create dashboard element request: %w", err)
}
@@ -264,14 +255,22 @@ func (t Tool) McpManifest() tools.McpManifest {
return t.mcpManifest
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
}
func (t Tool) GetAuthTokenHeaderName() string {
return t.AuthTokenHeaderName
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return "", err
}
return source.GetAuthTokenHeaderName(), nil
}

View File

@@ -19,7 +19,6 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
"github.com/googleapis/genai-toolbox/internal/util"
@@ -45,6 +44,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
UseClientAuthorization() bool
GetAuthTokenHeaderName() string
LookerClient() *v4.LookerSDK
LookerApiSettings() *rtl.ApiSettings
}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -62,18 +68,6 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(*lookersrc.Source)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind)
}
params := parameters.Parameters{}
dashIdParameter := parameters.NewStringParameter("dashboard_id", "The id of the dashboard where this filter will exist")
@@ -109,14 +103,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
return Tool{
Config: cfg,
Name: cfg.Name,
Kind: kind,
UseClientOAuth: s.UseClientAuthorization(),
AuthTokenHeaderName: s.GetAuthTokenHeaderName(),
Client: s.Client,
ApiSettings: s.ApiSettings,
Parameters: params,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{
Description: cfg.Description,
Parameters: params.Manifest(),
@@ -131,16 +119,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
Name string `yaml:"name"`
Kind string `yaml:"kind"`
UseClientOAuth bool
AuthTokenHeaderName string
Client *v4.LookerSDK
ApiSettings *rtl.ApiSettings
AuthRequired []string `yaml:"authRequired"`
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -148,6 +129,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)
@@ -205,12 +191,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
req.Dimension = &dimension
}
sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken)
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
if err != nil {
return nil, fmt.Errorf("error getting sdk: %w", err)
}
resp, err := sdk.CreateDashboardFilter(req, "name", t.ApiSettings)
resp, err := sdk.CreateDashboardFilter(req, "name", source.LookerApiSettings())
if err != nil {
return nil, fmt.Errorf("error making create dashboard filter request: %s", err)
}
@@ -239,10 +225,18 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return t.AuthTokenHeaderName
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return "", err
}
return source.GetAuthTokenHeaderName(), nil
}

View File

@@ -26,7 +26,6 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
lookerds "github.com/googleapis/genai-toolbox/internal/sources/looker"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
@@ -56,12 +55,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}
type compatibleSource interface {
GetApiSettings() *rtl.ApiSettings
GoogleCloudTokenSourceWithScope(ctx context.Context, scope string) (oauth2.TokenSource, error)
GoogleCloudProject() string
GoogleCloudLocation() string
UseClientAuthorization() bool
GetAuthTokenHeaderName() string
LookerApiSettings() *rtl.ApiSettings
}
// Structs for building the JSON payload
@@ -124,11 +123,6 @@ type CAPayload struct {
ClientIdEnum string `json:"clientIdEnum"`
}
// validate compatible sources are still compatible
var _ compatibleSource = &lookerds.Source{}
var compatibleSources = [...]string{lookerds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -155,7 +149,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
if s.GoogleCloudProject() == "" {
@@ -196,16 +190,11 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
ApiSettings: s.GetApiSettings(),
Project: s.GoogleCloudProject(),
Location: s.GoogleCloudLocation(),
Parameters: params,
UseClientOAuth: s.UseClientAuthorization(),
AuthTokenHeaderName: s.GetAuthTokenHeaderName(),
TokenSource: ts,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
TokenSource: ts,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -215,15 +204,10 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
ApiSettings *rtl.ApiSettings
UseClientOAuth bool `yaml:"useClientOAuth"`
AuthTokenHeaderName string
Parameters parameters.Parameters `yaml:"parameters"`
Project string
Location string
TokenSource oauth2.TokenSource
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
TokenSource oauth2.TokenSource
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -231,8 +215,12 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
var tokenStr string
var err error
// Get credentials for the API call
// Use cloud-platform token source for Gemini Data Analytics API
@@ -253,16 +241,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
ler := make([]LookerExploreReference, 0)
for _, er := range exploreReferences {
ler = append(ler, LookerExploreReference{
LookerInstanceUri: t.ApiSettings.BaseUrl,
LookerInstanceUri: source.LookerApiSettings().BaseUrl,
LookmlModel: er.(map[string]any)["model"].(string),
Explore: er.(map[string]any)["explore"].(string),
})
}
oauth_creds := OAuthCredentials{}
if t.UseClientOAuth {
if source.UseClientAuthorization() {
oauth_creds.Token = TokenBased{AccessToken: string(accessToken)}
} else {
oauth_creds.Secret = SecretBased{ClientId: t.ApiSettings.ClientId, ClientSecret: t.ApiSettings.ClientSecret}
oauth_creds.Secret = SecretBased{ClientId: source.LookerApiSettings().ClientId, ClientSecret: source.LookerApiSettings().ClientSecret}
}
lers := LookerExploreReferences{
@@ -273,8 +261,8 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
// Construct URL, headers, and payload
projectID := t.Project
location := t.Location
projectID := source.GoogleCloudProject()
location := source.GoogleCloudLocation()
caURL := fmt.Sprintf("https://geminidataanalytics.googleapis.com/v1beta/projects/%s/locations/%s:chat", url.PathEscape(projectID), url.PathEscape(location))
headers := map[string]string{
@@ -315,12 +303,16 @@ func (t Tool) McpManifest() tools.McpManifest {
return t.mcpManifest
}
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
// StreamMessage represents a single message object from the streaming API response.
@@ -563,6 +555,10 @@ func appendMessage(messages []map[string]any, newMessage map[string]any) []map[s
return append(messages, newMessage)
}
func (t Tool) GetAuthTokenHeaderName() string {
return t.AuthTokenHeaderName
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return "", err
}
return source.GetAuthTokenHeaderName(), nil
}

View File

@@ -19,7 +19,6 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
@@ -44,6 +43,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
UseClientAuthorization() bool
GetAuthTokenHeaderName() string
LookerClient() *v4.LookerSDK
LookerApiSettings() *rtl.ApiSettings
}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -61,18 +67,6 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(*lookersrc.Source)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind)
}
projectIdParameter := parameters.NewStringParameter("project_id", "The id of the project containing the files")
filePathParameter := parameters.NewStringParameter("file_path", "The path of the file within the project")
fileContentParameter := parameters.NewStringParameter("file_content", "The content of the file")
@@ -90,12 +84,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
return Tool{
Config: cfg,
Parameters: params,
UseClientOAuth: s.UseClientAuthorization(),
AuthTokenHeaderName: s.GetAuthTokenHeaderName(),
Client: s.Client,
ApiSettings: s.ApiSettings,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{
Description: cfg.Description,
Parameters: params.Manifest(),
@@ -110,13 +100,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool
AuthTokenHeaderName string
Client *v4.LookerSDK
ApiSettings *rtl.ApiSettings
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -124,7 +110,12 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken)
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
if err != nil {
return nil, fmt.Errorf("error getting sdk: %w", err)
}
@@ -148,7 +139,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
Content: fileContent,
}
err = lookercommon.CreateProjectFile(sdk, projectId, req, t.ApiSettings)
err = lookercommon.CreateProjectFile(sdk, projectId, req, source.LookerApiSettings())
if err != nil {
return nil, fmt.Errorf("error making create_project_file request: %s", err)
}
@@ -172,14 +163,22 @@ func (t Tool) McpManifest() tools.McpManifest {
return t.mcpManifest
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
}
func (t Tool) GetAuthTokenHeaderName() string {
return t.AuthTokenHeaderName
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return "", err
}
return source.GetAuthTokenHeaderName(), nil
}

View File

@@ -19,7 +19,6 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
@@ -44,6 +43,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
UseClientAuthorization() bool
GetAuthTokenHeaderName() string
LookerClient() *v4.LookerSDK
LookerApiSettings() *rtl.ApiSettings
}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -61,18 +67,6 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(*lookersrc.Source)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind)
}
projectIdParameter := parameters.NewStringParameter("project_id", "The id of the project containing the files")
filePathParameter := parameters.NewStringParameter("file_path", "The path of the file within the project")
params := parameters.Parameters{projectIdParameter, filePathParameter}
@@ -91,12 +85,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
return Tool{
Config: cfg,
Parameters: params,
UseClientOAuth: s.UseClientAuthorization(),
AuthTokenHeaderName: s.GetAuthTokenHeaderName(),
Client: s.Client,
ApiSettings: s.ApiSettings,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{
Description: cfg.Description,
Parameters: params.Manifest(),
@@ -111,13 +101,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool
AuthTokenHeaderName string
Client *v4.LookerSDK
ApiSettings *rtl.ApiSettings
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -125,7 +111,12 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken)
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
if err != nil {
return nil, fmt.Errorf("error getting sdk: %w", err)
}
@@ -140,7 +131,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("'file_path' must be a string, got %T", mapParams["file_path"])
}
err = lookercommon.DeleteProjectFile(sdk, projectId, filePath, t.ApiSettings)
err = lookercommon.DeleteProjectFile(sdk, projectId, filePath, source.LookerApiSettings())
if err != nil {
return nil, fmt.Errorf("error making delete_project_file request: %s", err)
}
@@ -164,14 +155,22 @@ func (t Tool) McpManifest() tools.McpManifest {
return t.mcpManifest
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
}
func (t Tool) GetAuthTokenHeaderName() string {
return t.AuthTokenHeaderName
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return "", err
}
return source.GetAuthTokenHeaderName(), nil
}

View File

@@ -19,7 +19,6 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
"github.com/googleapis/genai-toolbox/internal/util"
@@ -45,6 +44,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
UseClientAuthorization() bool
GetAuthTokenHeaderName() string
LookerClient() *v4.LookerSDK
LookerApiSettings() *rtl.ApiSettings
}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -62,18 +68,6 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(*lookersrc.Source)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind)
}
devModeParameter := parameters.NewBooleanParameterWithDefault("devMode", true, "Whether to set Dev Mode.")
params := parameters.Parameters{devModeParameter}
@@ -89,19 +83,14 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
return Tool{
Config: cfg,
Parameters: params,
UseClientOAuth: s.UseClientAuthorization(),
AuthTokenHeaderName: s.GetAuthTokenHeaderName(),
Client: s.Client,
ApiSettings: s.ApiSettings,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{
Description: cfg.Description,
Parameters: params.Manifest(),
AuthRequired: cfg.AuthRequired,
},
mcpManifest: mcpManifest,
ShowHiddenExplores: s.ShowHiddenExplores,
mcpManifest: mcpManifest,
}, nil
}
@@ -110,14 +99,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool
AuthTokenHeaderName string
Client *v4.LookerSDK
ApiSettings *rtl.ApiSettings
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
ShowHiddenExplores bool
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -125,6 +109,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)
@@ -135,7 +124,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("'devMode' must be a boolean, got %T", mapParams["devMode"])
}
sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken)
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
if err != nil {
return nil, fmt.Errorf("error getting sdk: %w", err)
}
@@ -148,7 +137,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
req := v4.WriteApiSession{
WorkspaceId: &devModeString,
}
resp, err := sdk.UpdateSession(req, t.ApiSettings)
resp, err := sdk.UpdateSession(req, source.LookerApiSettings())
if err != nil {
return nil, fmt.Errorf("error setting/resetting dev mode: %w", err)
}
@@ -169,14 +158,22 @@ func (t Tool) McpManifest() tools.McpManifest {
return t.mcpManifest
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
}
func (t Tool) GetAuthTokenHeaderName() string {
return t.AuthTokenHeaderName
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return "", err
}
return source.GetAuthTokenHeaderName(), nil
}

View File

@@ -20,7 +20,6 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
"github.com/googleapis/genai-toolbox/internal/util"
@@ -46,6 +45,14 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
UseClientAuthorization() bool
GetAuthTokenHeaderName() string
LookerClient() *v4.LookerSDK
LookerApiSettings() *rtl.ApiSettings
LookerSessionLength() int64
}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -63,18 +70,6 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(*lookersrc.Source)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind)
}
typeParameter := parameters.NewStringParameterWithDefault("type", "", "Type of Looker content to embed (ie. dashboards, looks, query-visualization)")
idParameter := parameters.NewStringParameterWithDefault("id", "", "The ID of the content to embed.")
params := parameters.Parameters{
@@ -94,19 +89,14 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
return Tool{
Config: cfg,
Parameters: params,
UseClientOAuth: s.UseClientAuthorization(),
AuthTokenHeaderName: s.GetAuthTokenHeaderName(),
Client: s.Client,
ApiSettings: s.ApiSettings,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{
Description: cfg.Description,
Parameters: params.Manifest(),
AuthRequired: cfg.AuthRequired,
},
mcpManifest: mcpManifest,
SessionLength: s.SessionLength,
mcpManifest: mcpManifest,
}, nil
}
@@ -115,15 +105,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool
AuthTokenHeaderName string
Client *v4.LookerSDK
ApiSettings *rtl.ApiSettings
AuthRequired []string `yaml:"authRequired"`
Parameters parameters.Parameters
manifest tools.Manifest
mcpManifest tools.McpManifest
SessionLength int64
Parameters parameters.Parameters
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -131,6 +115,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)
@@ -147,16 +136,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
contentId_ptr = nil
}
sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken)
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
if err != nil {
return nil, fmt.Errorf("error getting sdk: %w", err)
}
forceLogoutLogin := true
sessionLength := source.LookerSessionLength()
req := v4.EmbedParams{
TargetUrl: fmt.Sprintf("%s/embed/%s/%s", t.ApiSettings.BaseUrl, *embedType_ptr, *contentId_ptr),
SessionLength: &t.SessionLength,
TargetUrl: fmt.Sprintf("%s/embed/%s/%s", source.LookerApiSettings().BaseUrl, *embedType_ptr, *contentId_ptr),
SessionLength: &sessionLength,
ForceLogoutLogin: &forceLogoutLogin,
}
logger.ErrorContext(ctx, "Making request %v", req)
@@ -181,14 +170,22 @@ func (t Tool) McpManifest() tools.McpManifest {
return t.mcpManifest
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
}
func (t Tool) GetAuthTokenHeaderName() string {
return t.AuthTokenHeaderName
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return "", err
}
return source.GetAuthTokenHeaderName(), nil
}

View File

@@ -19,7 +19,6 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
@@ -44,6 +43,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
UseClientAuthorization() bool
GetAuthTokenHeaderName() string
LookerClient() *v4.LookerSDK
LookerApiSettings() *rtl.ApiSettings
}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -61,18 +67,6 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(*lookersrc.Source)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind)
}
connParameter := parameters.NewStringParameter("conn", "The connection containing the databases.")
params := parameters.Parameters{connParameter}
@@ -88,12 +82,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
return Tool{
Config: cfg,
Parameters: params,
UseClientOAuth: s.UseClientAuthorization(),
AuthTokenHeaderName: s.GetAuthTokenHeaderName(),
Client: s.Client,
ApiSettings: s.ApiSettings,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{
Description: cfg.Description,
Parameters: params.Manifest(),
@@ -108,13 +98,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool
AuthTokenHeaderName string
Client *v4.LookerSDK
ApiSettings *rtl.ApiSettings
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -122,17 +108,22 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
mapParams := params.AsMap()
conn, ok := mapParams["conn"].(string)
if !ok {
return nil, fmt.Errorf("'conn' must be a string, got %T", mapParams["conn"])
}
sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken)
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
if err != nil {
return nil, fmt.Errorf("error getting sdk: %w", err)
}
resp, err := sdk.ConnectionDatabases(conn, t.ApiSettings)
resp, err := sdk.ConnectionDatabases(conn, source.LookerApiSettings())
if err != nil {
return nil, fmt.Errorf("error making get_connection_databases request: %s", err)
}
@@ -153,14 +144,22 @@ func (t Tool) McpManifest() tools.McpManifest {
return t.mcpManifest
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
}
func (t Tool) GetAuthTokenHeaderName() string {
return t.AuthTokenHeaderName
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return "", err
}
return source.GetAuthTokenHeaderName(), nil
}

Some files were not shown because too many files have changed in this diff Show More