mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-05-02 03:00:36 -04:00
feat: add auth_required to tools (#123)
Check if a tool invocation request contains required auth token.
This commit is contained in:
@@ -91,7 +91,7 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
for _, aS := range s.authSources {
|
||||
claims, err := aS.GetClaimsFromHeader(r.Header)
|
||||
if err != nil {
|
||||
err := fmt.Errorf("Failure getting claims from header: %w", err)
|
||||
err := fmt.Errorf("failure getting claims from header: %w", err)
|
||||
_ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest))
|
||||
return
|
||||
}
|
||||
@@ -102,6 +102,21 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
claimsFromAuth[aS.GetName()] = claims
|
||||
}
|
||||
|
||||
// Tool authorization check
|
||||
verifiedAuthSources := make([]string, len(claimsFromAuth))
|
||||
i := 0
|
||||
for k := range claimsFromAuth {
|
||||
verifiedAuthSources[i] = k
|
||||
i++
|
||||
}
|
||||
// Check if any of the specified auth sources is verified
|
||||
isAuthorized := tool.Authorized(verifiedAuthSources)
|
||||
if !isAuthorized {
|
||||
err := fmt.Errorf("tool invocation not authorized. Please make sure your specify correct auth headers")
|
||||
_ = render.Render(w, r, newErrResponse(err, http.StatusUnauthorized))
|
||||
return
|
||||
}
|
||||
|
||||
var data map[string]any
|
||||
if err := render.DecodeJSON(r.Body, &data); err != nil {
|
||||
render.Status(r, http.StatusBadRequest)
|
||||
|
||||
@@ -52,6 +52,10 @@ func (t MockTool) Manifest() tools.Manifest {
|
||||
return tools.Manifest{Description: t.Description, Parameters: pMs}
|
||||
}
|
||||
|
||||
func (t MockTool) Authorized(verifiedAuthSources []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func TestToolsetEndpoint(t *testing.T) {
|
||||
// Set up resources to test against
|
||||
tool1 := MockTool{
|
||||
|
||||
@@ -41,12 +41,13 @@ var _ compatibleSource = &postgres.Source{}
|
||||
var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Source string `yaml:"source"`
|
||||
Description string `yaml:"description"`
|
||||
Statement string `yaml:"statement"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Source string `yaml:"source"`
|
||||
Description string `yaml:"description"`
|
||||
Statement string `yaml:"statement"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
@@ -71,24 +72,26 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: ToolKind,
|
||||
Parameters: cfg.Parameters,
|
||||
Statement: cfg.Statement,
|
||||
Pool: s.PostgresPool(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
|
||||
Name: cfg.Name,
|
||||
Kind: ToolKind,
|
||||
Parameters: cfg.Parameters,
|
||||
Statement: cfg.Statement,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Pool: s.PostgresPool(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func NewGenericTool(name, stmt, desc string, pool *pgxpool.Pool, parameters tools.Parameters) Tool {
|
||||
func NewGenericTool(name string, stmt string, authRequired []string, desc string, pool *pgxpool.Pool, parameters tools.Parameters) Tool {
|
||||
return Tool{
|
||||
Name: name,
|
||||
Kind: ToolKind,
|
||||
Statement: stmt,
|
||||
Pool: pool,
|
||||
manifest: tools.Manifest{Description: desc, Parameters: parameters.Manifest()},
|
||||
Parameters: parameters,
|
||||
Name: name,
|
||||
Kind: ToolKind,
|
||||
Statement: stmt,
|
||||
AuthRequired: authRequired,
|
||||
Pool: pool,
|
||||
manifest: tools.Manifest{Description: desc, Parameters: parameters.Manifest()},
|
||||
Parameters: parameters,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -96,9 +99,10 @@ func NewGenericTool(name, stmt, desc string, pool *pgxpool.Pool, parameters tool
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
|
||||
Pool *pgxpool.Pool
|
||||
Statement string
|
||||
@@ -132,3 +136,7 @@ func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any)
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthSources []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthSources)
|
||||
}
|
||||
|
||||
@@ -41,6 +41,9 @@ func TestParseFromYamlPostgres(t *testing.T) {
|
||||
description: some description
|
||||
statement: |
|
||||
SELECT * FROM SQL_STATEMENT;
|
||||
authRequired:
|
||||
- my-google-auth-service
|
||||
- other-auth-service
|
||||
parameters:
|
||||
- name: country
|
||||
type: string
|
||||
@@ -53,11 +56,12 @@ func TestParseFromYamlPostgres(t *testing.T) {
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": postgressql.Config{
|
||||
Name: "example_tool",
|
||||
Kind: postgressql.ToolKind,
|
||||
Source: "my-pg-instance",
|
||||
Description: "some description",
|
||||
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
||||
Name: "example_tool",
|
||||
Kind: postgressql.ToolKind,
|
||||
Source: "my-pg-instance",
|
||||
Description: "some description",
|
||||
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
||||
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
|
||||
Parameters: []tools.Parameter{
|
||||
tools.NewStringParameterWithAuth("country", "some description",
|
||||
[]tools.ParamAuthSource{{Name: "my-google-auth-service", Field: "user_id"},
|
||||
|
||||
@@ -39,12 +39,13 @@ var _ compatibleSource = &spannerdb.Source{}
|
||||
var compatibleSources = [...]string{spannerdb.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Source string `yaml:"source"`
|
||||
Description string `yaml:"description"`
|
||||
Statement string `yaml:"statement"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Source string `yaml:"source"`
|
||||
Description string `yaml:"description"`
|
||||
Statement string `yaml:"statement"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
@@ -69,26 +70,28 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: ToolKind,
|
||||
Parameters: cfg.Parameters,
|
||||
Statement: cfg.Statement,
|
||||
Client: s.SpannerClient(),
|
||||
dialect: s.DatabaseDialect(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
|
||||
Name: cfg.Name,
|
||||
Kind: ToolKind,
|
||||
Parameters: cfg.Parameters,
|
||||
Statement: cfg.Statement,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Client: s.SpannerClient(),
|
||||
dialect: s.DatabaseDialect(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func NewGenericTool(name, stmt, desc string, client *spanner.Client, dialect string, parameters tools.Parameters) Tool {
|
||||
func NewGenericTool(name string, stmt string, authRequired []string, desc string, client *spanner.Client, dialect string, parameters tools.Parameters) Tool {
|
||||
return Tool{
|
||||
Name: name,
|
||||
Kind: ToolKind,
|
||||
Statement: stmt,
|
||||
Client: client,
|
||||
dialect: dialect,
|
||||
manifest: tools.Manifest{Description: desc, Parameters: parameters.Manifest()},
|
||||
Parameters: parameters,
|
||||
Name: name,
|
||||
Kind: ToolKind,
|
||||
Statement: stmt,
|
||||
AuthRequired: authRequired,
|
||||
Client: client,
|
||||
dialect: dialect,
|
||||
manifest: tools.Manifest{Description: desc, Parameters: parameters.Manifest()},
|
||||
Parameters: parameters,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -96,9 +99,10 @@ func NewGenericTool(name, stmt, desc string, client *spanner.Client, dialect str
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *spanner.Client
|
||||
dialect string
|
||||
@@ -160,3 +164,7 @@ func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any)
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthSources []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthSources)
|
||||
}
|
||||
|
||||
@@ -15,6 +15,8 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"slices"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
)
|
||||
|
||||
@@ -27,6 +29,7 @@ type Tool interface {
|
||||
Invoke(ParamValues) (string, error)
|
||||
ParseParams(map[string]any, map[string]map[string]any) (ParamValues, error)
|
||||
Manifest() Manifest
|
||||
Authorized([]string) bool
|
||||
}
|
||||
|
||||
// Manifest is the representation of tools sent to Client SDKs.
|
||||
@@ -34,3 +37,17 @@ type Manifest struct {
|
||||
Description string `json:"description"`
|
||||
Parameters []ParameterManifest `json:"parameters"`
|
||||
}
|
||||
|
||||
// Helper function that returns if a tool invocation request is authorized
|
||||
func IsAuthorized(authRequiredSources []string, verifiedAuthSources []string) bool {
|
||||
if len(authRequiredSources) == 0 {
|
||||
// no authorization requirement
|
||||
return true
|
||||
}
|
||||
for _, a := range authRequiredSources {
|
||||
if slices.Contains(verifiedAuthSources, a) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user