mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-10 07:58:12 -05:00
Merge branch 'main' into spanner-create-instance
This commit is contained in:
18
.github/renovate.json5
vendored
18
.github/renovate.json5
vendored
@@ -24,5 +24,23 @@
|
||||
],
|
||||
pinDigests: true,
|
||||
},
|
||||
{
|
||||
groupName: 'Go',
|
||||
matchManagers: [
|
||||
'gomod',
|
||||
],
|
||||
},
|
||||
{
|
||||
groupName: 'Node',
|
||||
matchManagers: [
|
||||
'npm',
|
||||
],
|
||||
},
|
||||
{
|
||||
groupName: 'Pip',
|
||||
matchManagers: [
|
||||
'pip_requirements',
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
@@ -236,6 +236,7 @@ import (
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/bigtable"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/cassandra"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/clickhouse"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudgda"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudmonitoring"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
|
||||
|
||||
4
go.mod
4
go.mod
@@ -22,7 +22,7 @@ require (
|
||||
github.com/cenkalti/backoff/v5 v5.0.3
|
||||
github.com/couchbase/gocb/v2 v2.11.1
|
||||
github.com/couchbase/tools-common/http v1.0.9
|
||||
github.com/elastic/elastic-transport-go/v8 v8.7.0
|
||||
github.com/elastic/elastic-transport-go/v8 v8.8.0
|
||||
github.com/elastic/go-elasticsearch/v9 v9.2.0
|
||||
github.com/fsnotify/fsnotify v1.9.0
|
||||
github.com/go-chi/chi/v5 v5.2.3
|
||||
@@ -42,7 +42,7 @@ require (
|
||||
github.com/microsoft/go-mssqldb v1.9.3
|
||||
github.com/nakagami/firebirdsql v0.9.15
|
||||
github.com/neo4j/neo4j-go-driver/v5 v5.28.4
|
||||
github.com/redis/go-redis/v9 v9.16.0
|
||||
github.com/redis/go-redis/v9 v9.17.2
|
||||
github.com/sijms/go-ora/v2 v2.9.0
|
||||
github.com/spf13/cobra v1.10.1
|
||||
github.com/thlib/go-timezone-local v0.0.7
|
||||
|
||||
8
go.sum
8
go.sum
@@ -822,8 +822,8 @@ github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3
|
||||
github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/elastic/elastic-transport-go/v8 v8.7.0 h1:OgTneVuXP2uip4BA658Xi6Hfw+PeIOod2rY3GVMGoVE=
|
||||
github.com/elastic/elastic-transport-go/v8 v8.7.0/go.mod h1:YLHer5cj0csTzNFXoNQ8qhtGY1GTvSqPnKWKaqQE3Hk=
|
||||
github.com/elastic/elastic-transport-go/v8 v8.8.0 h1:7k1Ua+qluFr6p1jfJjGDl97ssJS/P7cHNInzfxgBQAo=
|
||||
github.com/elastic/elastic-transport-go/v8 v8.8.0/go.mod h1:YLHer5cj0csTzNFXoNQ8qhtGY1GTvSqPnKWKaqQE3Hk=
|
||||
github.com/elastic/go-elasticsearch/v9 v9.2.0 h1:COeL/g20+ixnUbffe4Wfbu88emrHjAq/LhVfmrjqRQs=
|
||||
github.com/elastic/go-elasticsearch/v9 v9.2.0/go.mod h1:2PB5YQPpY5tWbF65MRqzEXA31PZOdXCkloQSOZtU14I=
|
||||
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
||||
@@ -1222,8 +1222,8 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH
|
||||
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||
github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||
github.com/prometheus/client_model v0.3.0/go.mod h1:LDGWKZIo7rky3hgvBe+caln+Dr3dPggB5dvjtD7w9+w=
|
||||
github.com/redis/go-redis/v9 v9.16.0 h1:OotgqgLSRCmzfqChbQyG1PHC3tLNR89DG4jdOERSEP4=
|
||||
github.com/redis/go-redis/v9 v9.16.0/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
|
||||
github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI=
|
||||
github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
|
||||
@@ -172,7 +172,14 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
accessToken := tools.AccessToken(r.Header.Get("Authorization"))
|
||||
|
||||
// Check if this specific tool requires the standard authorization header
|
||||
if tool.RequiresClientAuthorization(s.ResourceMgr) {
|
||||
clientAuth, err := tool.RequiresClientAuthorization(s.ResourceMgr)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("error during invocation: %w", err)
|
||||
s.logger.DebugContext(ctx, errMsg.Error())
|
||||
_ = render.Render(w, r, newErrResponse(errMsg, http.StatusNotFound))
|
||||
return
|
||||
}
|
||||
if clientAuth {
|
||||
if accessToken == "" {
|
||||
err = fmt.Errorf("tool requires client authorization but access token is missing from the request header")
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
@@ -255,7 +262,7 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden {
|
||||
if tool.RequiresClientAuthorization(s.ResourceMgr) {
|
||||
if clientAuth {
|
||||
// Propagate the original 401/403 error.
|
||||
s.logger.DebugContext(ctx, fmt.Sprintf("error invoking tool. Client credentials lack authorization to the source: %v", err))
|
||||
_ = render.Render(w, r, newErrResponse(err, statusCode))
|
||||
|
||||
@@ -77,9 +77,9 @@ func (t MockTool) Authorized(verifiedAuthServices []string) bool {
|
||||
return !t.unauthorized
|
||||
}
|
||||
|
||||
func (t MockTool) RequiresClientAuthorization(tools.SourceProvider) bool {
|
||||
func (t MockTool) RequiresClientAuthorization(tools.SourceProvider) (bool, error) {
|
||||
// defaulted to false
|
||||
return t.requiresClientAuthrorization
|
||||
return t.requiresClientAuthrorization, nil
|
||||
}
|
||||
|
||||
func (t MockTool) McpManifest() tools.McpManifest {
|
||||
@@ -119,8 +119,8 @@ func (t MockTool) McpManifest() tools.McpManifest {
|
||||
return mcpManifest
|
||||
}
|
||||
|
||||
func (t MockTool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t MockTool) GetAuthTokenHeaderName(tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
// MockPrompt is used to mock prompts in tests
|
||||
|
||||
@@ -108,10 +108,20 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
}
|
||||
|
||||
// Get access token
|
||||
accessToken := tools.AccessToken(header.Get(tool.GetAuthTokenHeaderName()))
|
||||
authTokenHeadername, err := tool.GetAuthTokenHeaderName(resourceMgr)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("error during invocation: %w", err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
|
||||
}
|
||||
accessToken := tools.AccessToken(header.Get(authTokenHeadername))
|
||||
|
||||
// Check if this specific tool requires the standard authorization header
|
||||
if tool.RequiresClientAuthorization(resourceMgr) {
|
||||
clientAuth, err := tool.RequiresClientAuthorization(resourceMgr)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("error during invocation: %w", err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
|
||||
}
|
||||
if clientAuth {
|
||||
if accessToken == "" {
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
|
||||
}
|
||||
@@ -183,7 +193,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
}
|
||||
// Upstream auth error
|
||||
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
||||
if tool.RequiresClientAuthorization(resourceMgr) {
|
||||
if clientAuth {
|
||||
// Error with client credentials should pass down to the client
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
|
||||
@@ -108,10 +108,20 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
}
|
||||
|
||||
// Get access token
|
||||
accessToken := tools.AccessToken(header.Get(tool.GetAuthTokenHeaderName()))
|
||||
authTokenHeadername, err := tool.GetAuthTokenHeaderName(resourceMgr)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("error during invocation: %w", err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
|
||||
}
|
||||
accessToken := tools.AccessToken(header.Get(authTokenHeadername))
|
||||
|
||||
// Check if this specific tool requires the standard authorization header
|
||||
if tool.RequiresClientAuthorization(resourceMgr) {
|
||||
clientAuth, err := tool.RequiresClientAuthorization(resourceMgr)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("error during invocation: %w", err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
|
||||
}
|
||||
if clientAuth {
|
||||
if accessToken == "" {
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
|
||||
}
|
||||
@@ -183,7 +193,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
}
|
||||
// Upstream auth error
|
||||
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
||||
if tool.RequiresClientAuthorization(resourceMgr) {
|
||||
if clientAuth {
|
||||
// Error with client credentials should pass down to the client
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
|
||||
@@ -101,10 +101,20 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
}
|
||||
|
||||
// Get access token
|
||||
accessToken := tools.AccessToken(header.Get(tool.GetAuthTokenHeaderName()))
|
||||
authTokenHeadername, err := tool.GetAuthTokenHeaderName(resourceMgr)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("error during invocation: %w", err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
|
||||
}
|
||||
accessToken := tools.AccessToken(header.Get(authTokenHeadername))
|
||||
|
||||
// Check if this specific tool requires the standard authorization header
|
||||
if tool.RequiresClientAuthorization(resourceMgr) {
|
||||
clientAuth, err := tool.RequiresClientAuthorization(resourceMgr)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("error during invocation: %w", err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
|
||||
}
|
||||
if clientAuth {
|
||||
if accessToken == "" {
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
|
||||
}
|
||||
@@ -176,7 +186,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
}
|
||||
// Upstream auth error
|
||||
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
||||
if tool.RequiresClientAuthorization(resourceMgr) {
|
||||
if clientAuth {
|
||||
// Error with client credentials should pass down to the client
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
|
||||
@@ -110,6 +110,10 @@ func (s *Source) ToConfig() sources.SourceConfig {
|
||||
return s.Config
|
||||
}
|
||||
|
||||
func (s *Source) GetDefaultProject() string {
|
||||
return s.DefaultProject
|
||||
}
|
||||
|
||||
func (s *Source) GetService(ctx context.Context, accessToken string) (*alloydbrestapi.Service, error) {
|
||||
if s.UseClientOAuth {
|
||||
token := &oauth2.Token{AccessToken: accessToken}
|
||||
|
||||
@@ -107,6 +107,14 @@ func (s *Source) ToConfig() sources.SourceConfig {
|
||||
return s.Config
|
||||
}
|
||||
|
||||
func (s *Source) GetProjectID() string {
|
||||
return s.ProjectID
|
||||
}
|
||||
|
||||
func (s *Source) GetBaseURL() string {
|
||||
return s.BaseURL
|
||||
}
|
||||
|
||||
func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Client, error) {
|
||||
if s.UseClientOAuth {
|
||||
if accessToken == "" {
|
||||
|
||||
@@ -81,9 +81,9 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
|
||||
|
||||
s := &Source{
|
||||
Config: r,
|
||||
BaseURL: "https://monitoring.googleapis.com",
|
||||
Client: client,
|
||||
UserAgent: ua,
|
||||
baseURL: "https://monitoring.googleapis.com",
|
||||
client: client,
|
||||
userAgent: ua,
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
@@ -92,9 +92,9 @@ var _ sources.Source = &Source{}
|
||||
|
||||
type Source struct {
|
||||
Config
|
||||
BaseURL string `yaml:"baseUrl"`
|
||||
Client *http.Client
|
||||
UserAgent string
|
||||
baseURL string
|
||||
client *http.Client
|
||||
userAgent string
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
@@ -105,6 +105,18 @@ func (s *Source) ToConfig() sources.SourceConfig {
|
||||
return s.Config
|
||||
}
|
||||
|
||||
func (s *Source) BaseURL() string {
|
||||
return s.baseURL
|
||||
}
|
||||
|
||||
func (s *Source) Client() *http.Client {
|
||||
return s.client
|
||||
}
|
||||
|
||||
func (s *Source) UserAgent() string {
|
||||
return s.userAgent
|
||||
}
|
||||
|
||||
func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Client, error) {
|
||||
if s.UseClientOAuth {
|
||||
if accessToken == "" {
|
||||
@@ -113,7 +125,7 @@ func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Clien
|
||||
token := &oauth2.Token{AccessToken: accessToken}
|
||||
return oauth2.NewClient(ctx, oauth2.StaticTokenSource(token)), nil
|
||||
}
|
||||
return s.Client, nil
|
||||
return s.client, nil
|
||||
}
|
||||
|
||||
func (s *Source) UseClientAuthorization() bool {
|
||||
|
||||
@@ -110,6 +110,10 @@ func (s *Source) ToConfig() sources.SourceConfig {
|
||||
return s.Config
|
||||
}
|
||||
|
||||
func (s *Source) GetDefaultProject() string {
|
||||
return s.DefaultProject
|
||||
}
|
||||
|
||||
func (s *Source) GetService(ctx context.Context, accessToken string) (*sqladmin.Service, error) {
|
||||
if s.UseClientOAuth {
|
||||
token := &oauth2.Token{AccessToken: accessToken}
|
||||
|
||||
@@ -107,7 +107,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
|
||||
|
||||
s := &Source{
|
||||
Config: r,
|
||||
Client: &client,
|
||||
client: &client,
|
||||
}
|
||||
return s, nil
|
||||
|
||||
@@ -117,7 +117,7 @@ var _ sources.Source = &Source{}
|
||||
|
||||
type Source struct {
|
||||
Config
|
||||
Client *http.Client
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
@@ -127,3 +127,19 @@ func (s *Source) SourceKind() string {
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
return s.Config
|
||||
}
|
||||
|
||||
func (s *Source) HttpDefaultHeaders() map[string]string {
|
||||
return s.DefaultHeaders
|
||||
}
|
||||
|
||||
func (s *Source) HttpBaseURL() string {
|
||||
return s.BaseURL
|
||||
}
|
||||
|
||||
func (s *Source) HttpQueryParams() map[string]string {
|
||||
return s.QueryParams
|
||||
}
|
||||
|
||||
func (s *Source) Client() *http.Client {
|
||||
return s.client
|
||||
}
|
||||
|
||||
@@ -160,10 +160,6 @@ func (s *Source) ToConfig() sources.SourceConfig {
|
||||
return s.Config
|
||||
}
|
||||
|
||||
func (s *Source) GetApiSettings() *rtl.ApiSettings {
|
||||
return s.ApiSettings
|
||||
}
|
||||
|
||||
func (s *Source) UseClientAuthorization() bool {
|
||||
return strings.ToLower(s.UseClientOAuth) != "false"
|
||||
}
|
||||
@@ -188,6 +184,30 @@ func (s *Source) GoogleCloudTokenSourceWithScope(ctx context.Context, scope stri
|
||||
return google.DefaultTokenSource(ctx, scope)
|
||||
}
|
||||
|
||||
func (s *Source) LookerClient() *v4.LookerSDK {
|
||||
return s.Client
|
||||
}
|
||||
|
||||
func (s *Source) LookerApiSettings() *rtl.ApiSettings {
|
||||
return s.ApiSettings
|
||||
}
|
||||
|
||||
func (s *Source) LookerShowHiddenFields() bool {
|
||||
return s.ShowHiddenFields
|
||||
}
|
||||
|
||||
func (s *Source) LookerShowHiddenModels() bool {
|
||||
return s.ShowHiddenModels
|
||||
}
|
||||
|
||||
func (s *Source) LookerShowHiddenExplores() bool {
|
||||
return s.ShowHiddenExplores
|
||||
}
|
||||
|
||||
func (s *Source) LookerSessionLength() int64 {
|
||||
return s.SessionLength
|
||||
}
|
||||
|
||||
func initGoogleCloudConnection(ctx context.Context) (oauth2.TokenSource, error) {
|
||||
cred, err := google.FindDefaultCredentials(ctx, geminidataanalytics.DefaultAuthScopes()...)
|
||||
if err != nil {
|
||||
|
||||
@@ -96,6 +96,14 @@ func (s *Source) ToConfig() sources.SourceConfig {
|
||||
return s.Config
|
||||
}
|
||||
|
||||
func (s *Source) GetProject() string {
|
||||
return s.Project
|
||||
}
|
||||
|
||||
func (s *Source) GetLocation() string {
|
||||
return s.Location
|
||||
}
|
||||
|
||||
func (s *Source) GetBatchControllerClient() *dataproc.BatchControllerClient {
|
||||
return s.Client
|
||||
}
|
||||
|
||||
@@ -20,7 +20,6 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/alloydb/v1"
|
||||
@@ -42,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
UseClientAuthorization() bool
|
||||
GetService(context.Context, string) (*alloydb.Service, error)
|
||||
}
|
||||
|
||||
// Configuration for the create-cluster tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -66,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("source %q not found", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(*alloydbadmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `alloydb-admin`", kind)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -97,7 +102,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -107,7 +111,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the create-cluster tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
Source *alloydbadmin.Source
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
|
||||
manifest tools.Manifest
|
||||
@@ -120,6 +123,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
project, ok := paramsMap["project"].(string)
|
||||
if !ok || project == "" {
|
||||
@@ -151,7 +159,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid 'user' parameter; expected a string")
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -198,10 +206,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -20,7 +20,6 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/alloydb/v1"
|
||||
@@ -42,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
UseClientAuthorization() bool
|
||||
GetService(context.Context, string) (*alloydb.Service, error)
|
||||
}
|
||||
|
||||
// Configuration for the create-instance tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -66,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("source %q not found", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(*alloydbadmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `alloydb-admin`", kind)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -98,7 +103,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -108,7 +112,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the create-instance tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
Source *alloydbadmin.Source
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
|
||||
manifest tools.Manifest
|
||||
@@ -121,6 +124,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
project, ok := paramsMap["project"].(string)
|
||||
if !ok || project == "" {
|
||||
@@ -147,7 +155,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid 'instanceType' parameter; expected 'PRIMARY' or 'READ_POOL'")
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -208,10 +216,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -20,7 +20,6 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/alloydb/v1"
|
||||
@@ -42,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
UseClientAuthorization() bool
|
||||
GetService(context.Context, string) (*alloydb.Service, error)
|
||||
}
|
||||
|
||||
// Configuration for the create-user tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -66,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("source %q not found", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(*alloydbadmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `alloydb-admin`", kind)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -98,7 +103,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -108,9 +112,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the create-user tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
Source *alloydbadmin.Source
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -121,6 +123,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
project, ok := paramsMap["project"].(string)
|
||||
if !ok || project == "" {
|
||||
@@ -147,7 +154,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid or missing 'userType' parameter; expected 'ALLOYDB_BUILT_IN' or 'ALLOYDB_IAM_USER'")
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -208,10 +215,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -20,9 +20,9 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/alloydb/v1"
|
||||
)
|
||||
|
||||
const kind string = "alloydb-get-cluster"
|
||||
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
UseClientAuthorization() bool
|
||||
GetService(context.Context, string) (*alloydb.Service, error)
|
||||
}
|
||||
|
||||
// Configuration for the get-cluster tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -66,12 +72,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("source %q not found", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(*alloydbadmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -94,7 +100,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -104,7 +109,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the get-cluster tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
Source *alloydbadmin.Source
|
||||
AllParams parameters.Parameters
|
||||
|
||||
manifest tools.Manifest
|
||||
@@ -117,6 +121,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
@@ -132,7 +141,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string")
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -167,10 +176,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -20,9 +20,9 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/alloydb/v1"
|
||||
)
|
||||
|
||||
const kind string = "alloydb-get-instance"
|
||||
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
UseClientAuthorization() bool
|
||||
GetService(context.Context, string) (*alloydb.Service, error)
|
||||
}
|
||||
|
||||
// Configuration for the get-instance tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -66,12 +72,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("source %q not found", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(*alloydbadmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -95,7 +101,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -105,9 +110,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the get-instance tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
Source *alloydbadmin.Source
|
||||
AllParams parameters.Parameters
|
||||
|
||||
AllParams parameters.Parameters
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -118,6 +121,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
@@ -137,7 +145,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid 'instance' parameter; expected a string")
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -172,10 +180,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -20,9 +20,9 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/alloydb/v1"
|
||||
)
|
||||
|
||||
const kind string = "alloydb-get-user"
|
||||
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
UseClientAuthorization() bool
|
||||
GetService(context.Context, string) (*alloydb.Service, error)
|
||||
}
|
||||
|
||||
// Configuration for the get-user tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -66,12 +72,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("source %q not found", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(*alloydbadmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -95,7 +101,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -105,9 +110,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the get-user tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
Source *alloydbadmin.Source
|
||||
AllParams parameters.Parameters
|
||||
|
||||
AllParams parameters.Parameters
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -118,6 +121,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
@@ -137,7 +145,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid 'user' parameter; expected a string")
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -172,10 +180,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -20,9 +20,9 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/alloydb/v1"
|
||||
)
|
||||
|
||||
const kind string = "alloydb-list-clusters"
|
||||
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
UseClientAuthorization() bool
|
||||
GetService(context.Context, string) (*alloydb.Service, error)
|
||||
}
|
||||
|
||||
// Configuration for the list-clusters tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -66,12 +72,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("source %q not found", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(*alloydbadmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -93,7 +99,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -103,9 +108,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the list-clusters tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
Source *alloydbadmin.Source
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -116,6 +119,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
@@ -127,7 +135,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid 'location' parameter; expected a string")
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -162,10 +170,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -20,9 +20,9 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/alloydb/v1"
|
||||
)
|
||||
|
||||
const kind string = "alloydb-list-instances"
|
||||
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
UseClientAuthorization() bool
|
||||
GetService(context.Context, string) (*alloydb.Service, error)
|
||||
}
|
||||
|
||||
// Configuration for the list-instances tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -66,12 +72,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("source %q not found", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(*alloydbadmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -94,7 +100,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -104,9 +109,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the list-instances tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
Source *alloydbadmin.Source
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -117,6 +120,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
@@ -132,7 +140,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string")
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -167,10 +175,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -20,9 +20,9 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/alloydb/v1"
|
||||
)
|
||||
|
||||
const kind string = "alloydb-list-users"
|
||||
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
UseClientAuthorization() bool
|
||||
GetService(context.Context, string) (*alloydb.Service, error)
|
||||
}
|
||||
|
||||
// Configuration for the list-users tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -66,12 +72,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("source %q not found", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(*alloydbadmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -94,7 +100,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -104,9 +109,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the list-users tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
Source *alloydbadmin.Source
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -117,6 +120,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
@@ -132,7 +140,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string")
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -167,10 +175,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -25,9 +25,9 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/alloydb/v1"
|
||||
)
|
||||
|
||||
const kind string = "alloydb-wait-for-operation"
|
||||
@@ -89,6 +89,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
UseClientAuthorization() bool
|
||||
GetService(context.Context, string) (*alloydb.Service, error)
|
||||
}
|
||||
|
||||
// Config defines the configuration for the wait-for-operation tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -119,12 +125,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(*alloydbadmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -180,7 +186,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -194,19 +199,16 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the wait-for-operation tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
|
||||
Source *alloydbadmin.Source
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
Client *http.Client
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
|
||||
// Polling configuration
|
||||
Delay time.Duration
|
||||
MaxDelay time.Duration
|
||||
Multiplier float64
|
||||
MaxRetries int
|
||||
|
||||
Client *http.Client
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -215,6 +217,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
@@ -230,7 +237,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("missing 'operation' parameter")
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -363,10 +370,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/alloydbpg"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
@@ -47,11 +46,6 @@ type compatibleSource interface {
|
||||
PostgresPool() *pgxpool.Pool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &alloydbpg.Source{}
|
||||
|
||||
var compatibleSources = [...]string{alloydbpg.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -70,18 +64,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
numParams := len(cfg.NLConfigParameters)
|
||||
quotedNameParts := make([]string, 0, numParams)
|
||||
placeholderParts := make([]string, 0, numParams)
|
||||
@@ -126,7 +108,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
Config: cfg,
|
||||
Parameters: cfg.NLConfigParameters,
|
||||
Statement: stmt,
|
||||
Pool: s.PostgresPool(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.NLConfigParameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -139,9 +120,7 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Pool *pgxpool.Pool
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
Statement string
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
@@ -152,6 +131,12 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pool := source.PostgresPool()
|
||||
|
||||
sliceParams := params.AsSlice()
|
||||
allParamValues := make([]any, len(sliceParams)+1)
|
||||
allParamValues[0] = fmt.Sprintf("%s", sliceParams[0]) // nl_question
|
||||
@@ -160,7 +145,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
allParamValues[i+2] = fmt.Sprintf("%s", param)
|
||||
}
|
||||
|
||||
results, err := t.Pool.Query(ctx, t.Statement, allParamValues...)
|
||||
results, err := pool.Query(ctx, t.Statement, allParamValues...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w. Query: %v , Values: %v. Toolbox v0.19.0+ is only compatible with AlloyDB AI NL v1.0.3+. Please ensure that you are using the latest AlloyDB AI NL extension", err, t.Statement, allParamValues)
|
||||
}
|
||||
@@ -203,10 +188,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -57,11 +57,6 @@ type compatibleSource interface {
|
||||
BigQuerySession() bigqueryds.BigQuerySessionProvider
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &bigqueryds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{bigqueryds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -87,7 +82,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
allowedDatasets := s.BigQueryAllowedDatasets()
|
||||
@@ -136,17 +131,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ClientCreator: s.BigQueryClientCreator(),
|
||||
Client: s.BigQueryClient(),
|
||||
RestService: s.BigQueryRestService(),
|
||||
IsDatasetAllowed: s.IsDatasetAllowed,
|
||||
AllowedDatasets: allowedDatasets,
|
||||
SessionProvider: s.BigQuerySession(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -156,17 +144,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *bigqueryapi.Client
|
||||
RestService *bigqueryrestapi.Service
|
||||
ClientCreator bigqueryds.BigqueryClientCreator
|
||||
IsDatasetAllowed func(projectID, datasetID string) bool
|
||||
AllowedDatasets []string
|
||||
SessionProvider bigqueryds.BigQuerySessionProvider
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -175,23 +155,27 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke runs the contribution analysis.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
inputData, ok := paramsMap["input_data"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unable to cast input_data parameter %s", paramsMap["input_data"])
|
||||
}
|
||||
|
||||
bqClient := t.Client
|
||||
restService := t.RestService
|
||||
var err error
|
||||
bqClient := source.BigQueryClient()
|
||||
restService := source.BigQueryRestService()
|
||||
|
||||
// Initialize new client if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
bqClient, restService, err = t.ClientCreator(tokenStr, true)
|
||||
bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
||||
}
|
||||
@@ -229,9 +213,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
var inputDataSource string
|
||||
trimmedUpperInputData := strings.TrimSpace(strings.ToUpper(inputData))
|
||||
if strings.HasPrefix(trimmedUpperInputData, "SELECT") || strings.HasPrefix(trimmedUpperInputData, "WITH") {
|
||||
if len(t.AllowedDatasets) > 0 {
|
||||
if len(source.BigQueryAllowedDatasets()) > 0 {
|
||||
var connProps []*bigqueryapi.ConnectionProperty
|
||||
session, err := t.SessionProvider(ctx)
|
||||
session, err := source.BigQuerySession()(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
||||
}
|
||||
@@ -240,7 +224,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
{Key: "session_id", Value: session.ID},
|
||||
}
|
||||
}
|
||||
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, t.Client.Project(), t.Client.Location, inputData, nil, connProps)
|
||||
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, inputData, nil, connProps)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query validation failed: %w", err)
|
||||
}
|
||||
@@ -252,7 +236,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
queryStats := dryRunJob.Statistics.Query
|
||||
if queryStats != nil {
|
||||
for _, tableRef := range queryStats.ReferencedTables {
|
||||
if !t.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) {
|
||||
if !source.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) {
|
||||
return nil, fmt.Errorf("query in input_data accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectId, tableRef.DatasetId)
|
||||
}
|
||||
}
|
||||
@@ -262,18 +246,18 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
inputDataSource = fmt.Sprintf("(%s)", inputData)
|
||||
} else {
|
||||
if len(t.AllowedDatasets) > 0 {
|
||||
if len(source.BigQueryAllowedDatasets()) > 0 {
|
||||
parts := strings.Split(inputData, ".")
|
||||
var projectID, datasetID string
|
||||
switch len(parts) {
|
||||
case 3: // project.dataset.table
|
||||
projectID, datasetID = parts[0], parts[1]
|
||||
case 2: // dataset.table
|
||||
projectID, datasetID = t.Client.Project(), parts[0]
|
||||
projectID, datasetID = source.BigQueryClient().Project(), parts[0]
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid table ID format for 'input_data': %q. Expected 'dataset.table' or 'project.dataset.table'", inputData)
|
||||
}
|
||||
if !t.IsDatasetAllowed(projectID, datasetID) {
|
||||
if !source.IsDatasetAllowed(projectID, datasetID) {
|
||||
return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", projectID, datasetID, inputData)
|
||||
}
|
||||
}
|
||||
@@ -292,7 +276,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
|
||||
// Get session from provider if in protected mode.
|
||||
// Otherwise, a new session will be created by the first query.
|
||||
session, err := t.SessionProvider(ctx)
|
||||
session, err := source.BigQuerySession()(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
||||
}
|
||||
@@ -385,10 +369,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -26,7 +26,6 @@ import (
|
||||
bigqueryapi "cloud.google.com/go/bigquery"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
@@ -105,11 +104,6 @@ type CAPayload struct {
|
||||
ClientIdEnum string `json:"clientIdEnum"`
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &bigqueryds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{bigqueryds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -135,7 +129,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
allowedDatasets := s.BigQueryAllowedDatasets()
|
||||
@@ -153,31 +147,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
params := parameters.Parameters{userQueryParameter, tableRefsParameter}
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil)
|
||||
|
||||
// Get cloud-platform token source for Gemini Data Analytics API during initialization
|
||||
var bigQueryTokenSourceWithScope oauth2.TokenSource
|
||||
if !s.UseClientAuthorization() {
|
||||
ctx := context.Background()
|
||||
ts, err := s.BigQueryTokenSourceWithScope(ctx, "https://www.googleapis.com/auth/cloud-platform")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get cloud-platform token source: %w", err)
|
||||
}
|
||||
bigQueryTokenSourceWithScope = ts
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Project: s.BigQueryProject(),
|
||||
Location: s.BigQueryLocation(),
|
||||
Parameters: params,
|
||||
Client: s.BigQueryClient(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
TokenSource: bigQueryTokenSourceWithScope,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
MaxQueryResultRows: s.GetMaxQueryResultRows(),
|
||||
IsDatasetAllowed: s.IsDatasetAllowed,
|
||||
AllowedDatasets: allowedDatasets,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -187,18 +162,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Project string
|
||||
Location string
|
||||
Client *bigqueryapi.Client
|
||||
TokenSource oauth2.TokenSource
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
MaxQueryResultRows int
|
||||
IsDatasetAllowed func(projectID, datasetID string) bool
|
||||
AllowedDatasets []string
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -206,11 +172,15 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var tokenStr string
|
||||
var err error
|
||||
|
||||
// Get credentials for the API call
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
// Use client-side access token
|
||||
if accessToken == "" {
|
||||
return nil, fmt.Errorf("tool is configured for client OAuth but no token was provided in the request header: %w", util.ErrUnauthorized)
|
||||
@@ -220,11 +190,17 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
} else {
|
||||
// Get cloud-platform token source for Gemini Data Analytics API during initialization
|
||||
tokenSource, err := source.BigQueryTokenSourceWithScope(ctx, "https://www.googleapis.com/auth/cloud-platform")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get cloud-platform token source: %w", err)
|
||||
}
|
||||
|
||||
// Use cloud-platform token source for Gemini Data Analytics API
|
||||
if t.TokenSource == nil {
|
||||
if tokenSource == nil {
|
||||
return nil, fmt.Errorf("cloud-platform token source is missing")
|
||||
}
|
||||
token, err := t.TokenSource.Token()
|
||||
token, err := tokenSource.Token()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get token from cloud-platform token source: %w", err)
|
||||
}
|
||||
@@ -245,17 +221,17 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
}
|
||||
|
||||
if len(t.AllowedDatasets) > 0 {
|
||||
if len(source.BigQueryAllowedDatasets()) > 0 {
|
||||
for _, tableRef := range tableRefs {
|
||||
if !t.IsDatasetAllowed(tableRef.ProjectID, tableRef.DatasetID) {
|
||||
if !source.IsDatasetAllowed(tableRef.ProjectID, tableRef.DatasetID) {
|
||||
return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", tableRef.ProjectID, tableRef.DatasetID, tableRef.TableID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Construct URL, headers, and payload
|
||||
projectID := t.Project
|
||||
location := t.Location
|
||||
projectID := source.BigQueryProject()
|
||||
location := source.BigQueryLocation()
|
||||
if location == "" {
|
||||
location = "us"
|
||||
}
|
||||
@@ -279,7 +255,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
|
||||
// Call the streaming API
|
||||
response, err := getStream(caURL, payload, headers, t.MaxQueryResultRows)
|
||||
response, err := getStream(caURL, payload, headers, source.GetMaxQueryResultRows())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get response from conversational analytics API: %w", err)
|
||||
}
|
||||
@@ -303,8 +279,12 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
// StreamMessage represents a single message object from the streaming API response.
|
||||
@@ -580,6 +560,6 @@ func appendMessage(messages []map[string]any, newMessage map[string]any) []map[s
|
||||
return append(messages, newMessage)
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -60,11 +60,6 @@ type compatibleSource interface {
|
||||
BigQueryAllowedDatasets() []string
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &bigqueryds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{bigqueryds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -90,7 +85,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
var sqlDescriptionBuilder strings.Builder
|
||||
@@ -136,18 +131,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ClientCreator: s.BigQueryClientCreator(),
|
||||
Client: s.BigQueryClient(),
|
||||
RestService: s.BigQueryRestService(),
|
||||
WriteMode: s.BigQueryWriteMode(),
|
||||
SessionProvider: s.BigQuerySession(),
|
||||
IsDatasetAllowed: s.IsDatasetAllowed,
|
||||
AllowedDatasets: allowedDatasets,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -157,18 +144,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *bigqueryapi.Client
|
||||
RestService *bigqueryrestapi.Service
|
||||
WriteMode string
|
||||
SessionProvider bigqueryds.BigQuerySessionProvider
|
||||
ClientCreator bigqueryds.BigqueryClientCreator
|
||||
IsDatasetAllowed func(projectID, datasetID string) bool
|
||||
AllowedDatasets []string
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -176,6 +154,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
sql, ok := paramsMap["sql"].(string)
|
||||
if !ok {
|
||||
@@ -186,17 +169,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("unable to cast dry_run parameter %s", paramsMap["dry_run"])
|
||||
}
|
||||
|
||||
bqClient := t.Client
|
||||
restService := t.RestService
|
||||
bqClient := source.BigQueryClient()
|
||||
restService := source.BigQueryRestService()
|
||||
|
||||
var err error
|
||||
// Initialize new client if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
bqClient, restService, err = t.ClientCreator(tokenStr, true)
|
||||
bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
||||
}
|
||||
@@ -204,8 +186,8 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
|
||||
var connProps []*bigqueryapi.ConnectionProperty
|
||||
var session *bigqueryds.Session
|
||||
if t.WriteMode == bigqueryds.WriteModeProtected {
|
||||
session, err = t.SessionProvider(ctx)
|
||||
if source.BigQueryWriteMode() == bigqueryds.WriteModeProtected {
|
||||
session, err = source.BigQuerySession()(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get BigQuery session for protected mode: %w", err)
|
||||
}
|
||||
@@ -221,7 +203,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
|
||||
statementType := dryRunJob.Statistics.Query.StatementType
|
||||
|
||||
switch t.WriteMode {
|
||||
switch source.BigQueryWriteMode() {
|
||||
case bigqueryds.WriteModeBlocked:
|
||||
if statementType != "SELECT" {
|
||||
return nil, fmt.Errorf("write mode is 'blocked', only SELECT statements are allowed")
|
||||
@@ -235,7 +217,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
}
|
||||
|
||||
if len(t.AllowedDatasets) > 0 {
|
||||
if len(source.BigQueryAllowedDatasets()) > 0 {
|
||||
switch statementType {
|
||||
case "CREATE_SCHEMA", "DROP_SCHEMA", "ALTER_SCHEMA":
|
||||
return nil, fmt.Errorf("dataset-level operations like '%s' are not allowed when dataset restrictions are in place", statementType)
|
||||
@@ -270,7 +252,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
} else if statementType != "SELECT" {
|
||||
// If dry run yields no tables, fall back to the parser for non-SELECT statements
|
||||
// to catch unsafe operations like EXECUTE IMMEDIATE.
|
||||
parsedTables, parseErr := bqutil.TableParser(sql, t.Client.Project())
|
||||
parsedTables, parseErr := bqutil.TableParser(sql, source.BigQueryClient().Project())
|
||||
if parseErr != nil {
|
||||
// If parsing fails (e.g., EXECUTE IMMEDIATE), we cannot guarantee safety, so we must fail.
|
||||
return nil, fmt.Errorf("could not parse tables from query to validate against allowed datasets: %w", parseErr)
|
||||
@@ -282,7 +264,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
parts := strings.Split(tableID, ".")
|
||||
if len(parts) == 3 {
|
||||
projectID, datasetID := parts[0], parts[1]
|
||||
if !t.IsDatasetAllowed(projectID, datasetID) {
|
||||
if !source.IsDatasetAllowed(projectID, datasetID) {
|
||||
return nil, fmt.Errorf("query accesses dataset '%s.%s', which is not in the allowed list", projectID, datasetID)
|
||||
}
|
||||
}
|
||||
@@ -374,10 +356,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -57,11 +57,6 @@ type compatibleSource interface {
|
||||
BigQuerySession() bigqueryds.BigQuerySessionProvider
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &bigqueryds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{bigqueryds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -87,7 +82,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
allowedDatasets := s.BigQueryAllowedDatasets()
|
||||
@@ -116,17 +111,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ClientCreator: s.BigQueryClientCreator(),
|
||||
Client: s.BigQueryClient(),
|
||||
RestService: s.BigQueryRestService(),
|
||||
IsDatasetAllowed: s.IsDatasetAllowed,
|
||||
SessionProvider: s.BigQuerySession(),
|
||||
AllowedDatasets: allowedDatasets,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -136,17 +124,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *bigqueryapi.Client
|
||||
RestService *bigqueryrestapi.Service
|
||||
ClientCreator bigqueryds.BigqueryClientCreator
|
||||
IsDatasetAllowed func(projectID, datasetID string) bool
|
||||
AllowedDatasets []string
|
||||
SessionProvider bigqueryds.BigQuerySessionProvider
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -154,6 +134,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
historyData, ok := paramsMap["history_data"].(string)
|
||||
if !ok {
|
||||
@@ -188,17 +173,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
}
|
||||
|
||||
bqClient := t.Client
|
||||
restService := t.RestService
|
||||
var err error
|
||||
bqClient := source.BigQueryClient()
|
||||
restService := source.BigQueryRestService()
|
||||
|
||||
// Initialize new client if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
bqClient, restService, err = t.ClientCreator(tokenStr, false)
|
||||
bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
||||
}
|
||||
@@ -207,9 +191,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
var historyDataSource string
|
||||
trimmedUpperHistoryData := strings.TrimSpace(strings.ToUpper(historyData))
|
||||
if strings.HasPrefix(trimmedUpperHistoryData, "SELECT") || strings.HasPrefix(trimmedUpperHistoryData, "WITH") {
|
||||
if len(t.AllowedDatasets) > 0 {
|
||||
if len(source.BigQueryAllowedDatasets()) > 0 {
|
||||
var connProps []*bigqueryapi.ConnectionProperty
|
||||
session, err := t.SessionProvider(ctx)
|
||||
session, err := source.BigQuerySession()(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
||||
}
|
||||
@@ -218,7 +202,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
{Key: "session_id", Value: session.ID},
|
||||
}
|
||||
}
|
||||
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, t.Client.Project(), t.Client.Location, historyData, nil, connProps)
|
||||
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, historyData, nil, connProps)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query validation failed: %w", err)
|
||||
}
|
||||
@@ -230,7 +214,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
queryStats := dryRunJob.Statistics.Query
|
||||
if queryStats != nil {
|
||||
for _, tableRef := range queryStats.ReferencedTables {
|
||||
if !t.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) {
|
||||
if !source.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) {
|
||||
return nil, fmt.Errorf("query in history_data accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectId, tableRef.DatasetId)
|
||||
}
|
||||
}
|
||||
@@ -240,7 +224,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
historyDataSource = fmt.Sprintf("(%s)", historyData)
|
||||
} else {
|
||||
if len(t.AllowedDatasets) > 0 {
|
||||
if len(source.BigQueryAllowedDatasets()) > 0 {
|
||||
parts := strings.Split(historyData, ".")
|
||||
var projectID, datasetID string
|
||||
|
||||
@@ -249,13 +233,13 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
projectID = parts[0]
|
||||
datasetID = parts[1]
|
||||
case 2: // dataset.table
|
||||
projectID = t.Client.Project()
|
||||
projectID = source.BigQueryClient().Project()
|
||||
datasetID = parts[0]
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid table ID format for 'history_data': %q. Expected 'dataset.table' or 'project.dataset.table'", historyData)
|
||||
}
|
||||
|
||||
if !t.IsDatasetAllowed(projectID, datasetID) {
|
||||
if !source.IsDatasetAllowed(projectID, datasetID) {
|
||||
return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", projectID, datasetID, historyData)
|
||||
}
|
||||
}
|
||||
@@ -279,7 +263,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
// JobStatistics.QueryStatistics.StatementType
|
||||
query := bqClient.Query(sql)
|
||||
query.Location = bqClient.Location
|
||||
session, err := t.SessionProvider(ctx)
|
||||
session, err := source.BigQuerySession()(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
||||
}
|
||||
@@ -349,10 +333,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -54,11 +54,6 @@ type compatibleSource interface {
|
||||
BigQueryAllowedDatasets() []string
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &bigqueryds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{bigqueryds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -84,7 +79,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
defaultProjectID := s.BigQueryProject()
|
||||
@@ -104,14 +99,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ClientCreator: s.BigQueryClientCreator(),
|
||||
Client: s.BigQueryClient(),
|
||||
IsDatasetAllowed: s.IsDatasetAllowed,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -121,15 +112,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *bigqueryapi.Client
|
||||
ClientCreator bigqueryds.BigqueryClientCreator
|
||||
Statement string
|
||||
IsDatasetAllowed func(projectID, datasetID string) bool
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -137,6 +122,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mapParams := params.AsMap()
|
||||
projectId, ok := mapParams[projectKey].(string)
|
||||
if !ok {
|
||||
@@ -148,22 +138,21 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", datasetKey)
|
||||
}
|
||||
|
||||
bqClient := t.Client
|
||||
var err error
|
||||
bqClient := source.BigQueryClient()
|
||||
|
||||
// Initialize new client if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
bqClient, _, err = t.ClientCreator(tokenStr, false)
|
||||
bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if !t.IsDatasetAllowed(projectId, datasetId) {
|
||||
if !source.IsDatasetAllowed(projectId, datasetId) {
|
||||
return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId)
|
||||
}
|
||||
|
||||
@@ -193,10 +182,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -55,11 +55,6 @@ type compatibleSource interface {
|
||||
BigQueryAllowedDatasets() []string
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &bigqueryds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{bigqueryds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -85,7 +80,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
defaultProjectID := s.BigQueryProject()
|
||||
@@ -108,14 +103,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ClientCreator: s.BigQueryClientCreator(),
|
||||
Client: s.BigQueryClient(),
|
||||
IsDatasetAllowed: s.IsDatasetAllowed,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -125,15 +116,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *bigqueryapi.Client
|
||||
ClientCreator bigqueryds.BigqueryClientCreator
|
||||
Statement string
|
||||
IsDatasetAllowed func(projectID, datasetID string) bool
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -141,6 +126,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mapParams := params.AsMap()
|
||||
projectId, ok := mapParams[projectKey].(string)
|
||||
if !ok {
|
||||
@@ -157,20 +147,19 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", tableKey)
|
||||
}
|
||||
|
||||
if !t.IsDatasetAllowed(projectId, datasetId) {
|
||||
if !source.IsDatasetAllowed(projectId, datasetId) {
|
||||
return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId)
|
||||
}
|
||||
|
||||
bqClient := t.Client
|
||||
bqClient := source.BigQueryClient()
|
||||
|
||||
var err error
|
||||
// Initialize new client if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
bqClient, _, err = t.ClientCreator(tokenStr, false)
|
||||
bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
||||
}
|
||||
@@ -203,10 +192,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -52,11 +52,6 @@ type compatibleSource interface {
|
||||
BigQueryAllowedDatasets() []string
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &bigqueryds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{bigqueryds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -82,7 +77,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
var projectParameter parameters.Parameter
|
||||
@@ -103,14 +98,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ClientCreator: s.BigQueryClientCreator(),
|
||||
Client: s.BigQueryClient(),
|
||||
AllowedDatasets: allowedDatasets,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -120,15 +111,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *bigqueryapi.Client
|
||||
ClientCreator bigqueryds.BigqueryClientCreator
|
||||
Statement string
|
||||
AllowedDatasets []string
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -136,8 +121,13 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
if len(t.AllowedDatasets) > 0 {
|
||||
return t.AllowedDatasets, nil
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(source.BigQueryAllowedDatasets()) > 0 {
|
||||
return source.BigQueryAllowedDatasets(), nil
|
||||
}
|
||||
mapParams := params.AsMap()
|
||||
projectId, ok := mapParams[projectKey].(string)
|
||||
@@ -145,14 +135,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", projectKey)
|
||||
}
|
||||
|
||||
bqClient := t.Client
|
||||
bqClient := source.BigQueryClient()
|
||||
// Initialize new client if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
bqClient, _, err = t.ClientCreator(tokenStr, false)
|
||||
bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
||||
}
|
||||
@@ -197,10 +187,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -55,11 +55,6 @@ type compatibleSource interface {
|
||||
BigQueryAllowedDatasets() []string
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &bigqueryds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{bigqueryds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -85,7 +80,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
defaultProjectID := s.BigQueryProject()
|
||||
@@ -107,14 +102,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ClientCreator: s.BigQueryClientCreator(),
|
||||
Client: s.BigQueryClient(),
|
||||
IsDatasetAllowed: s.IsDatasetAllowed,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -124,15 +115,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *bigqueryapi.Client
|
||||
ClientCreator bigqueryds.BigqueryClientCreator
|
||||
IsDatasetAllowed func(projectID, datasetID string) bool
|
||||
Statement string
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -140,6 +125,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mapParams := params.AsMap()
|
||||
projectId, ok := mapParams[projectKey].(string)
|
||||
if !ok {
|
||||
@@ -151,18 +141,18 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", datasetKey)
|
||||
}
|
||||
|
||||
if !t.IsDatasetAllowed(projectId, datasetId) {
|
||||
if !source.IsDatasetAllowed(projectId, datasetId) {
|
||||
return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId)
|
||||
}
|
||||
|
||||
bqClient := t.Client
|
||||
bqClient := source.BigQueryClient()
|
||||
// Initialize new client if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
bqClient, _, err = t.ClientCreator(tokenStr, false)
|
||||
bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
||||
}
|
||||
@@ -208,10 +198,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -51,11 +51,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &bigqueryds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{bigqueryds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -72,20 +67,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// Initialize the search configuration with the provided sources
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
// Get the Dataplex client using the method from the source
|
||||
makeCatalogClient := s.MakeDataplexCatalogClient()
|
||||
|
||||
prompt := parameters.NewStringParameter("prompt", "Prompt representing search intention. Do not rewrite the prompt.")
|
||||
datasetIds := parameters.NewArrayParameterWithDefault("datasetIds", []any{}, "Array of dataset IDs.", parameters.NewStringParameter("datasetId", "The IDs of the bigquery dataset."))
|
||||
projectIds := parameters.NewArrayParameterWithDefault("projectIds", []any{}, "Array of project IDs.", parameters.NewStringParameter("projectId", "The IDs of the bigquery project."))
|
||||
@@ -100,11 +81,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, params, nil)
|
||||
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
MakeCatalogClient: makeCatalogClient,
|
||||
ProjectID: s.BigQueryProject(),
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{
|
||||
Description: cfg.Description,
|
||||
Parameters: params.Manifest(),
|
||||
@@ -117,12 +95,9 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
Parameters parameters.Parameters
|
||||
UseClientOAuth bool
|
||||
MakeCatalogClient func() (*dataplexapi.CatalogClient, bigqueryds.DataplexClientCreator, error)
|
||||
ProjectID string
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -133,8 +108,12 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func constructSearchQueryHelper(predicate string, operator string, items []string) string {
|
||||
@@ -207,6 +186,11 @@ func ExtractType(resourceString string) string {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
pageSize := int32(paramsMap["pageSize"].(int))
|
||||
prompt, _ := paramsMap["prompt"].(string)
|
||||
@@ -228,14 +212,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
|
||||
req := &dataplexpb.SearchEntriesRequest{
|
||||
Query: fmt.Sprintf("%s %s", prompt, constructSearchQuery(projectIds, datasetIds, types)),
|
||||
Name: fmt.Sprintf("projects/%s/locations/global", t.ProjectID),
|
||||
Name: fmt.Sprintf("projects/%s/locations/global", source.BigQueryProject()),
|
||||
PageSize: pageSize,
|
||||
SemanticSearch: true,
|
||||
}
|
||||
|
||||
catalogClient, dataplexClientCreator, _ := t.MakeCatalogClient()
|
||||
catalogClient, dataplexClientCreator, _ := source.MakeDataplexCatalogClient()()
|
||||
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
@@ -248,7 +232,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
|
||||
it := catalogClient.SearchEntries(ctx, req)
|
||||
if it == nil {
|
||||
return nil, fmt.Errorf("failed to create search entries iterator for project %q", t.ProjectID)
|
||||
return nil, fmt.Errorf("failed to create search entries iterator for project %q", source.BigQueryProject())
|
||||
}
|
||||
|
||||
var results []Response
|
||||
@@ -288,6 +272,6 @@ func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -57,11 +57,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &bigqueryds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{bigqueryds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -81,18 +76,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -102,15 +85,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
AllParams: allParameters,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
Client: s.BigQueryClient(),
|
||||
RestService: s.BigQueryRestService(),
|
||||
SessionProvider: s.BigQuerySession(),
|
||||
ClientCreator: s.BigQueryClientCreator(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -120,15 +98,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
|
||||
Client *bigqueryapi.Client
|
||||
RestService *bigqueryrestapi.Service
|
||||
SessionProvider bigqueryds.BigQuerySessionProvider
|
||||
ClientCreator bigqueryds.BigqueryClientCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -136,6 +108,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
highLevelParams := make([]bigqueryapi.QueryParameter, 0, len(t.Parameters))
|
||||
lowLevelParams := make([]*bigqueryrestapi.QueryParameter, 0, len(t.Parameters))
|
||||
|
||||
@@ -212,16 +189,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
lowLevelParams = append(lowLevelParams, lowLevelParam)
|
||||
}
|
||||
|
||||
bqClient := t.Client
|
||||
restService := t.RestService
|
||||
bqClient := source.BigQueryClient()
|
||||
restService := source.BigQueryRestService()
|
||||
|
||||
// Initialize new client if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
bqClient, restService, err = t.ClientCreator(tokenStr, true)
|
||||
bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
||||
}
|
||||
@@ -232,8 +209,8 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
query.Location = bqClient.Location
|
||||
|
||||
connProps := []*bigqueryapi.ConnectionProperty{}
|
||||
if t.SessionProvider != nil {
|
||||
session, err := t.SessionProvider(ctx)
|
||||
if source.BigQuerySession() != nil {
|
||||
session, err := source.BigQuerySession()(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
||||
}
|
||||
@@ -311,10 +288,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"cloud.google.com/go/bigtable"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
bigtabledb "github.com/googleapis/genai-toolbox/internal/sources/bigtable"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
@@ -46,11 +45,6 @@ type compatibleSource interface {
|
||||
BigtableClient() *bigtable.Client
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &bigtabledb.Source{}
|
||||
|
||||
var compatibleSources = [...]string{bigtabledb.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -70,18 +64,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -93,7 +75,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
AllParams: allParameters,
|
||||
Client: s.BigtableClient(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -105,9 +86,7 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
|
||||
Client *bigtable.Client
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -156,6 +135,11 @@ func getMapParamsType(tparams parameters.Parameters, params parameters.ParamValu
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
|
||||
if err != nil {
|
||||
@@ -172,7 +156,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("fail to get map params: %w", err)
|
||||
}
|
||||
|
||||
ps, err := t.Client.PrepareStatement(
|
||||
ps, err := source.BigtableClient().PrepareStatement(
|
||||
ctx,
|
||||
newStatement,
|
||||
mapParamsType,
|
||||
@@ -224,10 +208,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
gocql "github.com/apache/cassandra-gocql-driver/v2"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cassandra"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
@@ -46,10 +45,6 @@ type compatibleSource interface {
|
||||
CassandraSession() *gocql.Session
|
||||
}
|
||||
|
||||
var _ compatibleSource = &cassandra.Source{}
|
||||
|
||||
var compatibleSources = [...]string{cassandra.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -61,20 +56,15 @@ type Config struct {
|
||||
TemplateParameters parameters.Parameters `yaml:"templateParameters"`
|
||||
}
|
||||
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
// ToolConfigKind implements tools.ToolConfig.
|
||||
func (c Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
// Initialize implements tools.ToolConfig.
|
||||
func (c Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[c.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", c.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
allParameters, paramManifest, err := parameters.ProcessParameters(c.TemplateParameters, c.Parameters)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -85,25 +75,17 @@ func (c Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
t := Tool{
|
||||
Config: c,
|
||||
AllParams: allParameters,
|
||||
Session: s.CassandraSession(),
|
||||
manifest: tools.Manifest{Description: c.Description, Parameters: paramManifest, AuthRequired: c.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// ToolConfigKind implements tools.ToolConfig.
|
||||
func (c Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
var _ tools.ToolConfig = Config{}
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
|
||||
Session *gocql.Session
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -113,8 +95,8 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
// RequiresClientAuthorization implements tools.Tool.
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Authorized implements tools.Tool.
|
||||
@@ -124,6 +106,11 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
|
||||
// Invoke implements tools.Tool.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
|
||||
if err != nil {
|
||||
@@ -135,7 +122,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("unable to extract standard params %w", err)
|
||||
}
|
||||
sliceParams := newParams.AsSlice()
|
||||
iter := t.Session.Query(newStatement, sliceParams...).IterContext(ctx)
|
||||
iter := source.CassandraSession().Query(newStatement, sliceParams...).IterContext(ctx)
|
||||
|
||||
// Create a slice to store the out
|
||||
var out []map[string]interface{}
|
||||
@@ -170,8 +157,6 @@ func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any)
|
||||
return parameters.ParseParams(t.AllParams, data, claims)
|
||||
}
|
||||
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -25,12 +25,6 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
|
||||
type compatibleSource interface {
|
||||
ClickHousePool() *sql.DB
|
||||
}
|
||||
|
||||
var compatibleSources = []string{"clickhouse"}
|
||||
|
||||
const executeSQLKind string = "clickhouse-execute-sql"
|
||||
|
||||
func init() {
|
||||
@@ -47,6 +41,10 @@ func newExecuteSQLConfig(ctx context.Context, name string, decoder *yaml.Decoder
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
ClickHousePool() *sql.DB
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -62,16 +60,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", executeSQLKind, compatibleSources)
|
||||
}
|
||||
|
||||
sqlParameter := parameters.NewStringParameter("sql", "The SQL statement to execute.")
|
||||
params := parameters.Parameters{sqlParameter}
|
||||
|
||||
@@ -80,7 +68,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Pool: s.ClickHousePool(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -91,9 +78,7 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Pool *sql.DB
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -103,13 +88,18 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
sql, ok := paramsMap["sql"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unable to cast sql parameter %s", paramsMap["sql"])
|
||||
}
|
||||
|
||||
results, err := t.Pool.QueryContext(ctx, sql)
|
||||
results, err := source.ClickHousePool().QueryContext(ctx, sql)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
@@ -183,10 +173,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -25,12 +25,6 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
|
||||
type compatibleSource interface {
|
||||
ClickHousePool() *sql.DB
|
||||
}
|
||||
|
||||
var compatibleSources = []string{"clickhouse"}
|
||||
|
||||
const listDatabasesKind string = "clickhouse-list-databases"
|
||||
|
||||
func init() {
|
||||
@@ -47,6 +41,10 @@ func newListDatabasesConfig(ctx context.Context, name string, decoder *yaml.Deco
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
ClickHousePool() *sql.DB
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -63,23 +61,12 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", listDatabasesKind, compatibleSources)
|
||||
}
|
||||
|
||||
allParameters, paramManifest, _ := parameters.ProcessParameters(nil, cfg.Parameters)
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil)
|
||||
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
AllParams: allParameters,
|
||||
Pool: s.ClickHousePool(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -90,9 +77,7 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
|
||||
Pool *sql.DB
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -102,10 +87,15 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Query to list all databases
|
||||
query := "SHOW DATABASES"
|
||||
|
||||
results, err := t.Pool.QueryContext(ctx, query)
|
||||
results, err := source.ClickHousePool().QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
@@ -146,10 +136,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -20,7 +20,6 @@ import (
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
@@ -32,21 +31,6 @@ func TestListDatabasesConfigToolConfigKind(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestListDatabasesConfigInitializeMissingSource(t *testing.T) {
|
||||
cfg := Config{
|
||||
Name: "test-list-databases",
|
||||
Kind: listDatabasesKind,
|
||||
Source: "missing-source",
|
||||
Description: "Test list databases tool",
|
||||
}
|
||||
|
||||
srcs := map[string]sources.Source{}
|
||||
_, err := cfg.Initialize(srcs)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing source")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFromYamlClickHouseListDatabases(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
|
||||
@@ -25,12 +25,6 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
|
||||
type compatibleSource interface {
|
||||
ClickHousePool() *sql.DB
|
||||
}
|
||||
|
||||
var compatibleSources = []string{"clickhouse"}
|
||||
|
||||
const listTablesKind string = "clickhouse-list-tables"
|
||||
const databaseKey string = "database"
|
||||
|
||||
@@ -48,6 +42,10 @@ func newListTablesConfig(ctx context.Context, name string, decoder *yaml.Decoder
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
ClickHousePool() *sql.DB
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -64,16 +62,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", listTablesKind, compatibleSources)
|
||||
}
|
||||
|
||||
databaseParameter := parameters.NewStringParameter(databaseKey, "The database to list tables from.")
|
||||
params := parameters.Parameters{databaseParameter}
|
||||
|
||||
@@ -83,7 +71,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
AllParams: allParameters,
|
||||
Pool: s.ClickHousePool(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -94,9 +81,7 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
|
||||
Pool *sql.DB
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -106,6 +91,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mapParams := params.AsMap()
|
||||
database, ok := mapParams[databaseKey].(string)
|
||||
if !ok {
|
||||
@@ -115,7 +105,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
// Query to list all tables in the specified database
|
||||
query := fmt.Sprintf("SHOW TABLES FROM %s", database)
|
||||
|
||||
results, err := t.Pool.QueryContext(ctx, query)
|
||||
results, err := source.ClickHousePool().QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
@@ -157,10 +147,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -20,7 +20,6 @@ import (
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
@@ -32,21 +31,6 @@ func TestListTablesConfigToolConfigKind(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestListTablesConfigInitializeMissingSource(t *testing.T) {
|
||||
cfg := Config{
|
||||
Name: "test-list-tables",
|
||||
Kind: listTablesKind,
|
||||
Source: "missing-source",
|
||||
Description: "Test list tables tool",
|
||||
}
|
||||
|
||||
srcs := map[string]sources.Source{}
|
||||
_, err := cfg.Initialize(srcs)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing source")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFromYamlClickHouseListTables(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
|
||||
@@ -25,21 +25,15 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
|
||||
type compatibleSource interface {
|
||||
ClickHousePool() *sql.DB
|
||||
}
|
||||
|
||||
var compatibleSources = []string{"clickhouse"}
|
||||
|
||||
const sqlKind string = "clickhouse-sql"
|
||||
|
||||
func init() {
|
||||
if !tools.Register(sqlKind, newSQLConfig) {
|
||||
if !tools.Register(sqlKind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", sqlKind))
|
||||
}
|
||||
}
|
||||
|
||||
func newSQLConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
@@ -47,6 +41,10 @@ func newSQLConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tool
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
ClickHousePool() *sql.DB
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -65,23 +63,12 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", sqlKind, compatibleSources)
|
||||
}
|
||||
|
||||
allParameters, paramManifest, _ := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters)
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil)
|
||||
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
AllParams: allParameters,
|
||||
Pool: s.ClickHousePool(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -93,7 +80,6 @@ var _ tools.Tool = Tool{}
|
||||
type Tool struct {
|
||||
Config
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
Pool *sql.DB
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -103,6 +89,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
|
||||
if err != nil {
|
||||
@@ -115,7 +106,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
|
||||
sliceParams := newParams.AsSlice()
|
||||
results, err := t.Pool.QueryContext(ctx, newStatement, sliceParams...)
|
||||
results, err := source.ClickHousePool().QueryContext(ctx, newStatement, sliceParams...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
@@ -191,10 +182,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -142,66 +142,6 @@ func TestSQLConfigInitializeValidSource(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLConfigInitializeMissingSource(t *testing.T) {
|
||||
config := Config{
|
||||
Name: "test-tool",
|
||||
Kind: sqlKind,
|
||||
Source: "missing-source",
|
||||
Description: "Test tool",
|
||||
Statement: "SELECT 1",
|
||||
Parameters: parameters.Parameters{},
|
||||
}
|
||||
|
||||
sources := map[string]sources.Source{}
|
||||
|
||||
_, err := config.Initialize(sources)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for missing source, got nil")
|
||||
}
|
||||
|
||||
expectedErr := `no source named "missing-source" configured`
|
||||
if err.Error() != expectedErr {
|
||||
t.Errorf("Expected error %q, got %q", expectedErr, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// mockIncompatibleSource is a mock source that doesn't implement the compatibleSource interface
|
||||
type mockIncompatibleSource struct{}
|
||||
|
||||
func (m *mockIncompatibleSource) SourceKind() string {
|
||||
return "mock"
|
||||
}
|
||||
|
||||
func (m *mockIncompatibleSource) ToConfig() sources.SourceConfig {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestSQLConfigInitializeIncompatibleSource(t *testing.T) {
|
||||
config := Config{
|
||||
Name: "test-tool",
|
||||
Kind: sqlKind,
|
||||
Source: "incompatible-source",
|
||||
Description: "Test tool",
|
||||
Statement: "SELECT 1",
|
||||
Parameters: parameters.Parameters{},
|
||||
}
|
||||
|
||||
mockSource := &mockIncompatibleSource{}
|
||||
|
||||
sources := map[string]sources.Source{
|
||||
"incompatible-source": mockSource,
|
||||
}
|
||||
|
||||
_, err := config.Initialize(sources)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for incompatible source, got nil")
|
||||
}
|
||||
|
||||
if err.Error() == "" {
|
||||
t.Error("Expected non-empty error message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolManifest(t *testing.T) {
|
||||
tool := Tool{
|
||||
manifest: tools.Manifest{
|
||||
|
||||
@@ -24,7 +24,6 @@ import (
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
cloudgdasrc "github.com/googleapis/genai-toolbox/internal/sources/cloudgda"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
@@ -45,6 +44,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetProjectID() string
|
||||
GetBaseURL() string
|
||||
UseClientAuthorization() bool
|
||||
GetClient(context.Context, string) (*http.Client, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -64,18 +70,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(*cloudgdasrc.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-gemini-data-analytics`", kind)
|
||||
}
|
||||
|
||||
// Define the parameters for the Gemini Data Analytics Query API
|
||||
// The prompt is the only input parameter.
|
||||
allParameters := parameters.Parameters{
|
||||
@@ -87,7 +81,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
AllParams: allParameters,
|
||||
Source: s,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: allParameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
@@ -99,7 +92,6 @@ var _ tools.Tool = Tool{}
|
||||
type Tool struct {
|
||||
Config
|
||||
AllParams parameters.Parameters
|
||||
Source *cloudgdasrc.Source
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -110,6 +102,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool logic
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
prompt, ok := paramsMap["prompt"].(string)
|
||||
if !ok {
|
||||
@@ -118,11 +115,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
|
||||
// The API endpoint itself always uses the "global" location.
|
||||
apiLocation := "global"
|
||||
apiParent := fmt.Sprintf("projects/%s/locations/%s", t.Source.ProjectID, apiLocation)
|
||||
apiURL := fmt.Sprintf("%s/v1beta/%s:queryData", t.Source.BaseURL, apiParent)
|
||||
apiParent := fmt.Sprintf("projects/%s/locations/%s", source.GetProjectID(), apiLocation)
|
||||
apiURL := fmt.Sprintf("%s/v1beta/%s:queryData", source.GetBaseURL(), apiParent)
|
||||
|
||||
// The parent in the request payload uses the tool's configured location.
|
||||
payloadParent := fmt.Sprintf("projects/%s/locations/%s", t.Source.ProjectID, t.Location)
|
||||
payloadParent := fmt.Sprintf("projects/%s/locations/%s", source.GetProjectID(), t.Location)
|
||||
|
||||
payload := &QueryDataRequest{
|
||||
Parent: payloadParent,
|
||||
@@ -138,7 +135,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
|
||||
// Parse the access token if provided
|
||||
var tokenStr string
|
||||
if t.RequiresClientAuthorization(resourceMgr) {
|
||||
if source.UseClientAuthorization() {
|
||||
var err error
|
||||
tokenStr, err = accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
@@ -146,7 +143,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
}
|
||||
|
||||
client, err := t.Source.GetClient(ctx, tokenStr)
|
||||
client, err := source.GetClient(ctx, tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get HTTP client: %w", err)
|
||||
}
|
||||
@@ -196,10 +193,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(_ tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/server/resources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
cloudgdasrc "github.com/googleapis/genai-toolbox/internal/sources/cloudgda"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
@@ -172,9 +173,8 @@ func TestInitialize(t *testing.T) {
|
||||
}
|
||||
|
||||
tcs := []struct {
|
||||
desc string
|
||||
cfg cloudgdatool.Config
|
||||
expectErr bool
|
||||
desc string
|
||||
cfg cloudgdatool.Config
|
||||
}{
|
||||
{
|
||||
desc: "successful initialization",
|
||||
@@ -185,29 +185,6 @@ func TestInitialize(t *testing.T) {
|
||||
Description: "Test Description",
|
||||
Location: "us-central1",
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
desc: "missing source",
|
||||
cfg: cloudgdatool.Config{
|
||||
Name: "my-gda-query-tool",
|
||||
Kind: "cloud-gemini-data-analytics-query",
|
||||
Source: "non-existent-source",
|
||||
Description: "Test Description",
|
||||
Location: "us-central1",
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
desc: "incompatible source kind",
|
||||
cfg: cloudgdatool.Config{
|
||||
Name: "my-gda-query-tool",
|
||||
Kind: "cloud-gemini-data-analytics-query",
|
||||
Source: "incompatible-source",
|
||||
Description: "Test Description",
|
||||
Location: "us-central1",
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -219,16 +196,11 @@ func TestInitialize(t *testing.T) {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tool, err := tc.cfg.Initialize(srcs)
|
||||
if tc.expectErr && err == nil {
|
||||
t.Fatalf("expected an error but got none")
|
||||
}
|
||||
if !tc.expectErr && err != nil {
|
||||
if err != nil {
|
||||
t.Fatalf("did not expect an error but got: %v", err)
|
||||
}
|
||||
if !tc.expectErr {
|
||||
// Basic sanity check on the returned tool
|
||||
_ = tool // Avoid unused variable error
|
||||
}
|
||||
// Basic sanity check on the returned tool
|
||||
_ = tool // Avoid unused variable error
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -361,8 +333,10 @@ func TestInvoke(t *testing.T) {
|
||||
{Name: "prompt", Value: "How many accounts who have region in Prague are eligible for loans?"},
|
||||
}
|
||||
|
||||
resourceMgr := resources.NewResourceManager(srcs, nil, nil, nil, nil, nil)
|
||||
|
||||
// Invoke the tool
|
||||
result, err := tool.Invoke(ctx, nil, params, "") // No accessToken needed for ADC client
|
||||
result, err := tool.Invoke(ctx, resourceMgr, params, "") // No accessToken needed for ADC client
|
||||
if err != nil {
|
||||
t.Fatalf("tool invocation failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -62,11 +62,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -83,35 +78,16 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
urlParameter := parameters.NewStringParameter(pageURLKey, "The full URL of the FHIR page to fetch. This would be the value of `Bundle.entry.link.url` field within the response returned from FHIR search or FHIR patient everything operations.")
|
||||
params := parameters.Parameters{urlParameter}
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
AllowedStores: s.AllowedFHIRStores(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -121,14 +97,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
Project, Region, Dataset string
|
||||
AllowedStores map[string]struct{}
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -136,13 +107,18 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
url, ok := params.AsMap()[pageURLKey].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", pageURLKey)
|
||||
}
|
||||
|
||||
var httpClient *http.Client
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
@@ -150,7 +126,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: tokenStr})
|
||||
httpClient = oauth2.NewClient(ctx, ts)
|
||||
} else {
|
||||
// The t.Service object holds a client with the default credentials.
|
||||
// The source.Service() object holds a client with the default credentials.
|
||||
// However, the client is not exported, so we have to create a new one.
|
||||
var err error
|
||||
httpClient, err = google.DefaultClient(ctx, healthcare.CloudHealthcareScope)
|
||||
@@ -201,10 +177,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -62,11 +62,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -92,7 +87,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
idParameter := parameters.NewStringParameter(patientIDKey, "The ID of the patient FHIR resource for which the information is required")
|
||||
@@ -106,17 +101,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
AllowedStores: s.AllowedFHIRStores(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -126,15 +114,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Project, Region, Dataset string
|
||||
AllowedStores map[string]struct{}
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -142,7 +124,12 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -151,20 +138,20 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", patientIDKey)
|
||||
}
|
||||
|
||||
svc := t.Service
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = t.ServiceCreator(tokenStr)
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/Patient/%s", t.Project, t.Region, t.Dataset, storeID, patientID)
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/Patient/%s", source.Project(), source.Region(), source.DatasetID(), storeID, patientID)
|
||||
var opts []googleapi.CallOption
|
||||
if val, ok := params.AsMap()[typeFilterKey]; ok {
|
||||
types, ok := val.([]any)
|
||||
@@ -225,10 +212,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -78,11 +78,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -108,7 +103,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
params := parameters.Parameters{
|
||||
@@ -140,17 +135,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
AllowedStores: s.AllowedFHIRStores(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -160,15 +148,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Project, Region, Dataset string
|
||||
AllowedStores map[string]struct{}
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -176,19 +158,24 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := t.Service
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = t.ServiceCreator(tokenStr)
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
@@ -261,7 +248,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
opts = append(opts, googleapi.QueryParameter("_summary", "text"))
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", t.Project, t.Region, t.Dataset, storeID)
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
resp, err := svc.Projects.Locations.Datasets.FhirStores.Fhir.SearchType(name, "Patient", &healthcare.SearchResourcesRequest{ResourceType: "Patient"}).Do(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to search patient resources: %w", err)
|
||||
@@ -298,10 +285,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -51,11 +51,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -72,33 +67,15 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
params := parameters.Parameters{}
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -108,13 +85,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
Project, Region, Dataset string
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -122,22 +95,26 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
svc := t.Service
|
||||
var err error
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
|
||||
// Initialize new service if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = t.ServiceCreator(tokenStr)
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", t.Project, t.Region, t.Dataset)
|
||||
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", source.Project(), source.Region(), source.DatasetID())
|
||||
dataset, err := svc.Projects.Locations.Datasets.Get(datasetName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
|
||||
@@ -161,10 +138,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -53,11 +53,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -83,7 +78,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
params := parameters.Parameters{}
|
||||
@@ -94,17 +89,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
AllowedStores: s.AllowedDICOMStores(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -114,15 +102,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Project, Region, Dataset string
|
||||
AllowedStores map[string]struct{}
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -130,25 +112,30 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := t.Service
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = t.ServiceCreator(tokenStr)
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", t.Project, t.Region, t.Dataset, storeID)
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
store, err := svc.Projects.Locations.Datasets.DicomStores.Get(storeName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get DICOM store %q: %w", storeName, err)
|
||||
@@ -172,10 +159,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -53,11 +53,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -83,7 +78,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
params := parameters.Parameters{}
|
||||
@@ -94,17 +89,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
AllowedStores: s.AllowedDICOMStores(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -114,15 +102,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Project, Region, Dataset string
|
||||
AllowedStores map[string]struct{}
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -130,25 +112,30 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := t.Service
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = t.ServiceCreator(tokenStr)
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", t.Project, t.Region, t.Dataset, storeID)
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
store, err := svc.Projects.Locations.Datasets.DicomStores.GetDICOMStoreMetrics(storeName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get metrics for DICOM store %q: %w", storeName, err)
|
||||
@@ -172,10 +159,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -59,11 +59,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -89,7 +84,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
typeParameter := parameters.NewStringParameter(typeKey, "The FHIR resource type to retrieve (e.g., Patient, Observation).")
|
||||
@@ -102,17 +97,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
AllowedStores: s.AllowedFHIRStores(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -122,15 +110,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Project, Region, Dataset string
|
||||
AllowedStores map[string]struct{}
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -138,7 +120,12 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -152,20 +139,20 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", idKey)
|
||||
}
|
||||
|
||||
svc := t.Service
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = t.ServiceCreator(tokenStr)
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/%s/%s", t.Project, t.Region, t.Dataset, storeID, resType, resID)
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/%s/%s", source.Project(), source.Region(), source.DatasetID(), storeID, resType, resID)
|
||||
call := svc.Projects.Locations.Datasets.FhirStores.Fhir.Read(name)
|
||||
call.Header().Set("Content-Type", "application/fhir+json;charset=utf-8")
|
||||
resp, err := call.Do()
|
||||
@@ -204,10 +191,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -53,11 +53,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -83,7 +78,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
params := parameters.Parameters{}
|
||||
@@ -94,17 +89,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
AllowedStores: s.AllowedFHIRStores(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -114,15 +102,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Project, Region, Dataset string
|
||||
AllowedStores map[string]struct{}
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -130,25 +112,30 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := t.Service
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = t.ServiceCreator(tokenStr)
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", t.Project, t.Region, t.Dataset, storeID)
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
store, err := svc.Projects.Locations.Datasets.FhirStores.Get(storeName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get FHIR store %q: %w", storeName, err)
|
||||
@@ -172,10 +159,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -53,11 +53,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -83,7 +78,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
params := parameters.Parameters{}
|
||||
@@ -94,17 +89,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
AllowedStores: s.AllowedFHIRStores(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -114,15 +102,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Project, Region, Dataset string
|
||||
AllowedStores map[string]struct{}
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -130,25 +112,30 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := t.Service
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = t.ServiceCreator(tokenStr)
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", t.Project, t.Region, t.Dataset, storeID)
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
store, err := svc.Projects.Locations.Datasets.FhirStores.GetFHIRStoreMetrics(storeName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get metrics for FHIR store %q: %w", storeName, err)
|
||||
@@ -172,10 +159,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -53,11 +53,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -74,34 +69,15 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
params := parameters.Parameters{}
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
AllowedStores: s.AllowedDICOMStores(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -111,15 +87,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Project, Region, Dataset string
|
||||
AllowedStores map[string]struct{}
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -127,29 +97,33 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
svc := t.Service
|
||||
var err error
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
|
||||
// Initialize new service if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = t.ServiceCreator(tokenStr)
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", t.Project, t.Region, t.Dataset)
|
||||
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", source.Project(), source.Region(), source.DatasetID())
|
||||
stores, err := svc.Projects.Locations.Datasets.DicomStores.List(datasetName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
|
||||
}
|
||||
var filtered []*healthcare.DicomStore
|
||||
for _, store := range stores.DicomStores {
|
||||
if len(t.AllowedStores) == 0 {
|
||||
if len(source.AllowedDICOMStores()) == 0 {
|
||||
filtered = append(filtered, store)
|
||||
continue
|
||||
}
|
||||
@@ -157,7 +131,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
continue
|
||||
}
|
||||
parts := strings.Split(store.Name, "/")
|
||||
if _, ok := t.AllowedStores[parts[len(parts)-1]]; ok {
|
||||
if _, ok := source.AllowedDICOMStores()[parts[len(parts)-1]]; ok {
|
||||
filtered = append(filtered, store)
|
||||
}
|
||||
}
|
||||
@@ -180,10 +154,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -53,11 +53,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -74,34 +69,15 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
params := parameters.Parameters{}
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
AllowedStores: s.AllowedFHIRStores(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -111,15 +87,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Project, Region, Dataset string
|
||||
AllowedStores map[string]struct{}
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -127,29 +97,33 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
svc := t.Service
|
||||
var err error
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
|
||||
// Initialize new service if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = t.ServiceCreator(tokenStr)
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", t.Project, t.Region, t.Dataset)
|
||||
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", source.Project(), source.Region(), source.DatasetID())
|
||||
stores, err := svc.Projects.Locations.Datasets.FhirStores.List(datasetName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
|
||||
}
|
||||
var filtered []*healthcare.FhirStore
|
||||
for _, store := range stores.FhirStores {
|
||||
if len(t.AllowedStores) == 0 {
|
||||
if len(source.AllowedFHIRStores()) == 0 {
|
||||
filtered = append(filtered, store)
|
||||
continue
|
||||
}
|
||||
@@ -157,7 +131,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
continue
|
||||
}
|
||||
parts := strings.Split(store.Name, "/")
|
||||
if _, ok := t.AllowedStores[parts[len(parts)-1]]; ok {
|
||||
if _, ok := source.AllowedFHIRStores()[parts[len(parts)-1]]; ok {
|
||||
filtered = append(filtered, store)
|
||||
}
|
||||
}
|
||||
@@ -180,10 +154,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -61,11 +61,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -91,7 +86,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
params := parameters.Parameters{
|
||||
@@ -107,17 +102,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
AllowedStores: s.AllowedDICOMStores(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -127,15 +115,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Project, Region, Dataset string
|
||||
AllowedStores map[string]struct{}
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -143,19 +125,24 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := t.Service
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = t.ServiceCreator(tokenStr)
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
@@ -177,7 +164,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid '%s' parameter; expected an integer", frameNumberKey)
|
||||
}
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", t.Project, t.Region, t.Dataset, storeID)
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
dicomWebPath := fmt.Sprintf("studies/%s/series/%s/instances/%s/frames/%d/rendered", study, series, sop, frame)
|
||||
call := svc.Projects.Locations.Datasets.DicomStores.Studies.Series.Instances.Frames.RetrieveRendered(name, dicomWebPath)
|
||||
call.Header().Set("Accept", "image/jpeg")
|
||||
@@ -214,10 +201,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -68,11 +68,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -98,7 +93,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
params := parameters.Parameters{
|
||||
@@ -121,17 +116,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
AllowedStores: s.AllowedDICOMStores(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -141,15 +129,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Project, Region, Dataset string
|
||||
AllowedStores map[string]struct{}
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -157,19 +139,24 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := t.Service
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = t.ServiceCreator(tokenStr)
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
@@ -204,7 +191,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", t.Project, t.Region, t.Dataset, storeID)
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
resp, err := svc.Projects.Locations.Datasets.DicomStores.SearchForInstances(name, dicomWebPath).Do(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to search dicom instances: %w", err)
|
||||
@@ -244,10 +231,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -65,11 +65,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -95,7 +90,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
params := parameters.Parameters{
|
||||
@@ -117,17 +112,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
AllowedStores: s.AllowedDICOMStores(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -137,15 +125,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Project, Region, Dataset string
|
||||
AllowedStores map[string]struct{}
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -153,19 +135,24 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := t.Service
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = t.ServiceCreator(tokenStr)
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
@@ -187,7 +174,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", t.Project, t.Region, t.Dataset, storeID)
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
resp, err := svc.Projects.Locations.Datasets.DicomStores.SearchForSeries(name, dicomWebPath).Do(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to search dicom series: %w", err)
|
||||
@@ -227,10 +214,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -63,11 +63,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -93,7 +88,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
params := parameters.Parameters{
|
||||
@@ -113,17 +108,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
AllowedStores: s.AllowedDICOMStores(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -133,15 +121,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Project, Region, Dataset string
|
||||
AllowedStores map[string]struct{}
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -149,19 +131,24 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := t.Service
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = t.ServiceCreator(tokenStr)
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
@@ -171,7 +158,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", t.Project, t.Region, t.Dataset, storeID)
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
resp, err := svc.Projects.Locations.Datasets.DicomStores.SearchForStudies(name, "studies").Do(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to search dicom studies: %w", err)
|
||||
@@ -211,10 +198,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -23,7 +23,6 @@ import (
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
cloudmonitoringsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudmonitoring"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
@@ -44,6 +43,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
BaseURL() string
|
||||
Client() *http.Client
|
||||
UserAgent() string
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -60,18 +65,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(*cloudmonitoringsrc.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloudmonitoring`", kind)
|
||||
}
|
||||
|
||||
// Define the parameters internally instead of from the config file.
|
||||
allParameters := parameters.Parameters{
|
||||
parameters.NewStringParameterWithRequired("projectId", "The Id of the Google Cloud project.", true),
|
||||
@@ -83,9 +76,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
AllParams: allParameters,
|
||||
BaseURL: s.BaseURL,
|
||||
UserAgent: s.UserAgent,
|
||||
Client: s.Client,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: allParameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
@@ -97,9 +87,6 @@ var _ tools.Tool = Tool{}
|
||||
type Tool struct {
|
||||
Config
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
BaseURL string `yaml:"baseURL"`
|
||||
UserAgent string
|
||||
Client *http.Client
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -109,6 +96,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
projectID, ok := paramsMap["projectId"].(string)
|
||||
if !ok {
|
||||
@@ -119,7 +111,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("query parameter not found or not a string")
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/v1/projects/%s/location/global/prometheus/api/v1/query", t.BaseURL, projectID)
|
||||
url := fmt.Sprintf("%s/v1/projects/%s/location/global/prometheus/api/v1/query", source.BaseURL(), projectID)
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
@@ -130,9 +122,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
q.Add("query", query)
|
||||
req.URL.RawQuery = q.Encode()
|
||||
|
||||
req.Header.Set("User-Agent", t.UserAgent)
|
||||
req.Header.Set("User-Agent", source.UserAgent())
|
||||
|
||||
resp, err := t.Client.Do(req)
|
||||
resp, err := source.Client().Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -175,10 +167,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -81,22 +81,6 @@ func TestInitialize(t *testing.T) {
|
||||
AuthRequired: []string{"google-auth-service"},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Error: source not found",
|
||||
cfg: cloudmonitoring.Config{
|
||||
Name: "test-tool",
|
||||
Source: "non-existent-source",
|
||||
},
|
||||
wantErr: `no source named "non-existent-source" configured`,
|
||||
},
|
||||
{
|
||||
desc: "Error: incompatible source kind",
|
||||
cfg: cloudmonitoring.Config{
|
||||
Name: "test-tool",
|
||||
Source: "incompatible-source",
|
||||
},
|
||||
wantErr: "invalid source for \"cloud-monitoring-query-prometheus\" tool",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
|
||||
@@ -20,7 +20,6 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
sqladmin "google.golang.org/api/sqladmin/v1"
|
||||
@@ -42,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
GetService(context.Context, string) (*sqladmin.Service, error)
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// Config defines the configuration for the clone-instance tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -65,12 +70,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
s, ok := rawS.(*cloudsqladmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -98,7 +103,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -108,7 +112,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the clone-instance tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
Source *cloudsqladmin.Source
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
@@ -120,6 +123,10 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
@@ -156,7 +163,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
CloneContext: cloneContext,
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -189,10 +196,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -20,7 +20,6 @@ import (
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
sqladmin "google.golang.org/api/sqladmin/v1"
|
||||
@@ -42,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
GetService(context.Context, string) (*sqladmin.Service, error)
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// Config defines the configuration for the create-database tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -65,12 +70,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
s, ok := rawS.(*cloudsqladmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -93,7 +98,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -103,7 +107,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the create-database tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
Source *cloudsqladmin.Source
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
@@ -115,6 +118,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
@@ -136,7 +144,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
Instance: instance,
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -169,10 +177,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -20,7 +20,6 @@ import (
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
sqladmin "google.golang.org/api/sqladmin/v1"
|
||||
@@ -42,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
GetService(context.Context, string) (*sqladmin.Service, error)
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// Config defines the configuration for the create-user tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -65,12 +70,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
s, ok := rawS.(*cloudsqladmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -95,7 +100,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -105,7 +109,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the create-user tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
Source *cloudsqladmin.Source
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
@@ -117,6 +120,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
@@ -149,7 +157,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
user.Password = password
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -182,10 +190,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -20,9 +20,9 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/sqladmin/v1"
|
||||
)
|
||||
|
||||
const kind string = "cloud-sql-get-instance"
|
||||
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
GetService(context.Context, string) (*sqladmin.Service, error)
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// Config defines the configuration for the get-instances tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -65,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(*cloudsqladmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("projectId", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -92,7 +98,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -102,7 +107,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the get-instances tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
Source *cloudsqladmin.Source
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
@@ -114,6 +118,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
projectId, ok := paramsMap["projectId"].(string)
|
||||
@@ -125,7 +134,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("missing 'instanceId' parameter")
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -158,10 +167,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -20,9 +20,9 @@ import (
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
cloudsqladminsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/sqladmin/v1"
|
||||
)
|
||||
|
||||
const kind string = "cloud-sql-list-databases"
|
||||
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
GetService(context.Context, string) (*sqladmin.Service, error)
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// Config defines the configuration for the list-databases tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -64,12 +70,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
s, ok := rawS.(*cloudsqladminsrc.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -91,7 +97,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -102,7 +107,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
type Tool struct {
|
||||
Config
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
Source *cloudsqladminsrc.Source
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -113,6 +117,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
@@ -124,7 +133,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("missing 'instance' parameter")
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -176,10 +185,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -20,9 +20,9 @@ import (
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
cloudsqladminsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/sqladmin/v1"
|
||||
)
|
||||
|
||||
const kind string = "cloud-sql-list-instances"
|
||||
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
GetService(context.Context, string) (*sqladmin.Service, error)
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// Config defines the configuration for the list-instance tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -64,12 +70,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
s, ok := rawS.(*cloudsqladminsrc.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -90,7 +96,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -101,7 +106,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
type Tool struct {
|
||||
Config
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
source *cloudsqladminsrc.Source
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -112,6 +116,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
@@ -119,7 +128,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("missing 'project' parameter")
|
||||
}
|
||||
|
||||
service, err := t.source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -169,10 +178,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -25,9 +25,9 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/sqladmin/v1"
|
||||
)
|
||||
|
||||
const kind string = "cloud-sql-wait-for-operation"
|
||||
@@ -87,6 +87,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
GetService(context.Context, string) (*sqladmin.Service, error)
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// Config defines the configuration for the wait-for-operation tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -118,12 +124,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(*cloudsqladmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -177,7 +183,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -191,17 +196,15 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the wait-for-operation tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
Source *cloudsqladmin.Source
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
|
||||
// Polling configuration
|
||||
Delay time.Duration
|
||||
MaxDelay time.Duration
|
||||
Multiplier float64
|
||||
MaxRetries int
|
||||
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -210,6 +213,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
@@ -221,7 +229,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("missing 'operation' parameter")
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -267,7 +275,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("could not unmarshal operation: %w", err)
|
||||
}
|
||||
|
||||
if msg, ok := t.generateCloudSQLConnectionMessage(data); ok {
|
||||
if msg, ok := t.generateCloudSQLConnectionMessage(source, data); ok {
|
||||
return msg, nil
|
||||
}
|
||||
return string(opBytes), nil
|
||||
@@ -305,11 +313,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) generateCloudSQLConnectionMessage(opResponse map[string]any) (string, bool) {
|
||||
func (t Tool) generateCloudSQLConnectionMessage(source compatibleSource, opResponse map[string]any) (string, bool) {
|
||||
operationType, ok := opResponse["operationType"].(string)
|
||||
if !ok || operationType != "CREATE_DATABASE" {
|
||||
return "", false
|
||||
@@ -329,7 +341,7 @@ func (t Tool) generateCloudSQLConnectionMessage(opResponse map[string]any) (stri
|
||||
instance := matches[2]
|
||||
database := matches[3]
|
||||
|
||||
instanceData, err := t.fetchInstanceData(context.Background(), project, instance)
|
||||
instanceData, err := t.fetchInstanceData(context.Background(), source, project, instance)
|
||||
if err != nil {
|
||||
fmt.Printf("error fetching instance data: %v\n", err)
|
||||
return "", false
|
||||
@@ -385,8 +397,8 @@ func (t Tool) generateCloudSQLConnectionMessage(opResponse map[string]any) (stri
|
||||
return b.String(), true
|
||||
}
|
||||
|
||||
func (t Tool) fetchInstanceData(ctx context.Context, project, instance string) (map[string]any, error) {
|
||||
service, err := t.Source.GetService(ctx, "")
|
||||
func (t Tool) fetchInstanceData(ctx context.Context, source compatibleSource, project, instance string) (map[string]any, error) {
|
||||
service, err := source.GetService(ctx, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -408,6 +420,6 @@ func (t Tool) fetchInstanceData(ctx context.Context, project, instance string) (
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
sqladmin "google.golang.org/api/sqladmin/v1"
|
||||
@@ -43,6 +42,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
GetService(context.Context, string) (*sqladmin.Service, error)
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// Config defines the configuration for the create-instances tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -66,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
s, ok := rawS.(*cloudsqladmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -96,7 +101,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -106,7 +110,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the create-instances tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
Source *cloudsqladmin.Source
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
@@ -118,6 +121,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
@@ -167,7 +175,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
Project: project,
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -200,10 +208,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
sqladmin "google.golang.org/api/sqladmin/v1"
|
||||
@@ -43,6 +42,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
GetService(context.Context, string) (*sqladmin.Service, error)
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// Config defines the configuration for the create-instances tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -66,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
s, ok := rawS.(*cloudsqladmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -96,7 +101,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -106,7 +110,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the create-instances tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
Source *cloudsqladmin.Source
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
@@ -118,6 +121,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
@@ -167,7 +175,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
Project: project,
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -200,10 +208,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
sqladmin "google.golang.org/api/sqladmin/v1"
|
||||
@@ -43,6 +42,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
GetService(context.Context, string) (*sqladmin.Service, error)
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// Config defines the configuration for the create-instances tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -66,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
s, ok := rawS.(*cloudsqladmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -96,7 +101,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -106,7 +110,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the create-instances tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
Source *cloudsqladmin.Source
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
@@ -118,6 +121,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
@@ -167,7 +175,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
Project: project,
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -200,10 +208,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
sqladmin "google.golang.org/api/sqladmin/v1"
|
||||
@@ -43,6 +42,11 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetService(context.Context, string) (*sqladmin.Service, error)
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// Config defines the configuration for the precheck-upgrade tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -62,15 +66,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
|
||||
// Initialize initializes the tool from the configuration.
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
s, ok := rawS.(*cloudsqladmin.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind)
|
||||
}
|
||||
|
||||
allParameters := parameters.Parameters{
|
||||
parameters.NewStringParameter("project", "The project ID"),
|
||||
parameters.NewStringParameter("instance", "The name of the instance to check"),
|
||||
@@ -88,28 +83,19 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, allParameters, nil)
|
||||
|
||||
return Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Tool represents the precheck-upgrade tool.
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Description string `yaml:"description"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
|
||||
Source *cloudsqladmin.Source
|
||||
Config
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Config
|
||||
}
|
||||
|
||||
// PreCheckResultItem holds the details of a single check result.
|
||||
@@ -146,6 +132,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
@@ -162,7 +153,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("missing or empty 'targetDatabaseVersion' parameter")
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get HTTP client from source: %w", err)
|
||||
}
|
||||
@@ -234,10 +225,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -22,7 +22,6 @@ import (
|
||||
"github.com/couchbase/gocb/v2"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/couchbase"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
@@ -48,11 +47,6 @@ type compatibleSource interface {
|
||||
CouchbaseQueryScanConsistency() uint
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &couchbase.Source{}
|
||||
|
||||
var compatibleSources = [...]string{couchbase.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -72,18 +66,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -92,12 +74,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil)
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
AllParams: allParameters,
|
||||
Scope: s.CouchbaseScope(),
|
||||
QueryScanConsistency: s.CouchbaseQueryScanConsistency(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -107,12 +87,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
|
||||
Scope *gocb.Scope
|
||||
QueryScanConsistency uint
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -120,6 +97,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
namedParamsMap := params.AsMap()
|
||||
newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, namedParamsMap)
|
||||
if err != nil {
|
||||
@@ -130,8 +112,8 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to extract standard params %w", err)
|
||||
}
|
||||
results, err := t.Scope.Query(newStatement, &gocb.QueryOptions{
|
||||
ScanConsistency: gocb.QueryScanConsistency(t.QueryScanConsistency),
|
||||
results, err := source.CouchbaseScope().Query(newStatement, &gocb.QueryOptions{
|
||||
ScanConsistency: gocb.QueryScanConsistency(source.CouchbaseQueryScanConsistency()),
|
||||
NamedParameters: newParams.AsMap(),
|
||||
})
|
||||
if err != nil {
|
||||
@@ -166,10 +148,10 @@ func (t Tool) Authorized(verifiedAuthSources []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthSources)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -118,10 +118,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -22,7 +22,6 @@ import (
|
||||
dataplexpb "cloud.google.com/go/dataplex/apiv1/dataplexpb"
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
dataplexds "github.com/googleapis/genai-toolbox/internal/sources/dataplex"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
@@ -47,11 +46,6 @@ type compatibleSource interface {
|
||||
CatalogClient() *dataplexapi.CatalogClient
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &dataplexds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{dataplexds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -69,17 +63,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// Initialize the search configuration with the provided sources
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
viewDesc := `
|
||||
## Argument: view
|
||||
|
||||
@@ -104,9 +87,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil)
|
||||
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
CatalogClient: s.CatalogClient(),
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{
|
||||
Description: cfg.Description,
|
||||
Parameters: params.Manifest(),
|
||||
@@ -119,10 +101,9 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
Parameters parameters.Parameters
|
||||
CatalogClient *dataplexapi.CatalogClient
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -130,6 +111,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
viewMap := map[int]dataplexpb.EntryView{
|
||||
1: dataplexpb.EntryView_BASIC,
|
||||
@@ -153,7 +139,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
Entry: entry,
|
||||
}
|
||||
|
||||
result, err := t.CatalogClient.LookupEntry(ctx, req)
|
||||
result, err := source.CatalogClient().LookupEntry(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -179,10 +165,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -23,7 +23,6 @@ import (
|
||||
"github.com/cenkalti/backoff/v5"
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
dataplexds "github.com/googleapis/genai-toolbox/internal/sources/dataplex"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
@@ -49,11 +48,6 @@ type compatibleSource interface {
|
||||
ProjectID() string
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &dataplexds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{dataplexds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -70,17 +64,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// Initialize the search configuration with the provided sources
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
query := parameters.NewStringParameter("query", "The query against which aspect type should be matched.")
|
||||
pageSize := parameters.NewIntParameterWithDefault("pageSize", 5, "Number of returned aspect types in the search page.")
|
||||
orderBy := parameters.NewStringParameterWithDefault("orderBy", "relevance", "Specifies the ordering of results. Supported values are: relevance, last_modified_timestamp, last_modified_timestamp asc")
|
||||
@@ -89,10 +72,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil)
|
||||
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
CatalogClient: s.CatalogClient(),
|
||||
ProjectID: s.ProjectID(),
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{
|
||||
Description: cfg.Description,
|
||||
Parameters: params.Manifest(),
|
||||
@@ -105,11 +86,9 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
Parameters parameters.Parameters
|
||||
CatalogClient *dataplexapi.CatalogClient
|
||||
ProjectID string
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -117,6 +96,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Invoke the tool with the provided parameters
|
||||
paramsMap := params.AsMap()
|
||||
query, _ := paramsMap["query"].(string)
|
||||
@@ -126,16 +110,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
// Create SearchEntriesRequest with the provided parameters
|
||||
req := &dataplexpb.SearchEntriesRequest{
|
||||
Query: query + " type=projects/dataplex-types/locations/global/entryTypes/aspecttype",
|
||||
Name: fmt.Sprintf("projects/%s/locations/global", t.ProjectID),
|
||||
Name: fmt.Sprintf("projects/%s/locations/global", source.ProjectID()),
|
||||
PageSize: pageSize,
|
||||
OrderBy: orderBy,
|
||||
SemanticSearch: true,
|
||||
}
|
||||
|
||||
// Perform the search using the CatalogClient - this will return an iterator
|
||||
it := t.CatalogClient.SearchEntries(ctx, req)
|
||||
it := source.CatalogClient().SearchEntries(ctx, req)
|
||||
if it == nil {
|
||||
return nil, fmt.Errorf("failed to create search entries iterator for project %q", t.ProjectID)
|
||||
return nil, fmt.Errorf("failed to create search entries iterator for project %q", source.ProjectID())
|
||||
}
|
||||
|
||||
// Create an instance of exponential backoff with default values for retrying GetAspectType calls
|
||||
@@ -155,7 +139,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
|
||||
operation := func() (*dataplexpb.AspectType, error) {
|
||||
aspectType, err := t.CatalogClient.GetAspectType(ctx, getAspectTypeReq)
|
||||
aspectType, err := source.CatalogClient().GetAspectType(ctx, getAspectTypeReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get aspect type for entry %q: %w", resourceName, err)
|
||||
}
|
||||
@@ -192,10 +176,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -22,7 +22,6 @@ import (
|
||||
dataplexpb "cloud.google.com/go/dataplex/apiv1/dataplexpb"
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
dataplexds "github.com/googleapis/genai-toolbox/internal/sources/dataplex"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
@@ -48,11 +47,6 @@ type compatibleSource interface {
|
||||
ProjectID() string
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &dataplexds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{dataplexds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -69,17 +63,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// Initialize the search configuration with the provided sources
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
query := parameters.NewStringParameter("query", "The query against which entries in scope should be matched.")
|
||||
pageSize := parameters.NewIntParameterWithDefault("pageSize", 5, "Number of results in the search page.")
|
||||
orderBy := parameters.NewStringParameterWithDefault("orderBy", "relevance", "Specifies the ordering of results. Supported values are: relevance, last_modified_timestamp, last_modified_timestamp asc")
|
||||
@@ -88,10 +71,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil)
|
||||
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
CatalogClient: s.CatalogClient(),
|
||||
ProjectID: s.ProjectID(),
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{
|
||||
Description: cfg.Description,
|
||||
Parameters: params.Manifest(),
|
||||
@@ -104,11 +85,9 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
Parameters parameters.Parameters
|
||||
CatalogClient *dataplexapi.CatalogClient
|
||||
ProjectID string
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -116,6 +95,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
query, _ := paramsMap["query"].(string)
|
||||
pageSize := int32(paramsMap["pageSize"].(int))
|
||||
@@ -123,15 +107,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
|
||||
req := &dataplexpb.SearchEntriesRequest{
|
||||
Query: query,
|
||||
Name: fmt.Sprintf("projects/%s/locations/global", t.ProjectID),
|
||||
Name: fmt.Sprintf("projects/%s/locations/global", source.ProjectID()),
|
||||
PageSize: pageSize,
|
||||
OrderBy: orderBy,
|
||||
SemanticSearch: true,
|
||||
}
|
||||
|
||||
it := t.CatalogClient.SearchEntries(ctx, req)
|
||||
it := source.CatalogClient().SearchEntries(ctx, req)
|
||||
if it == nil {
|
||||
return nil, fmt.Errorf("failed to create search entries iterator for project %q", t.ProjectID)
|
||||
return nil, fmt.Errorf("failed to create search entries iterator for project %q", source.ProjectID())
|
||||
}
|
||||
|
||||
var results []*dataplexpb.SearchEntriesResult
|
||||
@@ -163,10 +147,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -46,11 +46,6 @@ type compatibleSource interface {
|
||||
DgraphClient() *dgraph.DgraphClient
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &dgraph.Source{}
|
||||
|
||||
var compatibleSources = [...]string{dgraph.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -71,26 +66,13 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, cfg.Parameters, nil)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
DgraphClient: s.DgraphClient(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -100,9 +82,8 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
DgraphClient *dgraph.DgraphClient
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -110,9 +91,14 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMapWithDollarPrefix()
|
||||
|
||||
resp, err := t.DgraphClient.ExecuteQuery(t.Statement, paramsMap, t.IsQuery, t.Timeout)
|
||||
resp, err := source.DgraphClient().ExecuteQuery(t.Statement, paramsMap, t.IsQuery, t.Timeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -148,10 +134,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -43,10 +43,6 @@ type compatibleSource interface {
|
||||
ElasticsearchClient() es.EsClient
|
||||
}
|
||||
|
||||
var _ compatibleSource = &es.Source{}
|
||||
|
||||
var compatibleSources = [...]string{es.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -77,29 +73,15 @@ type Tool struct {
|
||||
Config
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
EsClient es.EsClient
|
||||
}
|
||||
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
func (c Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
src, ok := srcs[c.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("source %q not found", c.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := src.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
mcpManifest := tools.GetMcpManifest(c.Name, c.Description, c.AuthRequired, c.Parameters, nil)
|
||||
|
||||
return Tool{
|
||||
Config: c,
|
||||
EsClient: s.ElasticsearchClient(),
|
||||
manifest: tools.Manifest{Description: c.Description, Parameters: c.Parameters.Manifest(), AuthRequired: c.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
@@ -120,6 +102,11 @@ type esqlResult struct {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var cancel context.CancelFunc
|
||||
if t.Timeout > 0 {
|
||||
ctx, cancel = context.WithTimeout(ctx, time.Duration(t.Timeout)*time.Second)
|
||||
@@ -164,8 +151,8 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
Body: bytes.NewReader(body),
|
||||
Format: t.Format,
|
||||
FilterPath: []string{"columns", "values"},
|
||||
Instrument: t.EsClient.InstrumentationEnabled(),
|
||||
}.Do(ctx, t.EsClient)
|
||||
Instrument: source.ElasticsearchClient().InstrumentationEnabled(),
|
||||
}.Do(ctx, source.ElasticsearchClient())
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -230,10 +217,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/firebird"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
@@ -47,10 +46,6 @@ type compatibleSource interface {
|
||||
FirebirdDB() *sql.DB
|
||||
}
|
||||
|
||||
var _ compatibleSource = &firebird.Source{}
|
||||
|
||||
var compatibleSources = [...]string{firebird.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -66,16 +61,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
sqlParameter := parameters.NewStringParameter("sql", "The sql to execute.")
|
||||
params := parameters.Parameters{sqlParameter}
|
||||
|
||||
@@ -84,7 +69,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Db: s.FirebirdDB(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -95,9 +79,7 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Db *sql.DB
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -107,6 +89,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
sql, ok := paramsMap["sql"].(string)
|
||||
if !ok {
|
||||
@@ -120,7 +107,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql))
|
||||
|
||||
rows, err := t.Db.QueryContext(ctx, sql)
|
||||
rows, err := source.FirebirdDB().QueryContext(ctx, sql)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
@@ -180,10 +167,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -22,7 +22,6 @@ import (
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/firebird"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
@@ -47,11 +46,6 @@ type compatibleSource interface {
|
||||
FirebirdDB() *sql.DB
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &firebird.Source{}
|
||||
|
||||
var compatibleSources = [...]string{firebird.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -71,18 +65,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -94,7 +76,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
AllParams: allParameters,
|
||||
Db: s.FirebirdDB(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -106,9 +87,7 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
|
||||
Db *sql.DB
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -118,6 +97,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
statement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
|
||||
if err != nil {
|
||||
@@ -142,7 +126,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
}
|
||||
|
||||
rows, err := t.Db.QueryContext(ctx, statement, namedArgs...)
|
||||
rows, err := source.FirebirdDB().QueryContext(ctx, statement, namedArgs...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
@@ -204,10 +188,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
firestoreapi "cloud.google.com/go/firestore"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/firestore/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
@@ -50,11 +49,6 @@ type compatibleSource interface {
|
||||
FirestoreClient() *firestoreapi.Client
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &firestoreds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{firestoreds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -71,18 +65,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
// Create parameters
|
||||
collectionPathParameter := parameters.NewStringParameter(
|
||||
collectionPathKey,
|
||||
@@ -124,7 +106,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Client: s.FirestoreClient(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -136,9 +117,7 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *firestoreapi.Client
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -148,6 +127,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mapParams := params.AsMap()
|
||||
|
||||
// Get collection path
|
||||
@@ -169,7 +153,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
|
||||
// Convert the document data from JSON format to Firestore format
|
||||
// The client is passed to handle referenceValue types
|
||||
documentData, err := util.JSONToFirestoreValue(documentDataRaw, t.Client)
|
||||
documentData, err := util.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert document data: %w", err)
|
||||
}
|
||||
@@ -181,7 +165,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
|
||||
// Get the collection reference
|
||||
collection := t.Client.Collection(collectionPath)
|
||||
collection := source.FirestoreClient().Collection(collectionPath)
|
||||
|
||||
// Add the document to the collection
|
||||
docRef, writeResult, err := collection.Add(ctx, documentData)
|
||||
@@ -221,10 +205,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
firestoreapi "cloud.google.com/go/firestore"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/firestore/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
@@ -48,11 +47,6 @@ type compatibleSource interface {
|
||||
FirestoreClient() *firestoreapi.Client
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &firestoreds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{firestoreds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -69,18 +63,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
documentPathsParameter := parameters.NewArrayParameter(documentPathsKey, "Array of relative document paths to delete from Firestore (e.g., 'users/userId' or 'users/userId/posts/postId'). Note: These are relative paths, NOT absolute paths like 'projects/{project_id}/databases/{database_id}/documents/...'", parameters.NewStringParameter("item", "Relative document path"))
|
||||
params := parameters.Parameters{documentPathsParameter}
|
||||
|
||||
@@ -90,7 +72,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Client: s.FirestoreClient(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -102,9 +83,7 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *firestoreapi.Client
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -114,6 +93,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mapParams := params.AsMap()
|
||||
documentPathsRaw, ok := mapParams[documentPathsKey].([]any)
|
||||
if !ok {
|
||||
@@ -143,14 +127,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
|
||||
// Create a BulkWriter to handle multiple deletions efficiently
|
||||
bulkWriter := t.Client.BulkWriter(ctx)
|
||||
bulkWriter := source.FirestoreClient().BulkWriter(ctx)
|
||||
|
||||
// Keep track of jobs for each document
|
||||
jobs := make([]*firestoreapi.BulkWriterJob, len(documentPaths))
|
||||
|
||||
// Add all delete operations to the BulkWriter
|
||||
for i, path := range documentPaths {
|
||||
docRef := t.Client.Doc(path)
|
||||
docRef := source.FirestoreClient().Doc(path)
|
||||
job, err := bulkWriter.Delete(docRef)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add delete operation for document %q: %w", path, err)
|
||||
@@ -198,10 +182,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
firestoreapi "cloud.google.com/go/firestore"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/firestore/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
@@ -48,11 +47,6 @@ type compatibleSource interface {
|
||||
FirestoreClient() *firestoreapi.Client
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &firestoreds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{firestoreds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -69,18 +63,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
documentPathsParameter := parameters.NewArrayParameter(documentPathsKey, "Array of relative document paths to retrieve from Firestore (e.g., 'users/userId' or 'users/userId/posts/postId'). Note: These are relative paths, NOT absolute paths like 'projects/{project_id}/databases/{database_id}/documents/...'", parameters.NewStringParameter("item", "Relative document path"))
|
||||
params := parameters.Parameters{documentPathsParameter}
|
||||
|
||||
@@ -90,7 +72,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Client: s.FirestoreClient(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -102,9 +83,7 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *firestoreapi.Client
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -114,6 +93,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mapParams := params.AsMap()
|
||||
documentPathsRaw, ok := mapParams[documentPathsKey].([]any)
|
||||
if !ok {
|
||||
@@ -145,11 +129,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
// Create document references from paths
|
||||
docRefs := make([]*firestoreapi.DocumentRef, len(documentPaths))
|
||||
for i, path := range documentPaths {
|
||||
docRefs[i] = t.Client.Doc(path)
|
||||
docRefs[i] = source.FirestoreClient().Doc(path)
|
||||
}
|
||||
|
||||
// Get all documents
|
||||
snapshots, err := t.Client.GetAll(ctx, docRefs)
|
||||
snapshots, err := source.FirestoreClient().GetAll(ctx, docRefs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get documents: %w", err)
|
||||
}
|
||||
@@ -190,10 +174,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -20,7 +20,6 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/firebaserules/v1"
|
||||
@@ -48,11 +47,6 @@ type compatibleSource interface {
|
||||
GetDatabaseId() string
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &firestoreds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{firestoreds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -69,18 +63,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
// No parameters needed for this tool
|
||||
params := parameters.Parameters{}
|
||||
|
||||
@@ -90,9 +72,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
RulesClient: s.FirebaseRulesClient(),
|
||||
ProjectId: s.GetProjectId(),
|
||||
DatabaseId: s.GetDatabaseId(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -104,11 +83,7 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
RulesClient *firebaserules.Service
|
||||
ProjectId string
|
||||
DatabaseId string
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -118,19 +93,24 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get the latest release for Firestore
|
||||
releaseName := fmt.Sprintf("projects/%s/releases/cloud.firestore/%s", t.ProjectId, t.DatabaseId)
|
||||
release, err := t.RulesClient.Projects.Releases.Get(releaseName).Context(ctx).Do()
|
||||
releaseName := fmt.Sprintf("projects/%s/releases/cloud.firestore/%s", source.GetProjectId(), source.GetDatabaseId())
|
||||
release, err := source.FirebaseRulesClient().Projects.Releases.Get(releaseName).Context(ctx).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get latest Firestore release: %w", err)
|
||||
}
|
||||
|
||||
if release.RulesetName == "" {
|
||||
return nil, fmt.Errorf("no active Firestore rules were found in project '%s' and database '%s'", t.ProjectId, t.DatabaseId)
|
||||
return nil, fmt.Errorf("no active Firestore rules were found in project '%s' and database '%s'", source.GetProjectId(), source.GetDatabaseId())
|
||||
}
|
||||
|
||||
// Get the ruleset content
|
||||
ruleset, err := t.RulesClient.Projects.Rulesets.Get(release.RulesetName).Context(ctx).Do()
|
||||
ruleset, err := source.FirebaseRulesClient().Projects.Rulesets.Get(release.RulesetName).Context(ctx).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get ruleset content: %w", err)
|
||||
}
|
||||
@@ -158,10 +138,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
firestoreapi "cloud.google.com/go/firestore"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/firestore/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
@@ -48,11 +47,6 @@ type compatibleSource interface {
|
||||
FirestoreClient() *firestoreapi.Client
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &firestoreds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{firestoreds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -69,18 +63,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
emptyString := ""
|
||||
parentPathParameter := parameters.NewStringParameterWithDefault(parentPathKey, emptyString, "Relative parent document path to list subcollections from (e.g., 'users/userId'). If not provided, lists root collections. Note: This is a relative path, NOT an absolute path like 'projects/{project_id}/databases/{database_id}/documents/...'")
|
||||
params := parameters.Parameters{parentPathParameter}
|
||||
@@ -91,7 +73,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Client: s.FirestoreClient(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -103,9 +84,7 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *firestoreapi.Client
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -115,10 +94,14 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mapParams := params.AsMap()
|
||||
|
||||
var collectionRefs []*firestoreapi.CollectionRef
|
||||
var err error
|
||||
|
||||
// Check if parentPath is provided
|
||||
parentPath, hasParent := mapParams[parentPathKey].(string)
|
||||
@@ -130,14 +113,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
|
||||
// List subcollections of the specified document
|
||||
docRef := t.Client.Doc(parentPath)
|
||||
docRef := source.FirestoreClient().Doc(parentPath)
|
||||
collectionRefs, err = docRef.Collections(ctx).GetAll()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list subcollections of document %q: %w", parentPath, err)
|
||||
}
|
||||
} else {
|
||||
// List root collections
|
||||
collectionRefs, err = t.Client.Collections(ctx).GetAll()
|
||||
collectionRefs, err = source.FirestoreClient().Collections(ctx).GetAll()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list root collections: %w", err)
|
||||
}
|
||||
@@ -177,10 +160,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -24,7 +24,6 @@ import (
|
||||
firestoreapi "cloud.google.com/go/firestore"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/firestore/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
@@ -52,12 +51,9 @@ var validOperators = map[string]bool{
|
||||
|
||||
// Error messages
|
||||
const (
|
||||
errFilterParseFailed = "failed to parse filters: %w"
|
||||
errQueryExecutionFailed = "failed to execute query: %w"
|
||||
errTemplateParseFailed = "failed to parse template: %w"
|
||||
errTemplateExecFailed = "failed to execute template: %w"
|
||||
errLimitParseFailed = "failed to parse limit value '%s': %w"
|
||||
errSelectFieldParseFailed = "failed to parse select field: %w"
|
||||
errFilterParseFailed = "failed to parse filters: %w"
|
||||
errQueryExecutionFailed = "failed to execute query: %w"
|
||||
errLimitParseFailed = "failed to parse limit value '%s': %w"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -79,11 +75,6 @@ type compatibleSource interface {
|
||||
FirestoreClient() *firestoreapi.Client
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &firestoreds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{firestoreds.SourceKind}
|
||||
|
||||
// Config represents the configuration for the Firestore query tool
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -114,18 +105,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
|
||||
// Initialize creates a new Tool instance from the configuration
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
// Set default limit if not specified
|
||||
if cfg.Limit == "" {
|
||||
cfg.Limit = fmt.Sprintf("%d", defaultLimit)
|
||||
@@ -137,7 +116,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Client: s.FirestoreClient(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -201,6 +179,11 @@ type QueryResponse struct {
|
||||
|
||||
// Invoke executes the Firestore query based on the provided parameters
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
// Process collection path with template substitution
|
||||
@@ -210,7 +193,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
|
||||
// Build the query
|
||||
query, err := t.buildQuery(collectionPath, paramsMap)
|
||||
query, err := t.buildQuery(source, collectionPath, paramsMap)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -220,8 +203,8 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
|
||||
// buildQuery constructs the Firestore query from parameters
|
||||
func (t Tool) buildQuery(collectionPath string, params map[string]any) (*firestoreapi.Query, error) {
|
||||
collection := t.Client.Collection(collectionPath)
|
||||
func (t Tool) buildQuery(source compatibleSource, collectionPath string, params map[string]any) (*firestoreapi.Query, error) {
|
||||
collection := source.FirestoreClient().Collection(collectionPath)
|
||||
query := collection.Query
|
||||
|
||||
// Process and apply filters if template is provided
|
||||
@@ -239,7 +222,7 @@ func (t Tool) buildQuery(collectionPath string, params map[string]any) (*firesto
|
||||
}
|
||||
|
||||
// Convert simplified filter to Firestore filter
|
||||
if filter := t.convertToFirestoreFilter(simplifiedFilter); filter != nil {
|
||||
if filter := t.convertToFirestoreFilter(source, simplifiedFilter); filter != nil {
|
||||
query = query.WhereEntity(filter)
|
||||
}
|
||||
}
|
||||
@@ -280,12 +263,12 @@ func (t Tool) buildQuery(collectionPath string, params map[string]any) (*firesto
|
||||
}
|
||||
|
||||
// convertToFirestoreFilter converts simplified filter format to Firestore EntityFilter
|
||||
func (t Tool) convertToFirestoreFilter(filter SimplifiedFilter) firestoreapi.EntityFilter {
|
||||
func (t Tool) convertToFirestoreFilter(source compatibleSource, filter SimplifiedFilter) firestoreapi.EntityFilter {
|
||||
// Handle AND filters
|
||||
if len(filter.And) > 0 {
|
||||
filters := make([]firestoreapi.EntityFilter, 0, len(filter.And))
|
||||
for _, f := range filter.And {
|
||||
if converted := t.convertToFirestoreFilter(f); converted != nil {
|
||||
if converted := t.convertToFirestoreFilter(source, f); converted != nil {
|
||||
filters = append(filters, converted)
|
||||
}
|
||||
}
|
||||
@@ -299,7 +282,7 @@ func (t Tool) convertToFirestoreFilter(filter SimplifiedFilter) firestoreapi.Ent
|
||||
if len(filter.Or) > 0 {
|
||||
filters := make([]firestoreapi.EntityFilter, 0, len(filter.Or))
|
||||
for _, f := range filter.Or {
|
||||
if converted := t.convertToFirestoreFilter(f); converted != nil {
|
||||
if converted := t.convertToFirestoreFilter(source, f); converted != nil {
|
||||
filters = append(filters, converted)
|
||||
}
|
||||
}
|
||||
@@ -313,7 +296,7 @@ func (t Tool) convertToFirestoreFilter(filter SimplifiedFilter) firestoreapi.Ent
|
||||
if filter.Field != "" && filter.Op != "" && filter.Value != nil {
|
||||
if validOperators[filter.Op] {
|
||||
// Convert the value using the Firestore native JSON converter
|
||||
convertedValue, err := util.JSONToFirestoreValue(filter.Value, t.Client)
|
||||
convertedValue, err := util.JSONToFirestoreValue(filter.Value, source.FirestoreClient())
|
||||
if err != nil {
|
||||
// If conversion fails, use the original value
|
||||
convertedValue = filter.Value
|
||||
@@ -525,10 +508,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -23,7 +23,6 @@ import (
|
||||
firestoreapi "cloud.google.com/go/firestore"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/firestore/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
@@ -92,11 +91,6 @@ type compatibleSource interface {
|
||||
FirestoreClient() *firestoreapi.Client
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &firestoreds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{firestoreds.SourceKind}
|
||||
|
||||
// Config represents the configuration for the Firestore query collection tool
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -116,18 +110,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
|
||||
// Initialize creates a new Tool instance from the configuration
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
// Create parameters
|
||||
params := createParameters()
|
||||
|
||||
@@ -137,7 +119,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Client: s.FirestoreClient(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -199,9 +180,7 @@ var _ tools.Tool = Tool{}
|
||||
// Tool represents the Firestore query collection tool
|
||||
type Tool struct {
|
||||
Config
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *firestoreapi.Client
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -266,6 +245,11 @@ type QueryResponse struct {
|
||||
|
||||
// Invoke executes the Firestore query based on the provided parameters
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse parameters
|
||||
queryParams, err := t.parseQueryParameters(params)
|
||||
if err != nil {
|
||||
@@ -273,7 +257,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
|
||||
// Build the query
|
||||
query, err := t.buildQuery(queryParams)
|
||||
query, err := t.buildQuery(source, queryParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -396,8 +380,8 @@ func (t Tool) parseOrderBy(orderByRaw interface{}) (*OrderByConfig, error) {
|
||||
}
|
||||
|
||||
// buildQuery constructs the Firestore query from parameters
|
||||
func (t Tool) buildQuery(params *queryParameters) (*firestoreapi.Query, error) {
|
||||
collection := t.Client.Collection(params.CollectionPath)
|
||||
func (t Tool) buildQuery(source compatibleSource, params *queryParameters) (*firestoreapi.Query, error) {
|
||||
collection := source.FirestoreClient().Collection(params.CollectionPath)
|
||||
query := collection.Query
|
||||
|
||||
// Apply filters
|
||||
@@ -531,10 +515,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -22,7 +22,6 @@ import (
|
||||
firestoreapi "cloud.google.com/go/firestore"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/firestore/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
@@ -52,11 +51,6 @@ type compatibleSource interface {
|
||||
FirestoreClient() *firestoreapi.Client
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &firestoreds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{firestoreds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -73,18 +67,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
// Create parameters
|
||||
documentPathParameter := parameters.NewStringParameter(
|
||||
documentPathKey,
|
||||
@@ -134,7 +116,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Client: s.FirestoreClient(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -146,9 +127,7 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *firestoreapi.Client
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -158,6 +137,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mapParams := params.AsMap()
|
||||
|
||||
// Get document path
|
||||
@@ -200,7 +184,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
|
||||
// Get the document reference
|
||||
docRef := t.Client.Doc(documentPath)
|
||||
docRef := source.FirestoreClient().Doc(documentPath)
|
||||
|
||||
// Prepare update data
|
||||
var writeResult *firestoreapi.WriteResult
|
||||
@@ -211,7 +195,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
updates := make([]firestoreapi.Update, 0, len(updatePaths))
|
||||
|
||||
// Convert document data without delete markers
|
||||
dataMap, err := util.JSONToFirestoreValue(documentDataRaw, t.Client)
|
||||
dataMap, err := util.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert document data: %w", err)
|
||||
}
|
||||
@@ -239,7 +223,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
writeResult, writeErr = docRef.Update(ctx, updates)
|
||||
} else {
|
||||
// Update all fields in the document data (merge)
|
||||
documentData, err := util.JSONToFirestoreValue(documentDataRaw, t.Client)
|
||||
documentData, err := util.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert document data: %w", err)
|
||||
}
|
||||
@@ -314,10 +298,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -132,32 +132,6 @@ func TestConfig_Initialize(t *testing.T) {
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "source not found",
|
||||
config: Config{
|
||||
Name: "test-update-document",
|
||||
Kind: "firestore-update-document",
|
||||
Source: "missing-source",
|
||||
Description: "Update a document",
|
||||
},
|
||||
sources: map[string]sources.Source{},
|
||||
wantErr: true,
|
||||
errMsg: "no source named \"missing-source\" configured",
|
||||
},
|
||||
{
|
||||
name: "incompatible source",
|
||||
config: Config{
|
||||
Name: "test-update-document",
|
||||
Kind: "firestore-update-document",
|
||||
Source: "wrong-source",
|
||||
Description: "Update a document",
|
||||
},
|
||||
sources: map[string]sources.Source{
|
||||
"wrong-source": &mockIncompatibleSource{},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "invalid source for \"firestore-update-document\" tool",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -464,14 +438,3 @@ func TestGetFieldValue(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// mockIncompatibleSource is a mock source that doesn't implement compatibleSource
|
||||
type mockIncompatibleSource struct{}
|
||||
|
||||
func (m *mockIncompatibleSource) SourceKind() string {
|
||||
return "mock"
|
||||
}
|
||||
|
||||
func (m *mockIncompatibleSource) ToConfig() sources.SourceConfig {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/firebaserules/v1"
|
||||
@@ -53,11 +52,6 @@ type compatibleSource interface {
|
||||
GetProjectId() string
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &firestoreds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{firestoreds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -74,18 +68,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
// Create parameters
|
||||
params := createParameters()
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil)
|
||||
@@ -94,8 +76,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
RulesClient: s.FirebaseRulesClient(),
|
||||
ProjectId: s.GetProjectId(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -117,10 +97,7 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
RulesClient *firebaserules.Service
|
||||
ProjectId string
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -154,11 +131,16 @@ type ValidationResult struct {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mapParams := params.AsMap()
|
||||
|
||||
// Get source parameter
|
||||
source, ok := mapParams[sourceKey].(string)
|
||||
if !ok || source == "" {
|
||||
sourceParam, ok := mapParams[sourceKey].(string)
|
||||
if !ok || sourceParam == "" {
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter", sourceKey)
|
||||
}
|
||||
|
||||
@@ -168,7 +150,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
Files: []*firebaserules.File{
|
||||
{
|
||||
Name: "firestore.rules",
|
||||
Content: source,
|
||||
Content: sourceParam,
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -179,14 +161,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
|
||||
// Call the test API
|
||||
projectName := fmt.Sprintf("projects/%s", t.ProjectId)
|
||||
response, err := t.RulesClient.Projects.Test(projectName, testRequest).Context(ctx).Do()
|
||||
projectName := fmt.Sprintf("projects/%s", source.GetProjectId())
|
||||
response, err := source.FirebaseRulesClient().Projects.Test(projectName, testRequest).Context(ctx).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to validate rules: %w", err)
|
||||
}
|
||||
|
||||
// Process the response
|
||||
result := t.processValidationResponse(response, source)
|
||||
result := t.processValidationResponse(response, sourceParam)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
@@ -287,10 +269,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -29,7 +29,6 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
httpsrc "github.com/googleapis/genai-toolbox/internal/sources/http"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
@@ -50,6 +49,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
HttpDefaultHeaders() map[string]string
|
||||
HttpBaseURL() string
|
||||
HttpQueryParams() map[string]string
|
||||
Client() *http.Client
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -81,7 +87,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(*httpsrc.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `http`", kind)
|
||||
}
|
||||
@@ -89,7 +95,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Combine Source and Tool headers.
|
||||
// In case of conflict, Tool header overrides Source header
|
||||
combinedHeaders := make(map[string]string)
|
||||
maps.Copy(combinedHeaders, s.DefaultHeaders)
|
||||
maps.Copy(combinedHeaders, s.HttpDefaultHeaders())
|
||||
maps.Copy(combinedHeaders, cfg.Headers)
|
||||
|
||||
// Create a slice for all parameters
|
||||
@@ -113,14 +119,11 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
BaseURL: s.BaseURL,
|
||||
Headers: combinedHeaders,
|
||||
DefaultQueryParams: s.QueryParams,
|
||||
Client: s.Client,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Headers: combinedHeaders,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -129,12 +132,8 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
BaseURL string `yaml:"baseURL"`
|
||||
Headers map[string]string `yaml:"headers"`
|
||||
DefaultQueryParams map[string]string `yaml:"defaultQueryParams"`
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
|
||||
Client *http.Client
|
||||
Headers map[string]string `yaml:"headers"`
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -229,6 +228,11 @@ func getHeaders(headerParams parameters.Parameters, defaultHeaders map[string]st
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
// Calculate request body
|
||||
@@ -238,7 +242,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
|
||||
// Calculate URL
|
||||
urlString, err := getURL(t.BaseURL, t.Path, t.PathParams, t.QueryParams, t.DefaultQueryParams, paramsMap)
|
||||
urlString, err := getURL(source.HttpBaseURL(), t.Path, t.PathParams, t.QueryParams, source.HttpQueryParams(), paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error populating path parameters: %s", err)
|
||||
}
|
||||
@@ -256,7 +260,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
|
||||
// Make request and fetch response
|
||||
resp, err := t.Client.Do(req)
|
||||
resp, err := source.Client().Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error making HTTP request: %s", err)
|
||||
}
|
||||
@@ -295,10 +299,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -19,7 +19,6 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
@@ -45,6 +44,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
GetAuthTokenHeaderName() string
|
||||
LookerClient() *v4.LookerSDK
|
||||
LookerApiSettings() *rtl.ApiSettings
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -62,18 +68,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(*lookersrc.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind)
|
||||
}
|
||||
|
||||
params := lookercommon.GetQueryParameters()
|
||||
|
||||
dashIdParameter := parameters.NewStringParameter("dashboard_id", "The id of the dashboard where this tile will exist")
|
||||
@@ -109,12 +103,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
AuthTokenHeaderName: s.GetAuthTokenHeaderName(),
|
||||
Client: s.Client,
|
||||
ApiSettings: s.ApiSettings,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{
|
||||
Description: cfg.Description,
|
||||
Parameters: params.Manifest(),
|
||||
@@ -129,13 +119,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool
|
||||
AuthTokenHeaderName string
|
||||
Client *v4.LookerSDK
|
||||
ApiSettings *rtl.ApiSettings
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -148,6 +134,11 @@ var (
|
||||
)
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logger, err := util.LoggerFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)
|
||||
@@ -167,12 +158,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
visConfig := paramsMap["vis_config"].(map[string]any)
|
||||
wq.VisConfig = &visConfig
|
||||
|
||||
sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken)
|
||||
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting sdk: %w", err)
|
||||
}
|
||||
|
||||
qresp, err := sdk.CreateQuery(*wq, "id", t.ApiSettings)
|
||||
qresp, err := sdk.CreateQuery(*wq, "id", source.LookerApiSettings())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error making create query request: %w", err)
|
||||
}
|
||||
@@ -239,7 +230,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
Fields: &fields,
|
||||
}
|
||||
|
||||
resp, err := sdk.CreateDashboardElement(req, t.ApiSettings)
|
||||
resp, err := sdk.CreateDashboardElement(req, source.LookerApiSettings())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error making create dashboard element request: %w", err)
|
||||
}
|
||||
@@ -264,14 +255,22 @@ func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return t.AuthTokenHeaderName
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return source.GetAuthTokenHeaderName(), nil
|
||||
}
|
||||
|
||||
@@ -19,7 +19,6 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
@@ -45,6 +44,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
GetAuthTokenHeaderName() string
|
||||
LookerClient() *v4.LookerSDK
|
||||
LookerApiSettings() *rtl.ApiSettings
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -62,18 +68,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(*lookersrc.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind)
|
||||
}
|
||||
|
||||
params := parameters.Parameters{}
|
||||
|
||||
dashIdParameter := parameters.NewStringParameter("dashboard_id", "The id of the dashboard where this filter will exist")
|
||||
@@ -109,14 +103,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
AuthTokenHeaderName: s.GetAuthTokenHeaderName(),
|
||||
Client: s.Client,
|
||||
ApiSettings: s.ApiSettings,
|
||||
Parameters: params,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{
|
||||
Description: cfg.Description,
|
||||
Parameters: params.Manifest(),
|
||||
@@ -131,16 +119,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
UseClientOAuth bool
|
||||
AuthTokenHeaderName string
|
||||
Client *v4.LookerSDK
|
||||
ApiSettings *rtl.ApiSettings
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -148,6 +129,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logger, err := util.LoggerFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)
|
||||
@@ -205,12 +191,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
req.Dimension = &dimension
|
||||
}
|
||||
|
||||
sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken)
|
||||
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting sdk: %w", err)
|
||||
}
|
||||
|
||||
resp, err := sdk.CreateDashboardFilter(req, "name", t.ApiSettings)
|
||||
resp, err := sdk.CreateDashboardFilter(req, "name", source.LookerApiSettings())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error making create dashboard filter request: %s", err)
|
||||
}
|
||||
@@ -239,10 +225,18 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return t.AuthTokenHeaderName
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return source.GetAuthTokenHeaderName(), nil
|
||||
}
|
||||
|
||||
@@ -26,7 +26,6 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
lookerds "github.com/googleapis/genai-toolbox/internal/sources/looker"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
@@ -56,12 +55,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetApiSettings() *rtl.ApiSettings
|
||||
GoogleCloudTokenSourceWithScope(ctx context.Context, scope string) (oauth2.TokenSource, error)
|
||||
GoogleCloudProject() string
|
||||
GoogleCloudLocation() string
|
||||
UseClientAuthorization() bool
|
||||
GetAuthTokenHeaderName() string
|
||||
LookerApiSettings() *rtl.ApiSettings
|
||||
}
|
||||
|
||||
// Structs for building the JSON payload
|
||||
@@ -124,11 +123,6 @@ type CAPayload struct {
|
||||
ClientIdEnum string `json:"clientIdEnum"`
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &lookerds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{lookerds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -155,7 +149,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
if s.GoogleCloudProject() == "" {
|
||||
@@ -196,16 +190,11 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
ApiSettings: s.GetApiSettings(),
|
||||
Project: s.GoogleCloudProject(),
|
||||
Location: s.GoogleCloudLocation(),
|
||||
Parameters: params,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
AuthTokenHeaderName: s.GetAuthTokenHeaderName(),
|
||||
TokenSource: ts,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
TokenSource: ts,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -215,15 +204,10 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
ApiSettings *rtl.ApiSettings
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
AuthTokenHeaderName string
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
Project string
|
||||
Location string
|
||||
TokenSource oauth2.TokenSource
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
TokenSource oauth2.TokenSource
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -231,8 +215,12 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var tokenStr string
|
||||
var err error
|
||||
|
||||
// Get credentials for the API call
|
||||
// Use cloud-platform token source for Gemini Data Analytics API
|
||||
@@ -253,16 +241,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
ler := make([]LookerExploreReference, 0)
|
||||
for _, er := range exploreReferences {
|
||||
ler = append(ler, LookerExploreReference{
|
||||
LookerInstanceUri: t.ApiSettings.BaseUrl,
|
||||
LookerInstanceUri: source.LookerApiSettings().BaseUrl,
|
||||
LookmlModel: er.(map[string]any)["model"].(string),
|
||||
Explore: er.(map[string]any)["explore"].(string),
|
||||
})
|
||||
}
|
||||
oauth_creds := OAuthCredentials{}
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
oauth_creds.Token = TokenBased{AccessToken: string(accessToken)}
|
||||
} else {
|
||||
oauth_creds.Secret = SecretBased{ClientId: t.ApiSettings.ClientId, ClientSecret: t.ApiSettings.ClientSecret}
|
||||
oauth_creds.Secret = SecretBased{ClientId: source.LookerApiSettings().ClientId, ClientSecret: source.LookerApiSettings().ClientSecret}
|
||||
}
|
||||
|
||||
lers := LookerExploreReferences{
|
||||
@@ -273,8 +261,8 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
|
||||
// Construct URL, headers, and payload
|
||||
projectID := t.Project
|
||||
location := t.Location
|
||||
projectID := source.GoogleCloudProject()
|
||||
location := source.GoogleCloudLocation()
|
||||
caURL := fmt.Sprintf("https://geminidataanalytics.googleapis.com/v1beta/projects/%s/locations/%s:chat", url.PathEscape(projectID), url.PathEscape(location))
|
||||
|
||||
headers := map[string]string{
|
||||
@@ -315,12 +303,16 @@ func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
// StreamMessage represents a single message object from the streaming API response.
|
||||
@@ -563,6 +555,10 @@ func appendMessage(messages []map[string]any, newMessage map[string]any) []map[s
|
||||
return append(messages, newMessage)
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return t.AuthTokenHeaderName
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return source.GetAuthTokenHeaderName(), nil
|
||||
}
|
||||
|
||||
@@ -19,7 +19,6 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
@@ -44,6 +43,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
GetAuthTokenHeaderName() string
|
||||
LookerClient() *v4.LookerSDK
|
||||
LookerApiSettings() *rtl.ApiSettings
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -61,18 +67,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(*lookersrc.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind)
|
||||
}
|
||||
|
||||
projectIdParameter := parameters.NewStringParameter("project_id", "The id of the project containing the files")
|
||||
filePathParameter := parameters.NewStringParameter("file_path", "The path of the file within the project")
|
||||
fileContentParameter := parameters.NewStringParameter("file_content", "The content of the file")
|
||||
@@ -90,12 +84,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
AuthTokenHeaderName: s.GetAuthTokenHeaderName(),
|
||||
Client: s.Client,
|
||||
ApiSettings: s.ApiSettings,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{
|
||||
Description: cfg.Description,
|
||||
Parameters: params.Manifest(),
|
||||
@@ -110,13 +100,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool
|
||||
AuthTokenHeaderName string
|
||||
Client *v4.LookerSDK
|
||||
ApiSettings *rtl.ApiSettings
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -124,7 +110,12 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken)
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting sdk: %w", err)
|
||||
}
|
||||
@@ -148,7 +139,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
Content: fileContent,
|
||||
}
|
||||
|
||||
err = lookercommon.CreateProjectFile(sdk, projectId, req, t.ApiSettings)
|
||||
err = lookercommon.CreateProjectFile(sdk, projectId, req, source.LookerApiSettings())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error making create_project_file request: %s", err)
|
||||
}
|
||||
@@ -172,14 +163,22 @@ func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return t.AuthTokenHeaderName
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return source.GetAuthTokenHeaderName(), nil
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user