mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-09 07:28:05 -05:00
refactor: decouple Source from Tool (#2204)
This PR update the linking mechanism between Source and Tool.
Tools are directly linked to their Source, either by pointing to the
Source's functions or by assigning values from the source during Tool's
initialization. However, the existing approach means that any
modification to the Source after Tool's initialization might not be
reflected. To address this limitation, each tool should only store a
name reference to the Source, rather than direct link or assigned
values.
Tools will provide interface for `compatibleSource`. This will be used
to determine if a Source is compatible with the Tool.
```
type compatibleSource interface{
Client() http.Client
ProjectID() string
}
```
During `Invoke()`, the tool will run the following operations:
* retrieve Source from the `resourceManager` with source's named defined
in Tool's config
* validate Source via `compatibleSource interface{}`
* run the remaining `Invoke()` function. Fields that are needed is
retrieved directly from the source.
With this update, resource manager is also added as input to other
Tool's function that require access to source (e.g.
`RequiresClientAuthorization()`).
This commit is contained in:
@@ -172,7 +172,14 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
accessToken := tools.AccessToken(r.Header.Get("Authorization"))
|
||||
|
||||
// Check if this specific tool requires the standard authorization header
|
||||
if tool.RequiresClientAuthorization(s.ResourceMgr) {
|
||||
clientAuth, err := tool.RequiresClientAuthorization(s.ResourceMgr)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("error during invocation: %w", err)
|
||||
s.logger.DebugContext(ctx, errMsg.Error())
|
||||
_ = render.Render(w, r, newErrResponse(errMsg, http.StatusNotFound))
|
||||
return
|
||||
}
|
||||
if clientAuth {
|
||||
if accessToken == "" {
|
||||
err = fmt.Errorf("tool requires client authorization but access token is missing from the request header")
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
@@ -255,7 +262,7 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden {
|
||||
if tool.RequiresClientAuthorization(s.ResourceMgr) {
|
||||
if clientAuth {
|
||||
// Propagate the original 401/403 error.
|
||||
s.logger.DebugContext(ctx, fmt.Sprintf("error invoking tool. Client credentials lack authorization to the source: %v", err))
|
||||
_ = render.Render(w, r, newErrResponse(err, statusCode))
|
||||
|
||||
@@ -77,9 +77,9 @@ func (t MockTool) Authorized(verifiedAuthServices []string) bool {
|
||||
return !t.unauthorized
|
||||
}
|
||||
|
||||
func (t MockTool) RequiresClientAuthorization(tools.SourceProvider) bool {
|
||||
func (t MockTool) RequiresClientAuthorization(tools.SourceProvider) (bool, error) {
|
||||
// defaulted to false
|
||||
return t.requiresClientAuthrorization
|
||||
return t.requiresClientAuthrorization, nil
|
||||
}
|
||||
|
||||
func (t MockTool) McpManifest() tools.McpManifest {
|
||||
@@ -119,8 +119,8 @@ func (t MockTool) McpManifest() tools.McpManifest {
|
||||
return mcpManifest
|
||||
}
|
||||
|
||||
func (t MockTool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t MockTool) GetAuthTokenHeaderName(tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
// MockPrompt is used to mock prompts in tests
|
||||
|
||||
@@ -108,10 +108,20 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
}
|
||||
|
||||
// Get access token
|
||||
accessToken := tools.AccessToken(header.Get(tool.GetAuthTokenHeaderName()))
|
||||
authTokenHeadername, err := tool.GetAuthTokenHeaderName(resourceMgr)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("error during invocation: %w", err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
|
||||
}
|
||||
accessToken := tools.AccessToken(header.Get(authTokenHeadername))
|
||||
|
||||
// Check if this specific tool requires the standard authorization header
|
||||
if tool.RequiresClientAuthorization(resourceMgr) {
|
||||
clientAuth, err := tool.RequiresClientAuthorization(resourceMgr)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("error during invocation: %w", err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
|
||||
}
|
||||
if clientAuth {
|
||||
if accessToken == "" {
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
|
||||
}
|
||||
@@ -183,7 +193,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
}
|
||||
// Upstream auth error
|
||||
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
||||
if tool.RequiresClientAuthorization(resourceMgr) {
|
||||
if clientAuth {
|
||||
// Error with client credentials should pass down to the client
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
|
||||
@@ -108,10 +108,20 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
}
|
||||
|
||||
// Get access token
|
||||
accessToken := tools.AccessToken(header.Get(tool.GetAuthTokenHeaderName()))
|
||||
authTokenHeadername, err := tool.GetAuthTokenHeaderName(resourceMgr)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("error during invocation: %w", err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
|
||||
}
|
||||
accessToken := tools.AccessToken(header.Get(authTokenHeadername))
|
||||
|
||||
// Check if this specific tool requires the standard authorization header
|
||||
if tool.RequiresClientAuthorization(resourceMgr) {
|
||||
clientAuth, err := tool.RequiresClientAuthorization(resourceMgr)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("error during invocation: %w", err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
|
||||
}
|
||||
if clientAuth {
|
||||
if accessToken == "" {
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
|
||||
}
|
||||
@@ -183,7 +193,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
}
|
||||
// Upstream auth error
|
||||
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
||||
if tool.RequiresClientAuthorization(resourceMgr) {
|
||||
if clientAuth {
|
||||
// Error with client credentials should pass down to the client
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
|
||||
@@ -101,10 +101,20 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
}
|
||||
|
||||
// Get access token
|
||||
accessToken := tools.AccessToken(header.Get(tool.GetAuthTokenHeaderName()))
|
||||
authTokenHeadername, err := tool.GetAuthTokenHeaderName(resourceMgr)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("error during invocation: %w", err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
|
||||
}
|
||||
accessToken := tools.AccessToken(header.Get(authTokenHeadername))
|
||||
|
||||
// Check if this specific tool requires the standard authorization header
|
||||
if tool.RequiresClientAuthorization(resourceMgr) {
|
||||
clientAuth, err := tool.RequiresClientAuthorization(resourceMgr)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("error during invocation: %w", err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
|
||||
}
|
||||
if clientAuth {
|
||||
if accessToken == "" {
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
|
||||
}
|
||||
@@ -176,7 +186,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
}
|
||||
// Upstream auth error
|
||||
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
||||
if tool.RequiresClientAuthorization(resourceMgr) {
|
||||
if clientAuth {
|
||||
// Error with client credentials should pass down to the client
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
|
||||
@@ -110,6 +110,10 @@ func (s *Source) ToConfig() sources.SourceConfig {
|
||||
return s.Config
|
||||
}
|
||||
|
||||
func (s *Source) GetDefaultProject() string {
|
||||
return s.DefaultProject
|
||||
}
|
||||
|
||||
func (s *Source) GetService(ctx context.Context, accessToken string) (*alloydbrestapi.Service, error) {
|
||||
if s.UseClientOAuth {
|
||||
token := &oauth2.Token{AccessToken: accessToken}
|
||||
|
||||
@@ -107,6 +107,14 @@ func (s *Source) ToConfig() sources.SourceConfig {
|
||||
return s.Config
|
||||
}
|
||||
|
||||
func (s *Source) GetProjectID() string {
|
||||
return s.ProjectID
|
||||
}
|
||||
|
||||
func (s *Source) GetBaseURL() string {
|
||||
return s.BaseURL
|
||||
}
|
||||
|
||||
func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Client, error) {
|
||||
if s.UseClientOAuth {
|
||||
if accessToken == "" {
|
||||
|
||||
@@ -81,9 +81,9 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
|
||||
|
||||
s := &Source{
|
||||
Config: r,
|
||||
BaseURL: "https://monitoring.googleapis.com",
|
||||
Client: client,
|
||||
UserAgent: ua,
|
||||
baseURL: "https://monitoring.googleapis.com",
|
||||
client: client,
|
||||
userAgent: ua,
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
@@ -92,9 +92,9 @@ var _ sources.Source = &Source{}
|
||||
|
||||
type Source struct {
|
||||
Config
|
||||
BaseURL string `yaml:"baseUrl"`
|
||||
Client *http.Client
|
||||
UserAgent string
|
||||
baseURL string
|
||||
client *http.Client
|
||||
userAgent string
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
@@ -105,6 +105,18 @@ func (s *Source) ToConfig() sources.SourceConfig {
|
||||
return s.Config
|
||||
}
|
||||
|
||||
func (s *Source) BaseURL() string {
|
||||
return s.baseURL
|
||||
}
|
||||
|
||||
func (s *Source) Client() *http.Client {
|
||||
return s.client
|
||||
}
|
||||
|
||||
func (s *Source) UserAgent() string {
|
||||
return s.userAgent
|
||||
}
|
||||
|
||||
func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Client, error) {
|
||||
if s.UseClientOAuth {
|
||||
if accessToken == "" {
|
||||
@@ -113,7 +125,7 @@ func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Clien
|
||||
token := &oauth2.Token{AccessToken: accessToken}
|
||||
return oauth2.NewClient(ctx, oauth2.StaticTokenSource(token)), nil
|
||||
}
|
||||
return s.Client, nil
|
||||
return s.client, nil
|
||||
}
|
||||
|
||||
func (s *Source) UseClientAuthorization() bool {
|
||||
|
||||
@@ -110,6 +110,10 @@ func (s *Source) ToConfig() sources.SourceConfig {
|
||||
return s.Config
|
||||
}
|
||||
|
||||
func (s *Source) GetDefaultProject() string {
|
||||
return s.DefaultProject
|
||||
}
|
||||
|
||||
func (s *Source) GetService(ctx context.Context, accessToken string) (*sqladmin.Service, error) {
|
||||
if s.UseClientOAuth {
|
||||
token := &oauth2.Token{AccessToken: accessToken}
|
||||
|
||||
@@ -107,7 +107,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
|
||||
|
||||
s := &Source{
|
||||
Config: r,
|
||||
Client: &client,
|
||||
client: &client,
|
||||
}
|
||||
return s, nil
|
||||
|
||||
@@ -117,7 +117,7 @@ var _ sources.Source = &Source{}
|
||||
|
||||
type Source struct {
|
||||
Config
|
||||
Client *http.Client
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
@@ -127,3 +127,19 @@ func (s *Source) SourceKind() string {
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
return s.Config
|
||||
}
|
||||
|
||||
func (s *Source) HttpDefaultHeaders() map[string]string {
|
||||
return s.DefaultHeaders
|
||||
}
|
||||
|
||||
func (s *Source) HttpBaseURL() string {
|
||||
return s.BaseURL
|
||||
}
|
||||
|
||||
func (s *Source) HttpQueryParams() map[string]string {
|
||||
return s.QueryParams
|
||||
}
|
||||
|
||||
func (s *Source) Client() *http.Client {
|
||||
return s.client
|
||||
}
|
||||
|
||||
@@ -160,10 +160,6 @@ func (s *Source) ToConfig() sources.SourceConfig {
|
||||
return s.Config
|
||||
}
|
||||
|
||||
func (s *Source) GetApiSettings() *rtl.ApiSettings {
|
||||
return s.ApiSettings
|
||||
}
|
||||
|
||||
func (s *Source) UseClientAuthorization() bool {
|
||||
return strings.ToLower(s.UseClientOAuth) != "false"
|
||||
}
|
||||
@@ -188,6 +184,30 @@ func (s *Source) GoogleCloudTokenSourceWithScope(ctx context.Context, scope stri
|
||||
return google.DefaultTokenSource(ctx, scope)
|
||||
}
|
||||
|
||||
func (s *Source) LookerClient() *v4.LookerSDK {
|
||||
return s.Client
|
||||
}
|
||||
|
||||
func (s *Source) LookerApiSettings() *rtl.ApiSettings {
|
||||
return s.ApiSettings
|
||||
}
|
||||
|
||||
func (s *Source) LookerShowHiddenFields() bool {
|
||||
return s.ShowHiddenFields
|
||||
}
|
||||
|
||||
func (s *Source) LookerShowHiddenModels() bool {
|
||||
return s.ShowHiddenModels
|
||||
}
|
||||
|
||||
func (s *Source) LookerShowHiddenExplores() bool {
|
||||
return s.ShowHiddenExplores
|
||||
}
|
||||
|
||||
func (s *Source) LookerSessionLength() int64 {
|
||||
return s.SessionLength
|
||||
}
|
||||
|
||||
func initGoogleCloudConnection(ctx context.Context) (oauth2.TokenSource, error) {
|
||||
cred, err := google.FindDefaultCredentials(ctx, geminidataanalytics.DefaultAuthScopes()...)
|
||||
if err != nil {
|
||||
|
||||
@@ -96,6 +96,14 @@ func (s *Source) ToConfig() sources.SourceConfig {
|
||||
return s.Config
|
||||
}
|
||||
|
||||
func (s *Source) GetProject() string {
|
||||
return s.Project
|
||||
}
|
||||
|
||||
func (s *Source) GetLocation() string {
|
||||
return s.Location
|
||||
}
|
||||
|
||||
func (s *Source) GetBatchControllerClient() *dataproc.BatchControllerClient {
|
||||
return s.Client
|
||||
}
|
||||
|
||||
@@ -20,7 +20,6 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/alloydb/v1"
|
||||
@@ -42,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
UseClientAuthorization() bool
|
||||
GetService(context.Context, string) (*alloydb.Service, error)
|
||||
}
|
||||
|
||||
// Configuration for the create-cluster tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -66,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("source %q not found", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(*alloydbadmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `alloydb-admin`", kind)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -97,7 +102,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -107,7 +111,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the create-cluster tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
Source *alloydbadmin.Source
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
|
||||
manifest tools.Manifest
|
||||
@@ -120,6 +123,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
project, ok := paramsMap["project"].(string)
|
||||
if !ok || project == "" {
|
||||
@@ -151,7 +159,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid 'user' parameter; expected a string")
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -198,10 +206,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -20,7 +20,6 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/alloydb/v1"
|
||||
@@ -42,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
UseClientAuthorization() bool
|
||||
GetService(context.Context, string) (*alloydb.Service, error)
|
||||
}
|
||||
|
||||
// Configuration for the create-instance tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -66,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("source %q not found", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(*alloydbadmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `alloydb-admin`", kind)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -98,7 +103,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -108,7 +112,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the create-instance tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
Source *alloydbadmin.Source
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
|
||||
manifest tools.Manifest
|
||||
@@ -121,6 +124,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
project, ok := paramsMap["project"].(string)
|
||||
if !ok || project == "" {
|
||||
@@ -147,7 +155,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid 'instanceType' parameter; expected 'PRIMARY' or 'READ_POOL'")
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -208,10 +216,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -20,7 +20,6 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/alloydb/v1"
|
||||
@@ -42,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
UseClientAuthorization() bool
|
||||
GetService(context.Context, string) (*alloydb.Service, error)
|
||||
}
|
||||
|
||||
// Configuration for the create-user tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -66,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("source %q not found", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(*alloydbadmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `alloydb-admin`", kind)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -98,7 +103,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -108,9 +112,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the create-user tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
Source *alloydbadmin.Source
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -121,6 +123,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
project, ok := paramsMap["project"].(string)
|
||||
if !ok || project == "" {
|
||||
@@ -147,7 +154,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid or missing 'userType' parameter; expected 'ALLOYDB_BUILT_IN' or 'ALLOYDB_IAM_USER'")
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -208,10 +215,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -20,9 +20,9 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/alloydb/v1"
|
||||
)
|
||||
|
||||
const kind string = "alloydb-get-cluster"
|
||||
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
UseClientAuthorization() bool
|
||||
GetService(context.Context, string) (*alloydb.Service, error)
|
||||
}
|
||||
|
||||
// Configuration for the get-cluster tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -66,12 +72,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("source %q not found", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(*alloydbadmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -94,7 +100,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -104,7 +109,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the get-cluster tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
Source *alloydbadmin.Source
|
||||
AllParams parameters.Parameters
|
||||
|
||||
manifest tools.Manifest
|
||||
@@ -117,6 +121,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
@@ -132,7 +141,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string")
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -167,10 +176,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -20,9 +20,9 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/alloydb/v1"
|
||||
)
|
||||
|
||||
const kind string = "alloydb-get-instance"
|
||||
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
UseClientAuthorization() bool
|
||||
GetService(context.Context, string) (*alloydb.Service, error)
|
||||
}
|
||||
|
||||
// Configuration for the get-instance tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -66,12 +72,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("source %q not found", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(*alloydbadmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -95,7 +101,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -105,9 +110,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the get-instance tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
Source *alloydbadmin.Source
|
||||
AllParams parameters.Parameters
|
||||
|
||||
AllParams parameters.Parameters
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -118,6 +121,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
@@ -137,7 +145,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid 'instance' parameter; expected a string")
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -172,10 +180,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -20,9 +20,9 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/alloydb/v1"
|
||||
)
|
||||
|
||||
const kind string = "alloydb-get-user"
|
||||
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
UseClientAuthorization() bool
|
||||
GetService(context.Context, string) (*alloydb.Service, error)
|
||||
}
|
||||
|
||||
// Configuration for the get-user tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -66,12 +72,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("source %q not found", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(*alloydbadmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -95,7 +101,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -105,9 +110,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the get-user tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
Source *alloydbadmin.Source
|
||||
AllParams parameters.Parameters
|
||||
|
||||
AllParams parameters.Parameters
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -118,6 +121,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
@@ -137,7 +145,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid 'user' parameter; expected a string")
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -172,10 +180,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -20,9 +20,9 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/alloydb/v1"
|
||||
)
|
||||
|
||||
const kind string = "alloydb-list-clusters"
|
||||
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
UseClientAuthorization() bool
|
||||
GetService(context.Context, string) (*alloydb.Service, error)
|
||||
}
|
||||
|
||||
// Configuration for the list-clusters tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -66,12 +72,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("source %q not found", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(*alloydbadmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -93,7 +99,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -103,9 +108,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the list-clusters tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
Source *alloydbadmin.Source
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -116,6 +119,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
@@ -127,7 +135,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid 'location' parameter; expected a string")
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -162,10 +170,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -20,9 +20,9 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/alloydb/v1"
|
||||
)
|
||||
|
||||
const kind string = "alloydb-list-instances"
|
||||
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
UseClientAuthorization() bool
|
||||
GetService(context.Context, string) (*alloydb.Service, error)
|
||||
}
|
||||
|
||||
// Configuration for the list-instances tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -66,12 +72,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("source %q not found", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(*alloydbadmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -94,7 +100,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -104,9 +109,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the list-instances tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
Source *alloydbadmin.Source
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -117,6 +120,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
@@ -132,7 +140,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string")
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -167,10 +175,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -20,9 +20,9 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/alloydb/v1"
|
||||
)
|
||||
|
||||
const kind string = "alloydb-list-users"
|
||||
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
UseClientAuthorization() bool
|
||||
GetService(context.Context, string) (*alloydb.Service, error)
|
||||
}
|
||||
|
||||
// Configuration for the list-users tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -66,12 +72,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("source %q not found", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(*alloydbadmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -94,7 +100,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -104,9 +109,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the list-users tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
Source *alloydbadmin.Source
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -117,6 +120,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
@@ -132,7 +140,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string")
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -167,10 +175,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -25,9 +25,9 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/alloydb/v1"
|
||||
)
|
||||
|
||||
const kind string = "alloydb-wait-for-operation"
|
||||
@@ -89,6 +89,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
UseClientAuthorization() bool
|
||||
GetService(context.Context, string) (*alloydb.Service, error)
|
||||
}
|
||||
|
||||
// Config defines the configuration for the wait-for-operation tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -119,12 +125,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(*alloydbadmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -180,7 +186,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -194,19 +199,16 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the wait-for-operation tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
|
||||
Source *alloydbadmin.Source
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
Client *http.Client
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
|
||||
// Polling configuration
|
||||
Delay time.Duration
|
||||
MaxDelay time.Duration
|
||||
Multiplier float64
|
||||
MaxRetries int
|
||||
|
||||
Client *http.Client
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -215,6 +217,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
@@ -230,7 +237,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("missing 'operation' parameter")
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -363,10 +370,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/alloydbpg"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
@@ -47,11 +46,6 @@ type compatibleSource interface {
|
||||
PostgresPool() *pgxpool.Pool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &alloydbpg.Source{}
|
||||
|
||||
var compatibleSources = [...]string{alloydbpg.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -70,18 +64,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
numParams := len(cfg.NLConfigParameters)
|
||||
quotedNameParts := make([]string, 0, numParams)
|
||||
placeholderParts := make([]string, 0, numParams)
|
||||
@@ -126,7 +108,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
Config: cfg,
|
||||
Parameters: cfg.NLConfigParameters,
|
||||
Statement: stmt,
|
||||
Pool: s.PostgresPool(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.NLConfigParameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -139,9 +120,7 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Pool *pgxpool.Pool
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
Statement string
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
@@ -152,6 +131,12 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pool := source.PostgresPool()
|
||||
|
||||
sliceParams := params.AsSlice()
|
||||
allParamValues := make([]any, len(sliceParams)+1)
|
||||
allParamValues[0] = fmt.Sprintf("%s", sliceParams[0]) // nl_question
|
||||
@@ -160,7 +145,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
allParamValues[i+2] = fmt.Sprintf("%s", param)
|
||||
}
|
||||
|
||||
results, err := t.Pool.Query(ctx, t.Statement, allParamValues...)
|
||||
results, err := pool.Query(ctx, t.Statement, allParamValues...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w. Query: %v , Values: %v. Toolbox v0.19.0+ is only compatible with AlloyDB AI NL v1.0.3+. Please ensure that you are using the latest AlloyDB AI NL extension", err, t.Statement, allParamValues)
|
||||
}
|
||||
@@ -203,10 +188,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -57,11 +57,6 @@ type compatibleSource interface {
|
||||
BigQuerySession() bigqueryds.BigQuerySessionProvider
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &bigqueryds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{bigqueryds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -87,7 +82,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
allowedDatasets := s.BigQueryAllowedDatasets()
|
||||
@@ -136,17 +131,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ClientCreator: s.BigQueryClientCreator(),
|
||||
Client: s.BigQueryClient(),
|
||||
RestService: s.BigQueryRestService(),
|
||||
IsDatasetAllowed: s.IsDatasetAllowed,
|
||||
AllowedDatasets: allowedDatasets,
|
||||
SessionProvider: s.BigQuerySession(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -156,17 +144,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *bigqueryapi.Client
|
||||
RestService *bigqueryrestapi.Service
|
||||
ClientCreator bigqueryds.BigqueryClientCreator
|
||||
IsDatasetAllowed func(projectID, datasetID string) bool
|
||||
AllowedDatasets []string
|
||||
SessionProvider bigqueryds.BigQuerySessionProvider
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -175,23 +155,27 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke runs the contribution analysis.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
inputData, ok := paramsMap["input_data"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unable to cast input_data parameter %s", paramsMap["input_data"])
|
||||
}
|
||||
|
||||
bqClient := t.Client
|
||||
restService := t.RestService
|
||||
var err error
|
||||
bqClient := source.BigQueryClient()
|
||||
restService := source.BigQueryRestService()
|
||||
|
||||
// Initialize new client if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
bqClient, restService, err = t.ClientCreator(tokenStr, true)
|
||||
bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
||||
}
|
||||
@@ -229,9 +213,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
var inputDataSource string
|
||||
trimmedUpperInputData := strings.TrimSpace(strings.ToUpper(inputData))
|
||||
if strings.HasPrefix(trimmedUpperInputData, "SELECT") || strings.HasPrefix(trimmedUpperInputData, "WITH") {
|
||||
if len(t.AllowedDatasets) > 0 {
|
||||
if len(source.BigQueryAllowedDatasets()) > 0 {
|
||||
var connProps []*bigqueryapi.ConnectionProperty
|
||||
session, err := t.SessionProvider(ctx)
|
||||
session, err := source.BigQuerySession()(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
||||
}
|
||||
@@ -240,7 +224,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
{Key: "session_id", Value: session.ID},
|
||||
}
|
||||
}
|
||||
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, t.Client.Project(), t.Client.Location, inputData, nil, connProps)
|
||||
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, inputData, nil, connProps)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query validation failed: %w", err)
|
||||
}
|
||||
@@ -252,7 +236,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
queryStats := dryRunJob.Statistics.Query
|
||||
if queryStats != nil {
|
||||
for _, tableRef := range queryStats.ReferencedTables {
|
||||
if !t.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) {
|
||||
if !source.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) {
|
||||
return nil, fmt.Errorf("query in input_data accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectId, tableRef.DatasetId)
|
||||
}
|
||||
}
|
||||
@@ -262,18 +246,18 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
inputDataSource = fmt.Sprintf("(%s)", inputData)
|
||||
} else {
|
||||
if len(t.AllowedDatasets) > 0 {
|
||||
if len(source.BigQueryAllowedDatasets()) > 0 {
|
||||
parts := strings.Split(inputData, ".")
|
||||
var projectID, datasetID string
|
||||
switch len(parts) {
|
||||
case 3: // project.dataset.table
|
||||
projectID, datasetID = parts[0], parts[1]
|
||||
case 2: // dataset.table
|
||||
projectID, datasetID = t.Client.Project(), parts[0]
|
||||
projectID, datasetID = source.BigQueryClient().Project(), parts[0]
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid table ID format for 'input_data': %q. Expected 'dataset.table' or 'project.dataset.table'", inputData)
|
||||
}
|
||||
if !t.IsDatasetAllowed(projectID, datasetID) {
|
||||
if !source.IsDatasetAllowed(projectID, datasetID) {
|
||||
return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", projectID, datasetID, inputData)
|
||||
}
|
||||
}
|
||||
@@ -292,7 +276,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
|
||||
// Get session from provider if in protected mode.
|
||||
// Otherwise, a new session will be created by the first query.
|
||||
session, err := t.SessionProvider(ctx)
|
||||
session, err := source.BigQuerySession()(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
||||
}
|
||||
@@ -385,10 +369,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -26,7 +26,6 @@ import (
|
||||
bigqueryapi "cloud.google.com/go/bigquery"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
@@ -105,11 +104,6 @@ type CAPayload struct {
|
||||
ClientIdEnum string `json:"clientIdEnum"`
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &bigqueryds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{bigqueryds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -135,7 +129,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
allowedDatasets := s.BigQueryAllowedDatasets()
|
||||
@@ -153,31 +147,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
params := parameters.Parameters{userQueryParameter, tableRefsParameter}
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil)
|
||||
|
||||
// Get cloud-platform token source for Gemini Data Analytics API during initialization
|
||||
var bigQueryTokenSourceWithScope oauth2.TokenSource
|
||||
if !s.UseClientAuthorization() {
|
||||
ctx := context.Background()
|
||||
ts, err := s.BigQueryTokenSourceWithScope(ctx, "https://www.googleapis.com/auth/cloud-platform")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get cloud-platform token source: %w", err)
|
||||
}
|
||||
bigQueryTokenSourceWithScope = ts
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Project: s.BigQueryProject(),
|
||||
Location: s.BigQueryLocation(),
|
||||
Parameters: params,
|
||||
Client: s.BigQueryClient(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
TokenSource: bigQueryTokenSourceWithScope,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
MaxQueryResultRows: s.GetMaxQueryResultRows(),
|
||||
IsDatasetAllowed: s.IsDatasetAllowed,
|
||||
AllowedDatasets: allowedDatasets,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -187,18 +162,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Project string
|
||||
Location string
|
||||
Client *bigqueryapi.Client
|
||||
TokenSource oauth2.TokenSource
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
MaxQueryResultRows int
|
||||
IsDatasetAllowed func(projectID, datasetID string) bool
|
||||
AllowedDatasets []string
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -206,11 +172,15 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var tokenStr string
|
||||
var err error
|
||||
|
||||
// Get credentials for the API call
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
// Use client-side access token
|
||||
if accessToken == "" {
|
||||
return nil, fmt.Errorf("tool is configured for client OAuth but no token was provided in the request header: %w", util.ErrUnauthorized)
|
||||
@@ -220,11 +190,17 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
} else {
|
||||
// Get cloud-platform token source for Gemini Data Analytics API during initialization
|
||||
tokenSource, err := source.BigQueryTokenSourceWithScope(ctx, "https://www.googleapis.com/auth/cloud-platform")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get cloud-platform token source: %w", err)
|
||||
}
|
||||
|
||||
// Use cloud-platform token source for Gemini Data Analytics API
|
||||
if t.TokenSource == nil {
|
||||
if tokenSource == nil {
|
||||
return nil, fmt.Errorf("cloud-platform token source is missing")
|
||||
}
|
||||
token, err := t.TokenSource.Token()
|
||||
token, err := tokenSource.Token()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get token from cloud-platform token source: %w", err)
|
||||
}
|
||||
@@ -245,17 +221,17 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
}
|
||||
|
||||
if len(t.AllowedDatasets) > 0 {
|
||||
if len(source.BigQueryAllowedDatasets()) > 0 {
|
||||
for _, tableRef := range tableRefs {
|
||||
if !t.IsDatasetAllowed(tableRef.ProjectID, tableRef.DatasetID) {
|
||||
if !source.IsDatasetAllowed(tableRef.ProjectID, tableRef.DatasetID) {
|
||||
return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", tableRef.ProjectID, tableRef.DatasetID, tableRef.TableID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Construct URL, headers, and payload
|
||||
projectID := t.Project
|
||||
location := t.Location
|
||||
projectID := source.BigQueryProject()
|
||||
location := source.BigQueryLocation()
|
||||
if location == "" {
|
||||
location = "us"
|
||||
}
|
||||
@@ -279,7 +255,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
|
||||
// Call the streaming API
|
||||
response, err := getStream(caURL, payload, headers, t.MaxQueryResultRows)
|
||||
response, err := getStream(caURL, payload, headers, source.GetMaxQueryResultRows())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get response from conversational analytics API: %w", err)
|
||||
}
|
||||
@@ -303,8 +279,12 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
// StreamMessage represents a single message object from the streaming API response.
|
||||
@@ -580,6 +560,6 @@ func appendMessage(messages []map[string]any, newMessage map[string]any) []map[s
|
||||
return append(messages, newMessage)
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -60,11 +60,6 @@ type compatibleSource interface {
|
||||
BigQueryAllowedDatasets() []string
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &bigqueryds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{bigqueryds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -90,7 +85,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
var sqlDescriptionBuilder strings.Builder
|
||||
@@ -136,18 +131,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ClientCreator: s.BigQueryClientCreator(),
|
||||
Client: s.BigQueryClient(),
|
||||
RestService: s.BigQueryRestService(),
|
||||
WriteMode: s.BigQueryWriteMode(),
|
||||
SessionProvider: s.BigQuerySession(),
|
||||
IsDatasetAllowed: s.IsDatasetAllowed,
|
||||
AllowedDatasets: allowedDatasets,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -157,18 +144,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *bigqueryapi.Client
|
||||
RestService *bigqueryrestapi.Service
|
||||
WriteMode string
|
||||
SessionProvider bigqueryds.BigQuerySessionProvider
|
||||
ClientCreator bigqueryds.BigqueryClientCreator
|
||||
IsDatasetAllowed func(projectID, datasetID string) bool
|
||||
AllowedDatasets []string
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -176,6 +154,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
sql, ok := paramsMap["sql"].(string)
|
||||
if !ok {
|
||||
@@ -186,17 +169,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("unable to cast dry_run parameter %s", paramsMap["dry_run"])
|
||||
}
|
||||
|
||||
bqClient := t.Client
|
||||
restService := t.RestService
|
||||
bqClient := source.BigQueryClient()
|
||||
restService := source.BigQueryRestService()
|
||||
|
||||
var err error
|
||||
// Initialize new client if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
bqClient, restService, err = t.ClientCreator(tokenStr, true)
|
||||
bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
||||
}
|
||||
@@ -204,8 +186,8 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
|
||||
var connProps []*bigqueryapi.ConnectionProperty
|
||||
var session *bigqueryds.Session
|
||||
if t.WriteMode == bigqueryds.WriteModeProtected {
|
||||
session, err = t.SessionProvider(ctx)
|
||||
if source.BigQueryWriteMode() == bigqueryds.WriteModeProtected {
|
||||
session, err = source.BigQuerySession()(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get BigQuery session for protected mode: %w", err)
|
||||
}
|
||||
@@ -221,7 +203,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
|
||||
statementType := dryRunJob.Statistics.Query.StatementType
|
||||
|
||||
switch t.WriteMode {
|
||||
switch source.BigQueryWriteMode() {
|
||||
case bigqueryds.WriteModeBlocked:
|
||||
if statementType != "SELECT" {
|
||||
return nil, fmt.Errorf("write mode is 'blocked', only SELECT statements are allowed")
|
||||
@@ -235,7 +217,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
}
|
||||
|
||||
if len(t.AllowedDatasets) > 0 {
|
||||
if len(source.BigQueryAllowedDatasets()) > 0 {
|
||||
switch statementType {
|
||||
case "CREATE_SCHEMA", "DROP_SCHEMA", "ALTER_SCHEMA":
|
||||
return nil, fmt.Errorf("dataset-level operations like '%s' are not allowed when dataset restrictions are in place", statementType)
|
||||
@@ -270,7 +252,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
} else if statementType != "SELECT" {
|
||||
// If dry run yields no tables, fall back to the parser for non-SELECT statements
|
||||
// to catch unsafe operations like EXECUTE IMMEDIATE.
|
||||
parsedTables, parseErr := bqutil.TableParser(sql, t.Client.Project())
|
||||
parsedTables, parseErr := bqutil.TableParser(sql, source.BigQueryClient().Project())
|
||||
if parseErr != nil {
|
||||
// If parsing fails (e.g., EXECUTE IMMEDIATE), we cannot guarantee safety, so we must fail.
|
||||
return nil, fmt.Errorf("could not parse tables from query to validate against allowed datasets: %w", parseErr)
|
||||
@@ -282,7 +264,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
parts := strings.Split(tableID, ".")
|
||||
if len(parts) == 3 {
|
||||
projectID, datasetID := parts[0], parts[1]
|
||||
if !t.IsDatasetAllowed(projectID, datasetID) {
|
||||
if !source.IsDatasetAllowed(projectID, datasetID) {
|
||||
return nil, fmt.Errorf("query accesses dataset '%s.%s', which is not in the allowed list", projectID, datasetID)
|
||||
}
|
||||
}
|
||||
@@ -374,10 +356,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -57,11 +57,6 @@ type compatibleSource interface {
|
||||
BigQuerySession() bigqueryds.BigQuerySessionProvider
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &bigqueryds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{bigqueryds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -87,7 +82,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
allowedDatasets := s.BigQueryAllowedDatasets()
|
||||
@@ -116,17 +111,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ClientCreator: s.BigQueryClientCreator(),
|
||||
Client: s.BigQueryClient(),
|
||||
RestService: s.BigQueryRestService(),
|
||||
IsDatasetAllowed: s.IsDatasetAllowed,
|
||||
SessionProvider: s.BigQuerySession(),
|
||||
AllowedDatasets: allowedDatasets,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -136,17 +124,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *bigqueryapi.Client
|
||||
RestService *bigqueryrestapi.Service
|
||||
ClientCreator bigqueryds.BigqueryClientCreator
|
||||
IsDatasetAllowed func(projectID, datasetID string) bool
|
||||
AllowedDatasets []string
|
||||
SessionProvider bigqueryds.BigQuerySessionProvider
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -154,6 +134,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
historyData, ok := paramsMap["history_data"].(string)
|
||||
if !ok {
|
||||
@@ -188,17 +173,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
}
|
||||
|
||||
bqClient := t.Client
|
||||
restService := t.RestService
|
||||
var err error
|
||||
bqClient := source.BigQueryClient()
|
||||
restService := source.BigQueryRestService()
|
||||
|
||||
// Initialize new client if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
bqClient, restService, err = t.ClientCreator(tokenStr, false)
|
||||
bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
||||
}
|
||||
@@ -207,9 +191,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
var historyDataSource string
|
||||
trimmedUpperHistoryData := strings.TrimSpace(strings.ToUpper(historyData))
|
||||
if strings.HasPrefix(trimmedUpperHistoryData, "SELECT") || strings.HasPrefix(trimmedUpperHistoryData, "WITH") {
|
||||
if len(t.AllowedDatasets) > 0 {
|
||||
if len(source.BigQueryAllowedDatasets()) > 0 {
|
||||
var connProps []*bigqueryapi.ConnectionProperty
|
||||
session, err := t.SessionProvider(ctx)
|
||||
session, err := source.BigQuerySession()(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
||||
}
|
||||
@@ -218,7 +202,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
{Key: "session_id", Value: session.ID},
|
||||
}
|
||||
}
|
||||
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, t.Client.Project(), t.Client.Location, historyData, nil, connProps)
|
||||
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, historyData, nil, connProps)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query validation failed: %w", err)
|
||||
}
|
||||
@@ -230,7 +214,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
queryStats := dryRunJob.Statistics.Query
|
||||
if queryStats != nil {
|
||||
for _, tableRef := range queryStats.ReferencedTables {
|
||||
if !t.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) {
|
||||
if !source.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) {
|
||||
return nil, fmt.Errorf("query in history_data accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectId, tableRef.DatasetId)
|
||||
}
|
||||
}
|
||||
@@ -240,7 +224,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
historyDataSource = fmt.Sprintf("(%s)", historyData)
|
||||
} else {
|
||||
if len(t.AllowedDatasets) > 0 {
|
||||
if len(source.BigQueryAllowedDatasets()) > 0 {
|
||||
parts := strings.Split(historyData, ".")
|
||||
var projectID, datasetID string
|
||||
|
||||
@@ -249,13 +233,13 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
projectID = parts[0]
|
||||
datasetID = parts[1]
|
||||
case 2: // dataset.table
|
||||
projectID = t.Client.Project()
|
||||
projectID = source.BigQueryClient().Project()
|
||||
datasetID = parts[0]
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid table ID format for 'history_data': %q. Expected 'dataset.table' or 'project.dataset.table'", historyData)
|
||||
}
|
||||
|
||||
if !t.IsDatasetAllowed(projectID, datasetID) {
|
||||
if !source.IsDatasetAllowed(projectID, datasetID) {
|
||||
return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", projectID, datasetID, historyData)
|
||||
}
|
||||
}
|
||||
@@ -279,7 +263,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
// JobStatistics.QueryStatistics.StatementType
|
||||
query := bqClient.Query(sql)
|
||||
query.Location = bqClient.Location
|
||||
session, err := t.SessionProvider(ctx)
|
||||
session, err := source.BigQuerySession()(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
||||
}
|
||||
@@ -349,10 +333,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -54,11 +54,6 @@ type compatibleSource interface {
|
||||
BigQueryAllowedDatasets() []string
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &bigqueryds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{bigqueryds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -84,7 +79,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
defaultProjectID := s.BigQueryProject()
|
||||
@@ -104,14 +99,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ClientCreator: s.BigQueryClientCreator(),
|
||||
Client: s.BigQueryClient(),
|
||||
IsDatasetAllowed: s.IsDatasetAllowed,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -121,15 +112,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *bigqueryapi.Client
|
||||
ClientCreator bigqueryds.BigqueryClientCreator
|
||||
Statement string
|
||||
IsDatasetAllowed func(projectID, datasetID string) bool
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -137,6 +122,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mapParams := params.AsMap()
|
||||
projectId, ok := mapParams[projectKey].(string)
|
||||
if !ok {
|
||||
@@ -148,22 +138,21 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", datasetKey)
|
||||
}
|
||||
|
||||
bqClient := t.Client
|
||||
var err error
|
||||
bqClient := source.BigQueryClient()
|
||||
|
||||
// Initialize new client if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
bqClient, _, err = t.ClientCreator(tokenStr, false)
|
||||
bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if !t.IsDatasetAllowed(projectId, datasetId) {
|
||||
if !source.IsDatasetAllowed(projectId, datasetId) {
|
||||
return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId)
|
||||
}
|
||||
|
||||
@@ -193,10 +182,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -55,11 +55,6 @@ type compatibleSource interface {
|
||||
BigQueryAllowedDatasets() []string
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &bigqueryds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{bigqueryds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -85,7 +80,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
defaultProjectID := s.BigQueryProject()
|
||||
@@ -108,14 +103,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ClientCreator: s.BigQueryClientCreator(),
|
||||
Client: s.BigQueryClient(),
|
||||
IsDatasetAllowed: s.IsDatasetAllowed,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -125,15 +116,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *bigqueryapi.Client
|
||||
ClientCreator bigqueryds.BigqueryClientCreator
|
||||
Statement string
|
||||
IsDatasetAllowed func(projectID, datasetID string) bool
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -141,6 +126,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mapParams := params.AsMap()
|
||||
projectId, ok := mapParams[projectKey].(string)
|
||||
if !ok {
|
||||
@@ -157,20 +147,19 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", tableKey)
|
||||
}
|
||||
|
||||
if !t.IsDatasetAllowed(projectId, datasetId) {
|
||||
if !source.IsDatasetAllowed(projectId, datasetId) {
|
||||
return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId)
|
||||
}
|
||||
|
||||
bqClient := t.Client
|
||||
bqClient := source.BigQueryClient()
|
||||
|
||||
var err error
|
||||
// Initialize new client if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
bqClient, _, err = t.ClientCreator(tokenStr, false)
|
||||
bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
||||
}
|
||||
@@ -203,10 +192,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -52,11 +52,6 @@ type compatibleSource interface {
|
||||
BigQueryAllowedDatasets() []string
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &bigqueryds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{bigqueryds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -82,7 +77,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
var projectParameter parameters.Parameter
|
||||
@@ -103,14 +98,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ClientCreator: s.BigQueryClientCreator(),
|
||||
Client: s.BigQueryClient(),
|
||||
AllowedDatasets: allowedDatasets,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -120,15 +111,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *bigqueryapi.Client
|
||||
ClientCreator bigqueryds.BigqueryClientCreator
|
||||
Statement string
|
||||
AllowedDatasets []string
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -136,8 +121,13 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
if len(t.AllowedDatasets) > 0 {
|
||||
return t.AllowedDatasets, nil
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(source.BigQueryAllowedDatasets()) > 0 {
|
||||
return source.BigQueryAllowedDatasets(), nil
|
||||
}
|
||||
mapParams := params.AsMap()
|
||||
projectId, ok := mapParams[projectKey].(string)
|
||||
@@ -145,14 +135,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", projectKey)
|
||||
}
|
||||
|
||||
bqClient := t.Client
|
||||
bqClient := source.BigQueryClient()
|
||||
// Initialize new client if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
bqClient, _, err = t.ClientCreator(tokenStr, false)
|
||||
bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
||||
}
|
||||
@@ -197,10 +187,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -55,11 +55,6 @@ type compatibleSource interface {
|
||||
BigQueryAllowedDatasets() []string
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &bigqueryds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{bigqueryds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -85,7 +80,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
defaultProjectID := s.BigQueryProject()
|
||||
@@ -107,14 +102,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ClientCreator: s.BigQueryClientCreator(),
|
||||
Client: s.BigQueryClient(),
|
||||
IsDatasetAllowed: s.IsDatasetAllowed,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -124,15 +115,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *bigqueryapi.Client
|
||||
ClientCreator bigqueryds.BigqueryClientCreator
|
||||
IsDatasetAllowed func(projectID, datasetID string) bool
|
||||
Statement string
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -140,6 +125,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mapParams := params.AsMap()
|
||||
projectId, ok := mapParams[projectKey].(string)
|
||||
if !ok {
|
||||
@@ -151,18 +141,18 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", datasetKey)
|
||||
}
|
||||
|
||||
if !t.IsDatasetAllowed(projectId, datasetId) {
|
||||
if !source.IsDatasetAllowed(projectId, datasetId) {
|
||||
return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId)
|
||||
}
|
||||
|
||||
bqClient := t.Client
|
||||
bqClient := source.BigQueryClient()
|
||||
// Initialize new client if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
bqClient, _, err = t.ClientCreator(tokenStr, false)
|
||||
bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
||||
}
|
||||
@@ -208,10 +198,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -51,11 +51,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &bigqueryds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{bigqueryds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -72,20 +67,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// Initialize the search configuration with the provided sources
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
// Get the Dataplex client using the method from the source
|
||||
makeCatalogClient := s.MakeDataplexCatalogClient()
|
||||
|
||||
prompt := parameters.NewStringParameter("prompt", "Prompt representing search intention. Do not rewrite the prompt.")
|
||||
datasetIds := parameters.NewArrayParameterWithDefault("datasetIds", []any{}, "Array of dataset IDs.", parameters.NewStringParameter("datasetId", "The IDs of the bigquery dataset."))
|
||||
projectIds := parameters.NewArrayParameterWithDefault("projectIds", []any{}, "Array of project IDs.", parameters.NewStringParameter("projectId", "The IDs of the bigquery project."))
|
||||
@@ -100,11 +81,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, params, nil)
|
||||
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
MakeCatalogClient: makeCatalogClient,
|
||||
ProjectID: s.BigQueryProject(),
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{
|
||||
Description: cfg.Description,
|
||||
Parameters: params.Manifest(),
|
||||
@@ -117,12 +95,9 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
Parameters parameters.Parameters
|
||||
UseClientOAuth bool
|
||||
MakeCatalogClient func() (*dataplexapi.CatalogClient, bigqueryds.DataplexClientCreator, error)
|
||||
ProjectID string
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -133,8 +108,12 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func constructSearchQueryHelper(predicate string, operator string, items []string) string {
|
||||
@@ -207,6 +186,11 @@ func ExtractType(resourceString string) string {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
pageSize := int32(paramsMap["pageSize"].(int))
|
||||
prompt, _ := paramsMap["prompt"].(string)
|
||||
@@ -228,14 +212,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
|
||||
req := &dataplexpb.SearchEntriesRequest{
|
||||
Query: fmt.Sprintf("%s %s", prompt, constructSearchQuery(projectIds, datasetIds, types)),
|
||||
Name: fmt.Sprintf("projects/%s/locations/global", t.ProjectID),
|
||||
Name: fmt.Sprintf("projects/%s/locations/global", source.BigQueryProject()),
|
||||
PageSize: pageSize,
|
||||
SemanticSearch: true,
|
||||
}
|
||||
|
||||
catalogClient, dataplexClientCreator, _ := t.MakeCatalogClient()
|
||||
catalogClient, dataplexClientCreator, _ := source.MakeDataplexCatalogClient()()
|
||||
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
@@ -248,7 +232,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
|
||||
it := catalogClient.SearchEntries(ctx, req)
|
||||
if it == nil {
|
||||
return nil, fmt.Errorf("failed to create search entries iterator for project %q", t.ProjectID)
|
||||
return nil, fmt.Errorf("failed to create search entries iterator for project %q", source.BigQueryProject())
|
||||
}
|
||||
|
||||
var results []Response
|
||||
@@ -288,6 +272,6 @@ func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -57,11 +57,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &bigqueryds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{bigqueryds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -81,18 +76,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -102,15 +85,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
AllParams: allParameters,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
Client: s.BigQueryClient(),
|
||||
RestService: s.BigQueryRestService(),
|
||||
SessionProvider: s.BigQuerySession(),
|
||||
ClientCreator: s.BigQueryClientCreator(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -120,15 +98,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
|
||||
Client *bigqueryapi.Client
|
||||
RestService *bigqueryrestapi.Service
|
||||
SessionProvider bigqueryds.BigQuerySessionProvider
|
||||
ClientCreator bigqueryds.BigqueryClientCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -136,6 +108,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
highLevelParams := make([]bigqueryapi.QueryParameter, 0, len(t.Parameters))
|
||||
lowLevelParams := make([]*bigqueryrestapi.QueryParameter, 0, len(t.Parameters))
|
||||
|
||||
@@ -212,16 +189,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
lowLevelParams = append(lowLevelParams, lowLevelParam)
|
||||
}
|
||||
|
||||
bqClient := t.Client
|
||||
restService := t.RestService
|
||||
bqClient := source.BigQueryClient()
|
||||
restService := source.BigQueryRestService()
|
||||
|
||||
// Initialize new client if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
bqClient, restService, err = t.ClientCreator(tokenStr, true)
|
||||
bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
||||
}
|
||||
@@ -232,8 +209,8 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
query.Location = bqClient.Location
|
||||
|
||||
connProps := []*bigqueryapi.ConnectionProperty{}
|
||||
if t.SessionProvider != nil {
|
||||
session, err := t.SessionProvider(ctx)
|
||||
if source.BigQuerySession() != nil {
|
||||
session, err := source.BigQuerySession()(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
||||
}
|
||||
@@ -311,10 +288,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"cloud.google.com/go/bigtable"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
bigtabledb "github.com/googleapis/genai-toolbox/internal/sources/bigtable"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
@@ -46,11 +45,6 @@ type compatibleSource interface {
|
||||
BigtableClient() *bigtable.Client
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &bigtabledb.Source{}
|
||||
|
||||
var compatibleSources = [...]string{bigtabledb.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -70,18 +64,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -93,7 +75,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
AllParams: allParameters,
|
||||
Client: s.BigtableClient(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -105,9 +86,7 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
|
||||
Client *bigtable.Client
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -156,6 +135,11 @@ func getMapParamsType(tparams parameters.Parameters, params parameters.ParamValu
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
|
||||
if err != nil {
|
||||
@@ -172,7 +156,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("fail to get map params: %w", err)
|
||||
}
|
||||
|
||||
ps, err := t.Client.PrepareStatement(
|
||||
ps, err := source.BigtableClient().PrepareStatement(
|
||||
ctx,
|
||||
newStatement,
|
||||
mapParamsType,
|
||||
@@ -224,10 +208,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
gocql "github.com/apache/cassandra-gocql-driver/v2"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cassandra"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
@@ -46,10 +45,6 @@ type compatibleSource interface {
|
||||
CassandraSession() *gocql.Session
|
||||
}
|
||||
|
||||
var _ compatibleSource = &cassandra.Source{}
|
||||
|
||||
var compatibleSources = [...]string{cassandra.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -61,20 +56,15 @@ type Config struct {
|
||||
TemplateParameters parameters.Parameters `yaml:"templateParameters"`
|
||||
}
|
||||
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
// ToolConfigKind implements tools.ToolConfig.
|
||||
func (c Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
// Initialize implements tools.ToolConfig.
|
||||
func (c Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[c.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", c.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
allParameters, paramManifest, err := parameters.ProcessParameters(c.TemplateParameters, c.Parameters)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -85,25 +75,17 @@ func (c Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
t := Tool{
|
||||
Config: c,
|
||||
AllParams: allParameters,
|
||||
Session: s.CassandraSession(),
|
||||
manifest: tools.Manifest{Description: c.Description, Parameters: paramManifest, AuthRequired: c.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// ToolConfigKind implements tools.ToolConfig.
|
||||
func (c Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
var _ tools.ToolConfig = Config{}
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
|
||||
Session *gocql.Session
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -113,8 +95,8 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
// RequiresClientAuthorization implements tools.Tool.
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Authorized implements tools.Tool.
|
||||
@@ -124,6 +106,11 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
|
||||
// Invoke implements tools.Tool.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
|
||||
if err != nil {
|
||||
@@ -135,7 +122,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("unable to extract standard params %w", err)
|
||||
}
|
||||
sliceParams := newParams.AsSlice()
|
||||
iter := t.Session.Query(newStatement, sliceParams...).IterContext(ctx)
|
||||
iter := source.CassandraSession().Query(newStatement, sliceParams...).IterContext(ctx)
|
||||
|
||||
// Create a slice to store the out
|
||||
var out []map[string]interface{}
|
||||
@@ -170,8 +157,6 @@ func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any)
|
||||
return parameters.ParseParams(t.AllParams, data, claims)
|
||||
}
|
||||
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -25,12 +25,6 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
|
||||
type compatibleSource interface {
|
||||
ClickHousePool() *sql.DB
|
||||
}
|
||||
|
||||
var compatibleSources = []string{"clickhouse"}
|
||||
|
||||
const executeSQLKind string = "clickhouse-execute-sql"
|
||||
|
||||
func init() {
|
||||
@@ -47,6 +41,10 @@ func newExecuteSQLConfig(ctx context.Context, name string, decoder *yaml.Decoder
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
ClickHousePool() *sql.DB
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -62,16 +60,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", executeSQLKind, compatibleSources)
|
||||
}
|
||||
|
||||
sqlParameter := parameters.NewStringParameter("sql", "The SQL statement to execute.")
|
||||
params := parameters.Parameters{sqlParameter}
|
||||
|
||||
@@ -80,7 +68,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Pool: s.ClickHousePool(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -91,9 +78,7 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Pool *sql.DB
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -103,13 +88,18 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
sql, ok := paramsMap["sql"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unable to cast sql parameter %s", paramsMap["sql"])
|
||||
}
|
||||
|
||||
results, err := t.Pool.QueryContext(ctx, sql)
|
||||
results, err := source.ClickHousePool().QueryContext(ctx, sql)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
@@ -183,10 +173,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -25,12 +25,6 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
|
||||
type compatibleSource interface {
|
||||
ClickHousePool() *sql.DB
|
||||
}
|
||||
|
||||
var compatibleSources = []string{"clickhouse"}
|
||||
|
||||
const listDatabasesKind string = "clickhouse-list-databases"
|
||||
|
||||
func init() {
|
||||
@@ -47,6 +41,10 @@ func newListDatabasesConfig(ctx context.Context, name string, decoder *yaml.Deco
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
ClickHousePool() *sql.DB
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -63,23 +61,12 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", listDatabasesKind, compatibleSources)
|
||||
}
|
||||
|
||||
allParameters, paramManifest, _ := parameters.ProcessParameters(nil, cfg.Parameters)
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil)
|
||||
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
AllParams: allParameters,
|
||||
Pool: s.ClickHousePool(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -90,9 +77,7 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
|
||||
Pool *sql.DB
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -102,10 +87,15 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Query to list all databases
|
||||
query := "SHOW DATABASES"
|
||||
|
||||
results, err := t.Pool.QueryContext(ctx, query)
|
||||
results, err := source.ClickHousePool().QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
@@ -146,10 +136,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -20,7 +20,6 @@ import (
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
@@ -32,21 +31,6 @@ func TestListDatabasesConfigToolConfigKind(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestListDatabasesConfigInitializeMissingSource(t *testing.T) {
|
||||
cfg := Config{
|
||||
Name: "test-list-databases",
|
||||
Kind: listDatabasesKind,
|
||||
Source: "missing-source",
|
||||
Description: "Test list databases tool",
|
||||
}
|
||||
|
||||
srcs := map[string]sources.Source{}
|
||||
_, err := cfg.Initialize(srcs)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing source")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFromYamlClickHouseListDatabases(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
|
||||
@@ -25,12 +25,6 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
|
||||
type compatibleSource interface {
|
||||
ClickHousePool() *sql.DB
|
||||
}
|
||||
|
||||
var compatibleSources = []string{"clickhouse"}
|
||||
|
||||
const listTablesKind string = "clickhouse-list-tables"
|
||||
const databaseKey string = "database"
|
||||
|
||||
@@ -48,6 +42,10 @@ func newListTablesConfig(ctx context.Context, name string, decoder *yaml.Decoder
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
ClickHousePool() *sql.DB
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -64,16 +62,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", listTablesKind, compatibleSources)
|
||||
}
|
||||
|
||||
databaseParameter := parameters.NewStringParameter(databaseKey, "The database to list tables from.")
|
||||
params := parameters.Parameters{databaseParameter}
|
||||
|
||||
@@ -83,7 +71,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
AllParams: allParameters,
|
||||
Pool: s.ClickHousePool(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -94,9 +81,7 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
|
||||
Pool *sql.DB
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -106,6 +91,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mapParams := params.AsMap()
|
||||
database, ok := mapParams[databaseKey].(string)
|
||||
if !ok {
|
||||
@@ -115,7 +105,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
// Query to list all tables in the specified database
|
||||
query := fmt.Sprintf("SHOW TABLES FROM %s", database)
|
||||
|
||||
results, err := t.Pool.QueryContext(ctx, query)
|
||||
results, err := source.ClickHousePool().QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
@@ -157,10 +147,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -20,7 +20,6 @@ import (
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
@@ -32,21 +31,6 @@ func TestListTablesConfigToolConfigKind(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestListTablesConfigInitializeMissingSource(t *testing.T) {
|
||||
cfg := Config{
|
||||
Name: "test-list-tables",
|
||||
Kind: listTablesKind,
|
||||
Source: "missing-source",
|
||||
Description: "Test list tables tool",
|
||||
}
|
||||
|
||||
srcs := map[string]sources.Source{}
|
||||
_, err := cfg.Initialize(srcs)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing source")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFromYamlClickHouseListTables(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
|
||||
@@ -25,21 +25,15 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
|
||||
type compatibleSource interface {
|
||||
ClickHousePool() *sql.DB
|
||||
}
|
||||
|
||||
var compatibleSources = []string{"clickhouse"}
|
||||
|
||||
const sqlKind string = "clickhouse-sql"
|
||||
|
||||
func init() {
|
||||
if !tools.Register(sqlKind, newSQLConfig) {
|
||||
if !tools.Register(sqlKind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", sqlKind))
|
||||
}
|
||||
}
|
||||
|
||||
func newSQLConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
@@ -47,6 +41,10 @@ func newSQLConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tool
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
ClickHousePool() *sql.DB
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -65,23 +63,12 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", sqlKind, compatibleSources)
|
||||
}
|
||||
|
||||
allParameters, paramManifest, _ := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters)
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil)
|
||||
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
AllParams: allParameters,
|
||||
Pool: s.ClickHousePool(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -93,7 +80,6 @@ var _ tools.Tool = Tool{}
|
||||
type Tool struct {
|
||||
Config
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
Pool *sql.DB
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -103,6 +89,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
|
||||
if err != nil {
|
||||
@@ -115,7 +106,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
|
||||
sliceParams := newParams.AsSlice()
|
||||
results, err := t.Pool.QueryContext(ctx, newStatement, sliceParams...)
|
||||
results, err := source.ClickHousePool().QueryContext(ctx, newStatement, sliceParams...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
@@ -191,10 +182,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -142,66 +142,6 @@ func TestSQLConfigInitializeValidSource(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLConfigInitializeMissingSource(t *testing.T) {
|
||||
config := Config{
|
||||
Name: "test-tool",
|
||||
Kind: sqlKind,
|
||||
Source: "missing-source",
|
||||
Description: "Test tool",
|
||||
Statement: "SELECT 1",
|
||||
Parameters: parameters.Parameters{},
|
||||
}
|
||||
|
||||
sources := map[string]sources.Source{}
|
||||
|
||||
_, err := config.Initialize(sources)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for missing source, got nil")
|
||||
}
|
||||
|
||||
expectedErr := `no source named "missing-source" configured`
|
||||
if err.Error() != expectedErr {
|
||||
t.Errorf("Expected error %q, got %q", expectedErr, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// mockIncompatibleSource is a mock source that doesn't implement the compatibleSource interface
|
||||
type mockIncompatibleSource struct{}
|
||||
|
||||
func (m *mockIncompatibleSource) SourceKind() string {
|
||||
return "mock"
|
||||
}
|
||||
|
||||
func (m *mockIncompatibleSource) ToConfig() sources.SourceConfig {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestSQLConfigInitializeIncompatibleSource(t *testing.T) {
|
||||
config := Config{
|
||||
Name: "test-tool",
|
||||
Kind: sqlKind,
|
||||
Source: "incompatible-source",
|
||||
Description: "Test tool",
|
||||
Statement: "SELECT 1",
|
||||
Parameters: parameters.Parameters{},
|
||||
}
|
||||
|
||||
mockSource := &mockIncompatibleSource{}
|
||||
|
||||
sources := map[string]sources.Source{
|
||||
"incompatible-source": mockSource,
|
||||
}
|
||||
|
||||
_, err := config.Initialize(sources)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for incompatible source, got nil")
|
||||
}
|
||||
|
||||
if err.Error() == "" {
|
||||
t.Error("Expected non-empty error message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolManifest(t *testing.T) {
|
||||
tool := Tool{
|
||||
manifest: tools.Manifest{
|
||||
|
||||
@@ -24,7 +24,6 @@ import (
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
cloudgdasrc "github.com/googleapis/genai-toolbox/internal/sources/cloudgda"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
@@ -45,6 +44,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetProjectID() string
|
||||
GetBaseURL() string
|
||||
UseClientAuthorization() bool
|
||||
GetClient(context.Context, string) (*http.Client, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -64,18 +70,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(*cloudgdasrc.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-gemini-data-analytics`", kind)
|
||||
}
|
||||
|
||||
// Define the parameters for the Gemini Data Analytics Query API
|
||||
// The prompt is the only input parameter.
|
||||
allParameters := parameters.Parameters{
|
||||
@@ -87,7 +81,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
AllParams: allParameters,
|
||||
Source: s,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: allParameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
@@ -99,7 +92,6 @@ var _ tools.Tool = Tool{}
|
||||
type Tool struct {
|
||||
Config
|
||||
AllParams parameters.Parameters
|
||||
Source *cloudgdasrc.Source
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -110,6 +102,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool logic
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
prompt, ok := paramsMap["prompt"].(string)
|
||||
if !ok {
|
||||
@@ -118,11 +115,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
|
||||
// The API endpoint itself always uses the "global" location.
|
||||
apiLocation := "global"
|
||||
apiParent := fmt.Sprintf("projects/%s/locations/%s", t.Source.ProjectID, apiLocation)
|
||||
apiURL := fmt.Sprintf("%s/v1beta/%s:queryData", t.Source.BaseURL, apiParent)
|
||||
apiParent := fmt.Sprintf("projects/%s/locations/%s", source.GetProjectID(), apiLocation)
|
||||
apiURL := fmt.Sprintf("%s/v1beta/%s:queryData", source.GetBaseURL(), apiParent)
|
||||
|
||||
// The parent in the request payload uses the tool's configured location.
|
||||
payloadParent := fmt.Sprintf("projects/%s/locations/%s", t.Source.ProjectID, t.Location)
|
||||
payloadParent := fmt.Sprintf("projects/%s/locations/%s", source.GetProjectID(), t.Location)
|
||||
|
||||
payload := &QueryDataRequest{
|
||||
Parent: payloadParent,
|
||||
@@ -138,7 +135,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
|
||||
// Parse the access token if provided
|
||||
var tokenStr string
|
||||
if t.RequiresClientAuthorization(resourceMgr) {
|
||||
if source.UseClientAuthorization() {
|
||||
var err error
|
||||
tokenStr, err = accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
@@ -146,7 +143,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
}
|
||||
|
||||
client, err := t.Source.GetClient(ctx, tokenStr)
|
||||
client, err := source.GetClient(ctx, tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get HTTP client: %w", err)
|
||||
}
|
||||
@@ -196,10 +193,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(_ tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/server/resources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
cloudgdasrc "github.com/googleapis/genai-toolbox/internal/sources/cloudgda"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
@@ -172,9 +173,8 @@ func TestInitialize(t *testing.T) {
|
||||
}
|
||||
|
||||
tcs := []struct {
|
||||
desc string
|
||||
cfg cloudgdatool.Config
|
||||
expectErr bool
|
||||
desc string
|
||||
cfg cloudgdatool.Config
|
||||
}{
|
||||
{
|
||||
desc: "successful initialization",
|
||||
@@ -185,29 +185,6 @@ func TestInitialize(t *testing.T) {
|
||||
Description: "Test Description",
|
||||
Location: "us-central1",
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
desc: "missing source",
|
||||
cfg: cloudgdatool.Config{
|
||||
Name: "my-gda-query-tool",
|
||||
Kind: "cloud-gemini-data-analytics-query",
|
||||
Source: "non-existent-source",
|
||||
Description: "Test Description",
|
||||
Location: "us-central1",
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
desc: "incompatible source kind",
|
||||
cfg: cloudgdatool.Config{
|
||||
Name: "my-gda-query-tool",
|
||||
Kind: "cloud-gemini-data-analytics-query",
|
||||
Source: "incompatible-source",
|
||||
Description: "Test Description",
|
||||
Location: "us-central1",
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -219,16 +196,11 @@ func TestInitialize(t *testing.T) {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tool, err := tc.cfg.Initialize(srcs)
|
||||
if tc.expectErr && err == nil {
|
||||
t.Fatalf("expected an error but got none")
|
||||
}
|
||||
if !tc.expectErr && err != nil {
|
||||
if err != nil {
|
||||
t.Fatalf("did not expect an error but got: %v", err)
|
||||
}
|
||||
if !tc.expectErr {
|
||||
// Basic sanity check on the returned tool
|
||||
_ = tool // Avoid unused variable error
|
||||
}
|
||||
// Basic sanity check on the returned tool
|
||||
_ = tool // Avoid unused variable error
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -361,8 +333,10 @@ func TestInvoke(t *testing.T) {
|
||||
{Name: "prompt", Value: "How many accounts who have region in Prague are eligible for loans?"},
|
||||
}
|
||||
|
||||
resourceMgr := resources.NewResourceManager(srcs, nil, nil, nil, nil, nil)
|
||||
|
||||
// Invoke the tool
|
||||
result, err := tool.Invoke(ctx, nil, params, "") // No accessToken needed for ADC client
|
||||
result, err := tool.Invoke(ctx, resourceMgr, params, "") // No accessToken needed for ADC client
|
||||
if err != nil {
|
||||
t.Fatalf("tool invocation failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -62,11 +62,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -83,35 +78,16 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
urlParameter := parameters.NewStringParameter(pageURLKey, "The full URL of the FHIR page to fetch. This would be the value of `Bundle.entry.link.url` field within the response returned from FHIR search or FHIR patient everything operations.")
|
||||
params := parameters.Parameters{urlParameter}
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
AllowedStores: s.AllowedFHIRStores(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -121,14 +97,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
Project, Region, Dataset string
|
||||
AllowedStores map[string]struct{}
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -136,13 +107,18 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
url, ok := params.AsMap()[pageURLKey].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", pageURLKey)
|
||||
}
|
||||
|
||||
var httpClient *http.Client
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
@@ -150,7 +126,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: tokenStr})
|
||||
httpClient = oauth2.NewClient(ctx, ts)
|
||||
} else {
|
||||
// The t.Service object holds a client with the default credentials.
|
||||
// The source.Service() object holds a client with the default credentials.
|
||||
// However, the client is not exported, so we have to create a new one.
|
||||
var err error
|
||||
httpClient, err = google.DefaultClient(ctx, healthcare.CloudHealthcareScope)
|
||||
@@ -201,10 +177,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -62,11 +62,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -92,7 +87,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
idParameter := parameters.NewStringParameter(patientIDKey, "The ID of the patient FHIR resource for which the information is required")
|
||||
@@ -106,17 +101,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
AllowedStores: s.AllowedFHIRStores(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -126,15 +114,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Project, Region, Dataset string
|
||||
AllowedStores map[string]struct{}
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -142,7 +124,12 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -151,20 +138,20 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", patientIDKey)
|
||||
}
|
||||
|
||||
svc := t.Service
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = t.ServiceCreator(tokenStr)
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/Patient/%s", t.Project, t.Region, t.Dataset, storeID, patientID)
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/Patient/%s", source.Project(), source.Region(), source.DatasetID(), storeID, patientID)
|
||||
var opts []googleapi.CallOption
|
||||
if val, ok := params.AsMap()[typeFilterKey]; ok {
|
||||
types, ok := val.([]any)
|
||||
@@ -225,10 +212,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -78,11 +78,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -108,7 +103,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
params := parameters.Parameters{
|
||||
@@ -140,17 +135,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
AllowedStores: s.AllowedFHIRStores(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -160,15 +148,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Project, Region, Dataset string
|
||||
AllowedStores map[string]struct{}
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -176,19 +158,24 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := t.Service
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = t.ServiceCreator(tokenStr)
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
@@ -261,7 +248,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
opts = append(opts, googleapi.QueryParameter("_summary", "text"))
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", t.Project, t.Region, t.Dataset, storeID)
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
resp, err := svc.Projects.Locations.Datasets.FhirStores.Fhir.SearchType(name, "Patient", &healthcare.SearchResourcesRequest{ResourceType: "Patient"}).Do(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to search patient resources: %w", err)
|
||||
@@ -298,10 +285,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -51,11 +51,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -72,33 +67,15 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
params := parameters.Parameters{}
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -108,13 +85,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
Project, Region, Dataset string
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -122,22 +95,26 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
svc := t.Service
|
||||
var err error
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
|
||||
// Initialize new service if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = t.ServiceCreator(tokenStr)
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", t.Project, t.Region, t.Dataset)
|
||||
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", source.Project(), source.Region(), source.DatasetID())
|
||||
dataset, err := svc.Projects.Locations.Datasets.Get(datasetName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
|
||||
@@ -161,10 +138,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -53,11 +53,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -83,7 +78,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
params := parameters.Parameters{}
|
||||
@@ -94,17 +89,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
AllowedStores: s.AllowedDICOMStores(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -114,15 +102,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Project, Region, Dataset string
|
||||
AllowedStores map[string]struct{}
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -130,25 +112,30 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := t.Service
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = t.ServiceCreator(tokenStr)
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", t.Project, t.Region, t.Dataset, storeID)
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
store, err := svc.Projects.Locations.Datasets.DicomStores.Get(storeName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get DICOM store %q: %w", storeName, err)
|
||||
@@ -172,10 +159,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -53,11 +53,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -83,7 +78,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
params := parameters.Parameters{}
|
||||
@@ -94,17 +89,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
AllowedStores: s.AllowedDICOMStores(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -114,15 +102,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Project, Region, Dataset string
|
||||
AllowedStores map[string]struct{}
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -130,25 +112,30 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := t.Service
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = t.ServiceCreator(tokenStr)
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", t.Project, t.Region, t.Dataset, storeID)
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
store, err := svc.Projects.Locations.Datasets.DicomStores.GetDICOMStoreMetrics(storeName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get metrics for DICOM store %q: %w", storeName, err)
|
||||
@@ -172,10 +159,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -59,11 +59,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -89,7 +84,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
typeParameter := parameters.NewStringParameter(typeKey, "The FHIR resource type to retrieve (e.g., Patient, Observation).")
|
||||
@@ -102,17 +97,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
AllowedStores: s.AllowedFHIRStores(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -122,15 +110,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Project, Region, Dataset string
|
||||
AllowedStores map[string]struct{}
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -138,7 +120,12 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -152,20 +139,20 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", idKey)
|
||||
}
|
||||
|
||||
svc := t.Service
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = t.ServiceCreator(tokenStr)
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/%s/%s", t.Project, t.Region, t.Dataset, storeID, resType, resID)
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/%s/%s", source.Project(), source.Region(), source.DatasetID(), storeID, resType, resID)
|
||||
call := svc.Projects.Locations.Datasets.FhirStores.Fhir.Read(name)
|
||||
call.Header().Set("Content-Type", "application/fhir+json;charset=utf-8")
|
||||
resp, err := call.Do()
|
||||
@@ -204,10 +191,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -53,11 +53,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -83,7 +78,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
params := parameters.Parameters{}
|
||||
@@ -94,17 +89,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
AllowedStores: s.AllowedFHIRStores(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -114,15 +102,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Project, Region, Dataset string
|
||||
AllowedStores map[string]struct{}
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -130,25 +112,30 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := t.Service
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = t.ServiceCreator(tokenStr)
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", t.Project, t.Region, t.Dataset, storeID)
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
store, err := svc.Projects.Locations.Datasets.FhirStores.Get(storeName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get FHIR store %q: %w", storeName, err)
|
||||
@@ -172,10 +159,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -53,11 +53,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -83,7 +78,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
params := parameters.Parameters{}
|
||||
@@ -94,17 +89,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
AllowedStores: s.AllowedFHIRStores(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -114,15 +102,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Project, Region, Dataset string
|
||||
AllowedStores map[string]struct{}
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -130,25 +112,30 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := t.Service
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = t.ServiceCreator(tokenStr)
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", t.Project, t.Region, t.Dataset, storeID)
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
store, err := svc.Projects.Locations.Datasets.FhirStores.GetFHIRStoreMetrics(storeName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get metrics for FHIR store %q: %w", storeName, err)
|
||||
@@ -172,10 +159,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -53,11 +53,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -74,34 +69,15 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
params := parameters.Parameters{}
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
AllowedStores: s.AllowedDICOMStores(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -111,15 +87,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Project, Region, Dataset string
|
||||
AllowedStores map[string]struct{}
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -127,29 +97,33 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
svc := t.Service
|
||||
var err error
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
|
||||
// Initialize new service if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = t.ServiceCreator(tokenStr)
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", t.Project, t.Region, t.Dataset)
|
||||
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", source.Project(), source.Region(), source.DatasetID())
|
||||
stores, err := svc.Projects.Locations.Datasets.DicomStores.List(datasetName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
|
||||
}
|
||||
var filtered []*healthcare.DicomStore
|
||||
for _, store := range stores.DicomStores {
|
||||
if len(t.AllowedStores) == 0 {
|
||||
if len(source.AllowedDICOMStores()) == 0 {
|
||||
filtered = append(filtered, store)
|
||||
continue
|
||||
}
|
||||
@@ -157,7 +131,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
continue
|
||||
}
|
||||
parts := strings.Split(store.Name, "/")
|
||||
if _, ok := t.AllowedStores[parts[len(parts)-1]]; ok {
|
||||
if _, ok := source.AllowedDICOMStores()[parts[len(parts)-1]]; ok {
|
||||
filtered = append(filtered, store)
|
||||
}
|
||||
}
|
||||
@@ -180,10 +154,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -53,11 +53,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -74,34 +69,15 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
params := parameters.Parameters{}
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
AllowedStores: s.AllowedFHIRStores(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -111,15 +87,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Project, Region, Dataset string
|
||||
AllowedStores map[string]struct{}
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -127,29 +97,33 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
svc := t.Service
|
||||
var err error
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
|
||||
// Initialize new service if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = t.ServiceCreator(tokenStr)
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", t.Project, t.Region, t.Dataset)
|
||||
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", source.Project(), source.Region(), source.DatasetID())
|
||||
stores, err := svc.Projects.Locations.Datasets.FhirStores.List(datasetName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
|
||||
}
|
||||
var filtered []*healthcare.FhirStore
|
||||
for _, store := range stores.FhirStores {
|
||||
if len(t.AllowedStores) == 0 {
|
||||
if len(source.AllowedFHIRStores()) == 0 {
|
||||
filtered = append(filtered, store)
|
||||
continue
|
||||
}
|
||||
@@ -157,7 +131,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
continue
|
||||
}
|
||||
parts := strings.Split(store.Name, "/")
|
||||
if _, ok := t.AllowedStores[parts[len(parts)-1]]; ok {
|
||||
if _, ok := source.AllowedFHIRStores()[parts[len(parts)-1]]; ok {
|
||||
filtered = append(filtered, store)
|
||||
}
|
||||
}
|
||||
@@ -180,10 +154,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -61,11 +61,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -91,7 +86,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
params := parameters.Parameters{
|
||||
@@ -107,17 +102,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
AllowedStores: s.AllowedDICOMStores(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -127,15 +115,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Project, Region, Dataset string
|
||||
AllowedStores map[string]struct{}
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -143,19 +125,24 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := t.Service
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = t.ServiceCreator(tokenStr)
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
@@ -177,7 +164,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid '%s' parameter; expected an integer", frameNumberKey)
|
||||
}
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", t.Project, t.Region, t.Dataset, storeID)
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
dicomWebPath := fmt.Sprintf("studies/%s/series/%s/instances/%s/frames/%d/rendered", study, series, sop, frame)
|
||||
call := svc.Projects.Locations.Datasets.DicomStores.Studies.Series.Instances.Frames.RetrieveRendered(name, dicomWebPath)
|
||||
call.Header().Set("Accept", "image/jpeg")
|
||||
@@ -214,10 +201,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -68,11 +68,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -98,7 +93,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
params := parameters.Parameters{
|
||||
@@ -121,17 +116,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
AllowedStores: s.AllowedDICOMStores(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -141,15 +129,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Project, Region, Dataset string
|
||||
AllowedStores map[string]struct{}
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -157,19 +139,24 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := t.Service
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = t.ServiceCreator(tokenStr)
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
@@ -204,7 +191,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", t.Project, t.Region, t.Dataset, storeID)
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
resp, err := svc.Projects.Locations.Datasets.DicomStores.SearchForInstances(name, dicomWebPath).Do(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to search dicom instances: %w", err)
|
||||
@@ -244,10 +231,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -65,11 +65,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -95,7 +90,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
params := parameters.Parameters{
|
||||
@@ -117,17 +112,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
AllowedStores: s.AllowedDICOMStores(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -137,15 +125,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Project, Region, Dataset string
|
||||
AllowedStores map[string]struct{}
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -153,19 +135,24 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := t.Service
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = t.ServiceCreator(tokenStr)
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
@@ -187,7 +174,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", t.Project, t.Region, t.Dataset, storeID)
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
resp, err := svc.Projects.Locations.Datasets.DicomStores.SearchForSeries(name, dicomWebPath).Do(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to search dicom series: %w", err)
|
||||
@@ -227,10 +214,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -63,11 +63,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -93,7 +88,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
params := parameters.Parameters{
|
||||
@@ -113,17 +108,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
AllowedStores: s.AllowedDICOMStores(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -133,15 +121,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Project, Region, Dataset string
|
||||
AllowedStores map[string]struct{}
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -149,19 +131,24 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := t.Service
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = t.ServiceCreator(tokenStr)
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
@@ -171,7 +158,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", t.Project, t.Region, t.Dataset, storeID)
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
resp, err := svc.Projects.Locations.Datasets.DicomStores.SearchForStudies(name, "studies").Do(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to search dicom studies: %w", err)
|
||||
@@ -211,10 +198,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -23,7 +23,6 @@ import (
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
cloudmonitoringsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudmonitoring"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
@@ -44,6 +43,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
BaseURL() string
|
||||
Client() *http.Client
|
||||
UserAgent() string
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -60,18 +65,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(*cloudmonitoringsrc.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloudmonitoring`", kind)
|
||||
}
|
||||
|
||||
// Define the parameters internally instead of from the config file.
|
||||
allParameters := parameters.Parameters{
|
||||
parameters.NewStringParameterWithRequired("projectId", "The Id of the Google Cloud project.", true),
|
||||
@@ -83,9 +76,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
AllParams: allParameters,
|
||||
BaseURL: s.BaseURL,
|
||||
UserAgent: s.UserAgent,
|
||||
Client: s.Client,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: allParameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
@@ -97,9 +87,6 @@ var _ tools.Tool = Tool{}
|
||||
type Tool struct {
|
||||
Config
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
BaseURL string `yaml:"baseURL"`
|
||||
UserAgent string
|
||||
Client *http.Client
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -109,6 +96,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
projectID, ok := paramsMap["projectId"].(string)
|
||||
if !ok {
|
||||
@@ -119,7 +111,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("query parameter not found or not a string")
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/v1/projects/%s/location/global/prometheus/api/v1/query", t.BaseURL, projectID)
|
||||
url := fmt.Sprintf("%s/v1/projects/%s/location/global/prometheus/api/v1/query", source.BaseURL(), projectID)
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
@@ -130,9 +122,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
q.Add("query", query)
|
||||
req.URL.RawQuery = q.Encode()
|
||||
|
||||
req.Header.Set("User-Agent", t.UserAgent)
|
||||
req.Header.Set("User-Agent", source.UserAgent())
|
||||
|
||||
resp, err := t.Client.Do(req)
|
||||
resp, err := source.Client().Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -175,10 +167,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -81,22 +81,6 @@ func TestInitialize(t *testing.T) {
|
||||
AuthRequired: []string{"google-auth-service"},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Error: source not found",
|
||||
cfg: cloudmonitoring.Config{
|
||||
Name: "test-tool",
|
||||
Source: "non-existent-source",
|
||||
},
|
||||
wantErr: `no source named "non-existent-source" configured`,
|
||||
},
|
||||
{
|
||||
desc: "Error: incompatible source kind",
|
||||
cfg: cloudmonitoring.Config{
|
||||
Name: "test-tool",
|
||||
Source: "incompatible-source",
|
||||
},
|
||||
wantErr: "invalid source for \"cloud-monitoring-query-prometheus\" tool",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
|
||||
@@ -20,7 +20,6 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
sqladmin "google.golang.org/api/sqladmin/v1"
|
||||
@@ -42,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
GetService(context.Context, string) (*sqladmin.Service, error)
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// Config defines the configuration for the clone-instance tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -65,12 +70,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
s, ok := rawS.(*cloudsqladmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -98,7 +103,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -108,7 +112,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the clone-instance tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
Source *cloudsqladmin.Source
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
@@ -120,6 +123,10 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
@@ -156,7 +163,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
CloneContext: cloneContext,
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -189,10 +196,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -20,7 +20,6 @@ import (
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
sqladmin "google.golang.org/api/sqladmin/v1"
|
||||
@@ -42,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
GetService(context.Context, string) (*sqladmin.Service, error)
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// Config defines the configuration for the create-database tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -65,12 +70,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
s, ok := rawS.(*cloudsqladmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -93,7 +98,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -103,7 +107,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the create-database tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
Source *cloudsqladmin.Source
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
@@ -115,6 +118,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
@@ -136,7 +144,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
Instance: instance,
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -169,10 +177,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -20,7 +20,6 @@ import (
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
sqladmin "google.golang.org/api/sqladmin/v1"
|
||||
@@ -42,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
GetService(context.Context, string) (*sqladmin.Service, error)
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// Config defines the configuration for the create-user tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -65,12 +70,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
s, ok := rawS.(*cloudsqladmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -95,7 +100,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -105,7 +109,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the create-user tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
Source *cloudsqladmin.Source
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
@@ -117,6 +120,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
@@ -149,7 +157,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
user.Password = password
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -182,10 +190,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -20,9 +20,9 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/sqladmin/v1"
|
||||
)
|
||||
|
||||
const kind string = "cloud-sql-get-instance"
|
||||
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
GetService(context.Context, string) (*sqladmin.Service, error)
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// Config defines the configuration for the get-instances tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -65,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(*cloudsqladmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("projectId", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -92,7 +98,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -102,7 +107,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the get-instances tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
Source *cloudsqladmin.Source
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
@@ -114,6 +118,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
projectId, ok := paramsMap["projectId"].(string)
|
||||
@@ -125,7 +134,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("missing 'instanceId' parameter")
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -158,10 +167,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -20,9 +20,9 @@ import (
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
cloudsqladminsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/sqladmin/v1"
|
||||
)
|
||||
|
||||
const kind string = "cloud-sql-list-databases"
|
||||
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
GetService(context.Context, string) (*sqladmin.Service, error)
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// Config defines the configuration for the list-databases tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -64,12 +70,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
s, ok := rawS.(*cloudsqladminsrc.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -91,7 +97,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -102,7 +107,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
type Tool struct {
|
||||
Config
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
Source *cloudsqladminsrc.Source
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -113,6 +117,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
@@ -124,7 +133,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("missing 'instance' parameter")
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -176,10 +185,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -20,9 +20,9 @@ import (
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
cloudsqladminsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/sqladmin/v1"
|
||||
)
|
||||
|
||||
const kind string = "cloud-sql-list-instances"
|
||||
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
GetService(context.Context, string) (*sqladmin.Service, error)
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// Config defines the configuration for the list-instance tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -64,12 +70,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
s, ok := rawS.(*cloudsqladminsrc.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -90,7 +96,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -101,7 +106,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
type Tool struct {
|
||||
Config
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
source *cloudsqladminsrc.Source
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -112,6 +116,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
@@ -119,7 +128,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("missing 'project' parameter")
|
||||
}
|
||||
|
||||
service, err := t.source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -169,10 +178,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -25,9 +25,9 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/sqladmin/v1"
|
||||
)
|
||||
|
||||
const kind string = "cloud-sql-wait-for-operation"
|
||||
@@ -87,6 +87,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
GetService(context.Context, string) (*sqladmin.Service, error)
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// Config defines the configuration for the wait-for-operation tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -118,12 +124,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(*cloudsqladmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -177,7 +183,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -191,17 +196,15 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the wait-for-operation tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
Source *cloudsqladmin.Source
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
|
||||
// Polling configuration
|
||||
Delay time.Duration
|
||||
MaxDelay time.Duration
|
||||
Multiplier float64
|
||||
MaxRetries int
|
||||
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -210,6 +213,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
@@ -221,7 +229,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("missing 'operation' parameter")
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -267,7 +275,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("could not unmarshal operation: %w", err)
|
||||
}
|
||||
|
||||
if msg, ok := t.generateCloudSQLConnectionMessage(data); ok {
|
||||
if msg, ok := t.generateCloudSQLConnectionMessage(source, data); ok {
|
||||
return msg, nil
|
||||
}
|
||||
return string(opBytes), nil
|
||||
@@ -305,11 +313,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) generateCloudSQLConnectionMessage(opResponse map[string]any) (string, bool) {
|
||||
func (t Tool) generateCloudSQLConnectionMessage(source compatibleSource, opResponse map[string]any) (string, bool) {
|
||||
operationType, ok := opResponse["operationType"].(string)
|
||||
if !ok || operationType != "CREATE_DATABASE" {
|
||||
return "", false
|
||||
@@ -329,7 +341,7 @@ func (t Tool) generateCloudSQLConnectionMessage(opResponse map[string]any) (stri
|
||||
instance := matches[2]
|
||||
database := matches[3]
|
||||
|
||||
instanceData, err := t.fetchInstanceData(context.Background(), project, instance)
|
||||
instanceData, err := t.fetchInstanceData(context.Background(), source, project, instance)
|
||||
if err != nil {
|
||||
fmt.Printf("error fetching instance data: %v\n", err)
|
||||
return "", false
|
||||
@@ -385,8 +397,8 @@ func (t Tool) generateCloudSQLConnectionMessage(opResponse map[string]any) (stri
|
||||
return b.String(), true
|
||||
}
|
||||
|
||||
func (t Tool) fetchInstanceData(ctx context.Context, project, instance string) (map[string]any, error) {
|
||||
service, err := t.Source.GetService(ctx, "")
|
||||
func (t Tool) fetchInstanceData(ctx context.Context, source compatibleSource, project, instance string) (map[string]any, error) {
|
||||
service, err := source.GetService(ctx, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -408,6 +420,6 @@ func (t Tool) fetchInstanceData(ctx context.Context, project, instance string) (
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
sqladmin "google.golang.org/api/sqladmin/v1"
|
||||
@@ -43,6 +42,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
GetService(context.Context, string) (*sqladmin.Service, error)
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// Config defines the configuration for the create-instances tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -66,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
s, ok := rawS.(*cloudsqladmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -96,7 +101,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -106,7 +110,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the create-instances tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
Source *cloudsqladmin.Source
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
@@ -118,6 +121,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
@@ -167,7 +175,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
Project: project,
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -200,10 +208,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
sqladmin "google.golang.org/api/sqladmin/v1"
|
||||
@@ -43,6 +42,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
GetService(context.Context, string) (*sqladmin.Service, error)
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// Config defines the configuration for the create-instances tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -66,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
s, ok := rawS.(*cloudsqladmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -96,7 +101,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -106,7 +110,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the create-instances tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
Source *cloudsqladmin.Source
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
@@ -118,6 +121,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
@@ -167,7 +175,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
Project: project,
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -200,10 +208,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
sqladmin "google.golang.org/api/sqladmin/v1"
|
||||
@@ -43,6 +42,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
GetService(context.Context, string) (*sqladmin.Service, error)
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// Config defines the configuration for the create-instances tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -66,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
s, ok := rawS.(*cloudsqladmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -96,7 +101,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -106,7 +110,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the create-instances tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
Source *cloudsqladmin.Source
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
@@ -118,6 +121,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
@@ -167,7 +175,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
Project: project,
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -200,10 +208,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
sqladmin "google.golang.org/api/sqladmin/v1"
|
||||
@@ -43,6 +42,11 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetService(context.Context, string) (*sqladmin.Service, error)
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// Config defines the configuration for the precheck-upgrade tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -62,15 +66,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
|
||||
// Initialize initializes the tool from the configuration.
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
s, ok := rawS.(*cloudsqladmin.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind)
|
||||
}
|
||||
|
||||
allParameters := parameters.Parameters{
|
||||
parameters.NewStringParameter("project", "The project ID"),
|
||||
parameters.NewStringParameter("instance", "The name of the instance to check"),
|
||||
@@ -88,28 +83,19 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, allParameters, nil)
|
||||
|
||||
return Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Tool represents the precheck-upgrade tool.
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Description string `yaml:"description"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
|
||||
Source *cloudsqladmin.Source
|
||||
Config
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Config
|
||||
}
|
||||
|
||||
// PreCheckResultItem holds the details of a single check result.
|
||||
@@ -146,6 +132,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
@@ -162,7 +153,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("missing or empty 'targetDatabaseVersion' parameter")
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get HTTP client from source: %w", err)
|
||||
}
|
||||
@@ -234,10 +225,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -22,7 +22,6 @@ import (
|
||||
"github.com/couchbase/gocb/v2"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/couchbase"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
@@ -48,11 +47,6 @@ type compatibleSource interface {
|
||||
CouchbaseQueryScanConsistency() uint
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &couchbase.Source{}
|
||||
|
||||
var compatibleSources = [...]string{couchbase.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -72,18 +66,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -92,12 +74,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil)
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
AllParams: allParameters,
|
||||
Scope: s.CouchbaseScope(),
|
||||
QueryScanConsistency: s.CouchbaseQueryScanConsistency(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -107,12 +87,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
|
||||
Scope *gocb.Scope
|
||||
QueryScanConsistency uint
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -120,6 +97,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
namedParamsMap := params.AsMap()
|
||||
newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, namedParamsMap)
|
||||
if err != nil {
|
||||
@@ -130,8 +112,8 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to extract standard params %w", err)
|
||||
}
|
||||
results, err := t.Scope.Query(newStatement, &gocb.QueryOptions{
|
||||
ScanConsistency: gocb.QueryScanConsistency(t.QueryScanConsistency),
|
||||
results, err := source.CouchbaseScope().Query(newStatement, &gocb.QueryOptions{
|
||||
ScanConsistency: gocb.QueryScanConsistency(source.CouchbaseQueryScanConsistency()),
|
||||
NamedParameters: newParams.AsMap(),
|
||||
})
|
||||
if err != nil {
|
||||
@@ -166,10 +148,10 @@ func (t Tool) Authorized(verifiedAuthSources []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthSources)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -118,10 +118,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -22,7 +22,6 @@ import (
|
||||
dataplexpb "cloud.google.com/go/dataplex/apiv1/dataplexpb"
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
dataplexds "github.com/googleapis/genai-toolbox/internal/sources/dataplex"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
@@ -47,11 +46,6 @@ type compatibleSource interface {
|
||||
CatalogClient() *dataplexapi.CatalogClient
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &dataplexds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{dataplexds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -69,17 +63,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// Initialize the search configuration with the provided sources
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
viewDesc := `
|
||||
## Argument: view
|
||||
|
||||
@@ -104,9 +87,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil)
|
||||
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
CatalogClient: s.CatalogClient(),
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{
|
||||
Description: cfg.Description,
|
||||
Parameters: params.Manifest(),
|
||||
@@ -119,10 +101,9 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
Parameters parameters.Parameters
|
||||
CatalogClient *dataplexapi.CatalogClient
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -130,6 +111,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
viewMap := map[int]dataplexpb.EntryView{
|
||||
1: dataplexpb.EntryView_BASIC,
|
||||
@@ -153,7 +139,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
Entry: entry,
|
||||
}
|
||||
|
||||
result, err := t.CatalogClient.LookupEntry(ctx, req)
|
||||
result, err := source.CatalogClient().LookupEntry(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -179,10 +165,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -23,7 +23,6 @@ import (
|
||||
"github.com/cenkalti/backoff/v5"
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
dataplexds "github.com/googleapis/genai-toolbox/internal/sources/dataplex"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
@@ -49,11 +48,6 @@ type compatibleSource interface {
|
||||
ProjectID() string
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &dataplexds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{dataplexds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -70,17 +64,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// Initialize the search configuration with the provided sources
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
query := parameters.NewStringParameter("query", "The query against which aspect type should be matched.")
|
||||
pageSize := parameters.NewIntParameterWithDefault("pageSize", 5, "Number of returned aspect types in the search page.")
|
||||
orderBy := parameters.NewStringParameterWithDefault("orderBy", "relevance", "Specifies the ordering of results. Supported values are: relevance, last_modified_timestamp, last_modified_timestamp asc")
|
||||
@@ -89,10 +72,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil)
|
||||
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
CatalogClient: s.CatalogClient(),
|
||||
ProjectID: s.ProjectID(),
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{
|
||||
Description: cfg.Description,
|
||||
Parameters: params.Manifest(),
|
||||
@@ -105,11 +86,9 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
Parameters parameters.Parameters
|
||||
CatalogClient *dataplexapi.CatalogClient
|
||||
ProjectID string
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -117,6 +96,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Invoke the tool with the provided parameters
|
||||
paramsMap := params.AsMap()
|
||||
query, _ := paramsMap["query"].(string)
|
||||
@@ -126,16 +110,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
// Create SearchEntriesRequest with the provided parameters
|
||||
req := &dataplexpb.SearchEntriesRequest{
|
||||
Query: query + " type=projects/dataplex-types/locations/global/entryTypes/aspecttype",
|
||||
Name: fmt.Sprintf("projects/%s/locations/global", t.ProjectID),
|
||||
Name: fmt.Sprintf("projects/%s/locations/global", source.ProjectID()),
|
||||
PageSize: pageSize,
|
||||
OrderBy: orderBy,
|
||||
SemanticSearch: true,
|
||||
}
|
||||
|
||||
// Perform the search using the CatalogClient - this will return an iterator
|
||||
it := t.CatalogClient.SearchEntries(ctx, req)
|
||||
it := source.CatalogClient().SearchEntries(ctx, req)
|
||||
if it == nil {
|
||||
return nil, fmt.Errorf("failed to create search entries iterator for project %q", t.ProjectID)
|
||||
return nil, fmt.Errorf("failed to create search entries iterator for project %q", source.ProjectID())
|
||||
}
|
||||
|
||||
// Create an instance of exponential backoff with default values for retrying GetAspectType calls
|
||||
@@ -155,7 +139,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
|
||||
operation := func() (*dataplexpb.AspectType, error) {
|
||||
aspectType, err := t.CatalogClient.GetAspectType(ctx, getAspectTypeReq)
|
||||
aspectType, err := source.CatalogClient().GetAspectType(ctx, getAspectTypeReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get aspect type for entry %q: %w", resourceName, err)
|
||||
}
|
||||
@@ -192,10 +176,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -22,7 +22,6 @@ import (
|
||||
dataplexpb "cloud.google.com/go/dataplex/apiv1/dataplexpb"
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
dataplexds "github.com/googleapis/genai-toolbox/internal/sources/dataplex"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
@@ -48,11 +47,6 @@ type compatibleSource interface {
|
||||
ProjectID() string
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &dataplexds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{dataplexds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -69,17 +63,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// Initialize the search configuration with the provided sources
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
query := parameters.NewStringParameter("query", "The query against which entries in scope should be matched.")
|
||||
pageSize := parameters.NewIntParameterWithDefault("pageSize", 5, "Number of results in the search page.")
|
||||
orderBy := parameters.NewStringParameterWithDefault("orderBy", "relevance", "Specifies the ordering of results. Supported values are: relevance, last_modified_timestamp, last_modified_timestamp asc")
|
||||
@@ -88,10 +71,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil)
|
||||
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
CatalogClient: s.CatalogClient(),
|
||||
ProjectID: s.ProjectID(),
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{
|
||||
Description: cfg.Description,
|
||||
Parameters: params.Manifest(),
|
||||
@@ -104,11 +85,9 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
Parameters parameters.Parameters
|
||||
CatalogClient *dataplexapi.CatalogClient
|
||||
ProjectID string
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -116,6 +95,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
query, _ := paramsMap["query"].(string)
|
||||
pageSize := int32(paramsMap["pageSize"].(int))
|
||||
@@ -123,15 +107,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
|
||||
req := &dataplexpb.SearchEntriesRequest{
|
||||
Query: query,
|
||||
Name: fmt.Sprintf("projects/%s/locations/global", t.ProjectID),
|
||||
Name: fmt.Sprintf("projects/%s/locations/global", source.ProjectID()),
|
||||
PageSize: pageSize,
|
||||
OrderBy: orderBy,
|
||||
SemanticSearch: true,
|
||||
}
|
||||
|
||||
it := t.CatalogClient.SearchEntries(ctx, req)
|
||||
it := source.CatalogClient().SearchEntries(ctx, req)
|
||||
if it == nil {
|
||||
return nil, fmt.Errorf("failed to create search entries iterator for project %q", t.ProjectID)
|
||||
return nil, fmt.Errorf("failed to create search entries iterator for project %q", source.ProjectID())
|
||||
}
|
||||
|
||||
var results []*dataplexpb.SearchEntriesResult
|
||||
@@ -163,10 +147,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -46,11 +46,6 @@ type compatibleSource interface {
|
||||
DgraphClient() *dgraph.DgraphClient
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &dgraph.Source{}
|
||||
|
||||
var compatibleSources = [...]string{dgraph.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -71,26 +66,13 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, cfg.Parameters, nil)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
DgraphClient: s.DgraphClient(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -100,9 +82,8 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
DgraphClient *dgraph.DgraphClient
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -110,9 +91,14 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMapWithDollarPrefix()
|
||||
|
||||
resp, err := t.DgraphClient.ExecuteQuery(t.Statement, paramsMap, t.IsQuery, t.Timeout)
|
||||
resp, err := source.DgraphClient().ExecuteQuery(t.Statement, paramsMap, t.IsQuery, t.Timeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -148,10 +134,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -43,10 +43,6 @@ type compatibleSource interface {
|
||||
ElasticsearchClient() es.EsClient
|
||||
}
|
||||
|
||||
var _ compatibleSource = &es.Source{}
|
||||
|
||||
var compatibleSources = [...]string{es.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -77,29 +73,15 @@ type Tool struct {
|
||||
Config
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
EsClient es.EsClient
|
||||
}
|
||||
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
func (c Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
src, ok := srcs[c.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("source %q not found", c.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := src.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
mcpManifest := tools.GetMcpManifest(c.Name, c.Description, c.AuthRequired, c.Parameters, nil)
|
||||
|
||||
return Tool{
|
||||
Config: c,
|
||||
EsClient: s.ElasticsearchClient(),
|
||||
manifest: tools.Manifest{Description: c.Description, Parameters: c.Parameters.Manifest(), AuthRequired: c.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
@@ -120,6 +102,11 @@ type esqlResult struct {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var cancel context.CancelFunc
|
||||
if t.Timeout > 0 {
|
||||
ctx, cancel = context.WithTimeout(ctx, time.Duration(t.Timeout)*time.Second)
|
||||
@@ -164,8 +151,8 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
Body: bytes.NewReader(body),
|
||||
Format: t.Format,
|
||||
FilterPath: []string{"columns", "values"},
|
||||
Instrument: t.EsClient.InstrumentationEnabled(),
|
||||
}.Do(ctx, t.EsClient)
|
||||
Instrument: source.ElasticsearchClient().InstrumentationEnabled(),
|
||||
}.Do(ctx, source.ElasticsearchClient())
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -230,10 +217,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/firebird"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
@@ -47,10 +46,6 @@ type compatibleSource interface {
|
||||
FirebirdDB() *sql.DB
|
||||
}
|
||||
|
||||
var _ compatibleSource = &firebird.Source{}
|
||||
|
||||
var compatibleSources = [...]string{firebird.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -66,16 +61,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
sqlParameter := parameters.NewStringParameter("sql", "The sql to execute.")
|
||||
params := parameters.Parameters{sqlParameter}
|
||||
|
||||
@@ -84,7 +69,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Db: s.FirebirdDB(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -95,9 +79,7 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Db *sql.DB
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -107,6 +89,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
sql, ok := paramsMap["sql"].(string)
|
||||
if !ok {
|
||||
@@ -120,7 +107,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql))
|
||||
|
||||
rows, err := t.Db.QueryContext(ctx, sql)
|
||||
rows, err := source.FirebirdDB().QueryContext(ctx, sql)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
@@ -180,10 +167,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -22,7 +22,6 @@ import (
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/firebird"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
@@ -47,11 +46,6 @@ type compatibleSource interface {
|
||||
FirebirdDB() *sql.DB
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &firebird.Source{}
|
||||
|
||||
var compatibleSources = [...]string{firebird.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -71,18 +65,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -94,7 +76,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
AllParams: allParameters,
|
||||
Db: s.FirebirdDB(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -106,9 +87,7 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
|
||||
Db *sql.DB
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -118,6 +97,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
statement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
|
||||
if err != nil {
|
||||
@@ -142,7 +126,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
}
|
||||
|
||||
rows, err := t.Db.QueryContext(ctx, statement, namedArgs...)
|
||||
rows, err := source.FirebirdDB().QueryContext(ctx, statement, namedArgs...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
@@ -204,10 +188,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
firestoreapi "cloud.google.com/go/firestore"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/firestore/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
@@ -50,11 +49,6 @@ type compatibleSource interface {
|
||||
FirestoreClient() *firestoreapi.Client
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &firestoreds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{firestoreds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -71,18 +65,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
// Create parameters
|
||||
collectionPathParameter := parameters.NewStringParameter(
|
||||
collectionPathKey,
|
||||
@@ -124,7 +106,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Client: s.FirestoreClient(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -136,9 +117,7 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *firestoreapi.Client
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -148,6 +127,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mapParams := params.AsMap()
|
||||
|
||||
// Get collection path
|
||||
@@ -169,7 +153,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
|
||||
// Convert the document data from JSON format to Firestore format
|
||||
// The client is passed to handle referenceValue types
|
||||
documentData, err := util.JSONToFirestoreValue(documentDataRaw, t.Client)
|
||||
documentData, err := util.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert document data: %w", err)
|
||||
}
|
||||
@@ -181,7 +165,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
|
||||
// Get the collection reference
|
||||
collection := t.Client.Collection(collectionPath)
|
||||
collection := source.FirestoreClient().Collection(collectionPath)
|
||||
|
||||
// Add the document to the collection
|
||||
docRef, writeResult, err := collection.Add(ctx, documentData)
|
||||
@@ -221,10 +205,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
firestoreapi "cloud.google.com/go/firestore"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/firestore/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
@@ -48,11 +47,6 @@ type compatibleSource interface {
|
||||
FirestoreClient() *firestoreapi.Client
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &firestoreds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{firestoreds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -69,18 +63,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
documentPathsParameter := parameters.NewArrayParameter(documentPathsKey, "Array of relative document paths to delete from Firestore (e.g., 'users/userId' or 'users/userId/posts/postId'). Note: These are relative paths, NOT absolute paths like 'projects/{project_id}/databases/{database_id}/documents/...'", parameters.NewStringParameter("item", "Relative document path"))
|
||||
params := parameters.Parameters{documentPathsParameter}
|
||||
|
||||
@@ -90,7 +72,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Client: s.FirestoreClient(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -102,9 +83,7 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *firestoreapi.Client
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -114,6 +93,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mapParams := params.AsMap()
|
||||
documentPathsRaw, ok := mapParams[documentPathsKey].([]any)
|
||||
if !ok {
|
||||
@@ -143,14 +127,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
|
||||
// Create a BulkWriter to handle multiple deletions efficiently
|
||||
bulkWriter := t.Client.BulkWriter(ctx)
|
||||
bulkWriter := source.FirestoreClient().BulkWriter(ctx)
|
||||
|
||||
// Keep track of jobs for each document
|
||||
jobs := make([]*firestoreapi.BulkWriterJob, len(documentPaths))
|
||||
|
||||
// Add all delete operations to the BulkWriter
|
||||
for i, path := range documentPaths {
|
||||
docRef := t.Client.Doc(path)
|
||||
docRef := source.FirestoreClient().Doc(path)
|
||||
job, err := bulkWriter.Delete(docRef)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add delete operation for document %q: %w", path, err)
|
||||
@@ -198,10 +182,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
firestoreapi "cloud.google.com/go/firestore"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/firestore/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
@@ -48,11 +47,6 @@ type compatibleSource interface {
|
||||
FirestoreClient() *firestoreapi.Client
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &firestoreds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{firestoreds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -69,18 +63,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
documentPathsParameter := parameters.NewArrayParameter(documentPathsKey, "Array of relative document paths to retrieve from Firestore (e.g., 'users/userId' or 'users/userId/posts/postId'). Note: These are relative paths, NOT absolute paths like 'projects/{project_id}/databases/{database_id}/documents/...'", parameters.NewStringParameter("item", "Relative document path"))
|
||||
params := parameters.Parameters{documentPathsParameter}
|
||||
|
||||
@@ -90,7 +72,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Client: s.FirestoreClient(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -102,9 +83,7 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *firestoreapi.Client
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -114,6 +93,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mapParams := params.AsMap()
|
||||
documentPathsRaw, ok := mapParams[documentPathsKey].([]any)
|
||||
if !ok {
|
||||
@@ -145,11 +129,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
// Create document references from paths
|
||||
docRefs := make([]*firestoreapi.DocumentRef, len(documentPaths))
|
||||
for i, path := range documentPaths {
|
||||
docRefs[i] = t.Client.Doc(path)
|
||||
docRefs[i] = source.FirestoreClient().Doc(path)
|
||||
}
|
||||
|
||||
// Get all documents
|
||||
snapshots, err := t.Client.GetAll(ctx, docRefs)
|
||||
snapshots, err := source.FirestoreClient().GetAll(ctx, docRefs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get documents: %w", err)
|
||||
}
|
||||
@@ -190,10 +174,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -20,7 +20,6 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/firebaserules/v1"
|
||||
@@ -48,11 +47,6 @@ type compatibleSource interface {
|
||||
GetDatabaseId() string
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &firestoreds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{firestoreds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -69,18 +63,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
// No parameters needed for this tool
|
||||
params := parameters.Parameters{}
|
||||
|
||||
@@ -90,9 +72,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
RulesClient: s.FirebaseRulesClient(),
|
||||
ProjectId: s.GetProjectId(),
|
||||
DatabaseId: s.GetDatabaseId(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -104,11 +83,7 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
RulesClient *firebaserules.Service
|
||||
ProjectId string
|
||||
DatabaseId string
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -118,19 +93,24 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get the latest release for Firestore
|
||||
releaseName := fmt.Sprintf("projects/%s/releases/cloud.firestore/%s", t.ProjectId, t.DatabaseId)
|
||||
release, err := t.RulesClient.Projects.Releases.Get(releaseName).Context(ctx).Do()
|
||||
releaseName := fmt.Sprintf("projects/%s/releases/cloud.firestore/%s", source.GetProjectId(), source.GetDatabaseId())
|
||||
release, err := source.FirebaseRulesClient().Projects.Releases.Get(releaseName).Context(ctx).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get latest Firestore release: %w", err)
|
||||
}
|
||||
|
||||
if release.RulesetName == "" {
|
||||
return nil, fmt.Errorf("no active Firestore rules were found in project '%s' and database '%s'", t.ProjectId, t.DatabaseId)
|
||||
return nil, fmt.Errorf("no active Firestore rules were found in project '%s' and database '%s'", source.GetProjectId(), source.GetDatabaseId())
|
||||
}
|
||||
|
||||
// Get the ruleset content
|
||||
ruleset, err := t.RulesClient.Projects.Rulesets.Get(release.RulesetName).Context(ctx).Do()
|
||||
ruleset, err := source.FirebaseRulesClient().Projects.Rulesets.Get(release.RulesetName).Context(ctx).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get ruleset content: %w", err)
|
||||
}
|
||||
@@ -158,10 +138,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
firestoreapi "cloud.google.com/go/firestore"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/firestore/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
@@ -48,11 +47,6 @@ type compatibleSource interface {
|
||||
FirestoreClient() *firestoreapi.Client
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &firestoreds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{firestoreds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -69,18 +63,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
emptyString := ""
|
||||
parentPathParameter := parameters.NewStringParameterWithDefault(parentPathKey, emptyString, "Relative parent document path to list subcollections from (e.g., 'users/userId'). If not provided, lists root collections. Note: This is a relative path, NOT an absolute path like 'projects/{project_id}/databases/{database_id}/documents/...'")
|
||||
params := parameters.Parameters{parentPathParameter}
|
||||
@@ -91,7 +73,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Client: s.FirestoreClient(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -103,9 +84,7 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *firestoreapi.Client
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -115,10 +94,14 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mapParams := params.AsMap()
|
||||
|
||||
var collectionRefs []*firestoreapi.CollectionRef
|
||||
var err error
|
||||
|
||||
// Check if parentPath is provided
|
||||
parentPath, hasParent := mapParams[parentPathKey].(string)
|
||||
@@ -130,14 +113,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
|
||||
// List subcollections of the specified document
|
||||
docRef := t.Client.Doc(parentPath)
|
||||
docRef := source.FirestoreClient().Doc(parentPath)
|
||||
collectionRefs, err = docRef.Collections(ctx).GetAll()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list subcollections of document %q: %w", parentPath, err)
|
||||
}
|
||||
} else {
|
||||
// List root collections
|
||||
collectionRefs, err = t.Client.Collections(ctx).GetAll()
|
||||
collectionRefs, err = source.FirestoreClient().Collections(ctx).GetAll()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list root collections: %w", err)
|
||||
}
|
||||
@@ -177,10 +160,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -24,7 +24,6 @@ import (
|
||||
firestoreapi "cloud.google.com/go/firestore"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/firestore/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
@@ -52,12 +51,9 @@ var validOperators = map[string]bool{
|
||||
|
||||
// Error messages
|
||||
const (
|
||||
errFilterParseFailed = "failed to parse filters: %w"
|
||||
errQueryExecutionFailed = "failed to execute query: %w"
|
||||
errTemplateParseFailed = "failed to parse template: %w"
|
||||
errTemplateExecFailed = "failed to execute template: %w"
|
||||
errLimitParseFailed = "failed to parse limit value '%s': %w"
|
||||
errSelectFieldParseFailed = "failed to parse select field: %w"
|
||||
errFilterParseFailed = "failed to parse filters: %w"
|
||||
errQueryExecutionFailed = "failed to execute query: %w"
|
||||
errLimitParseFailed = "failed to parse limit value '%s': %w"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -79,11 +75,6 @@ type compatibleSource interface {
|
||||
FirestoreClient() *firestoreapi.Client
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &firestoreds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{firestoreds.SourceKind}
|
||||
|
||||
// Config represents the configuration for the Firestore query tool
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -114,18 +105,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
|
||||
// Initialize creates a new Tool instance from the configuration
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
// Set default limit if not specified
|
||||
if cfg.Limit == "" {
|
||||
cfg.Limit = fmt.Sprintf("%d", defaultLimit)
|
||||
@@ -137,7 +116,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Client: s.FirestoreClient(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -201,6 +179,11 @@ type QueryResponse struct {
|
||||
|
||||
// Invoke executes the Firestore query based on the provided parameters
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
// Process collection path with template substitution
|
||||
@@ -210,7 +193,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
|
||||
// Build the query
|
||||
query, err := t.buildQuery(collectionPath, paramsMap)
|
||||
query, err := t.buildQuery(source, collectionPath, paramsMap)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -220,8 +203,8 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
|
||||
// buildQuery constructs the Firestore query from parameters
|
||||
func (t Tool) buildQuery(collectionPath string, params map[string]any) (*firestoreapi.Query, error) {
|
||||
collection := t.Client.Collection(collectionPath)
|
||||
func (t Tool) buildQuery(source compatibleSource, collectionPath string, params map[string]any) (*firestoreapi.Query, error) {
|
||||
collection := source.FirestoreClient().Collection(collectionPath)
|
||||
query := collection.Query
|
||||
|
||||
// Process and apply filters if template is provided
|
||||
@@ -239,7 +222,7 @@ func (t Tool) buildQuery(collectionPath string, params map[string]any) (*firesto
|
||||
}
|
||||
|
||||
// Convert simplified filter to Firestore filter
|
||||
if filter := t.convertToFirestoreFilter(simplifiedFilter); filter != nil {
|
||||
if filter := t.convertToFirestoreFilter(source, simplifiedFilter); filter != nil {
|
||||
query = query.WhereEntity(filter)
|
||||
}
|
||||
}
|
||||
@@ -280,12 +263,12 @@ func (t Tool) buildQuery(collectionPath string, params map[string]any) (*firesto
|
||||
}
|
||||
|
||||
// convertToFirestoreFilter converts simplified filter format to Firestore EntityFilter
|
||||
func (t Tool) convertToFirestoreFilter(filter SimplifiedFilter) firestoreapi.EntityFilter {
|
||||
func (t Tool) convertToFirestoreFilter(source compatibleSource, filter SimplifiedFilter) firestoreapi.EntityFilter {
|
||||
// Handle AND filters
|
||||
if len(filter.And) > 0 {
|
||||
filters := make([]firestoreapi.EntityFilter, 0, len(filter.And))
|
||||
for _, f := range filter.And {
|
||||
if converted := t.convertToFirestoreFilter(f); converted != nil {
|
||||
if converted := t.convertToFirestoreFilter(source, f); converted != nil {
|
||||
filters = append(filters, converted)
|
||||
}
|
||||
}
|
||||
@@ -299,7 +282,7 @@ func (t Tool) convertToFirestoreFilter(filter SimplifiedFilter) firestoreapi.Ent
|
||||
if len(filter.Or) > 0 {
|
||||
filters := make([]firestoreapi.EntityFilter, 0, len(filter.Or))
|
||||
for _, f := range filter.Or {
|
||||
if converted := t.convertToFirestoreFilter(f); converted != nil {
|
||||
if converted := t.convertToFirestoreFilter(source, f); converted != nil {
|
||||
filters = append(filters, converted)
|
||||
}
|
||||
}
|
||||
@@ -313,7 +296,7 @@ func (t Tool) convertToFirestoreFilter(filter SimplifiedFilter) firestoreapi.Ent
|
||||
if filter.Field != "" && filter.Op != "" && filter.Value != nil {
|
||||
if validOperators[filter.Op] {
|
||||
// Convert the value using the Firestore native JSON converter
|
||||
convertedValue, err := util.JSONToFirestoreValue(filter.Value, t.Client)
|
||||
convertedValue, err := util.JSONToFirestoreValue(filter.Value, source.FirestoreClient())
|
||||
if err != nil {
|
||||
// If conversion fails, use the original value
|
||||
convertedValue = filter.Value
|
||||
@@ -525,10 +508,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -23,7 +23,6 @@ import (
|
||||
firestoreapi "cloud.google.com/go/firestore"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/firestore/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
@@ -92,11 +91,6 @@ type compatibleSource interface {
|
||||
FirestoreClient() *firestoreapi.Client
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &firestoreds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{firestoreds.SourceKind}
|
||||
|
||||
// Config represents the configuration for the Firestore query collection tool
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -116,18 +110,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
|
||||
// Initialize creates a new Tool instance from the configuration
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
// Create parameters
|
||||
params := createParameters()
|
||||
|
||||
@@ -137,7 +119,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Client: s.FirestoreClient(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -199,9 +180,7 @@ var _ tools.Tool = Tool{}
|
||||
// Tool represents the Firestore query collection tool
|
||||
type Tool struct {
|
||||
Config
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *firestoreapi.Client
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -266,6 +245,11 @@ type QueryResponse struct {
|
||||
|
||||
// Invoke executes the Firestore query based on the provided parameters
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse parameters
|
||||
queryParams, err := t.parseQueryParameters(params)
|
||||
if err != nil {
|
||||
@@ -273,7 +257,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
|
||||
// Build the query
|
||||
query, err := t.buildQuery(queryParams)
|
||||
query, err := t.buildQuery(source, queryParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -396,8 +380,8 @@ func (t Tool) parseOrderBy(orderByRaw interface{}) (*OrderByConfig, error) {
|
||||
}
|
||||
|
||||
// buildQuery constructs the Firestore query from parameters
|
||||
func (t Tool) buildQuery(params *queryParameters) (*firestoreapi.Query, error) {
|
||||
collection := t.Client.Collection(params.CollectionPath)
|
||||
func (t Tool) buildQuery(source compatibleSource, params *queryParameters) (*firestoreapi.Query, error) {
|
||||
collection := source.FirestoreClient().Collection(params.CollectionPath)
|
||||
query := collection.Query
|
||||
|
||||
// Apply filters
|
||||
@@ -531,10 +515,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -22,7 +22,6 @@ import (
|
||||
firestoreapi "cloud.google.com/go/firestore"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/firestore/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
@@ -52,11 +51,6 @@ type compatibleSource interface {
|
||||
FirestoreClient() *firestoreapi.Client
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &firestoreds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{firestoreds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -73,18 +67,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
// Create parameters
|
||||
documentPathParameter := parameters.NewStringParameter(
|
||||
documentPathKey,
|
||||
@@ -134,7 +116,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Client: s.FirestoreClient(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -146,9 +127,7 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *firestoreapi.Client
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -158,6 +137,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mapParams := params.AsMap()
|
||||
|
||||
// Get document path
|
||||
@@ -200,7 +184,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
|
||||
// Get the document reference
|
||||
docRef := t.Client.Doc(documentPath)
|
||||
docRef := source.FirestoreClient().Doc(documentPath)
|
||||
|
||||
// Prepare update data
|
||||
var writeResult *firestoreapi.WriteResult
|
||||
@@ -211,7 +195,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
updates := make([]firestoreapi.Update, 0, len(updatePaths))
|
||||
|
||||
// Convert document data without delete markers
|
||||
dataMap, err := util.JSONToFirestoreValue(documentDataRaw, t.Client)
|
||||
dataMap, err := util.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert document data: %w", err)
|
||||
}
|
||||
@@ -239,7 +223,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
writeResult, writeErr = docRef.Update(ctx, updates)
|
||||
} else {
|
||||
// Update all fields in the document data (merge)
|
||||
documentData, err := util.JSONToFirestoreValue(documentDataRaw, t.Client)
|
||||
documentData, err := util.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert document data: %w", err)
|
||||
}
|
||||
@@ -314,10 +298,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -132,32 +132,6 @@ func TestConfig_Initialize(t *testing.T) {
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "source not found",
|
||||
config: Config{
|
||||
Name: "test-update-document",
|
||||
Kind: "firestore-update-document",
|
||||
Source: "missing-source",
|
||||
Description: "Update a document",
|
||||
},
|
||||
sources: map[string]sources.Source{},
|
||||
wantErr: true,
|
||||
errMsg: "no source named \"missing-source\" configured",
|
||||
},
|
||||
{
|
||||
name: "incompatible source",
|
||||
config: Config{
|
||||
Name: "test-update-document",
|
||||
Kind: "firestore-update-document",
|
||||
Source: "wrong-source",
|
||||
Description: "Update a document",
|
||||
},
|
||||
sources: map[string]sources.Source{
|
||||
"wrong-source": &mockIncompatibleSource{},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "invalid source for \"firestore-update-document\" tool",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -464,14 +438,3 @@ func TestGetFieldValue(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// mockIncompatibleSource is a mock source that doesn't implement compatibleSource
|
||||
type mockIncompatibleSource struct{}
|
||||
|
||||
func (m *mockIncompatibleSource) SourceKind() string {
|
||||
return "mock"
|
||||
}
|
||||
|
||||
func (m *mockIncompatibleSource) ToConfig() sources.SourceConfig {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/firebaserules/v1"
|
||||
@@ -53,11 +52,6 @@ type compatibleSource interface {
|
||||
GetProjectId() string
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &firestoreds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{firestoreds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -74,18 +68,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
// Create parameters
|
||||
params := createParameters()
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil)
|
||||
@@ -94,8 +76,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
RulesClient: s.FirebaseRulesClient(),
|
||||
ProjectId: s.GetProjectId(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -117,10 +97,7 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
RulesClient *firebaserules.Service
|
||||
ProjectId string
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -154,11 +131,16 @@ type ValidationResult struct {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mapParams := params.AsMap()
|
||||
|
||||
// Get source parameter
|
||||
source, ok := mapParams[sourceKey].(string)
|
||||
if !ok || source == "" {
|
||||
sourceParam, ok := mapParams[sourceKey].(string)
|
||||
if !ok || sourceParam == "" {
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter", sourceKey)
|
||||
}
|
||||
|
||||
@@ -168,7 +150,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
Files: []*firebaserules.File{
|
||||
{
|
||||
Name: "firestore.rules",
|
||||
Content: source,
|
||||
Content: sourceParam,
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -179,14 +161,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
|
||||
// Call the test API
|
||||
projectName := fmt.Sprintf("projects/%s", t.ProjectId)
|
||||
response, err := t.RulesClient.Projects.Test(projectName, testRequest).Context(ctx).Do()
|
||||
projectName := fmt.Sprintf("projects/%s", source.GetProjectId())
|
||||
response, err := source.FirebaseRulesClient().Projects.Test(projectName, testRequest).Context(ctx).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to validate rules: %w", err)
|
||||
}
|
||||
|
||||
// Process the response
|
||||
result := t.processValidationResponse(response, source)
|
||||
result := t.processValidationResponse(response, sourceParam)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
@@ -287,10 +269,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -29,7 +29,6 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
httpsrc "github.com/googleapis/genai-toolbox/internal/sources/http"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
@@ -50,6 +49,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
HttpDefaultHeaders() map[string]string
|
||||
HttpBaseURL() string
|
||||
HttpQueryParams() map[string]string
|
||||
Client() *http.Client
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -81,7 +87,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(*httpsrc.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `http`", kind)
|
||||
}
|
||||
@@ -89,7 +95,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Combine Source and Tool headers.
|
||||
// In case of conflict, Tool header overrides Source header
|
||||
combinedHeaders := make(map[string]string)
|
||||
maps.Copy(combinedHeaders, s.DefaultHeaders)
|
||||
maps.Copy(combinedHeaders, s.HttpDefaultHeaders())
|
||||
maps.Copy(combinedHeaders, cfg.Headers)
|
||||
|
||||
// Create a slice for all parameters
|
||||
@@ -113,14 +119,11 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
BaseURL: s.BaseURL,
|
||||
Headers: combinedHeaders,
|
||||
DefaultQueryParams: s.QueryParams,
|
||||
Client: s.Client,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Headers: combinedHeaders,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -129,12 +132,8 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
BaseURL string `yaml:"baseURL"`
|
||||
Headers map[string]string `yaml:"headers"`
|
||||
DefaultQueryParams map[string]string `yaml:"defaultQueryParams"`
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
|
||||
Client *http.Client
|
||||
Headers map[string]string `yaml:"headers"`
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -229,6 +228,11 @@ func getHeaders(headerParams parameters.Parameters, defaultHeaders map[string]st
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
// Calculate request body
|
||||
@@ -238,7 +242,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
|
||||
// Calculate URL
|
||||
urlString, err := getURL(t.BaseURL, t.Path, t.PathParams, t.QueryParams, t.DefaultQueryParams, paramsMap)
|
||||
urlString, err := getURL(source.HttpBaseURL(), t.Path, t.PathParams, t.QueryParams, source.HttpQueryParams(), paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error populating path parameters: %s", err)
|
||||
}
|
||||
@@ -256,7 +260,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
|
||||
// Make request and fetch response
|
||||
resp, err := t.Client.Do(req)
|
||||
resp, err := source.Client().Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error making HTTP request: %s", err)
|
||||
}
|
||||
@@ -295,10 +299,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -19,7 +19,6 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
@@ -45,6 +44,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
GetAuthTokenHeaderName() string
|
||||
LookerClient() *v4.LookerSDK
|
||||
LookerApiSettings() *rtl.ApiSettings
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -62,18 +68,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(*lookersrc.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind)
|
||||
}
|
||||
|
||||
params := lookercommon.GetQueryParameters()
|
||||
|
||||
dashIdParameter := parameters.NewStringParameter("dashboard_id", "The id of the dashboard where this tile will exist")
|
||||
@@ -109,12 +103,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
AuthTokenHeaderName: s.GetAuthTokenHeaderName(),
|
||||
Client: s.Client,
|
||||
ApiSettings: s.ApiSettings,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{
|
||||
Description: cfg.Description,
|
||||
Parameters: params.Manifest(),
|
||||
@@ -129,13 +119,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool
|
||||
AuthTokenHeaderName string
|
||||
Client *v4.LookerSDK
|
||||
ApiSettings *rtl.ApiSettings
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -148,6 +134,11 @@ var (
|
||||
)
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logger, err := util.LoggerFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)
|
||||
@@ -167,12 +158,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
visConfig := paramsMap["vis_config"].(map[string]any)
|
||||
wq.VisConfig = &visConfig
|
||||
|
||||
sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken)
|
||||
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting sdk: %w", err)
|
||||
}
|
||||
|
||||
qresp, err := sdk.CreateQuery(*wq, "id", t.ApiSettings)
|
||||
qresp, err := sdk.CreateQuery(*wq, "id", source.LookerApiSettings())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error making create query request: %w", err)
|
||||
}
|
||||
@@ -239,7 +230,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
Fields: &fields,
|
||||
}
|
||||
|
||||
resp, err := sdk.CreateDashboardElement(req, t.ApiSettings)
|
||||
resp, err := sdk.CreateDashboardElement(req, source.LookerApiSettings())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error making create dashboard element request: %w", err)
|
||||
}
|
||||
@@ -264,14 +255,22 @@ func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return t.AuthTokenHeaderName
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return source.GetAuthTokenHeaderName(), nil
|
||||
}
|
||||
|
||||
@@ -19,7 +19,6 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
@@ -45,6 +44,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
GetAuthTokenHeaderName() string
|
||||
LookerClient() *v4.LookerSDK
|
||||
LookerApiSettings() *rtl.ApiSettings
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -62,18 +68,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(*lookersrc.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind)
|
||||
}
|
||||
|
||||
params := parameters.Parameters{}
|
||||
|
||||
dashIdParameter := parameters.NewStringParameter("dashboard_id", "The id of the dashboard where this filter will exist")
|
||||
@@ -109,14 +103,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
AuthTokenHeaderName: s.GetAuthTokenHeaderName(),
|
||||
Client: s.Client,
|
||||
ApiSettings: s.ApiSettings,
|
||||
Parameters: params,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{
|
||||
Description: cfg.Description,
|
||||
Parameters: params.Manifest(),
|
||||
@@ -131,16 +119,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
UseClientOAuth bool
|
||||
AuthTokenHeaderName string
|
||||
Client *v4.LookerSDK
|
||||
ApiSettings *rtl.ApiSettings
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -148,6 +129,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logger, err := util.LoggerFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)
|
||||
@@ -205,12 +191,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
req.Dimension = &dimension
|
||||
}
|
||||
|
||||
sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken)
|
||||
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting sdk: %w", err)
|
||||
}
|
||||
|
||||
resp, err := sdk.CreateDashboardFilter(req, "name", t.ApiSettings)
|
||||
resp, err := sdk.CreateDashboardFilter(req, "name", source.LookerApiSettings())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error making create dashboard filter request: %s", err)
|
||||
}
|
||||
@@ -239,10 +225,18 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return t.AuthTokenHeaderName
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return source.GetAuthTokenHeaderName(), nil
|
||||
}
|
||||
|
||||
@@ -26,7 +26,6 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
lookerds "github.com/googleapis/genai-toolbox/internal/sources/looker"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
@@ -56,12 +55,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetApiSettings() *rtl.ApiSettings
|
||||
GoogleCloudTokenSourceWithScope(ctx context.Context, scope string) (oauth2.TokenSource, error)
|
||||
GoogleCloudProject() string
|
||||
GoogleCloudLocation() string
|
||||
UseClientAuthorization() bool
|
||||
GetAuthTokenHeaderName() string
|
||||
LookerApiSettings() *rtl.ApiSettings
|
||||
}
|
||||
|
||||
// Structs for building the JSON payload
|
||||
@@ -124,11 +123,6 @@ type CAPayload struct {
|
||||
ClientIdEnum string `json:"clientIdEnum"`
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &lookerds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{lookerds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -155,7 +149,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
if s.GoogleCloudProject() == "" {
|
||||
@@ -196,16 +190,11 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
ApiSettings: s.GetApiSettings(),
|
||||
Project: s.GoogleCloudProject(),
|
||||
Location: s.GoogleCloudLocation(),
|
||||
Parameters: params,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
AuthTokenHeaderName: s.GetAuthTokenHeaderName(),
|
||||
TokenSource: ts,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
TokenSource: ts,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -215,15 +204,10 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
ApiSettings *rtl.ApiSettings
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
AuthTokenHeaderName string
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
Project string
|
||||
Location string
|
||||
TokenSource oauth2.TokenSource
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
TokenSource oauth2.TokenSource
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -231,8 +215,12 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var tokenStr string
|
||||
var err error
|
||||
|
||||
// Get credentials for the API call
|
||||
// Use cloud-platform token source for Gemini Data Analytics API
|
||||
@@ -253,16 +241,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
ler := make([]LookerExploreReference, 0)
|
||||
for _, er := range exploreReferences {
|
||||
ler = append(ler, LookerExploreReference{
|
||||
LookerInstanceUri: t.ApiSettings.BaseUrl,
|
||||
LookerInstanceUri: source.LookerApiSettings().BaseUrl,
|
||||
LookmlModel: er.(map[string]any)["model"].(string),
|
||||
Explore: er.(map[string]any)["explore"].(string),
|
||||
})
|
||||
}
|
||||
oauth_creds := OAuthCredentials{}
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
oauth_creds.Token = TokenBased{AccessToken: string(accessToken)}
|
||||
} else {
|
||||
oauth_creds.Secret = SecretBased{ClientId: t.ApiSettings.ClientId, ClientSecret: t.ApiSettings.ClientSecret}
|
||||
oauth_creds.Secret = SecretBased{ClientId: source.LookerApiSettings().ClientId, ClientSecret: source.LookerApiSettings().ClientSecret}
|
||||
}
|
||||
|
||||
lers := LookerExploreReferences{
|
||||
@@ -273,8 +261,8 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
|
||||
// Construct URL, headers, and payload
|
||||
projectID := t.Project
|
||||
location := t.Location
|
||||
projectID := source.GoogleCloudProject()
|
||||
location := source.GoogleCloudLocation()
|
||||
caURL := fmt.Sprintf("https://geminidataanalytics.googleapis.com/v1beta/projects/%s/locations/%s:chat", url.PathEscape(projectID), url.PathEscape(location))
|
||||
|
||||
headers := map[string]string{
|
||||
@@ -315,12 +303,16 @@ func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
// StreamMessage represents a single message object from the streaming API response.
|
||||
@@ -563,6 +555,10 @@ func appendMessage(messages []map[string]any, newMessage map[string]any) []map[s
|
||||
return append(messages, newMessage)
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return t.AuthTokenHeaderName
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return source.GetAuthTokenHeaderName(), nil
|
||||
}
|
||||
|
||||
@@ -19,7 +19,6 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
@@ -44,6 +43,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
GetAuthTokenHeaderName() string
|
||||
LookerClient() *v4.LookerSDK
|
||||
LookerApiSettings() *rtl.ApiSettings
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -61,18 +67,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(*lookersrc.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind)
|
||||
}
|
||||
|
||||
projectIdParameter := parameters.NewStringParameter("project_id", "The id of the project containing the files")
|
||||
filePathParameter := parameters.NewStringParameter("file_path", "The path of the file within the project")
|
||||
fileContentParameter := parameters.NewStringParameter("file_content", "The content of the file")
|
||||
@@ -90,12 +84,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
AuthTokenHeaderName: s.GetAuthTokenHeaderName(),
|
||||
Client: s.Client,
|
||||
ApiSettings: s.ApiSettings,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{
|
||||
Description: cfg.Description,
|
||||
Parameters: params.Manifest(),
|
||||
@@ -110,13 +100,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool
|
||||
AuthTokenHeaderName string
|
||||
Client *v4.LookerSDK
|
||||
ApiSettings *rtl.ApiSettings
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -124,7 +110,12 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken)
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting sdk: %w", err)
|
||||
}
|
||||
@@ -148,7 +139,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
Content: fileContent,
|
||||
}
|
||||
|
||||
err = lookercommon.CreateProjectFile(sdk, projectId, req, t.ApiSettings)
|
||||
err = lookercommon.CreateProjectFile(sdk, projectId, req, source.LookerApiSettings())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error making create_project_file request: %s", err)
|
||||
}
|
||||
@@ -172,14 +163,22 @@ func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return t.AuthTokenHeaderName
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return source.GetAuthTokenHeaderName(), nil
|
||||
}
|
||||
|
||||
@@ -19,7 +19,6 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
@@ -44,6 +43,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
GetAuthTokenHeaderName() string
|
||||
LookerClient() *v4.LookerSDK
|
||||
LookerApiSettings() *rtl.ApiSettings
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -61,18 +67,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(*lookersrc.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind)
|
||||
}
|
||||
|
||||
projectIdParameter := parameters.NewStringParameter("project_id", "The id of the project containing the files")
|
||||
filePathParameter := parameters.NewStringParameter("file_path", "The path of the file within the project")
|
||||
params := parameters.Parameters{projectIdParameter, filePathParameter}
|
||||
@@ -91,12 +85,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
AuthTokenHeaderName: s.GetAuthTokenHeaderName(),
|
||||
Client: s.Client,
|
||||
ApiSettings: s.ApiSettings,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{
|
||||
Description: cfg.Description,
|
||||
Parameters: params.Manifest(),
|
||||
@@ -111,13 +101,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool
|
||||
AuthTokenHeaderName string
|
||||
Client *v4.LookerSDK
|
||||
ApiSettings *rtl.ApiSettings
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -125,7 +111,12 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken)
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting sdk: %w", err)
|
||||
}
|
||||
@@ -140,7 +131,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("'file_path' must be a string, got %T", mapParams["file_path"])
|
||||
}
|
||||
|
||||
err = lookercommon.DeleteProjectFile(sdk, projectId, filePath, t.ApiSettings)
|
||||
err = lookercommon.DeleteProjectFile(sdk, projectId, filePath, source.LookerApiSettings())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error making delete_project_file request: %s", err)
|
||||
}
|
||||
@@ -164,14 +155,22 @@ func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return t.AuthTokenHeaderName
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return source.GetAuthTokenHeaderName(), nil
|
||||
}
|
||||
|
||||
@@ -19,7 +19,6 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
@@ -45,6 +44,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
GetAuthTokenHeaderName() string
|
||||
LookerClient() *v4.LookerSDK
|
||||
LookerApiSettings() *rtl.ApiSettings
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -62,18 +68,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(*lookersrc.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind)
|
||||
}
|
||||
|
||||
devModeParameter := parameters.NewBooleanParameterWithDefault("devMode", true, "Whether to set Dev Mode.")
|
||||
params := parameters.Parameters{devModeParameter}
|
||||
|
||||
@@ -89,19 +83,14 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
AuthTokenHeaderName: s.GetAuthTokenHeaderName(),
|
||||
Client: s.Client,
|
||||
ApiSettings: s.ApiSettings,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{
|
||||
Description: cfg.Description,
|
||||
Parameters: params.Manifest(),
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
},
|
||||
mcpManifest: mcpManifest,
|
||||
ShowHiddenExplores: s.ShowHiddenExplores,
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -110,14 +99,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool
|
||||
AuthTokenHeaderName string
|
||||
Client *v4.LookerSDK
|
||||
ApiSettings *rtl.ApiSettings
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
ShowHiddenExplores bool
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -125,6 +109,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logger, err := util.LoggerFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)
|
||||
@@ -135,7 +124,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("'devMode' must be a boolean, got %T", mapParams["devMode"])
|
||||
}
|
||||
|
||||
sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken)
|
||||
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting sdk: %w", err)
|
||||
}
|
||||
@@ -148,7 +137,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
req := v4.WriteApiSession{
|
||||
WorkspaceId: &devModeString,
|
||||
}
|
||||
resp, err := sdk.UpdateSession(req, t.ApiSettings)
|
||||
resp, err := sdk.UpdateSession(req, source.LookerApiSettings())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error setting/resetting dev mode: %w", err)
|
||||
}
|
||||
@@ -169,14 +158,22 @@ func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return t.AuthTokenHeaderName
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return source.GetAuthTokenHeaderName(), nil
|
||||
}
|
||||
|
||||
@@ -20,7 +20,6 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
@@ -46,6 +45,14 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
GetAuthTokenHeaderName() string
|
||||
LookerClient() *v4.LookerSDK
|
||||
LookerApiSettings() *rtl.ApiSettings
|
||||
LookerSessionLength() int64
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -63,18 +70,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(*lookersrc.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind)
|
||||
}
|
||||
|
||||
typeParameter := parameters.NewStringParameterWithDefault("type", "", "Type of Looker content to embed (ie. dashboards, looks, query-visualization)")
|
||||
idParameter := parameters.NewStringParameterWithDefault("id", "", "The ID of the content to embed.")
|
||||
params := parameters.Parameters{
|
||||
@@ -94,19 +89,14 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
AuthTokenHeaderName: s.GetAuthTokenHeaderName(),
|
||||
Client: s.Client,
|
||||
ApiSettings: s.ApiSettings,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{
|
||||
Description: cfg.Description,
|
||||
Parameters: params.Manifest(),
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
},
|
||||
mcpManifest: mcpManifest,
|
||||
SessionLength: s.SessionLength,
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -115,15 +105,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool
|
||||
AuthTokenHeaderName string
|
||||
Client *v4.LookerSDK
|
||||
ApiSettings *rtl.ApiSettings
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters parameters.Parameters
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
SessionLength int64
|
||||
Parameters parameters.Parameters
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -131,6 +115,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logger, err := util.LoggerFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)
|
||||
@@ -147,16 +136,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
contentId_ptr = nil
|
||||
}
|
||||
|
||||
sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken)
|
||||
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting sdk: %w", err)
|
||||
}
|
||||
|
||||
forceLogoutLogin := true
|
||||
|
||||
sessionLength := source.LookerSessionLength()
|
||||
req := v4.EmbedParams{
|
||||
TargetUrl: fmt.Sprintf("%s/embed/%s/%s", t.ApiSettings.BaseUrl, *embedType_ptr, *contentId_ptr),
|
||||
SessionLength: &t.SessionLength,
|
||||
TargetUrl: fmt.Sprintf("%s/embed/%s/%s", source.LookerApiSettings().BaseUrl, *embedType_ptr, *contentId_ptr),
|
||||
SessionLength: &sessionLength,
|
||||
ForceLogoutLogin: &forceLogoutLogin,
|
||||
}
|
||||
logger.ErrorContext(ctx, "Making request %v", req)
|
||||
@@ -181,14 +170,22 @@ func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return t.AuthTokenHeaderName
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return source.GetAuthTokenHeaderName(), nil
|
||||
}
|
||||
|
||||
@@ -19,7 +19,6 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
@@ -44,6 +43,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
GetAuthTokenHeaderName() string
|
||||
LookerClient() *v4.LookerSDK
|
||||
LookerApiSettings() *rtl.ApiSettings
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -61,18 +67,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(*lookersrc.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind)
|
||||
}
|
||||
|
||||
connParameter := parameters.NewStringParameter("conn", "The connection containing the databases.")
|
||||
params := parameters.Parameters{connParameter}
|
||||
|
||||
@@ -88,12 +82,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
AuthTokenHeaderName: s.GetAuthTokenHeaderName(),
|
||||
Client: s.Client,
|
||||
ApiSettings: s.ApiSettings,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{
|
||||
Description: cfg.Description,
|
||||
Parameters: params.Manifest(),
|
||||
@@ -108,13 +98,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool
|
||||
AuthTokenHeaderName string
|
||||
Client *v4.LookerSDK
|
||||
ApiSettings *rtl.ApiSettings
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -122,17 +108,22 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mapParams := params.AsMap()
|
||||
conn, ok := mapParams["conn"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("'conn' must be a string, got %T", mapParams["conn"])
|
||||
}
|
||||
|
||||
sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken)
|
||||
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting sdk: %w", err)
|
||||
}
|
||||
resp, err := sdk.ConnectionDatabases(conn, t.ApiSettings)
|
||||
resp, err := sdk.ConnectionDatabases(conn, source.LookerApiSettings())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error making get_connection_databases request: %s", err)
|
||||
}
|
||||
@@ -153,14 +144,22 @@ func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return t.AuthTokenHeaderName
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return source.GetAuthTokenHeaderName(), nil
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user