feat!: deprecate authsource in favor of authservice (#297)

Rename existing `authSource` to `authService` through deprecation.
`AuthService` more clearly distinguishes it from `Sources` objects.

`authSources` will be converted into `authServices` after the
unmarshalling process. A warning log is shown if `authSources` are used
(for both within tools parameters and defining auth services):
```
2025-02-20T13:57:51.156025-08:00 WARN "`authSources` is deprecated, use `authServices` for parameters instead"
2025-02-20T13:57:51.156569-08:00 WARN "`authSources` is deprecated, use `authServices` instead"
2025-02-20T13:57:52.047584-08:00 INFO "Initialized 1 sources."
...
```

The manifest generated will continue to use `authSources` to keep
compatibility with the sdks:
```
{
"serverVersion":"0.1.0",
"tools":{
  "test_tool2":{
    "description":"Use this tool to test\n",
    "parameters":[{
      "name":"user_id",
      "type":"string",
      "description":"Auto-populated from Google login",
      "authSources":["my-google-auth"]
      }]
    }
  }
}
```



Test cases with `authSources` are kept for compatibility. Will be
removed when `authSources` are no longer supported.
This commit is contained in:
Yuan
2025-02-21 18:36:04 -08:00
committed by GitHub
parent df28036b84
commit 04cb5fbc3e
28 changed files with 561 additions and 248 deletions

View File

@@ -119,10 +119,11 @@ func NewCommand(opts ...Option) *Command {
}
type ToolsFile struct {
Sources server.SourceConfigs `yaml:"sources"`
AuthSources server.AuthSourceConfigs `yaml:"authSources"`
Tools server.ToolConfigs `yaml:"tools"`
Toolsets server.ToolsetConfigs `yaml:"toolsets"`
Sources server.SourceConfigs `yaml:"sources"`
AuthSources server.AuthServiceConfigs `yaml:"authSources"` // Deprecated: Kept for compatibility.
AuthServices server.AuthServiceConfigs `yaml:"authServices"`
Tools server.ToolConfigs `yaml:"tools"`
Toolsets server.ToolsetConfigs `yaml:"toolsets"`
}
// parseToolsFile parses the provided yaml into appropriate configs.
@@ -203,7 +204,12 @@ func run(cmd *Command) error {
return errMsg
}
toolsFile, err := parseToolsFile(ctx, buf)
cmd.cfg.SourceConfigs, cmd.cfg.AuthSourceConfigs, cmd.cfg.ToolConfigs, cmd.cfg.ToolsetConfigs = toolsFile.Sources, toolsFile.AuthSources, toolsFile.Tools, toolsFile.Toolsets
cmd.cfg.SourceConfigs, cmd.cfg.AuthServiceConfigs, cmd.cfg.ToolConfigs, cmd.cfg.ToolsetConfigs = toolsFile.Sources, toolsFile.AuthServices, toolsFile.Tools, toolsFile.Toolsets
authSourceConfigs := toolsFile.AuthSources
if authSourceConfigs != nil {
cmd.logger.WarnContext(ctx, "`authSources` is deprecated, use `authServices` instead")
cmd.cfg.AuthServiceConfigs = authSourceConfigs
}
if err != nil {
errMsg := fmt.Errorf("unable to parse tool file at %q: %w", cmd.tools_file, err)
cmd.logger.ErrorContext(ctx, errMsg.Error())

View File

@@ -16,7 +16,6 @@ package cmd
import (
"bytes"
"context"
_ "embed"
"os"
"strings"
@@ -261,6 +260,10 @@ func TestDefaultLogLevel(t *testing.T) {
}
func TestParseToolFile(t *testing.T) {
ctx, err := testutils.ContextWithNewLogger()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
tcs := []struct {
description string
in string
@@ -330,15 +333,15 @@ func TestParseToolFile(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.description, func(t *testing.T) {
toolsFile, err := parseToolsFile(context.Background(), testutils.FormatYaml(tc.in))
toolsFile, err := parseToolsFile(ctx, testutils.FormatYaml(tc.in))
if err != nil {
t.Fatalf("failed to parse input: %v", err)
}
if diff := cmp.Diff(tc.wantToolsFile.Sources, toolsFile.Sources); diff != "" {
t.Fatalf("incorrect sources parse: diff %v", diff)
}
if diff := cmp.Diff(tc.wantToolsFile.AuthSources, toolsFile.AuthSources); diff != "" {
t.Fatalf("incorrect authSources parse: diff %v", diff)
if diff := cmp.Diff(tc.wantToolsFile.AuthServices, toolsFile.AuthServices); diff != "" {
t.Fatalf("incorrect authServices parse: diff %v", diff)
}
if diff := cmp.Diff(tc.wantToolsFile.Tools, toolsFile.Tools); diff != "" {
t.Fatalf("incorrect tools parse: diff %v", diff)
@@ -352,6 +355,10 @@ func TestParseToolFile(t *testing.T) {
}
func TestParseToolFileWithAuth(t *testing.T) {
ctx, err := testutils.ContextWithNewLogger()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
tcs := []struct {
description string
in string
@@ -360,6 +367,104 @@ func TestParseToolFileWithAuth(t *testing.T) {
{
description: "basic example",
in: `
sources:
my-pg-instance:
kind: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
authServices:
my-google-service:
kind: google
clientId: my-client-id
other-google-service:
kind: google
clientId: other-client-id
tools:
example_tool:
kind: postgres-sql
source: my-pg-instance
description: some description
statement: |
SELECT * FROM SQL_STATEMENT;
parameters:
- name: country
type: string
description: some description
- name: id
type: integer
description: user id
authServices:
- name: my-google-service
field: user_id
- name: email
type: string
description: user email
authServices:
- name: my-google-service
field: email
- name: other-google-service
field: other_email
toolsets:
example_toolset:
- example_tool
`,
wantToolsFile: ToolsFile{
Sources: server.SourceConfigs{
"my-pg-instance": cloudsqlpgsrc.Config{
Name: "my-pg-instance",
Kind: cloudsqlpgsrc.SourceKind,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
IPType: "public",
Database: "my_db",
User: "my_user",
Password: "my_pass",
},
},
AuthServices: server.AuthServiceConfigs{
"my-google-service": google.Config{
Name: "my-google-service",
Kind: google.AuthServiceKind,
ClientID: "my-client-id",
},
"other-google-service": google.Config{
Name: "other-google-service",
Kind: google.AuthServiceKind,
ClientID: "other-client-id",
},
},
Tools: server.ToolConfigs{
"example_tool": postgressql.Config{
Name: "example_tool",
Kind: postgressql.ToolKind,
Source: "my-pg-instance",
Description: "some description",
Statement: "SELECT * FROM SQL_STATEMENT;\n",
Parameters: []tools.Parameter{
tools.NewStringParameter("country", "some description"),
tools.NewIntParameterWithAuth("id", "user id", []tools.ParamAuthService{{Name: "my-google-service", Field: "user_id"}}),
tools.NewStringParameterWithAuth("email", "user email", []tools.ParamAuthService{{Name: "my-google-service", Field: "email"}, {Name: "other-google-service", Field: "other_email"}}),
},
},
},
Toolsets: server.ToolsetConfigs{
"example_toolset": tools.ToolsetConfig{
Name: "example_toolset",
ToolNames: []string{"example_tool"},
},
},
},
},
{
description: "basic example with authSources",
in: `
sources:
my-pg-instance:
kind: cloud-sql-postgres
@@ -421,15 +526,15 @@ func TestParseToolFileWithAuth(t *testing.T) {
Password: "my_pass",
},
},
AuthSources: server.AuthSourceConfigs{
AuthSources: server.AuthServiceConfigs{
"my-google-service": google.Config{
Name: "my-google-service",
Kind: google.AuthSourceKind,
Kind: google.AuthServiceKind,
ClientID: "my-client-id",
},
"other-google-service": google.Config{
Name: "other-google-service",
Kind: google.AuthSourceKind,
Kind: google.AuthServiceKind,
ClientID: "other-client-id",
},
},
@@ -442,8 +547,8 @@ func TestParseToolFileWithAuth(t *testing.T) {
Statement: "SELECT * FROM SQL_STATEMENT;\n",
Parameters: []tools.Parameter{
tools.NewStringParameter("country", "some description"),
tools.NewIntParameterWithAuth("id", "user id", []tools.ParamAuthSource{{Name: "my-google-service", Field: "user_id"}}),
tools.NewStringParameterWithAuth("email", "user email", []tools.ParamAuthSource{{Name: "my-google-service", Field: "email"}, {Name: "other-google-service", Field: "other_email"}}),
tools.NewIntParameterWithAuth("id", "user id", []tools.ParamAuthService{{Name: "my-google-service", Field: "user_id"}}),
tools.NewStringParameterWithAuth("email", "user email", []tools.ParamAuthService{{Name: "my-google-service", Field: "email"}, {Name: "other-google-service", Field: "other_email"}}),
},
},
},
@@ -458,15 +563,15 @@ func TestParseToolFileWithAuth(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.description, func(t *testing.T) {
toolsFile, err := parseToolsFile(context.Background(), testutils.FormatYaml(tc.in))
toolsFile, err := parseToolsFile(ctx, testutils.FormatYaml(tc.in))
if err != nil {
t.Fatalf("failed to parse input: %v", err)
}
if diff := cmp.Diff(tc.wantToolsFile.Sources, toolsFile.Sources); diff != "" {
t.Fatalf("incorrect sources parse: diff %v", diff)
}
if diff := cmp.Diff(tc.wantToolsFile.AuthSources, toolsFile.AuthSources); diff != "" {
t.Fatalf("incorrect authSources parse: diff %v", diff)
if diff := cmp.Diff(tc.wantToolsFile.AuthServices, toolsFile.AuthServices); diff != "" {
t.Fatalf("incorrect authServices parse: diff %v", diff)
}
if diff := cmp.Diff(tc.wantToolsFile.Tools, toolsFile.Tools); diff != "" {
t.Fatalf("incorrect tools parse: diff %v", diff)

View File

@@ -1,12 +1,12 @@
---
title: "AuthSources"
title: "AuthServices"
type: docs
weight: 1
description: >
AuthSources represent services that handle authentication and authorization.
AuthServices represent services that handle authentication and authorization.
---
AuthSources represent services that handle authentication and authorization. It
AuthServices represent services that handle authentication and authorization. It
can primarily be used by [Tools](../tools) in two different ways:
- [**Authorized Invocation**][auth-invoke] is when a tool
@@ -32,7 +32,7 @@ If you are accessing Toolbox with multiple applications, each
{{< /notice >}}
```yaml
authSources:
authServices:
my_auth_app_1:
kind: google
clientId: YOUR_CLIENT_ID_1
@@ -41,17 +41,17 @@ authSources:
clientId: YOUR_CLIENT_ID_2
```
After you've configured an `authSource` you'll, need to reference it in the
After you've configured an `authService` you'll, need to reference it in the
configuration for each tool that should use it:
- **Authorized Invocations** for authorizing a tool call, [use the
`requiredAuth` field in a tool config][auth-invoke]
- **Authenticated Parameters** for using the value from a ODIC claim, [use the
`authSources` field in a parameter config][auth-params]
`authServices` field in a parameter config][auth-params]
## Specifying ID Tokens from Clients
After [configuring](#example) your `authSources` section, use a Toolbox SDK to
After [configuring](#example) your `authServices` section, use a Toolbox SDK to
add your ID tokens to the header of a Tool invocation request. When specifying a
token you will provide a function (that returns an id). This function is called
when the tool is invoked. This allows you to cache and refresh the ID token as
@@ -89,4 +89,4 @@ authorized_tool = tools[0].add_auth_tokens({
{{< /tab >}}
{{< /tabpane >}}
## Kinds of Auth Sources
## Kinds of Auth Services

View File

@@ -12,7 +12,7 @@ Google Sign-In manages the OAuth 2.0 flow and token lifecycle. To integrate the
Google Sign-In workflow to your web app [follow this guide][gsi-setup].
After setting up the Google Sign-In workflow, you should have registered your
application and retrieved a [Client ID][client-id]. Configure your auth source
application and retrieved a [Client ID][client-id]. Configure your auth service
in with the `Client ID`.
[gsi-setup]: https://developers.google.com/identity/sign-in/web/sign-in
@@ -31,7 +31,7 @@ ID.
### Authenticated Parameters
When using [Authenticated Parameters][auth-params], any [claim provided by the
id-token][provided-claims] can be used as a source for the parameter.
id-token][provided-claims] can be used for the parameter.
[auth-params]: ../tools/#authenticated-phugarameters
[provided-claims]:
@@ -40,7 +40,7 @@ id-token][provided-claims] can be used as a source for the parameter.
## Example
```yaml
authSources:
authServices:
my-google-auth:
kind: google
clientId: YOUR_GOOGLE_CLIENT_ID

View File

@@ -115,7 +115,7 @@ Authenticated parameters are automatically populated with user
information decoded from [ID tokens](../authsources/#specifying-id-tokens-from-clients) that
are passed in request headers. They do not take input values in request bodies
like other parameters. To use authenticated parameters, you must configure
the tool to map the required [authSources](../authsources) to
the tool to map the required [authServices](../authservices) to
specific claims within the user's ID token.
```yaml
@@ -129,8 +129,8 @@ specific claims within the user's ID token.
- name: user_id
type: string
description: Auto-populated from Google login
authSources:
# Refer to one of the `authSources` defined
authServices:
# Refer to one of the `authServices` defined
- name: my-google-auth
# `sub` is the OIDC claim field for user ID
field: sub
@@ -138,14 +138,14 @@ specific claims within the user's ID token.
| **field** | **type** | **required** | **description** |
|-----------|:--------:|:------------:|-----------------------------------------------------------------------------------------|
| name | string | true | Name of the [authSources](../authsources) used to verify the OIDC auth token. |
| name | string | true | Name of the [authServices](../authservices) used to verify the OIDC auth token. |
| field | string | true | Claim field decoded from the OIDC token used to auto-populate this parameter. |
## Authorized Invocations
You can require an authorization check for any Tool invocation request by
specifying an `authRequired` field. Specify a list of
[authSources](../authsources) defined in the previous section.
[authServices](../authservices) defined in the previous section.
```yaml
tools:
@@ -154,7 +154,7 @@ tools:
source: my-pg-instance
statement: |
SELECT * FROM flights
# A list of `authSources` defined previously
# A list of `authServices` defined previously
authRequired:
- my-google-auth
- other-auth-service

View File

@@ -16,15 +16,15 @@ package auth
import "net/http"
// SourceConfig is the interface for configuring authentication sources.
type AuthSourceConfig interface {
AuthSourceConfigKind() string
Initialize() (AuthSource, error)
// AuthServiceConfig is the interface for configuring authentication services.
type AuthServiceConfig interface {
AuthServiceConfigKind() string
Initialize() (AuthService, error)
}
// AuthSource is the interface for authentication sources.
type AuthSource interface {
AuthSourceKind() string
// AuthService is the interface for authentication services.
type AuthService interface {
AuthServiceKind() string
GetName() string
GetClaimsFromHeader(http.Header) (map[string]any, error)
}

View File

@@ -23,54 +23,54 @@ import (
"google.golang.org/api/idtoken"
)
const AuthSourceKind string = "google"
const AuthServiceKind string = "google"
// validate interface
var _ auth.AuthSourceConfig = Config{}
var _ auth.AuthServiceConfig = Config{}
// Auth source configuration
// Auth service configuration
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
ClientID string `yaml:"clientId" validate:"required"`
}
// Returns the auth source kind
func (cfg Config) AuthSourceConfigKind() string {
return AuthSourceKind
// Returns the auth service kind
func (cfg Config) AuthServiceConfigKind() string {
return AuthServiceKind
}
// Initialize a Google auth source
func (cfg Config) Initialize() (auth.AuthSource, error) {
a := &AuthSource{
// Initialize a Google auth service
func (cfg Config) Initialize() (auth.AuthService, error) {
a := &AuthService{
Name: cfg.Name,
Kind: AuthSourceKind,
Kind: AuthServiceKind,
ClientID: cfg.ClientID,
}
return a, nil
}
var _ auth.AuthSource = AuthSource{}
var _ auth.AuthService = AuthService{}
// struct used to store auth source info
type AuthSource struct {
// struct used to store auth service info
type AuthService struct {
Name string `yaml:"name"`
Kind string `yaml:"kind"`
ClientID string `yaml:"clientId"`
}
// Returns the auth source kind
func (a AuthSource) AuthSourceKind() string {
return AuthSourceKind
// Returns the auth service kind
func (a AuthService) AuthServiceKind() string {
return AuthServiceKind
}
// Returns the name of the auth source
func (a AuthSource) GetName() string {
// Returns the name of the auth service
func (a AuthService) GetName() string {
return a.Name
}
// Verifies Google ID token and return claims
func (a AuthSource) GetClaimsFromHeader(h http.Header) (map[string]any, error) {
func (a AuthService) GetClaimsFromHeader(h http.Header) (map[string]any, error) {
if token := h.Get(a.Name + "_token"); token != "" {
payload, err := idtoken.Validate(context.Background(), token, a.ClientID)
if err != nil {

View File

@@ -163,31 +163,31 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
}
// Tool authentication
// claimsFromAuth maps the name of the authsource to the claims retrieved from it.
// claimsFromAuth maps the name of the authservice to the claims retrieved from it.
claimsFromAuth := make(map[string]map[string]any)
for _, aS := range s.authSources {
for _, aS := range s.authServices {
claims, err := aS.GetClaimsFromHeader(r.Header)
if err != nil {
s.logger.DebugContext(context.Background(), err.Error())
continue
}
if claims == nil {
// authSource not present in header
// authService not present in header
continue
}
claimsFromAuth[aS.GetName()] = claims
}
// Tool authorization check
verifiedAuthSources := make([]string, len(claimsFromAuth))
verifiedAuthServices := make([]string, len(claimsFromAuth))
i := 0
for k := range claimsFromAuth {
verifiedAuthSources[i] = k
verifiedAuthServices[i] = k
i++
}
// Check if any of the specified auth sources is verified
isAuthorized := tool.Authorized(verifiedAuthSources)
// Check if any of the specified auth services is verified
isAuthorized := tool.Authorized(verifiedAuthServices)
if !isAuthorized {
err = fmt.Errorf("tool invocation not authorized. Please make sure your specify correct auth headers")
s.logger.DebugContext(context.Background(), err.Error())

View File

@@ -57,7 +57,7 @@ func (t MockTool) Manifest() tools.Manifest {
return tools.Manifest{Description: t.Description, Parameters: pMs}
}
func (t MockTool) Authorized(verifiedAuthSources []string) bool {
func (t MockTool) Authorized(verifiedAuthServices []string) bool {
return true
}

View File

@@ -51,8 +51,8 @@ type ServerConfig struct {
Port int
// SourceConfigs defines what sources of data are available for tools.
SourceConfigs SourceConfigs
// AuthSourceConfigs defines what sources of authentication are available for tools.
AuthSourceConfigs AuthSourceConfigs
// AuthServiceConfigs defines what sources of authentication are available for tools.
AuthServiceConfigs AuthServiceConfigs
// ToolConfigs defines what tools are available.
ToolConfigs ToolConfigs
// ToolsetConfigs defines what tools are available.
@@ -220,15 +220,15 @@ func (c *SourceConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interf
return nil
}
// AuthSourceConfigs is a type used to allow unmarshal of the data authSource config map
type AuthSourceConfigs map[string]auth.AuthSourceConfig
// AuthServiceConfigs is a type used to allow unmarshal of the data authService config map
type AuthServiceConfigs map[string]auth.AuthServiceConfig
// validate interface
var _ yaml.InterfaceUnmarshalerContext = &AuthSourceConfigs{}
var _ yaml.InterfaceUnmarshalerContext = &AuthServiceConfigs{}
func (c *AuthSourceConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
*c = make(AuthSourceConfigs)
// Parse the 'kind' fields for each authSource
func (c *AuthServiceConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
*c = make(AuthServiceConfigs)
// Parse the 'kind' fields for each authService
var raw map[string]util.DelayedUnmarshaler
if err := unmarshal(&raw); err != nil {
return err
@@ -250,7 +250,7 @@ func (c *AuthSourceConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(in
return fmt.Errorf("error creating decoder: %w", err)
}
switch kind {
case google.AuthSourceKind:
case google.AuthServiceKind:
actual := google.Config{Name: name}
if err := dec.DecodeContext(ctx, &actual); err != nil {
return fmt.Errorf("unable to parse as %q: %w", kind, err)

View File

@@ -43,10 +43,10 @@ type Server struct {
logger log.Logger
instrumentation *Instrumentation
sources map[string]sources.Source
authSources map[string]auth.AuthSource
tools map[string]tools.Tool
toolsets map[string]tools.Toolset
sources map[string]sources.Source
authServices map[string]auth.AuthService
tools map[string]tools.Tool
toolsets map[string]tools.Toolset
}
// NewServer returns a Server object based on provided Config.
@@ -120,29 +120,29 @@ func NewServer(ctx context.Context, cfg ServerConfig, l log.Logger) (*Server, er
}
l.InfoContext(ctx, fmt.Sprintf("Initialized %d sources.", len(sourcesMap)))
// initialize and validate the auth sources from configs
authSourcesMap := make(map[string]auth.AuthSource)
for name, sc := range cfg.AuthSourceConfigs {
a, err := func() (auth.AuthSource, error) {
// initialize and validate the auth services from configs
authServicesMap := make(map[string]auth.AuthService)
for name, sc := range cfg.AuthServiceConfigs {
a, err := func() (auth.AuthService, error) {
_, span := instrumentation.Tracer.Start(
parentCtx,
"toolbox/server/auth/init",
trace.WithAttributes(attribute.String("auth_kind", sc.AuthSourceConfigKind())),
trace.WithAttributes(attribute.String("auth_kind", sc.AuthServiceConfigKind())),
trace.WithAttributes(attribute.String("auth_name", name)),
)
defer span.End()
a, err := sc.Initialize()
if err != nil {
return nil, fmt.Errorf("unable to initialize auth source %q: %w", name, err)
return nil, fmt.Errorf("unable to initialize auth service %q: %w", name, err)
}
return a, nil
}()
if err != nil {
return nil, err
}
authSourcesMap[name] = a
authServicesMap[name] = a
}
l.InfoContext(ctx, fmt.Sprintf("Initialized %d authSources.", len(authSourcesMap)))
l.InfoContext(ctx, fmt.Sprintf("Initialized %d authServices.", len(authServicesMap)))
// initialize and validate the tools from configs
toolsMap := make(map[string]tools.Tool)
@@ -211,10 +211,10 @@ func NewServer(ctx context.Context, cfg ServerConfig, l log.Logger) (*Server, er
logger: l,
instrumentation: instrumentation,
sources: sourcesMap,
authSources: authSourcesMap,
tools: toolsMap,
toolsets: toolsetsMap,
sources: sourcesMap,
authServices: authServicesMap,
tools: toolsMap,
toolsets: toolsetsMap,
}
// control plane
apiR, err := apiRouter(s)

View File

@@ -15,7 +15,13 @@
package testutils
import (
"context"
"fmt"
"os"
"strings"
"github.com/googleapis/genai-toolbox/internal/log"
"github.com/googleapis/genai-toolbox/internal/util"
)
// formatYaml is a utility function for stripping out tabs in multiline strings
@@ -26,3 +32,13 @@ func FormatYaml(in string) []byte {
in = strings.ReplaceAll(in, "\t", " ")
return []byte(in)
}
// ContextWithNewLogger create a new context with new logger
func ContextWithNewLogger() (context.Context, error) {
ctx := context.Background()
logger, err := log.NewStdLogger(os.Stdout, os.Stderr, "info")
if err != nil {
return nil, fmt.Errorf("unable to create logger: %s", err)
}
return util.WithLogger(ctx, logger), nil
}

View File

@@ -127,6 +127,6 @@ func (t Tool) Manifest() tools.Manifest {
return t.manifest
}
func (t Tool) Authorized(verifiedAuthSources []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthSources)
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}

View File

@@ -25,6 +25,10 @@ import (
)
func TestParseFromYamlDgraph(t *testing.T) {
ctx, err := testutils.ContextWithNewLogger()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
tcs := []struct {
desc string
in string
@@ -83,7 +87,7 @@ func TestParseFromYamlDgraph(t *testing.T) {
Tools server.ToolConfigs `yaml:"tools"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}

View File

@@ -169,6 +169,6 @@ func (t Tool) Manifest() tools.Manifest {
return t.manifest
}
func (t Tool) Authorized(verifiedAuthSources []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthSources)
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}

View File

@@ -26,6 +26,10 @@ import (
)
func TestParseFromYamlMssql(t *testing.T) {
ctx, err := testutils.ContextWithNewLogger()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
tcs := []struct {
desc string
in string
@@ -48,7 +52,7 @@ func TestParseFromYamlMssql(t *testing.T) {
- name: country
type: string
description: some description
authSources:
authServices:
- name: my-google-auth-service
field: user_id
- name: other-auth-service
@@ -64,7 +68,7 @@ func TestParseFromYamlMssql(t *testing.T) {
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
Parameters: []tools.Parameter{
tools.NewStringParameterWithAuth("country", "some description",
[]tools.ParamAuthSource{{Name: "my-google-auth-service", Field: "user_id"},
[]tools.ParamAuthService{{Name: "my-google-auth-service", Field: "user_id"},
{Name: "other-auth-service", Field: "user_id"}}),
},
},
@@ -77,7 +81,7 @@ func TestParseFromYamlMssql(t *testing.T) {
Tools server.ToolConfigs `yaml:"tools"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}

View File

@@ -164,6 +164,6 @@ func (t Tool) Manifest() tools.Manifest {
return t.manifest
}
func (t Tool) Authorized(verifiedAuthSources []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthSources)
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}

View File

@@ -26,6 +26,10 @@ import (
)
func TestParseFromYamlMySQL(t *testing.T) {
ctx, err := testutils.ContextWithNewLogger()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
tcs := []struct {
desc string
in string
@@ -48,7 +52,7 @@ func TestParseFromYamlMySQL(t *testing.T) {
- name: country
type: string
description: some description
authSources:
authServices:
- name: my-google-auth-service
field: user_id
- name: other-auth-service
@@ -64,7 +68,7 @@ func TestParseFromYamlMySQL(t *testing.T) {
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
Parameters: []tools.Parameter{
tools.NewStringParameterWithAuth("country", "some description",
[]tools.ParamAuthSource{{Name: "my-google-auth-service", Field: "user_id"},
[]tools.ParamAuthService{{Name: "my-google-auth-service", Field: "user_id"},
{Name: "other-auth-service", Field: "user_id"}}),
},
},
@@ -77,7 +81,7 @@ func TestParseFromYamlMySQL(t *testing.T) {
Tools server.ToolConfigs `yaml:"tools"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}

View File

@@ -127,6 +127,6 @@ func (t Tool) Manifest() tools.Manifest {
return t.manifest
}
func (t Tool) Authorized(verifiedAuthSources []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthSources)
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}

View File

@@ -26,6 +26,10 @@ import (
)
func TestParseFromYamlNeo4j(t *testing.T) {
ctx, err := testutils.ContextWithNewLogger()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
tcs := []struct {
desc string
in string
@@ -66,7 +70,7 @@ func TestParseFromYamlNeo4j(t *testing.T) {
Tools server.ToolConfigs `yaml:"tools"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}

View File

@@ -97,12 +97,12 @@ func (p ParamValues) AsMapWithDollarPrefix() map[string]interface{} {
return params
}
func parseFromAuthSource(paramAuthSources []ParamAuthSource, claimsMap map[string]map[string]any) (any, error) {
// parse a parameter from claims using its specified auth sources
for _, a := range paramAuthSources {
func parseFromAuthService(paramAuthServices []ParamAuthService, claimsMap map[string]map[string]any) (any, error) {
// parse a parameter from claims using its specified auth services
for _, a := range paramAuthServices {
claims, ok := claimsMap[a.Name]
if !ok {
// not validated for this authsource, skip to the next one
// not validated for this authservice, skip to the next one
continue
}
v, ok := claims[a.Field]
@@ -120,9 +120,9 @@ func ParseParams(ps Parameters, data map[string]any, claimsMap map[string]map[st
params := make([]ParamValue, 0, len(ps))
for _, p := range ps {
var v any
paramAuthSources := p.GetAuthSources()
paramAuthServices := p.GetAuthServices()
name := p.GetName()
if paramAuthSources == nil {
if paramAuthServices == nil {
// parse non auth-required parameter
var ok bool
v, ok = data[name]
@@ -132,7 +132,7 @@ func ParseParams(ps Parameters, data map[string]any, claimsMap map[string]map[st
} else {
// parse authenticated parameter
var err error
v, err = parseFromAuthSource(paramAuthSources, claimsMap)
v, err = parseFromAuthService(paramAuthServices, claimsMap)
if err != nil {
return nil, fmt.Errorf("error parsing authenticated parameter %q: %w", name, err)
}
@@ -151,7 +151,7 @@ type Parameter interface {
// but this is done to differentiate it from the fields in CommonParameter.
GetName() string
GetType() string
GetAuthSources() []ParamAuthSource
GetAuthServices() []ParamAuthService
Parse(any) (any, error)
Manifest() ParameterManifest
}
@@ -161,7 +161,6 @@ type Parameters []Parameter
func (c *Parameters) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
*c = make(Parameters, 0)
// Parse the 'kind' fields for each source
var rawList []util.DelayedUnmarshaler
if err := unmarshal(&rawList); err != nil {
return err
@@ -194,36 +193,65 @@ func parseParamFromDelayedUnmarshaler(ctx context.Context, u *util.DelayedUnmars
if err != nil {
return nil, fmt.Errorf("error creating decoder: %w", err)
}
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return nil, err
}
switch t {
case typeString:
a := &StringParameter{}
if err := dec.DecodeContext(ctx, a); err != nil {
return nil, fmt.Errorf("unable to parse as %q: %w", t, err)
}
if a.AuthSources != nil {
logger.WarnContext(ctx, "`authSources` is deprecated, use `authServices` for parameters instead")
a.AuthServices = append(a.AuthServices, a.AuthSources...)
a.AuthSources = nil
}
return a, nil
case typeInt:
a := &IntParameter{}
if err := dec.DecodeContext(ctx, a); err != nil {
return nil, fmt.Errorf("unable to parse as %q: %w", t, err)
}
if a.AuthSources != nil {
logger.WarnContext(ctx, "`authSources` is deprecated, use `authServices` for parameters instead")
a.AuthServices = append(a.AuthServices, a.AuthSources...)
a.AuthSources = nil
}
return a, nil
case typeFloat:
a := &FloatParameter{}
if err := dec.DecodeContext(ctx, a); err != nil {
return nil, fmt.Errorf("unable to parse as %q: %w", t, err)
}
if a.AuthSources != nil {
logger.WarnContext(ctx, "`authSources` is deprecated, use `authServices` for parameters instead")
a.AuthServices = append(a.AuthServices, a.AuthSources...)
a.AuthSources = nil
}
return a, nil
case typeBool:
a := &BooleanParameter{}
if err := dec.DecodeContext(ctx, a); err != nil {
return nil, fmt.Errorf("unable to parse as %q: %w", t, err)
}
if a.AuthSources != nil {
logger.WarnContext(ctx, "`authSources` is deprecated, use `authServices` for parameters instead")
a.AuthServices = append(a.AuthServices, a.AuthSources...)
a.AuthSources = nil
}
return a, nil
case typeArray:
a := &ArrayParameter{}
if err := dec.DecodeContext(ctx, a); err != nil {
return nil, fmt.Errorf("unable to parse as %q: %w", t, err)
}
if a.AuthSources != nil {
logger.WarnContext(ctx, "`authSources` is deprecated, use `authServices` for parameters instead")
a.AuthServices = append(a.AuthServices, a.AuthSources...)
a.AuthSources = nil
}
return a, nil
}
return nil, fmt.Errorf("%q is not valid type for a parameter!", t)
@@ -239,19 +267,20 @@ func (ps Parameters) Manifest() []ParameterManifest {
// ParameterManifest represents parameters when served as part of a ToolManifest.
type ParameterManifest struct {
Name string `json:"name"`
Type string `json:"type"`
Description string `json:"description"`
AuthSources []string `json:"authSources"`
Items *ParameterManifest `json:"items,omitempty"`
Name string `json:"name"`
Type string `json:"type"`
Description string `json:"description"`
AuthServices []string `json:"authSources"`
Items *ParameterManifest `json:"items,omitempty"`
}
// CommonParameter are default fields that are emebdding in most Parameter implementations. Embedding this stuct will give the object Name() and Type() functions.
type CommonParameter struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"type" validate:"required"`
Desc string `yaml:"description" validate:"required"`
AuthSources []ParamAuthSource `yaml:"authSources"`
Name string `yaml:"name" validate:"required"`
Type string `yaml:"type" validate:"required"`
Desc string `yaml:"description" validate:"required"`
AuthServices []ParamAuthService `yaml:"authServices"`
AuthSources []ParamAuthService `yaml:"authSources"` // Deprecated: Kept for compatibility.
}
// GetName returns the name specified for the Parameter.
@@ -266,16 +295,16 @@ func (p *CommonParameter) GetType() string {
// Manifest returns the manifest for the Parameter.
func (p *CommonParameter) Manifest() ParameterManifest {
// only list ParamAuthSource names (without fields) in manifest
authNames := make([]string, len(p.AuthSources))
for i, a := range p.AuthSources {
// only list ParamAuthService names (without fields) in manifest
authNames := make([]string, len(p.AuthServices))
for i, a := range p.AuthServices {
authNames[i] = a.Name
}
return ParameterManifest{
Name: p.Name,
Type: p.Type,
Description: p.Desc,
AuthSources: authNames,
Name: p.Name,
Type: p.Type,
Description: p.Desc,
AuthServices: authNames,
}
}
@@ -290,7 +319,7 @@ func (e ParseTypeError) Error() string {
return fmt.Sprintf("%q not type %q", e.Value, e.Type)
}
type ParamAuthSource struct {
type ParamAuthService struct {
Name string `yaml:"name"`
Field string `yaml:"field"`
}
@@ -299,22 +328,22 @@ type ParamAuthSource struct {
func NewStringParameter(name, desc string) *StringParameter {
return &StringParameter{
CommonParameter: CommonParameter{
Name: name,
Type: typeString,
Desc: desc,
AuthSources: nil,
Name: name,
Type: typeString,
Desc: desc,
AuthServices: nil,
},
}
}
// NewStringParameterWithAuth is a convenience function for initializing a StringParameter with a list of ParamAuthSource.
func NewStringParameterWithAuth(name, desc string, authSources []ParamAuthSource) *StringParameter {
// NewStringParameterWithAuth is a convenience function for initializing a StringParameter with a list of ParamAuthService.
func NewStringParameterWithAuth(name, desc string, authServices []ParamAuthService) *StringParameter {
return &StringParameter{
CommonParameter: CommonParameter{
Name: name,
Type: typeString,
Desc: desc,
AuthSources: authSources,
Name: name,
Type: typeString,
Desc: desc,
AuthServices: authServices,
},
}
}
@@ -334,30 +363,30 @@ func (p *StringParameter) Parse(v any) (any, error) {
}
return newV, nil
}
func (p *StringParameter) GetAuthSources() []ParamAuthSource {
return p.AuthSources
func (p *StringParameter) GetAuthServices() []ParamAuthService {
return p.AuthServices
}
// NewIntParameter is a convenience function for initializing a IntParameter.
func NewIntParameter(name, desc string) *IntParameter {
return &IntParameter{
CommonParameter: CommonParameter{
Name: name,
Type: typeInt,
Desc: desc,
AuthSources: nil,
Name: name,
Type: typeInt,
Desc: desc,
AuthServices: nil,
},
}
}
// NewIntParameterWithAuth is a convenience function for initializing a IntParameter with a list of ParamAuthSource.
func NewIntParameterWithAuth(name, desc string, authSources []ParamAuthSource) *IntParameter {
// NewIntParameterWithAuth is a convenience function for initializing a IntParameter with a list of ParamAuthService.
func NewIntParameterWithAuth(name, desc string, authServices []ParamAuthService) *IntParameter {
return &IntParameter{
CommonParameter: CommonParameter{
Name: name,
Type: typeInt,
Desc: desc,
AuthSources: authSources,
Name: name,
Type: typeInt,
Desc: desc,
AuthServices: authServices,
},
}
}
@@ -390,30 +419,30 @@ func (p *IntParameter) Parse(v any) (any, error) {
return out, nil
}
func (p *IntParameter) GetAuthSources() []ParamAuthSource {
return p.AuthSources
func (p *IntParameter) GetAuthServices() []ParamAuthService {
return p.AuthServices
}
// NewFloatParameter is a convenience function for initializing a FloatParameter.
func NewFloatParameter(name, desc string) *FloatParameter {
return &FloatParameter{
CommonParameter: CommonParameter{
Name: name,
Type: typeFloat,
Desc: desc,
AuthSources: nil,
Name: name,
Type: typeFloat,
Desc: desc,
AuthServices: nil,
},
}
}
// NewFloatParameterWithAuth is a convenience function for initializing a FloatParameter with a list of ParamAuthSource.
func NewFloatParameterWithAuth(name, desc string, authSources []ParamAuthSource) *FloatParameter {
// NewFloatParameterWithAuth is a convenience function for initializing a FloatParameter with a list of ParamAuthService.
func NewFloatParameterWithAuth(name, desc string, authServices []ParamAuthService) *FloatParameter {
return &FloatParameter{
CommonParameter: CommonParameter{
Name: name,
Type: typeFloat,
Desc: desc,
AuthSources: authSources,
Name: name,
Type: typeFloat,
Desc: desc,
AuthServices: authServices,
},
}
}
@@ -444,30 +473,30 @@ func (p *FloatParameter) Parse(v any) (any, error) {
return out, nil
}
func (p *FloatParameter) GetAuthSources() []ParamAuthSource {
return p.AuthSources
func (p *FloatParameter) GetAuthServices() []ParamAuthService {
return p.AuthServices
}
// NewBooleanParameter is a convenience function for initializing a BooleanParameter.
func NewBooleanParameter(name, desc string) *BooleanParameter {
return &BooleanParameter{
CommonParameter: CommonParameter{
Name: name,
Type: typeBool,
Desc: desc,
AuthSources: nil,
Name: name,
Type: typeBool,
Desc: desc,
AuthServices: nil,
},
}
}
// NewBooleanParameterWithAuth is a convenience function for initializing a BooleanParameter with a list of ParamAuthSource.
func NewBooleanParameterWithAuth(name, desc string, authSources []ParamAuthSource) *BooleanParameter {
// NewBooleanParameterWithAuth is a convenience function for initializing a BooleanParameter with a list of ParamAuthService.
func NewBooleanParameterWithAuth(name, desc string, authServices []ParamAuthService) *BooleanParameter {
return &BooleanParameter{
CommonParameter: CommonParameter{
Name: name,
Type: typeBool,
Desc: desc,
AuthSources: authSources,
Name: name,
Type: typeBool,
Desc: desc,
AuthServices: authServices,
},
}
}
@@ -487,31 +516,31 @@ func (p *BooleanParameter) Parse(v any) (any, error) {
return newV, nil
}
func (p *BooleanParameter) GetAuthSources() []ParamAuthSource {
return p.AuthSources
func (p *BooleanParameter) GetAuthServices() []ParamAuthService {
return p.AuthServices
}
// NewArrayParameter is a convenience function for initializing a ArrayParameter.
func NewArrayParameter(name, desc string, items Parameter) *ArrayParameter {
return &ArrayParameter{
CommonParameter: CommonParameter{
Name: name,
Type: typeArray,
Desc: desc,
AuthSources: nil,
Name: name,
Type: typeArray,
Desc: desc,
AuthServices: nil,
},
Items: items,
}
}
// NewArrayParameterWithAuth is a convenience function for initializing a ArrayParameter with a list of ParamAuthSource.
func NewArrayParameterWithAuth(name, desc string, items Parameter, authSources []ParamAuthSource) *ArrayParameter {
// NewArrayParameterWithAuth is a convenience function for initializing a ArrayParameter with a list of ParamAuthService.
func NewArrayParameterWithAuth(name, desc string, items Parameter, authServices []ParamAuthService) *ArrayParameter {
return &ArrayParameter{
CommonParameter: CommonParameter{
Name: name,
Type: typeArray,
Desc: desc,
AuthSources: authSources,
Name: name,
Type: typeArray,
Desc: desc,
AuthServices: authServices,
},
Items: items,
}
@@ -538,8 +567,8 @@ func (p *ArrayParameter) UnmarshalYAML(ctx context.Context, unmarshal func(inter
if err != nil {
return fmt.Errorf("unable to parse 'items' field: %w", err)
}
if i.GetAuthSources() != nil {
return fmt.Errorf("nested items should not have auth sources.")
if i.GetAuthServices() != nil {
return fmt.Errorf("nested items should not have auth services.")
}
p.Items = i
@@ -562,23 +591,23 @@ func (p *ArrayParameter) Parse(v any) (any, error) {
return rtn, nil
}
func (p *ArrayParameter) GetAuthSources() []ParamAuthSource {
return p.AuthSources
func (p *ArrayParameter) GetAuthServices() []ParamAuthService {
return p.AuthServices
}
// Manifest returns the manifest for the ArrayParameter.
func (p *ArrayParameter) Manifest() ParameterManifest {
// only list ParamAuthSource names (without fields) in manifest
authNames := make([]string, len(p.AuthSources))
for i, a := range p.AuthSources {
// only list ParamAuthService names (without fields) in manifest
authNames := make([]string, len(p.AuthServices))
for i, a := range p.AuthServices {
authNames[i] = a.Name
}
items := p.Items.Manifest()
return ParameterManifest{
Name: p.Name,
Type: p.Type,
Description: p.Desc,
AuthSources: authNames,
Items: &items,
Name: p.Name,
Type: p.Type,
Description: p.Desc,
AuthServices: authNames,
Items: &items,
}
}

View File

@@ -23,10 +23,15 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/internal/tools"
)
func TestParametersMarshal(t *testing.T) {
ctx, err := testutils.ContextWithNewLogger()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
tcs := []struct {
name string
in []map[string]any
@@ -130,7 +135,7 @@ func TestParametersMarshal(t *testing.T) {
t.Fatalf("unable to marshal input to yaml: %s", err)
}
// parse bytes to object
err = yaml.Unmarshal(data, &got)
err = yaml.UnmarshalContext(ctx, data, &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
@@ -142,7 +147,11 @@ func TestParametersMarshal(t *testing.T) {
}
func TestAuthParametersMarshal(t *testing.T) {
authSources := []tools.ParamAuthSource{{Name: "my-google-auth-service", Field: "user_id"}, {Name: "other-auth-service", Field: "user_id"}}
ctx, err := testutils.ContextWithNewLogger()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
authServices := []tools.ParamAuthService{{Name: "my-google-auth-service", Field: "user_id"}, {Name: "other-auth-service", Field: "user_id"}}
tcs := []struct {
name string
in []map[string]any
@@ -150,6 +159,29 @@ func TestAuthParametersMarshal(t *testing.T) {
}{
{
name: "string",
in: []map[string]any{
{
"name": "my_string",
"type": "string",
"description": "this param is a string",
"authServices": []map[string]string{
{
"name": "my-google-auth-service",
"field": "user_id",
},
{
"name": "other-auth-service",
"field": "user_id",
},
},
},
},
want: tools.Parameters{
tools.NewStringParameterWithAuth("my_string", "this param is a string", authServices),
},
},
{
name: "string with authSources",
in: []map[string]any{
{
"name": "my_string",
@@ -168,11 +200,34 @@ func TestAuthParametersMarshal(t *testing.T) {
},
},
want: tools.Parameters{
tools.NewStringParameterWithAuth("my_string", "this param is a string", authSources),
tools.NewStringParameterWithAuth("my_string", "this param is a string", authServices),
},
},
{
name: "int",
in: []map[string]any{
{
"name": "my_integer",
"type": "integer",
"description": "this param is an int",
"authServices": []map[string]string{
{
"name": "my-google-auth-service",
"field": "user_id",
},
{
"name": "other-auth-service",
"field": "user_id",
},
},
},
},
want: tools.Parameters{
tools.NewIntParameterWithAuth("my_integer", "this param is an int", authServices),
},
},
{
name: "int with authSources",
in: []map[string]any{
{
"name": "my_integer",
@@ -191,11 +246,34 @@ func TestAuthParametersMarshal(t *testing.T) {
},
},
want: tools.Parameters{
tools.NewIntParameterWithAuth("my_integer", "this param is an int", authSources),
tools.NewIntParameterWithAuth("my_integer", "this param is an int", authServices),
},
},
{
name: "float",
in: []map[string]any{
{
"name": "my_float",
"type": "float",
"description": "my param is a float",
"authServices": []map[string]string{
{
"name": "my-google-auth-service",
"field": "user_id",
},
{
"name": "other-auth-service",
"field": "user_id",
},
},
},
},
want: tools.Parameters{
tools.NewFloatParameterWithAuth("my_float", "my param is a float", authServices),
},
},
{
name: "float with authSources",
in: []map[string]any{
{
"name": "my_float",
@@ -214,11 +292,34 @@ func TestAuthParametersMarshal(t *testing.T) {
},
},
want: tools.Parameters{
tools.NewFloatParameterWithAuth("my_float", "my param is a float", authSources),
tools.NewFloatParameterWithAuth("my_float", "my param is a float", authServices),
},
},
{
name: "bool",
in: []map[string]any{
{
"name": "my_bool",
"type": "boolean",
"description": "this param is a boolean",
"authServices": []map[string]string{
{
"name": "my-google-auth-service",
"field": "user_id",
},
{
"name": "other-auth-service",
"field": "user_id",
},
},
},
},
want: tools.Parameters{
tools.NewBooleanParameterWithAuth("my_bool", "this param is a boolean", authServices),
},
},
{
name: "bool with authSources",
in: []map[string]any{
{
"name": "my_bool",
@@ -237,11 +338,39 @@ func TestAuthParametersMarshal(t *testing.T) {
},
},
want: tools.Parameters{
tools.NewBooleanParameterWithAuth("my_bool", "this param is a boolean", authSources),
tools.NewBooleanParameterWithAuth("my_bool", "this param is a boolean", authServices),
},
},
{
name: "string array",
in: []map[string]any{
{
"name": "my_array",
"type": "array",
"description": "this param is an array of strings",
"items": map[string]string{
"name": "my_string",
"type": "string",
"description": "string item",
},
"authServices": []map[string]string{
{
"name": "my-google-auth-service",
"field": "user_id",
},
{
"name": "other-auth-service",
"field": "user_id",
},
},
},
},
want: tools.Parameters{
tools.NewArrayParameterWithAuth("my_array", "this param is an array of strings", tools.NewStringParameter("my_string", "string item"), authServices),
},
},
{
name: "string array with authSources",
in: []map[string]any{
{
"name": "my_array",
@@ -265,7 +394,7 @@ func TestAuthParametersMarshal(t *testing.T) {
},
},
want: tools.Parameters{
tools.NewArrayParameterWithAuth("my_array", "this param is an array of strings", tools.NewStringParameter("my_string", "string item"), authSources),
tools.NewArrayParameterWithAuth("my_array", "this param is an array of strings", tools.NewStringParameter("my_string", "string item"), authServices),
},
},
{
@@ -280,7 +409,7 @@ func TestAuthParametersMarshal(t *testing.T) {
"type": "float",
"description": "float item",
},
"authSources": []map[string]string{
"authServices": []map[string]string{
{
"name": "my-google-auth-service",
"field": "user_id",
@@ -293,7 +422,7 @@ func TestAuthParametersMarshal(t *testing.T) {
},
},
want: tools.Parameters{
tools.NewArrayParameterWithAuth("my_array", "this param is an array of floats", tools.NewFloatParameter("my_float", "float item"), authSources),
tools.NewArrayParameterWithAuth("my_array", "this param is an array of floats", tools.NewFloatParameter("my_float", "float item"), authServices),
},
},
}
@@ -306,7 +435,7 @@ func TestAuthParametersMarshal(t *testing.T) {
t.Fatalf("unable to marshal input to yaml: %s", err)
}
// parse bytes to object
err = yaml.Unmarshal(data, &got)
err = yaml.UnmarshalContext(ctx, data, &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
@@ -454,7 +583,7 @@ func TestParametersParse(t *testing.T) {
}
func TestAuthParametersParse(t *testing.T) {
authSources := []tools.ParamAuthSource{
authServices := []tools.ParamAuthService{
{
Name: "my-google-auth-service",
Field: "auth_field",
@@ -473,7 +602,7 @@ func TestAuthParametersParse(t *testing.T) {
{
name: "string",
params: tools.Parameters{
tools.NewStringParameterWithAuth("my_string", "this param is a string", authSources),
tools.NewStringParameterWithAuth("my_string", "this param is a string", authServices),
},
in: map[string]any{
"my_string": "hello world",
@@ -484,7 +613,7 @@ func TestAuthParametersParse(t *testing.T) {
{
name: "not string",
params: tools.Parameters{
tools.NewStringParameterWithAuth("my_string", "this param is a string", authSources),
tools.NewStringParameterWithAuth("my_string", "this param is a string", authServices),
},
in: map[string]any{
"my_string": 4,
@@ -494,7 +623,7 @@ func TestAuthParametersParse(t *testing.T) {
{
name: "int",
params: tools.Parameters{
tools.NewIntParameterWithAuth("my_int", "this param is an int", authSources),
tools.NewIntParameterWithAuth("my_int", "this param is an int", authServices),
},
in: map[string]any{
"my_int": 100,
@@ -505,7 +634,7 @@ func TestAuthParametersParse(t *testing.T) {
{
name: "not int",
params: tools.Parameters{
tools.NewIntParameterWithAuth("my_int", "this param is an int", authSources),
tools.NewIntParameterWithAuth("my_int", "this param is an int", authServices),
},
in: map[string]any{
"my_int": 14.5,
@@ -515,7 +644,7 @@ func TestAuthParametersParse(t *testing.T) {
{
name: "float",
params: tools.Parameters{
tools.NewFloatParameterWithAuth("my_float", "this param is a float", authSources),
tools.NewFloatParameterWithAuth("my_float", "this param is a float", authServices),
},
in: map[string]any{
"my_float": 1.5,
@@ -526,7 +655,7 @@ func TestAuthParametersParse(t *testing.T) {
{
name: "not float",
params: tools.Parameters{
tools.NewFloatParameterWithAuth("my_float", "this param is a float", authSources),
tools.NewFloatParameterWithAuth("my_float", "this param is a float", authServices),
},
in: map[string]any{
"my_float": true,
@@ -536,7 +665,7 @@ func TestAuthParametersParse(t *testing.T) {
{
name: "bool",
params: tools.Parameters{
tools.NewBooleanParameterWithAuth("my_bool", "this param is a bool", authSources),
tools.NewBooleanParameterWithAuth("my_bool", "this param is a bool", authServices),
},
in: map[string]any{
"my_bool": true,
@@ -547,7 +676,7 @@ func TestAuthParametersParse(t *testing.T) {
{
name: "not bool",
params: tools.Parameters{
tools.NewBooleanParameterWithAuth("my_bool", "this param is a bool", authSources),
tools.NewBooleanParameterWithAuth("my_bool", "this param is a bool", authServices),
},
in: map[string]any{
"my_bool": 1.5,
@@ -557,7 +686,7 @@ func TestAuthParametersParse(t *testing.T) {
{
name: "username",
params: tools.Parameters{
tools.NewStringParameterWithAuth("username", "username string", authSources),
tools.NewStringParameterWithAuth("username", "username string", authServices),
},
in: map[string]any{
"username": "Violet",
@@ -568,7 +697,7 @@ func TestAuthParametersParse(t *testing.T) {
{
name: "expect claim error",
params: tools.Parameters{
tools.NewStringParameterWithAuth("username", "username string", authSources),
tools.NewStringParameterWithAuth("username", "username string", authServices),
},
in: map[string]any{
"username": "Violet",
@@ -679,32 +808,32 @@ func TestParamManifest(t *testing.T) {
{
name: "string",
in: tools.NewStringParameter("foo-string", "bar"),
want: tools.ParameterManifest{Name: "foo-string", Type: "string", Description: "bar", AuthSources: []string{}},
want: tools.ParameterManifest{Name: "foo-string", Type: "string", Description: "bar", AuthServices: []string{}},
},
{
name: "int",
in: tools.NewIntParameter("foo-int", "bar"),
want: tools.ParameterManifest{Name: "foo-int", Type: "integer", Description: "bar", AuthSources: []string{}},
want: tools.ParameterManifest{Name: "foo-int", Type: "integer", Description: "bar", AuthServices: []string{}},
},
{
name: "float",
in: tools.NewFloatParameter("foo-float", "bar"),
want: tools.ParameterManifest{Name: "foo-float", Type: "float", Description: "bar", AuthSources: []string{}},
want: tools.ParameterManifest{Name: "foo-float", Type: "float", Description: "bar", AuthServices: []string{}},
},
{
name: "boolean",
in: tools.NewBooleanParameter("foo-bool", "bar"),
want: tools.ParameterManifest{Name: "foo-bool", Type: "boolean", Description: "bar", AuthSources: []string{}},
want: tools.ParameterManifest{Name: "foo-bool", Type: "boolean", Description: "bar", AuthServices: []string{}},
},
{
name: "array",
in: tools.NewArrayParameter("foo-array", "bar", tools.NewStringParameter("foo-string", "bar")),
want: tools.ParameterManifest{
Name: "foo-array",
Type: "array",
Description: "bar",
AuthSources: []string{},
Items: &tools.ParameterManifest{Name: "foo-string", Type: "string", Description: "bar", AuthSources: []string{}},
Name: "foo-array",
Type: "array",
Description: "bar",
AuthServices: []string{},
Items: &tools.ParameterManifest{Name: "foo-string", Type: "string", Description: "bar", AuthServices: []string{}},
},
},
}
@@ -719,6 +848,10 @@ func TestParamManifest(t *testing.T) {
}
func TestFailParametersUnmarshal(t *testing.T) {
ctx, err := testutils.ContextWithNewLogger()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
tcs := []struct {
name string
in []map[string]any
@@ -790,7 +923,7 @@ func TestFailParametersUnmarshal(t *testing.T) {
t.Fatalf("unable to marshal input to yaml: %s", err)
}
// parse bytes to object
err = yaml.Unmarshal(data, &got)
err = yaml.UnmarshalContext(ctx, data, &got)
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -141,6 +141,6 @@ func (t Tool) Manifest() tools.Manifest {
return t.manifest
}
func (t Tool) Authorized(verifiedAuthSources []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthSources)
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}

View File

@@ -26,6 +26,10 @@ import (
)
func TestParseFromYamlPostgres(t *testing.T) {
ctx, err := testutils.ContextWithNewLogger()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
tcs := []struct {
desc string
in string
@@ -48,7 +52,7 @@ func TestParseFromYamlPostgres(t *testing.T) {
- name: country
type: string
description: some description
authSources:
authServices:
- name: my-google-auth-service
field: user_id
- name: other-auth-service
@@ -64,7 +68,7 @@ func TestParseFromYamlPostgres(t *testing.T) {
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
Parameters: []tools.Parameter{
tools.NewStringParameterWithAuth("country", "some description",
[]tools.ParamAuthSource{{Name: "my-google-auth-service", Field: "user_id"},
[]tools.ParamAuthService{{Name: "my-google-auth-service", Field: "user_id"},
{Name: "other-auth-service", Field: "user_id"}}),
},
},
@@ -77,7 +81,7 @@ func TestParseFromYamlPostgres(t *testing.T) {
Tools server.ToolConfigs `yaml:"tools"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}

View File

@@ -170,6 +170,6 @@ func (t Tool) Manifest() tools.Manifest {
return t.manifest
}
func (t Tool) Authorized(verifiedAuthSources []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthSources)
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}

View File

@@ -26,6 +26,10 @@ import (
)
func TestParseFromYamlSpanner(t *testing.T) {
ctx, err := testutils.ContextWithNewLogger()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
tcs := []struct {
desc string
in string
@@ -66,7 +70,7 @@ func TestParseFromYamlSpanner(t *testing.T) {
Tools server.ToolConfigs `yaml:"tools"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}

View File

@@ -39,13 +39,13 @@ type Manifest struct {
}
// Helper function that returns if a tool invocation request is authorized
func IsAuthorized(authRequiredSources []string, verifiedAuthSources []string) bool {
func IsAuthorized(authRequiredSources []string, verifiedAuthServices []string) bool {
if len(authRequiredSources) == 0 {
// no authorization requirement
return true
}
for _, a := range authRequiredSources {
if slices.Contains(verifiedAuthSources, a) {
if slices.Contains(verifiedAuthServices, a) {
return true
}
}

View File

@@ -31,7 +31,7 @@ func GetToolsConfig(sourceConfig map[string]any, toolKind, param_tool_statement,
"sources": map[string]any{
"my-instance": sourceConfig,
},
"authSources": map[string]any{
"authServices": map[string]any{
"my-google-auth": map[string]any{
"kind": "google",
"clientId": ClientId,
@@ -73,7 +73,7 @@ func GetToolsConfig(sourceConfig map[string]any, toolKind, param_tool_statement,
"name": "email",
"type": "string",
"description": "user email",
"authSources": []map[string]string{
"authServices": []map[string]string{
{
"name": "my-google-auth",
"field": "email",