feat: add auth_required to tools (#123)

Check if a tool invocation request contains required auth token.
This commit is contained in:
Wenxin Du
2024-12-16 22:41:13 -05:00
committed by GitHub
parent 380a6fbbd5
commit 3118104ae1
8 changed files with 110 additions and 54 deletions

View File

@@ -91,7 +91,7 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
for _, aS := range s.authSources {
claims, err := aS.GetClaimsFromHeader(r.Header)
if err != nil {
err := fmt.Errorf("Failure getting claims from header: %w", err)
err := fmt.Errorf("failure getting claims from header: %w", err)
_ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest))
return
}
@@ -102,6 +102,21 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
claimsFromAuth[aS.GetName()] = claims
}
// Tool authorization check
verifiedAuthSources := make([]string, len(claimsFromAuth))
i := 0
for k := range claimsFromAuth {
verifiedAuthSources[i] = k
i++
}
// Check if any of the specified auth sources is verified
isAuthorized := tool.Authorized(verifiedAuthSources)
if !isAuthorized {
err := fmt.Errorf("tool invocation not authorized. Please make sure your specify correct auth headers")
_ = render.Render(w, r, newErrResponse(err, http.StatusUnauthorized))
return
}
var data map[string]any
if err := render.DecodeJSON(r.Body, &data); err != nil {
render.Status(r, http.StatusBadRequest)

View File

@@ -52,6 +52,10 @@ func (t MockTool) Manifest() tools.Manifest {
return tools.Manifest{Description: t.Description, Parameters: pMs}
}
func (t MockTool) Authorized(verifiedAuthSources []string) bool {
return true
}
func TestToolsetEndpoint(t *testing.T) {
// Set up resources to test against
tool1 := MockTool{

View File

@@ -41,12 +41,13 @@ var _ compatibleSource = &postgres.Source{}
var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind}
type Config struct {
Name string `yaml:"name"`
Kind string `yaml:"kind"`
Source string `yaml:"source"`
Description string `yaml:"description"`
Statement string `yaml:"statement"`
Parameters tools.Parameters `yaml:"parameters"`
Name string `yaml:"name"`
Kind string `yaml:"kind"`
Source string `yaml:"source"`
Description string `yaml:"description"`
Statement string `yaml:"statement"`
AuthRequired []string `yaml:"authRequired"`
Parameters tools.Parameters `yaml:"parameters"`
}
// validate interface
@@ -71,24 +72,26 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Name: cfg.Name,
Kind: ToolKind,
Parameters: cfg.Parameters,
Statement: cfg.Statement,
Pool: s.PostgresPool(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
Name: cfg.Name,
Kind: ToolKind,
Parameters: cfg.Parameters,
Statement: cfg.Statement,
AuthRequired: cfg.AuthRequired,
Pool: s.PostgresPool(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
}
return t, nil
}
func NewGenericTool(name, stmt, desc string, pool *pgxpool.Pool, parameters tools.Parameters) Tool {
func NewGenericTool(name string, stmt string, authRequired []string, desc string, pool *pgxpool.Pool, parameters tools.Parameters) Tool {
return Tool{
Name: name,
Kind: ToolKind,
Statement: stmt,
Pool: pool,
manifest: tools.Manifest{Description: desc, Parameters: parameters.Manifest()},
Parameters: parameters,
Name: name,
Kind: ToolKind,
Statement: stmt,
AuthRequired: authRequired,
Pool: pool,
manifest: tools.Manifest{Description: desc, Parameters: parameters.Manifest()},
Parameters: parameters,
}
}
@@ -96,9 +99,10 @@ func NewGenericTool(name, stmt, desc string, pool *pgxpool.Pool, parameters tool
var _ tools.Tool = Tool{}
type Tool struct {
Name string `yaml:"name"`
Kind string `yaml:"kind"`
Parameters tools.Parameters `yaml:"parameters"`
Name string `yaml:"name"`
Kind string `yaml:"kind"`
AuthRequired []string `yaml:"authRequired"`
Parameters tools.Parameters `yaml:"parameters"`
Pool *pgxpool.Pool
Statement string
@@ -132,3 +136,7 @@ func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any)
func (t Tool) Manifest() tools.Manifest {
return t.manifest
}
func (t Tool) Authorized(verifiedAuthSources []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthSources)
}

View File

@@ -41,6 +41,9 @@ func TestParseFromYamlPostgres(t *testing.T) {
description: some description
statement: |
SELECT * FROM SQL_STATEMENT;
authRequired:
- my-google-auth-service
- other-auth-service
parameters:
- name: country
type: string
@@ -53,11 +56,12 @@ func TestParseFromYamlPostgres(t *testing.T) {
`,
want: server.ToolConfigs{
"example_tool": postgressql.Config{
Name: "example_tool",
Kind: postgressql.ToolKind,
Source: "my-pg-instance",
Description: "some description",
Statement: "SELECT * FROM SQL_STATEMENT;\n",
Name: "example_tool",
Kind: postgressql.ToolKind,
Source: "my-pg-instance",
Description: "some description",
Statement: "SELECT * FROM SQL_STATEMENT;\n",
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
Parameters: []tools.Parameter{
tools.NewStringParameterWithAuth("country", "some description",
[]tools.ParamAuthSource{{Name: "my-google-auth-service", Field: "user_id"},

View File

@@ -39,12 +39,13 @@ var _ compatibleSource = &spannerdb.Source{}
var compatibleSources = [...]string{spannerdb.SourceKind}
type Config struct {
Name string `yaml:"name"`
Kind string `yaml:"kind"`
Source string `yaml:"source"`
Description string `yaml:"description"`
Statement string `yaml:"statement"`
Parameters tools.Parameters `yaml:"parameters"`
Name string `yaml:"name"`
Kind string `yaml:"kind"`
Source string `yaml:"source"`
Description string `yaml:"description"`
Statement string `yaml:"statement"`
AuthRequired []string `yaml:"authRequired"`
Parameters tools.Parameters `yaml:"parameters"`
}
// validate interface
@@ -69,26 +70,28 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Name: cfg.Name,
Kind: ToolKind,
Parameters: cfg.Parameters,
Statement: cfg.Statement,
Client: s.SpannerClient(),
dialect: s.DatabaseDialect(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
Name: cfg.Name,
Kind: ToolKind,
Parameters: cfg.Parameters,
Statement: cfg.Statement,
AuthRequired: cfg.AuthRequired,
Client: s.SpannerClient(),
dialect: s.DatabaseDialect(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
}
return t, nil
}
func NewGenericTool(name, stmt, desc string, client *spanner.Client, dialect string, parameters tools.Parameters) Tool {
func NewGenericTool(name string, stmt string, authRequired []string, desc string, client *spanner.Client, dialect string, parameters tools.Parameters) Tool {
return Tool{
Name: name,
Kind: ToolKind,
Statement: stmt,
Client: client,
dialect: dialect,
manifest: tools.Manifest{Description: desc, Parameters: parameters.Manifest()},
Parameters: parameters,
Name: name,
Kind: ToolKind,
Statement: stmt,
AuthRequired: authRequired,
Client: client,
dialect: dialect,
manifest: tools.Manifest{Description: desc, Parameters: parameters.Manifest()},
Parameters: parameters,
}
}
@@ -96,9 +99,10 @@ func NewGenericTool(name, stmt, desc string, client *spanner.Client, dialect str
var _ tools.Tool = Tool{}
type Tool struct {
Name string `yaml:"name"`
Kind string `yaml:"kind"`
Parameters tools.Parameters `yaml:"parameters"`
Name string `yaml:"name"`
Kind string `yaml:"kind"`
AuthRequired []string `yaml:"authRequired"`
Parameters tools.Parameters `yaml:"parameters"`
Client *spanner.Client
dialect string
@@ -160,3 +164,7 @@ func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any)
func (t Tool) Manifest() tools.Manifest {
return t.manifest
}
func (t Tool) Authorized(verifiedAuthSources []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthSources)
}

View File

@@ -15,6 +15,8 @@
package tools
import (
"slices"
"github.com/googleapis/genai-toolbox/internal/sources"
)
@@ -27,6 +29,7 @@ type Tool interface {
Invoke(ParamValues) (string, error)
ParseParams(map[string]any, map[string]map[string]any) (ParamValues, error)
Manifest() Manifest
Authorized([]string) bool
}
// Manifest is the representation of tools sent to Client SDKs.
@@ -34,3 +37,17 @@ type Manifest struct {
Description string `json:"description"`
Parameters []ParameterManifest `json:"parameters"`
}
// Helper function that returns if a tool invocation request is authorized
func IsAuthorized(authRequiredSources []string, verifiedAuthSources []string) bool {
if len(authRequiredSources) == 0 {
// no authorization requirement
return true
}
for _, a := range authRequiredSources {
if slices.Contains(verifiedAuthSources, a) {
return true
}
}
return false
}