mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-02-17 18:41:28 -05:00
Currently, we are throwing 401 error immediately after auth token verification failure. This is not expected in the following situations: 1. Non-auth tool invocation with auth token that is invalid. 2. Auth tool invocation with all the required auth token, but the header contains extra non-required token that is invalid These requests should pass the authorization check but fail under the current implementation. Change made in this PR: 1. Do not throw error immediately after auth token verification failure. Instead only log it and continue to the next header iteration. 2. In the parseParams() method, if an auth parameter is missing, we should error with the message telling the user that either the auth header is missing or is invalid.
540 lines
14 KiB
Go
540 lines
14 KiB
Go
// Copyright 2024 Google LLC
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package tools
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
|
|
"github.com/googleapis/genai-toolbox/internal/util"
|
|
)
|
|
|
|
const (
|
|
typeString = "string"
|
|
typeInt = "integer"
|
|
typeFloat = "float"
|
|
typeBool = "boolean"
|
|
typeArray = "array"
|
|
)
|
|
|
|
// ParamValues is an ordered list of ParamValue
|
|
type ParamValues []ParamValue
|
|
|
|
// ParamValue represents the parameter's name and value.
|
|
type ParamValue struct {
|
|
Name string
|
|
Value any
|
|
}
|
|
|
|
// AsSlice returns a slice of the Param's values (in order).
|
|
func (p ParamValues) AsSlice() []any {
|
|
params := []any{}
|
|
|
|
for _, p := range p {
|
|
params = append(params, p.Value)
|
|
}
|
|
return params
|
|
}
|
|
|
|
// AsMap returns a map of ParamValue's names to values.
|
|
func (p ParamValues) AsMap() map[string]interface{} {
|
|
params := make(map[string]interface{})
|
|
for _, p := range p {
|
|
params[p.Name] = p.Value
|
|
}
|
|
return params
|
|
}
|
|
|
|
// AsReversedMap returns a map of ParamValue's values to names.
|
|
func (p ParamValues) AsReversedMap() map[any]string {
|
|
params := make(map[any]string)
|
|
for _, p := range p {
|
|
params[p.Value] = p.Name
|
|
}
|
|
return params
|
|
}
|
|
|
|
// AsMapByOrderedKeys returns a map of a key's position to it's value, as neccesary for Spanner PSQL.
|
|
// Example { $1 -> "value1", $2 -> "value2" }
|
|
func (p ParamValues) AsMapByOrderedKeys() map[string]interface{} {
|
|
params := make(map[string]interface{})
|
|
|
|
for i, p := range p {
|
|
key := fmt.Sprintf("p%d", i+1)
|
|
params[key] = p.Value
|
|
}
|
|
return params
|
|
}
|
|
|
|
func parseFromAuthSource(paramAuthSources []ParamAuthSource, claimsMap map[string]map[string]any) (any, error) {
|
|
// parse a parameter from claims using its specified auth sources
|
|
for _, a := range paramAuthSources {
|
|
claims, ok := claimsMap[a.Name]
|
|
if !ok {
|
|
// not validated for this authsource, skip to the next one
|
|
continue
|
|
}
|
|
v, ok := claims[a.Field]
|
|
if !ok {
|
|
// claims do not contain specified field
|
|
return nil, fmt.Errorf("no field named %s in claims", a.Field)
|
|
}
|
|
return v, nil
|
|
}
|
|
return nil, fmt.Errorf("missing or invalid authentication header")
|
|
}
|
|
|
|
// ParseParams is a helper function for parsing Parameters from an arbitraryJSON object.
|
|
func ParseParams(ps Parameters, data map[string]any, claimsMap map[string]map[string]any) (ParamValues, error) {
|
|
params := make([]ParamValue, 0, len(ps))
|
|
for _, p := range ps {
|
|
var v any
|
|
paramAuthSources := p.GetAuthSources()
|
|
name := p.GetName()
|
|
if paramAuthSources == nil {
|
|
// parse non auth-required parameter
|
|
var ok bool
|
|
v, ok = data[name]
|
|
if !ok {
|
|
return nil, fmt.Errorf("parameter %q is required", name)
|
|
}
|
|
} else {
|
|
// parse authenticated parameter
|
|
var err error
|
|
v, err = parseFromAuthSource(paramAuthSources, claimsMap)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error parsing authenticated parameter %q: %w", name, err)
|
|
}
|
|
}
|
|
newV, err := p.Parse(v)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to parse value for %q: %w", name, err)
|
|
}
|
|
params = append(params, ParamValue{Name: name, Value: newV})
|
|
}
|
|
return params, nil
|
|
}
|
|
|
|
type Parameter interface {
|
|
// Note: It's typically not idiomatic to include "Get" in the function name,
|
|
// but this is done to differentiate it from the fields in CommonParameter.
|
|
GetName() string
|
|
GetType() string
|
|
GetAuthSources() []ParamAuthSource
|
|
Parse(any) (any, error)
|
|
Manifest() ParameterManifest
|
|
}
|
|
|
|
// Parameters is a type used to allow unmarshal a list of parameters
|
|
type Parameters []Parameter
|
|
|
|
func (c *Parameters) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
|
*c = make(Parameters, 0)
|
|
// Parse the 'kind' fields for each source
|
|
var rawList []util.DelayedUnmarshaler
|
|
if err := unmarshal(&rawList); err != nil {
|
|
return err
|
|
}
|
|
for _, u := range rawList {
|
|
p, err := parseParamFromDelayedUnmarshaler(&u)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
(*c) = append((*c), p)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// parseParamFromDelayedUnmarshaler is a helper function that is required to parse
|
|
// parameters because there are multiple different types
|
|
func parseParamFromDelayedUnmarshaler(u *util.DelayedUnmarshaler) (Parameter, error) {
|
|
var p CommonParameter
|
|
err := u.Unmarshal(&p)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parameter missing required fields: %w", err)
|
|
}
|
|
switch p.Type {
|
|
case typeString:
|
|
a := &StringParameter{}
|
|
if err := u.Unmarshal(a); err != nil {
|
|
return nil, fmt.Errorf("unable to parse as %q: %w", p.Type, err)
|
|
}
|
|
return a, nil
|
|
case typeInt:
|
|
a := &IntParameter{}
|
|
if err := u.Unmarshal(a); err != nil {
|
|
return nil, fmt.Errorf("unable to parse as %q: %w", p.Type, err)
|
|
}
|
|
return a, nil
|
|
case typeFloat:
|
|
a := &FloatParameter{}
|
|
if err := u.Unmarshal(a); err != nil {
|
|
return nil, fmt.Errorf("unable to parse as %q: %w", p.Type, err)
|
|
}
|
|
return a, nil
|
|
case typeBool:
|
|
a := &BooleanParameter{}
|
|
if err := u.Unmarshal(a); err != nil {
|
|
return nil, fmt.Errorf("unable to parse as %q: %w", p.Type, err)
|
|
}
|
|
return a, nil
|
|
case typeArray:
|
|
a := &ArrayParameter{}
|
|
if err := u.Unmarshal(a); err != nil {
|
|
return nil, fmt.Errorf("unable to parse as %q: %w", p.Type, err)
|
|
}
|
|
return a, nil
|
|
}
|
|
return nil, fmt.Errorf("%q is not valid type for a parameter!", p.Type)
|
|
}
|
|
|
|
func (ps Parameters) Manifest() []ParameterManifest {
|
|
rtn := make([]ParameterManifest, 0, len(ps))
|
|
for _, p := range ps {
|
|
rtn = append(rtn, p.Manifest())
|
|
}
|
|
return rtn
|
|
}
|
|
|
|
// ParameterManifest represents parameters when served as part of a ToolManifest.
|
|
type ParameterManifest struct {
|
|
Name string `json:"name"`
|
|
Type string `json:"type"`
|
|
Description string `json:"description"`
|
|
AuthSources []string `json:"authSources"`
|
|
// Parameter *ParameterManifest `json:"parameter,omitempty"`
|
|
}
|
|
|
|
// CommonParameter are default fields that are emebdding in most Parameter implementations. Embedding this stuct will give the object Name() and Type() functions.
|
|
type CommonParameter struct {
|
|
Name string `yaml:"name"`
|
|
Type string `yaml:"type"`
|
|
Desc string `yaml:"description"`
|
|
AuthSources []ParamAuthSource `yaml:"authSources"`
|
|
}
|
|
|
|
// GetName returns the name specified for the Parameter.
|
|
func (p *CommonParameter) GetName() string {
|
|
return p.Name
|
|
}
|
|
|
|
// GetType returns the type specified for the Parameter.
|
|
func (p *CommonParameter) GetType() string {
|
|
return p.Type
|
|
}
|
|
|
|
// GetType returns the type specified for the Parameter.
|
|
func (p *CommonParameter) Manifest() ParameterManifest {
|
|
// only list ParamAuthSource names (without fields) in manifest
|
|
authNames := make([]string, len(p.AuthSources))
|
|
for i, a := range p.AuthSources {
|
|
authNames[i] = a.Name
|
|
}
|
|
return ParameterManifest{
|
|
Name: p.Name,
|
|
Type: p.Type,
|
|
Description: p.Desc,
|
|
AuthSources: authNames,
|
|
}
|
|
}
|
|
|
|
// ParseTypeError is a custom error for incorrectly typed Parameters.
|
|
type ParseTypeError struct {
|
|
Name string
|
|
Type string
|
|
Value any
|
|
}
|
|
|
|
func (e ParseTypeError) Error() string {
|
|
return fmt.Sprintf("%q not type %q", e.Value, e.Type)
|
|
}
|
|
|
|
type ParamAuthSource struct {
|
|
Name string `yaml:"name"`
|
|
Field string `yaml:"field"`
|
|
}
|
|
|
|
// NewStringParameter is a convenience function for initializing a StringParameter.
|
|
func NewStringParameter(name, desc string) *StringParameter {
|
|
return &StringParameter{
|
|
CommonParameter: CommonParameter{
|
|
Name: name,
|
|
Type: typeString,
|
|
Desc: desc,
|
|
AuthSources: nil,
|
|
},
|
|
}
|
|
}
|
|
|
|
// NewStringParameterWithAuth is a convenience function for initializing a StringParameter with a list of ParamAuthSource.
|
|
func NewStringParameterWithAuth(name, desc string, authSources []ParamAuthSource) *StringParameter {
|
|
return &StringParameter{
|
|
CommonParameter: CommonParameter{
|
|
Name: name,
|
|
Type: typeString,
|
|
Desc: desc,
|
|
AuthSources: authSources,
|
|
},
|
|
}
|
|
}
|
|
|
|
var _ Parameter = &StringParameter{}
|
|
|
|
// StringParameter is a parameter representing the "string" type.
|
|
type StringParameter struct {
|
|
CommonParameter `yaml:",inline"`
|
|
}
|
|
|
|
// Parse casts the value "v" as a "string".
|
|
func (p *StringParameter) Parse(v any) (any, error) {
|
|
newV, ok := v.(string)
|
|
if !ok {
|
|
return nil, &ParseTypeError{p.Name, p.Type, v}
|
|
}
|
|
return newV, nil
|
|
}
|
|
func (p *StringParameter) GetAuthSources() []ParamAuthSource {
|
|
return p.AuthSources
|
|
}
|
|
|
|
// NewIntParameter is a convenience function for initializing a IntParameter.
|
|
func NewIntParameter(name, desc string) *IntParameter {
|
|
return &IntParameter{
|
|
CommonParameter: CommonParameter{
|
|
Name: name,
|
|
Type: typeInt,
|
|
Desc: desc,
|
|
AuthSources: nil,
|
|
},
|
|
}
|
|
}
|
|
|
|
// NewIntParameterWithAuth is a convenience function for initializing a IntParameter with a list of ParamAuthSource.
|
|
func NewIntParameterWithAuth(name, desc string, authSources []ParamAuthSource) *IntParameter {
|
|
return &IntParameter{
|
|
CommonParameter: CommonParameter{
|
|
Name: name,
|
|
Type: typeInt,
|
|
Desc: desc,
|
|
AuthSources: authSources,
|
|
},
|
|
}
|
|
}
|
|
|
|
var _ Parameter = &IntParameter{}
|
|
|
|
// IntParameter is a parameter representing the "int" type.
|
|
type IntParameter struct {
|
|
CommonParameter `yaml:",inline"`
|
|
}
|
|
|
|
func (p *IntParameter) Parse(v any) (any, error) {
|
|
var out int
|
|
switch newV := v.(type) {
|
|
default:
|
|
return nil, &ParseTypeError{p.Name, p.Type, v}
|
|
case int:
|
|
out = int(newV)
|
|
case int32:
|
|
out = int(newV)
|
|
case int64:
|
|
out = int(newV)
|
|
case json.Number:
|
|
newI, err := newV.Int64()
|
|
if err != nil {
|
|
return nil, &ParseTypeError{p.Name, p.Type, v}
|
|
}
|
|
out = int(newI)
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
func (p *IntParameter) GetAuthSources() []ParamAuthSource {
|
|
return p.AuthSources
|
|
}
|
|
|
|
// NewFloatParameter is a convenience function for initializing a FloatParameter.
|
|
func NewFloatParameter(name, desc string) *FloatParameter {
|
|
return &FloatParameter{
|
|
CommonParameter: CommonParameter{
|
|
Name: name,
|
|
Type: typeFloat,
|
|
Desc: desc,
|
|
AuthSources: nil,
|
|
},
|
|
}
|
|
}
|
|
|
|
// NewFloatParameterWithAuth is a convenience function for initializing a FloatParameter with a list of ParamAuthSource.
|
|
func NewFloatParameterWithAuth(name, desc string, authSources []ParamAuthSource) *FloatParameter {
|
|
return &FloatParameter{
|
|
CommonParameter: CommonParameter{
|
|
Name: name,
|
|
Type: typeFloat,
|
|
Desc: desc,
|
|
AuthSources: authSources,
|
|
},
|
|
}
|
|
}
|
|
|
|
var _ Parameter = &FloatParameter{}
|
|
|
|
// FloatParameter is a parameter representing the "float" type.
|
|
type FloatParameter struct {
|
|
CommonParameter `yaml:",inline"`
|
|
}
|
|
|
|
func (p *FloatParameter) Parse(v any) (any, error) {
|
|
var out float64
|
|
switch newV := v.(type) {
|
|
default:
|
|
return nil, &ParseTypeError{p.Name, p.Type, v}
|
|
case float32:
|
|
out = float64(newV)
|
|
case float64:
|
|
out = newV
|
|
case json.Number:
|
|
newI, err := newV.Float64()
|
|
if err != nil {
|
|
return nil, &ParseTypeError{p.Name, p.Type, v}
|
|
}
|
|
out = float64(newI)
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
func (p *FloatParameter) GetAuthSources() []ParamAuthSource {
|
|
return p.AuthSources
|
|
}
|
|
|
|
// NewBooleanParameter is a convenience function for initializing a BooleanParameter.
|
|
func NewBooleanParameter(name, desc string) *BooleanParameter {
|
|
return &BooleanParameter{
|
|
CommonParameter: CommonParameter{
|
|
Name: name,
|
|
Type: typeBool,
|
|
Desc: desc,
|
|
AuthSources: nil,
|
|
},
|
|
}
|
|
}
|
|
|
|
// NewBooleanParameterWithAuth is a convenience function for initializing a BooleanParameter with a list of ParamAuthSource.
|
|
func NewBooleanParameterWithAuth(name, desc string, authSources []ParamAuthSource) *BooleanParameter {
|
|
return &BooleanParameter{
|
|
CommonParameter: CommonParameter{
|
|
Name: name,
|
|
Type: typeBool,
|
|
Desc: desc,
|
|
AuthSources: authSources,
|
|
},
|
|
}
|
|
}
|
|
|
|
var _ Parameter = &BooleanParameter{}
|
|
|
|
// BooleanParameter is a parameter representing the "boolean" type.
|
|
type BooleanParameter struct {
|
|
CommonParameter `yaml:",inline"`
|
|
}
|
|
|
|
func (p *BooleanParameter) Parse(v any) (any, error) {
|
|
newV, ok := v.(bool)
|
|
if !ok {
|
|
return nil, &ParseTypeError{p.Name, p.Type, v}
|
|
}
|
|
return newV, nil
|
|
}
|
|
|
|
func (p *BooleanParameter) GetAuthSources() []ParamAuthSource {
|
|
return p.AuthSources
|
|
}
|
|
|
|
// NewArrayParameter is a convenience function for initializing a ArrayParameter.
|
|
func NewArrayParameter(name, desc string, items Parameter) *ArrayParameter {
|
|
return &ArrayParameter{
|
|
CommonParameter: CommonParameter{
|
|
Name: name,
|
|
Type: typeArray,
|
|
Desc: desc,
|
|
AuthSources: nil,
|
|
},
|
|
Items: items,
|
|
}
|
|
}
|
|
|
|
// NewArrayParameterWithAuth is a convenience function for initializing a ArrayParameter with a list of ParamAuthSource.
|
|
func NewArrayParameterWithAuth(name, desc string, items Parameter, authSources []ParamAuthSource) *ArrayParameter {
|
|
return &ArrayParameter{
|
|
CommonParameter: CommonParameter{
|
|
Name: name,
|
|
Type: typeArray,
|
|
Desc: desc,
|
|
AuthSources: authSources,
|
|
},
|
|
Items: items,
|
|
}
|
|
}
|
|
|
|
var _ Parameter = &ArrayParameter{}
|
|
|
|
// ArrayParameter is a parameter representing the "array" type.
|
|
type ArrayParameter struct {
|
|
CommonParameter `yaml:",inline"`
|
|
Items Parameter `yaml:"items"`
|
|
}
|
|
|
|
func (p *ArrayParameter) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
|
if err := unmarshal(&p.CommonParameter); err != nil {
|
|
return err
|
|
}
|
|
var rawItem struct {
|
|
Items util.DelayedUnmarshaler `yaml:"items"`
|
|
}
|
|
if err := unmarshal(&rawItem); err != nil {
|
|
return err
|
|
}
|
|
i, err := parseParamFromDelayedUnmarshaler(&rawItem.Items)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to parse 'items' field: %w", err)
|
|
}
|
|
if i.GetAuthSources() != nil {
|
|
return fmt.Errorf("nested items should not have auth sources.")
|
|
}
|
|
p.Items = i
|
|
|
|
return nil
|
|
}
|
|
|
|
func (p *ArrayParameter) Parse(v any) (any, error) {
|
|
arrVal, ok := v.([]any)
|
|
if !ok {
|
|
return nil, &ParseTypeError{p.Name, p.Type, arrVal}
|
|
}
|
|
rtn := make([]any, 0, len(arrVal))
|
|
for idx, val := range arrVal {
|
|
val, err := p.Items.Parse(val)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to parse element #%d: %w", idx, err)
|
|
}
|
|
rtn = append(rtn, val)
|
|
}
|
|
return rtn, nil
|
|
}
|
|
|
|
func (p *ArrayParameter) GetAuthSources() []ParamAuthSource {
|
|
return p.AuthSources
|
|
}
|