mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-11 16:38:15 -05:00
feat!: deprecate authsource in favor of authservice (#297)
Rename existing `authSource` to `authService` through deprecation.
`AuthService` more clearly distinguishes it from `Sources` objects.
`authSources` will be converted into `authServices` after the
unmarshalling process. A warning log is shown if `authSources` are used
(for both within tools parameters and defining auth services):
```
2025-02-20T13:57:51.156025-08:00 WARN "`authSources` is deprecated, use `authServices` for parameters instead"
2025-02-20T13:57:51.156569-08:00 WARN "`authSources` is deprecated, use `authServices` instead"
2025-02-20T13:57:52.047584-08:00 INFO "Initialized 1 sources."
...
```
The manifest generated will continue to use `authSources` to keep
compatibility with the sdks:
```
{
"serverVersion":"0.1.0",
"tools":{
"test_tool2":{
"description":"Use this tool to test\n",
"parameters":[{
"name":"user_id",
"type":"string",
"description":"Auto-populated from Google login",
"authSources":["my-google-auth"]
}]
}
}
}
```
Test cases with `authSources` are kept for compatibility. Will be
removed when `authSources` are no longer supported.
This commit is contained in:
16
cmd/root.go
16
cmd/root.go
@@ -119,10 +119,11 @@ func NewCommand(opts ...Option) *Command {
|
||||
}
|
||||
|
||||
type ToolsFile struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
AuthSources server.AuthSourceConfigs `yaml:"authSources"`
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
Toolsets server.ToolsetConfigs `yaml:"toolsets"`
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
AuthSources server.AuthServiceConfigs `yaml:"authSources"` // Deprecated: Kept for compatibility.
|
||||
AuthServices server.AuthServiceConfigs `yaml:"authServices"`
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
Toolsets server.ToolsetConfigs `yaml:"toolsets"`
|
||||
}
|
||||
|
||||
// parseToolsFile parses the provided yaml into appropriate configs.
|
||||
@@ -203,7 +204,12 @@ func run(cmd *Command) error {
|
||||
return errMsg
|
||||
}
|
||||
toolsFile, err := parseToolsFile(ctx, buf)
|
||||
cmd.cfg.SourceConfigs, cmd.cfg.AuthSourceConfigs, cmd.cfg.ToolConfigs, cmd.cfg.ToolsetConfigs = toolsFile.Sources, toolsFile.AuthSources, toolsFile.Tools, toolsFile.Toolsets
|
||||
cmd.cfg.SourceConfigs, cmd.cfg.AuthServiceConfigs, cmd.cfg.ToolConfigs, cmd.cfg.ToolsetConfigs = toolsFile.Sources, toolsFile.AuthServices, toolsFile.Tools, toolsFile.Toolsets
|
||||
authSourceConfigs := toolsFile.AuthSources
|
||||
if authSourceConfigs != nil {
|
||||
cmd.logger.WarnContext(ctx, "`authSources` is deprecated, use `authServices` instead")
|
||||
cmd.cfg.AuthServiceConfigs = authSourceConfigs
|
||||
}
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("unable to parse tool file at %q: %w", cmd.tools_file, err)
|
||||
cmd.logger.ErrorContext(ctx, errMsg.Error())
|
||||
|
||||
129
cmd/root_test.go
129
cmd/root_test.go
@@ -16,7 +16,6 @@ package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
_ "embed"
|
||||
"os"
|
||||
"strings"
|
||||
@@ -261,6 +260,10 @@ func TestDefaultLogLevel(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestParseToolFile(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
description string
|
||||
in string
|
||||
@@ -330,15 +333,15 @@ func TestParseToolFile(t *testing.T) {
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.description, func(t *testing.T) {
|
||||
toolsFile, err := parseToolsFile(context.Background(), testutils.FormatYaml(tc.in))
|
||||
toolsFile, err := parseToolsFile(ctx, testutils.FormatYaml(tc.in))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse input: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.wantToolsFile.Sources, toolsFile.Sources); diff != "" {
|
||||
t.Fatalf("incorrect sources parse: diff %v", diff)
|
||||
}
|
||||
if diff := cmp.Diff(tc.wantToolsFile.AuthSources, toolsFile.AuthSources); diff != "" {
|
||||
t.Fatalf("incorrect authSources parse: diff %v", diff)
|
||||
if diff := cmp.Diff(tc.wantToolsFile.AuthServices, toolsFile.AuthServices); diff != "" {
|
||||
t.Fatalf("incorrect authServices parse: diff %v", diff)
|
||||
}
|
||||
if diff := cmp.Diff(tc.wantToolsFile.Tools, toolsFile.Tools); diff != "" {
|
||||
t.Fatalf("incorrect tools parse: diff %v", diff)
|
||||
@@ -352,6 +355,10 @@ func TestParseToolFile(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestParseToolFileWithAuth(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
description string
|
||||
in string
|
||||
@@ -360,6 +367,104 @@ func TestParseToolFileWithAuth(t *testing.T) {
|
||||
{
|
||||
description: "basic example",
|
||||
in: `
|
||||
sources:
|
||||
my-pg-instance:
|
||||
kind: cloud-sql-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
authServices:
|
||||
my-google-service:
|
||||
kind: google
|
||||
clientId: my-client-id
|
||||
other-google-service:
|
||||
kind: google
|
||||
clientId: other-client-id
|
||||
|
||||
tools:
|
||||
example_tool:
|
||||
kind: postgres-sql
|
||||
source: my-pg-instance
|
||||
description: some description
|
||||
statement: |
|
||||
SELECT * FROM SQL_STATEMENT;
|
||||
parameters:
|
||||
- name: country
|
||||
type: string
|
||||
description: some description
|
||||
- name: id
|
||||
type: integer
|
||||
description: user id
|
||||
authServices:
|
||||
- name: my-google-service
|
||||
field: user_id
|
||||
- name: email
|
||||
type: string
|
||||
description: user email
|
||||
authServices:
|
||||
- name: my-google-service
|
||||
field: email
|
||||
- name: other-google-service
|
||||
field: other_email
|
||||
|
||||
toolsets:
|
||||
example_toolset:
|
||||
- example_tool
|
||||
`,
|
||||
wantToolsFile: ToolsFile{
|
||||
Sources: server.SourceConfigs{
|
||||
"my-pg-instance": cloudsqlpgsrc.Config{
|
||||
Name: "my-pg-instance",
|
||||
Kind: cloudsqlpgsrc.SourceKind,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
IPType: "public",
|
||||
Database: "my_db",
|
||||
User: "my_user",
|
||||
Password: "my_pass",
|
||||
},
|
||||
},
|
||||
AuthServices: server.AuthServiceConfigs{
|
||||
"my-google-service": google.Config{
|
||||
Name: "my-google-service",
|
||||
Kind: google.AuthServiceKind,
|
||||
ClientID: "my-client-id",
|
||||
},
|
||||
"other-google-service": google.Config{
|
||||
Name: "other-google-service",
|
||||
Kind: google.AuthServiceKind,
|
||||
ClientID: "other-client-id",
|
||||
},
|
||||
},
|
||||
Tools: server.ToolConfigs{
|
||||
"example_tool": postgressql.Config{
|
||||
Name: "example_tool",
|
||||
Kind: postgressql.ToolKind,
|
||||
Source: "my-pg-instance",
|
||||
Description: "some description",
|
||||
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
||||
Parameters: []tools.Parameter{
|
||||
tools.NewStringParameter("country", "some description"),
|
||||
tools.NewIntParameterWithAuth("id", "user id", []tools.ParamAuthService{{Name: "my-google-service", Field: "user_id"}}),
|
||||
tools.NewStringParameterWithAuth("email", "user email", []tools.ParamAuthService{{Name: "my-google-service", Field: "email"}, {Name: "other-google-service", Field: "other_email"}}),
|
||||
},
|
||||
},
|
||||
},
|
||||
Toolsets: server.ToolsetConfigs{
|
||||
"example_toolset": tools.ToolsetConfig{
|
||||
Name: "example_toolset",
|
||||
ToolNames: []string{"example_tool"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "basic example with authSources",
|
||||
in: `
|
||||
sources:
|
||||
my-pg-instance:
|
||||
kind: cloud-sql-postgres
|
||||
@@ -421,15 +526,15 @@ func TestParseToolFileWithAuth(t *testing.T) {
|
||||
Password: "my_pass",
|
||||
},
|
||||
},
|
||||
AuthSources: server.AuthSourceConfigs{
|
||||
AuthSources: server.AuthServiceConfigs{
|
||||
"my-google-service": google.Config{
|
||||
Name: "my-google-service",
|
||||
Kind: google.AuthSourceKind,
|
||||
Kind: google.AuthServiceKind,
|
||||
ClientID: "my-client-id",
|
||||
},
|
||||
"other-google-service": google.Config{
|
||||
Name: "other-google-service",
|
||||
Kind: google.AuthSourceKind,
|
||||
Kind: google.AuthServiceKind,
|
||||
ClientID: "other-client-id",
|
||||
},
|
||||
},
|
||||
@@ -442,8 +547,8 @@ func TestParseToolFileWithAuth(t *testing.T) {
|
||||
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
||||
Parameters: []tools.Parameter{
|
||||
tools.NewStringParameter("country", "some description"),
|
||||
tools.NewIntParameterWithAuth("id", "user id", []tools.ParamAuthSource{{Name: "my-google-service", Field: "user_id"}}),
|
||||
tools.NewStringParameterWithAuth("email", "user email", []tools.ParamAuthSource{{Name: "my-google-service", Field: "email"}, {Name: "other-google-service", Field: "other_email"}}),
|
||||
tools.NewIntParameterWithAuth("id", "user id", []tools.ParamAuthService{{Name: "my-google-service", Field: "user_id"}}),
|
||||
tools.NewStringParameterWithAuth("email", "user email", []tools.ParamAuthService{{Name: "my-google-service", Field: "email"}, {Name: "other-google-service", Field: "other_email"}}),
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -458,15 +563,15 @@ func TestParseToolFileWithAuth(t *testing.T) {
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.description, func(t *testing.T) {
|
||||
toolsFile, err := parseToolsFile(context.Background(), testutils.FormatYaml(tc.in))
|
||||
toolsFile, err := parseToolsFile(ctx, testutils.FormatYaml(tc.in))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse input: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.wantToolsFile.Sources, toolsFile.Sources); diff != "" {
|
||||
t.Fatalf("incorrect sources parse: diff %v", diff)
|
||||
}
|
||||
if diff := cmp.Diff(tc.wantToolsFile.AuthSources, toolsFile.AuthSources); diff != "" {
|
||||
t.Fatalf("incorrect authSources parse: diff %v", diff)
|
||||
if diff := cmp.Diff(tc.wantToolsFile.AuthServices, toolsFile.AuthServices); diff != "" {
|
||||
t.Fatalf("incorrect authServices parse: diff %v", diff)
|
||||
}
|
||||
if diff := cmp.Diff(tc.wantToolsFile.Tools, toolsFile.Tools); diff != "" {
|
||||
t.Fatalf("incorrect tools parse: diff %v", diff)
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
---
|
||||
title: "AuthSources"
|
||||
title: "AuthServices"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
AuthSources represent services that handle authentication and authorization.
|
||||
AuthServices represent services that handle authentication and authorization.
|
||||
---
|
||||
|
||||
AuthSources represent services that handle authentication and authorization. It
|
||||
AuthServices represent services that handle authentication and authorization. It
|
||||
can primarily be used by [Tools](../tools) in two different ways:
|
||||
|
||||
- [**Authorized Invocation**][auth-invoke] is when a tool
|
||||
@@ -32,7 +32,7 @@ If you are accessing Toolbox with multiple applications, each
|
||||
{{< /notice >}}
|
||||
|
||||
```yaml
|
||||
authSources:
|
||||
authServices:
|
||||
my_auth_app_1:
|
||||
kind: google
|
||||
clientId: YOUR_CLIENT_ID_1
|
||||
@@ -41,17 +41,17 @@ authSources:
|
||||
clientId: YOUR_CLIENT_ID_2
|
||||
```
|
||||
|
||||
After you've configured an `authSource` you'll, need to reference it in the
|
||||
After you've configured an `authService` you'll, need to reference it in the
|
||||
configuration for each tool that should use it:
|
||||
- **Authorized Invocations** for authorizing a tool call, [use the
|
||||
`requiredAuth` field in a tool config][auth-invoke]
|
||||
- **Authenticated Parameters** for using the value from a ODIC claim, [use the
|
||||
`authSources` field in a parameter config][auth-params]
|
||||
`authServices` field in a parameter config][auth-params]
|
||||
|
||||
|
||||
## Specifying ID Tokens from Clients
|
||||
|
||||
After [configuring](#example) your `authSources` section, use a Toolbox SDK to
|
||||
After [configuring](#example) your `authServices` section, use a Toolbox SDK to
|
||||
add your ID tokens to the header of a Tool invocation request. When specifying a
|
||||
token you will provide a function (that returns an id). This function is called
|
||||
when the tool is invoked. This allows you to cache and refresh the ID token as
|
||||
@@ -89,4 +89,4 @@ authorized_tool = tools[0].add_auth_tokens({
|
||||
{{< /tab >}}
|
||||
{{< /tabpane >}}
|
||||
|
||||
## Kinds of Auth Sources
|
||||
## Kinds of Auth Services
|
||||
@@ -12,7 +12,7 @@ Google Sign-In manages the OAuth 2.0 flow and token lifecycle. To integrate the
|
||||
Google Sign-In workflow to your web app [follow this guide][gsi-setup].
|
||||
|
||||
After setting up the Google Sign-In workflow, you should have registered your
|
||||
application and retrieved a [Client ID][client-id]. Configure your auth source
|
||||
application and retrieved a [Client ID][client-id]. Configure your auth service
|
||||
in with the `Client ID`.
|
||||
|
||||
[gsi-setup]: https://developers.google.com/identity/sign-in/web/sign-in
|
||||
@@ -31,7 +31,7 @@ ID.
|
||||
### Authenticated Parameters
|
||||
|
||||
When using [Authenticated Parameters][auth-params], any [claim provided by the
|
||||
id-token][provided-claims] can be used as a source for the parameter.
|
||||
id-token][provided-claims] can be used for the parameter.
|
||||
|
||||
[auth-params]: ../tools/#authenticated-phugarameters
|
||||
[provided-claims]:
|
||||
@@ -40,7 +40,7 @@ id-token][provided-claims] can be used as a source for the parameter.
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
authSources:
|
||||
authServices:
|
||||
my-google-auth:
|
||||
kind: google
|
||||
clientId: YOUR_GOOGLE_CLIENT_ID
|
||||
@@ -115,7 +115,7 @@ Authenticated parameters are automatically populated with user
|
||||
information decoded from [ID tokens](../authsources/#specifying-id-tokens-from-clients) that
|
||||
are passed in request headers. They do not take input values in request bodies
|
||||
like other parameters. To use authenticated parameters, you must configure
|
||||
the tool to map the required [authSources](../authsources) to
|
||||
the tool to map the required [authServices](../authservices) to
|
||||
specific claims within the user's ID token.
|
||||
|
||||
```yaml
|
||||
@@ -129,8 +129,8 @@ specific claims within the user's ID token.
|
||||
- name: user_id
|
||||
type: string
|
||||
description: Auto-populated from Google login
|
||||
authSources:
|
||||
# Refer to one of the `authSources` defined
|
||||
authServices:
|
||||
# Refer to one of the `authServices` defined
|
||||
- name: my-google-auth
|
||||
# `sub` is the OIDC claim field for user ID
|
||||
field: sub
|
||||
@@ -138,14 +138,14 @@ specific claims within the user's ID token.
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|-----------|:--------:|:------------:|-----------------------------------------------------------------------------------------|
|
||||
| name | string | true | Name of the [authSources](../authsources) used to verify the OIDC auth token. |
|
||||
| name | string | true | Name of the [authServices](../authservices) used to verify the OIDC auth token. |
|
||||
| field | string | true | Claim field decoded from the OIDC token used to auto-populate this parameter. |
|
||||
|
||||
## Authorized Invocations
|
||||
|
||||
You can require an authorization check for any Tool invocation request by
|
||||
specifying an `authRequired` field. Specify a list of
|
||||
[authSources](../authsources) defined in the previous section.
|
||||
[authServices](../authservices) defined in the previous section.
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
@@ -154,7 +154,7 @@ tools:
|
||||
source: my-pg-instance
|
||||
statement: |
|
||||
SELECT * FROM flights
|
||||
# A list of `authSources` defined previously
|
||||
# A list of `authServices` defined previously
|
||||
authRequired:
|
||||
- my-google-auth
|
||||
- other-auth-service
|
||||
|
||||
@@ -16,15 +16,15 @@ package auth
|
||||
|
||||
import "net/http"
|
||||
|
||||
// SourceConfig is the interface for configuring authentication sources.
|
||||
type AuthSourceConfig interface {
|
||||
AuthSourceConfigKind() string
|
||||
Initialize() (AuthSource, error)
|
||||
// AuthServiceConfig is the interface for configuring authentication services.
|
||||
type AuthServiceConfig interface {
|
||||
AuthServiceConfigKind() string
|
||||
Initialize() (AuthService, error)
|
||||
}
|
||||
|
||||
// AuthSource is the interface for authentication sources.
|
||||
type AuthSource interface {
|
||||
AuthSourceKind() string
|
||||
// AuthService is the interface for authentication services.
|
||||
type AuthService interface {
|
||||
AuthServiceKind() string
|
||||
GetName() string
|
||||
GetClaimsFromHeader(http.Header) (map[string]any, error)
|
||||
}
|
||||
|
||||
@@ -23,54 +23,54 @@ import (
|
||||
"google.golang.org/api/idtoken"
|
||||
)
|
||||
|
||||
const AuthSourceKind string = "google"
|
||||
const AuthServiceKind string = "google"
|
||||
|
||||
// validate interface
|
||||
var _ auth.AuthSourceConfig = Config{}
|
||||
var _ auth.AuthServiceConfig = Config{}
|
||||
|
||||
// Auth source configuration
|
||||
// Auth service configuration
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
ClientID string `yaml:"clientId" validate:"required"`
|
||||
}
|
||||
|
||||
// Returns the auth source kind
|
||||
func (cfg Config) AuthSourceConfigKind() string {
|
||||
return AuthSourceKind
|
||||
// Returns the auth service kind
|
||||
func (cfg Config) AuthServiceConfigKind() string {
|
||||
return AuthServiceKind
|
||||
}
|
||||
|
||||
// Initialize a Google auth source
|
||||
func (cfg Config) Initialize() (auth.AuthSource, error) {
|
||||
a := &AuthSource{
|
||||
// Initialize a Google auth service
|
||||
func (cfg Config) Initialize() (auth.AuthService, error) {
|
||||
a := &AuthService{
|
||||
Name: cfg.Name,
|
||||
Kind: AuthSourceKind,
|
||||
Kind: AuthServiceKind,
|
||||
ClientID: cfg.ClientID,
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
var _ auth.AuthSource = AuthSource{}
|
||||
var _ auth.AuthService = AuthService{}
|
||||
|
||||
// struct used to store auth source info
|
||||
type AuthSource struct {
|
||||
// struct used to store auth service info
|
||||
type AuthService struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
ClientID string `yaml:"clientId"`
|
||||
}
|
||||
|
||||
// Returns the auth source kind
|
||||
func (a AuthSource) AuthSourceKind() string {
|
||||
return AuthSourceKind
|
||||
// Returns the auth service kind
|
||||
func (a AuthService) AuthServiceKind() string {
|
||||
return AuthServiceKind
|
||||
}
|
||||
|
||||
// Returns the name of the auth source
|
||||
func (a AuthSource) GetName() string {
|
||||
// Returns the name of the auth service
|
||||
func (a AuthService) GetName() string {
|
||||
return a.Name
|
||||
}
|
||||
|
||||
// Verifies Google ID token and return claims
|
||||
func (a AuthSource) GetClaimsFromHeader(h http.Header) (map[string]any, error) {
|
||||
func (a AuthService) GetClaimsFromHeader(h http.Header) (map[string]any, error) {
|
||||
if token := h.Get(a.Name + "_token"); token != "" {
|
||||
payload, err := idtoken.Validate(context.Background(), token, a.ClientID)
|
||||
if err != nil {
|
||||
|
||||
@@ -163,31 +163,31 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// Tool authentication
|
||||
// claimsFromAuth maps the name of the authsource to the claims retrieved from it.
|
||||
// claimsFromAuth maps the name of the authservice to the claims retrieved from it.
|
||||
claimsFromAuth := make(map[string]map[string]any)
|
||||
for _, aS := range s.authSources {
|
||||
for _, aS := range s.authServices {
|
||||
claims, err := aS.GetClaimsFromHeader(r.Header)
|
||||
if err != nil {
|
||||
s.logger.DebugContext(context.Background(), err.Error())
|
||||
continue
|
||||
}
|
||||
if claims == nil {
|
||||
// authSource not present in header
|
||||
// authService not present in header
|
||||
continue
|
||||
}
|
||||
claimsFromAuth[aS.GetName()] = claims
|
||||
}
|
||||
|
||||
// Tool authorization check
|
||||
verifiedAuthSources := make([]string, len(claimsFromAuth))
|
||||
verifiedAuthServices := make([]string, len(claimsFromAuth))
|
||||
i := 0
|
||||
for k := range claimsFromAuth {
|
||||
verifiedAuthSources[i] = k
|
||||
verifiedAuthServices[i] = k
|
||||
i++
|
||||
}
|
||||
|
||||
// Check if any of the specified auth sources is verified
|
||||
isAuthorized := tool.Authorized(verifiedAuthSources)
|
||||
// Check if any of the specified auth services is verified
|
||||
isAuthorized := tool.Authorized(verifiedAuthServices)
|
||||
if !isAuthorized {
|
||||
err = fmt.Errorf("tool invocation not authorized. Please make sure your specify correct auth headers")
|
||||
s.logger.DebugContext(context.Background(), err.Error())
|
||||
|
||||
@@ -57,7 +57,7 @@ func (t MockTool) Manifest() tools.Manifest {
|
||||
return tools.Manifest{Description: t.Description, Parameters: pMs}
|
||||
}
|
||||
|
||||
func (t MockTool) Authorized(verifiedAuthSources []string) bool {
|
||||
func (t MockTool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
@@ -51,8 +51,8 @@ type ServerConfig struct {
|
||||
Port int
|
||||
// SourceConfigs defines what sources of data are available for tools.
|
||||
SourceConfigs SourceConfigs
|
||||
// AuthSourceConfigs defines what sources of authentication are available for tools.
|
||||
AuthSourceConfigs AuthSourceConfigs
|
||||
// AuthServiceConfigs defines what sources of authentication are available for tools.
|
||||
AuthServiceConfigs AuthServiceConfigs
|
||||
// ToolConfigs defines what tools are available.
|
||||
ToolConfigs ToolConfigs
|
||||
// ToolsetConfigs defines what tools are available.
|
||||
@@ -220,15 +220,15 @@ func (c *SourceConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interf
|
||||
return nil
|
||||
}
|
||||
|
||||
// AuthSourceConfigs is a type used to allow unmarshal of the data authSource config map
|
||||
type AuthSourceConfigs map[string]auth.AuthSourceConfig
|
||||
// AuthServiceConfigs is a type used to allow unmarshal of the data authService config map
|
||||
type AuthServiceConfigs map[string]auth.AuthServiceConfig
|
||||
|
||||
// validate interface
|
||||
var _ yaml.InterfaceUnmarshalerContext = &AuthSourceConfigs{}
|
||||
var _ yaml.InterfaceUnmarshalerContext = &AuthServiceConfigs{}
|
||||
|
||||
func (c *AuthSourceConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
|
||||
*c = make(AuthSourceConfigs)
|
||||
// Parse the 'kind' fields for each authSource
|
||||
func (c *AuthServiceConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
|
||||
*c = make(AuthServiceConfigs)
|
||||
// Parse the 'kind' fields for each authService
|
||||
var raw map[string]util.DelayedUnmarshaler
|
||||
if err := unmarshal(&raw); err != nil {
|
||||
return err
|
||||
@@ -250,7 +250,7 @@ func (c *AuthSourceConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(in
|
||||
return fmt.Errorf("error creating decoder: %w", err)
|
||||
}
|
||||
switch kind {
|
||||
case google.AuthSourceKind:
|
||||
case google.AuthServiceKind:
|
||||
actual := google.Config{Name: name}
|
||||
if err := dec.DecodeContext(ctx, &actual); err != nil {
|
||||
return fmt.Errorf("unable to parse as %q: %w", kind, err)
|
||||
|
||||
@@ -43,10 +43,10 @@ type Server struct {
|
||||
logger log.Logger
|
||||
instrumentation *Instrumentation
|
||||
|
||||
sources map[string]sources.Source
|
||||
authSources map[string]auth.AuthSource
|
||||
tools map[string]tools.Tool
|
||||
toolsets map[string]tools.Toolset
|
||||
sources map[string]sources.Source
|
||||
authServices map[string]auth.AuthService
|
||||
tools map[string]tools.Tool
|
||||
toolsets map[string]tools.Toolset
|
||||
}
|
||||
|
||||
// NewServer returns a Server object based on provided Config.
|
||||
@@ -120,29 +120,29 @@ func NewServer(ctx context.Context, cfg ServerConfig, l log.Logger) (*Server, er
|
||||
}
|
||||
l.InfoContext(ctx, fmt.Sprintf("Initialized %d sources.", len(sourcesMap)))
|
||||
|
||||
// initialize and validate the auth sources from configs
|
||||
authSourcesMap := make(map[string]auth.AuthSource)
|
||||
for name, sc := range cfg.AuthSourceConfigs {
|
||||
a, err := func() (auth.AuthSource, error) {
|
||||
// initialize and validate the auth services from configs
|
||||
authServicesMap := make(map[string]auth.AuthService)
|
||||
for name, sc := range cfg.AuthServiceConfigs {
|
||||
a, err := func() (auth.AuthService, error) {
|
||||
_, span := instrumentation.Tracer.Start(
|
||||
parentCtx,
|
||||
"toolbox/server/auth/init",
|
||||
trace.WithAttributes(attribute.String("auth_kind", sc.AuthSourceConfigKind())),
|
||||
trace.WithAttributes(attribute.String("auth_kind", sc.AuthServiceConfigKind())),
|
||||
trace.WithAttributes(attribute.String("auth_name", name)),
|
||||
)
|
||||
defer span.End()
|
||||
a, err := sc.Initialize()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to initialize auth source %q: %w", name, err)
|
||||
return nil, fmt.Errorf("unable to initialize auth service %q: %w", name, err)
|
||||
}
|
||||
return a, nil
|
||||
}()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authSourcesMap[name] = a
|
||||
authServicesMap[name] = a
|
||||
}
|
||||
l.InfoContext(ctx, fmt.Sprintf("Initialized %d authSources.", len(authSourcesMap)))
|
||||
l.InfoContext(ctx, fmt.Sprintf("Initialized %d authServices.", len(authServicesMap)))
|
||||
|
||||
// initialize and validate the tools from configs
|
||||
toolsMap := make(map[string]tools.Tool)
|
||||
@@ -211,10 +211,10 @@ func NewServer(ctx context.Context, cfg ServerConfig, l log.Logger) (*Server, er
|
||||
logger: l,
|
||||
instrumentation: instrumentation,
|
||||
|
||||
sources: sourcesMap,
|
||||
authSources: authSourcesMap,
|
||||
tools: toolsMap,
|
||||
toolsets: toolsetsMap,
|
||||
sources: sourcesMap,
|
||||
authServices: authServicesMap,
|
||||
tools: toolsMap,
|
||||
toolsets: toolsetsMap,
|
||||
}
|
||||
// control plane
|
||||
apiR, err := apiRouter(s)
|
||||
|
||||
@@ -15,7 +15,13 @@
|
||||
package testutils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/log"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
)
|
||||
|
||||
// formatYaml is a utility function for stripping out tabs in multiline strings
|
||||
@@ -26,3 +32,13 @@ func FormatYaml(in string) []byte {
|
||||
in = strings.ReplaceAll(in, "\t", " ")
|
||||
return []byte(in)
|
||||
}
|
||||
|
||||
// ContextWithNewLogger create a new context with new logger
|
||||
func ContextWithNewLogger() (context.Context, error) {
|
||||
ctx := context.Background()
|
||||
logger, err := log.NewStdLogger(os.Stdout, os.Stderr, "info")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create logger: %s", err)
|
||||
}
|
||||
return util.WithLogger(ctx, logger), nil
|
||||
}
|
||||
|
||||
@@ -127,6 +127,6 @@ func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthSources []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthSources)
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
@@ -25,6 +25,10 @@ import (
|
||||
)
|
||||
|
||||
func TestParseFromYamlDgraph(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
@@ -83,7 +87,7 @@ func TestParseFromYamlDgraph(t *testing.T) {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
|
||||
@@ -169,6 +169,6 @@ func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthSources []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthSources)
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
@@ -26,6 +26,10 @@ import (
|
||||
)
|
||||
|
||||
func TestParseFromYamlMssql(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
@@ -48,7 +52,7 @@ func TestParseFromYamlMssql(t *testing.T) {
|
||||
- name: country
|
||||
type: string
|
||||
description: some description
|
||||
authSources:
|
||||
authServices:
|
||||
- name: my-google-auth-service
|
||||
field: user_id
|
||||
- name: other-auth-service
|
||||
@@ -64,7 +68,7 @@ func TestParseFromYamlMssql(t *testing.T) {
|
||||
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
|
||||
Parameters: []tools.Parameter{
|
||||
tools.NewStringParameterWithAuth("country", "some description",
|
||||
[]tools.ParamAuthSource{{Name: "my-google-auth-service", Field: "user_id"},
|
||||
[]tools.ParamAuthService{{Name: "my-google-auth-service", Field: "user_id"},
|
||||
{Name: "other-auth-service", Field: "user_id"}}),
|
||||
},
|
||||
},
|
||||
@@ -77,7 +81,7 @@ func TestParseFromYamlMssql(t *testing.T) {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
|
||||
@@ -164,6 +164,6 @@ func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthSources []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthSources)
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
@@ -26,6 +26,10 @@ import (
|
||||
)
|
||||
|
||||
func TestParseFromYamlMySQL(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
@@ -48,7 +52,7 @@ func TestParseFromYamlMySQL(t *testing.T) {
|
||||
- name: country
|
||||
type: string
|
||||
description: some description
|
||||
authSources:
|
||||
authServices:
|
||||
- name: my-google-auth-service
|
||||
field: user_id
|
||||
- name: other-auth-service
|
||||
@@ -64,7 +68,7 @@ func TestParseFromYamlMySQL(t *testing.T) {
|
||||
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
|
||||
Parameters: []tools.Parameter{
|
||||
tools.NewStringParameterWithAuth("country", "some description",
|
||||
[]tools.ParamAuthSource{{Name: "my-google-auth-service", Field: "user_id"},
|
||||
[]tools.ParamAuthService{{Name: "my-google-auth-service", Field: "user_id"},
|
||||
{Name: "other-auth-service", Field: "user_id"}}),
|
||||
},
|
||||
},
|
||||
@@ -77,7 +81,7 @@ func TestParseFromYamlMySQL(t *testing.T) {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
|
||||
@@ -127,6 +127,6 @@ func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthSources []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthSources)
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
@@ -26,6 +26,10 @@ import (
|
||||
)
|
||||
|
||||
func TestParseFromYamlNeo4j(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
@@ -66,7 +70,7 @@ func TestParseFromYamlNeo4j(t *testing.T) {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
|
||||
@@ -97,12 +97,12 @@ func (p ParamValues) AsMapWithDollarPrefix() map[string]interface{} {
|
||||
return params
|
||||
}
|
||||
|
||||
func parseFromAuthSource(paramAuthSources []ParamAuthSource, claimsMap map[string]map[string]any) (any, error) {
|
||||
// parse a parameter from claims using its specified auth sources
|
||||
for _, a := range paramAuthSources {
|
||||
func parseFromAuthService(paramAuthServices []ParamAuthService, claimsMap map[string]map[string]any) (any, error) {
|
||||
// parse a parameter from claims using its specified auth services
|
||||
for _, a := range paramAuthServices {
|
||||
claims, ok := claimsMap[a.Name]
|
||||
if !ok {
|
||||
// not validated for this authsource, skip to the next one
|
||||
// not validated for this authservice, skip to the next one
|
||||
continue
|
||||
}
|
||||
v, ok := claims[a.Field]
|
||||
@@ -120,9 +120,9 @@ func ParseParams(ps Parameters, data map[string]any, claimsMap map[string]map[st
|
||||
params := make([]ParamValue, 0, len(ps))
|
||||
for _, p := range ps {
|
||||
var v any
|
||||
paramAuthSources := p.GetAuthSources()
|
||||
paramAuthServices := p.GetAuthServices()
|
||||
name := p.GetName()
|
||||
if paramAuthSources == nil {
|
||||
if paramAuthServices == nil {
|
||||
// parse non auth-required parameter
|
||||
var ok bool
|
||||
v, ok = data[name]
|
||||
@@ -132,7 +132,7 @@ func ParseParams(ps Parameters, data map[string]any, claimsMap map[string]map[st
|
||||
} else {
|
||||
// parse authenticated parameter
|
||||
var err error
|
||||
v, err = parseFromAuthSource(paramAuthSources, claimsMap)
|
||||
v, err = parseFromAuthService(paramAuthServices, claimsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing authenticated parameter %q: %w", name, err)
|
||||
}
|
||||
@@ -151,7 +151,7 @@ type Parameter interface {
|
||||
// but this is done to differentiate it from the fields in CommonParameter.
|
||||
GetName() string
|
||||
GetType() string
|
||||
GetAuthSources() []ParamAuthSource
|
||||
GetAuthServices() []ParamAuthService
|
||||
Parse(any) (any, error)
|
||||
Manifest() ParameterManifest
|
||||
}
|
||||
@@ -161,7 +161,6 @@ type Parameters []Parameter
|
||||
|
||||
func (c *Parameters) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
|
||||
*c = make(Parameters, 0)
|
||||
// Parse the 'kind' fields for each source
|
||||
var rawList []util.DelayedUnmarshaler
|
||||
if err := unmarshal(&rawList); err != nil {
|
||||
return err
|
||||
@@ -194,36 +193,65 @@ func parseParamFromDelayedUnmarshaler(ctx context.Context, u *util.DelayedUnmars
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating decoder: %w", err)
|
||||
}
|
||||
logger, err := util.LoggerFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch t {
|
||||
case typeString:
|
||||
a := &StringParameter{}
|
||||
if err := dec.DecodeContext(ctx, a); err != nil {
|
||||
return nil, fmt.Errorf("unable to parse as %q: %w", t, err)
|
||||
}
|
||||
if a.AuthSources != nil {
|
||||
logger.WarnContext(ctx, "`authSources` is deprecated, use `authServices` for parameters instead")
|
||||
a.AuthServices = append(a.AuthServices, a.AuthSources...)
|
||||
a.AuthSources = nil
|
||||
}
|
||||
return a, nil
|
||||
case typeInt:
|
||||
a := &IntParameter{}
|
||||
if err := dec.DecodeContext(ctx, a); err != nil {
|
||||
return nil, fmt.Errorf("unable to parse as %q: %w", t, err)
|
||||
}
|
||||
if a.AuthSources != nil {
|
||||
logger.WarnContext(ctx, "`authSources` is deprecated, use `authServices` for parameters instead")
|
||||
a.AuthServices = append(a.AuthServices, a.AuthSources...)
|
||||
a.AuthSources = nil
|
||||
}
|
||||
return a, nil
|
||||
case typeFloat:
|
||||
a := &FloatParameter{}
|
||||
if err := dec.DecodeContext(ctx, a); err != nil {
|
||||
return nil, fmt.Errorf("unable to parse as %q: %w", t, err)
|
||||
}
|
||||
if a.AuthSources != nil {
|
||||
logger.WarnContext(ctx, "`authSources` is deprecated, use `authServices` for parameters instead")
|
||||
a.AuthServices = append(a.AuthServices, a.AuthSources...)
|
||||
a.AuthSources = nil
|
||||
}
|
||||
return a, nil
|
||||
case typeBool:
|
||||
a := &BooleanParameter{}
|
||||
if err := dec.DecodeContext(ctx, a); err != nil {
|
||||
return nil, fmt.Errorf("unable to parse as %q: %w", t, err)
|
||||
}
|
||||
if a.AuthSources != nil {
|
||||
logger.WarnContext(ctx, "`authSources` is deprecated, use `authServices` for parameters instead")
|
||||
a.AuthServices = append(a.AuthServices, a.AuthSources...)
|
||||
a.AuthSources = nil
|
||||
}
|
||||
return a, nil
|
||||
case typeArray:
|
||||
a := &ArrayParameter{}
|
||||
if err := dec.DecodeContext(ctx, a); err != nil {
|
||||
return nil, fmt.Errorf("unable to parse as %q: %w", t, err)
|
||||
}
|
||||
if a.AuthSources != nil {
|
||||
logger.WarnContext(ctx, "`authSources` is deprecated, use `authServices` for parameters instead")
|
||||
a.AuthServices = append(a.AuthServices, a.AuthSources...)
|
||||
a.AuthSources = nil
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
return nil, fmt.Errorf("%q is not valid type for a parameter!", t)
|
||||
@@ -239,19 +267,20 @@ func (ps Parameters) Manifest() []ParameterManifest {
|
||||
|
||||
// ParameterManifest represents parameters when served as part of a ToolManifest.
|
||||
type ParameterManifest struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
AuthSources []string `json:"authSources"`
|
||||
Items *ParameterManifest `json:"items,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
AuthServices []string `json:"authSources"`
|
||||
Items *ParameterManifest `json:"items,omitempty"`
|
||||
}
|
||||
|
||||
// CommonParameter are default fields that are emebdding in most Parameter implementations. Embedding this stuct will give the object Name() and Type() functions.
|
||||
type CommonParameter struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Type string `yaml:"type" validate:"required"`
|
||||
Desc string `yaml:"description" validate:"required"`
|
||||
AuthSources []ParamAuthSource `yaml:"authSources"`
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Type string `yaml:"type" validate:"required"`
|
||||
Desc string `yaml:"description" validate:"required"`
|
||||
AuthServices []ParamAuthService `yaml:"authServices"`
|
||||
AuthSources []ParamAuthService `yaml:"authSources"` // Deprecated: Kept for compatibility.
|
||||
}
|
||||
|
||||
// GetName returns the name specified for the Parameter.
|
||||
@@ -266,16 +295,16 @@ func (p *CommonParameter) GetType() string {
|
||||
|
||||
// Manifest returns the manifest for the Parameter.
|
||||
func (p *CommonParameter) Manifest() ParameterManifest {
|
||||
// only list ParamAuthSource names (without fields) in manifest
|
||||
authNames := make([]string, len(p.AuthSources))
|
||||
for i, a := range p.AuthSources {
|
||||
// only list ParamAuthService names (without fields) in manifest
|
||||
authNames := make([]string, len(p.AuthServices))
|
||||
for i, a := range p.AuthServices {
|
||||
authNames[i] = a.Name
|
||||
}
|
||||
return ParameterManifest{
|
||||
Name: p.Name,
|
||||
Type: p.Type,
|
||||
Description: p.Desc,
|
||||
AuthSources: authNames,
|
||||
Name: p.Name,
|
||||
Type: p.Type,
|
||||
Description: p.Desc,
|
||||
AuthServices: authNames,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -290,7 +319,7 @@ func (e ParseTypeError) Error() string {
|
||||
return fmt.Sprintf("%q not type %q", e.Value, e.Type)
|
||||
}
|
||||
|
||||
type ParamAuthSource struct {
|
||||
type ParamAuthService struct {
|
||||
Name string `yaml:"name"`
|
||||
Field string `yaml:"field"`
|
||||
}
|
||||
@@ -299,22 +328,22 @@ type ParamAuthSource struct {
|
||||
func NewStringParameter(name, desc string) *StringParameter {
|
||||
return &StringParameter{
|
||||
CommonParameter: CommonParameter{
|
||||
Name: name,
|
||||
Type: typeString,
|
||||
Desc: desc,
|
||||
AuthSources: nil,
|
||||
Name: name,
|
||||
Type: typeString,
|
||||
Desc: desc,
|
||||
AuthServices: nil,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewStringParameterWithAuth is a convenience function for initializing a StringParameter with a list of ParamAuthSource.
|
||||
func NewStringParameterWithAuth(name, desc string, authSources []ParamAuthSource) *StringParameter {
|
||||
// NewStringParameterWithAuth is a convenience function for initializing a StringParameter with a list of ParamAuthService.
|
||||
func NewStringParameterWithAuth(name, desc string, authServices []ParamAuthService) *StringParameter {
|
||||
return &StringParameter{
|
||||
CommonParameter: CommonParameter{
|
||||
Name: name,
|
||||
Type: typeString,
|
||||
Desc: desc,
|
||||
AuthSources: authSources,
|
||||
Name: name,
|
||||
Type: typeString,
|
||||
Desc: desc,
|
||||
AuthServices: authServices,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -334,30 +363,30 @@ func (p *StringParameter) Parse(v any) (any, error) {
|
||||
}
|
||||
return newV, nil
|
||||
}
|
||||
func (p *StringParameter) GetAuthSources() []ParamAuthSource {
|
||||
return p.AuthSources
|
||||
func (p *StringParameter) GetAuthServices() []ParamAuthService {
|
||||
return p.AuthServices
|
||||
}
|
||||
|
||||
// NewIntParameter is a convenience function for initializing a IntParameter.
|
||||
func NewIntParameter(name, desc string) *IntParameter {
|
||||
return &IntParameter{
|
||||
CommonParameter: CommonParameter{
|
||||
Name: name,
|
||||
Type: typeInt,
|
||||
Desc: desc,
|
||||
AuthSources: nil,
|
||||
Name: name,
|
||||
Type: typeInt,
|
||||
Desc: desc,
|
||||
AuthServices: nil,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewIntParameterWithAuth is a convenience function for initializing a IntParameter with a list of ParamAuthSource.
|
||||
func NewIntParameterWithAuth(name, desc string, authSources []ParamAuthSource) *IntParameter {
|
||||
// NewIntParameterWithAuth is a convenience function for initializing a IntParameter with a list of ParamAuthService.
|
||||
func NewIntParameterWithAuth(name, desc string, authServices []ParamAuthService) *IntParameter {
|
||||
return &IntParameter{
|
||||
CommonParameter: CommonParameter{
|
||||
Name: name,
|
||||
Type: typeInt,
|
||||
Desc: desc,
|
||||
AuthSources: authSources,
|
||||
Name: name,
|
||||
Type: typeInt,
|
||||
Desc: desc,
|
||||
AuthServices: authServices,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -390,30 +419,30 @@ func (p *IntParameter) Parse(v any) (any, error) {
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (p *IntParameter) GetAuthSources() []ParamAuthSource {
|
||||
return p.AuthSources
|
||||
func (p *IntParameter) GetAuthServices() []ParamAuthService {
|
||||
return p.AuthServices
|
||||
}
|
||||
|
||||
// NewFloatParameter is a convenience function for initializing a FloatParameter.
|
||||
func NewFloatParameter(name, desc string) *FloatParameter {
|
||||
return &FloatParameter{
|
||||
CommonParameter: CommonParameter{
|
||||
Name: name,
|
||||
Type: typeFloat,
|
||||
Desc: desc,
|
||||
AuthSources: nil,
|
||||
Name: name,
|
||||
Type: typeFloat,
|
||||
Desc: desc,
|
||||
AuthServices: nil,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewFloatParameterWithAuth is a convenience function for initializing a FloatParameter with a list of ParamAuthSource.
|
||||
func NewFloatParameterWithAuth(name, desc string, authSources []ParamAuthSource) *FloatParameter {
|
||||
// NewFloatParameterWithAuth is a convenience function for initializing a FloatParameter with a list of ParamAuthService.
|
||||
func NewFloatParameterWithAuth(name, desc string, authServices []ParamAuthService) *FloatParameter {
|
||||
return &FloatParameter{
|
||||
CommonParameter: CommonParameter{
|
||||
Name: name,
|
||||
Type: typeFloat,
|
||||
Desc: desc,
|
||||
AuthSources: authSources,
|
||||
Name: name,
|
||||
Type: typeFloat,
|
||||
Desc: desc,
|
||||
AuthServices: authServices,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -444,30 +473,30 @@ func (p *FloatParameter) Parse(v any) (any, error) {
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (p *FloatParameter) GetAuthSources() []ParamAuthSource {
|
||||
return p.AuthSources
|
||||
func (p *FloatParameter) GetAuthServices() []ParamAuthService {
|
||||
return p.AuthServices
|
||||
}
|
||||
|
||||
// NewBooleanParameter is a convenience function for initializing a BooleanParameter.
|
||||
func NewBooleanParameter(name, desc string) *BooleanParameter {
|
||||
return &BooleanParameter{
|
||||
CommonParameter: CommonParameter{
|
||||
Name: name,
|
||||
Type: typeBool,
|
||||
Desc: desc,
|
||||
AuthSources: nil,
|
||||
Name: name,
|
||||
Type: typeBool,
|
||||
Desc: desc,
|
||||
AuthServices: nil,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewBooleanParameterWithAuth is a convenience function for initializing a BooleanParameter with a list of ParamAuthSource.
|
||||
func NewBooleanParameterWithAuth(name, desc string, authSources []ParamAuthSource) *BooleanParameter {
|
||||
// NewBooleanParameterWithAuth is a convenience function for initializing a BooleanParameter with a list of ParamAuthService.
|
||||
func NewBooleanParameterWithAuth(name, desc string, authServices []ParamAuthService) *BooleanParameter {
|
||||
return &BooleanParameter{
|
||||
CommonParameter: CommonParameter{
|
||||
Name: name,
|
||||
Type: typeBool,
|
||||
Desc: desc,
|
||||
AuthSources: authSources,
|
||||
Name: name,
|
||||
Type: typeBool,
|
||||
Desc: desc,
|
||||
AuthServices: authServices,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -487,31 +516,31 @@ func (p *BooleanParameter) Parse(v any) (any, error) {
|
||||
return newV, nil
|
||||
}
|
||||
|
||||
func (p *BooleanParameter) GetAuthSources() []ParamAuthSource {
|
||||
return p.AuthSources
|
||||
func (p *BooleanParameter) GetAuthServices() []ParamAuthService {
|
||||
return p.AuthServices
|
||||
}
|
||||
|
||||
// NewArrayParameter is a convenience function for initializing a ArrayParameter.
|
||||
func NewArrayParameter(name, desc string, items Parameter) *ArrayParameter {
|
||||
return &ArrayParameter{
|
||||
CommonParameter: CommonParameter{
|
||||
Name: name,
|
||||
Type: typeArray,
|
||||
Desc: desc,
|
||||
AuthSources: nil,
|
||||
Name: name,
|
||||
Type: typeArray,
|
||||
Desc: desc,
|
||||
AuthServices: nil,
|
||||
},
|
||||
Items: items,
|
||||
}
|
||||
}
|
||||
|
||||
// NewArrayParameterWithAuth is a convenience function for initializing a ArrayParameter with a list of ParamAuthSource.
|
||||
func NewArrayParameterWithAuth(name, desc string, items Parameter, authSources []ParamAuthSource) *ArrayParameter {
|
||||
// NewArrayParameterWithAuth is a convenience function for initializing a ArrayParameter with a list of ParamAuthService.
|
||||
func NewArrayParameterWithAuth(name, desc string, items Parameter, authServices []ParamAuthService) *ArrayParameter {
|
||||
return &ArrayParameter{
|
||||
CommonParameter: CommonParameter{
|
||||
Name: name,
|
||||
Type: typeArray,
|
||||
Desc: desc,
|
||||
AuthSources: authSources,
|
||||
Name: name,
|
||||
Type: typeArray,
|
||||
Desc: desc,
|
||||
AuthServices: authServices,
|
||||
},
|
||||
Items: items,
|
||||
}
|
||||
@@ -538,8 +567,8 @@ func (p *ArrayParameter) UnmarshalYAML(ctx context.Context, unmarshal func(inter
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to parse 'items' field: %w", err)
|
||||
}
|
||||
if i.GetAuthSources() != nil {
|
||||
return fmt.Errorf("nested items should not have auth sources.")
|
||||
if i.GetAuthServices() != nil {
|
||||
return fmt.Errorf("nested items should not have auth services.")
|
||||
}
|
||||
p.Items = i
|
||||
|
||||
@@ -562,23 +591,23 @@ func (p *ArrayParameter) Parse(v any) (any, error) {
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func (p *ArrayParameter) GetAuthSources() []ParamAuthSource {
|
||||
return p.AuthSources
|
||||
func (p *ArrayParameter) GetAuthServices() []ParamAuthService {
|
||||
return p.AuthServices
|
||||
}
|
||||
|
||||
// Manifest returns the manifest for the ArrayParameter.
|
||||
func (p *ArrayParameter) Manifest() ParameterManifest {
|
||||
// only list ParamAuthSource names (without fields) in manifest
|
||||
authNames := make([]string, len(p.AuthSources))
|
||||
for i, a := range p.AuthSources {
|
||||
// only list ParamAuthService names (without fields) in manifest
|
||||
authNames := make([]string, len(p.AuthServices))
|
||||
for i, a := range p.AuthServices {
|
||||
authNames[i] = a.Name
|
||||
}
|
||||
items := p.Items.Manifest()
|
||||
return ParameterManifest{
|
||||
Name: p.Name,
|
||||
Type: p.Type,
|
||||
Description: p.Desc,
|
||||
AuthSources: authNames,
|
||||
Items: &items,
|
||||
Name: p.Name,
|
||||
Type: p.Type,
|
||||
Description: p.Desc,
|
||||
AuthServices: authNames,
|
||||
Items: &items,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,10 +23,15 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
func TestParametersMarshal(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
name string
|
||||
in []map[string]any
|
||||
@@ -130,7 +135,7 @@ func TestParametersMarshal(t *testing.T) {
|
||||
t.Fatalf("unable to marshal input to yaml: %s", err)
|
||||
}
|
||||
// parse bytes to object
|
||||
err = yaml.Unmarshal(data, &got)
|
||||
err = yaml.UnmarshalContext(ctx, data, &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
@@ -142,7 +147,11 @@ func TestParametersMarshal(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAuthParametersMarshal(t *testing.T) {
|
||||
authSources := []tools.ParamAuthSource{{Name: "my-google-auth-service", Field: "user_id"}, {Name: "other-auth-service", Field: "user_id"}}
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
authServices := []tools.ParamAuthService{{Name: "my-google-auth-service", Field: "user_id"}, {Name: "other-auth-service", Field: "user_id"}}
|
||||
tcs := []struct {
|
||||
name string
|
||||
in []map[string]any
|
||||
@@ -150,6 +159,29 @@ func TestAuthParametersMarshal(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "string",
|
||||
in: []map[string]any{
|
||||
{
|
||||
"name": "my_string",
|
||||
"type": "string",
|
||||
"description": "this param is a string",
|
||||
"authServices": []map[string]string{
|
||||
{
|
||||
"name": "my-google-auth-service",
|
||||
"field": "user_id",
|
||||
},
|
||||
{
|
||||
"name": "other-auth-service",
|
||||
"field": "user_id",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: tools.Parameters{
|
||||
tools.NewStringParameterWithAuth("my_string", "this param is a string", authServices),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "string with authSources",
|
||||
in: []map[string]any{
|
||||
{
|
||||
"name": "my_string",
|
||||
@@ -168,11 +200,34 @@ func TestAuthParametersMarshal(t *testing.T) {
|
||||
},
|
||||
},
|
||||
want: tools.Parameters{
|
||||
tools.NewStringParameterWithAuth("my_string", "this param is a string", authSources),
|
||||
tools.NewStringParameterWithAuth("my_string", "this param is a string", authServices),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "int",
|
||||
in: []map[string]any{
|
||||
{
|
||||
"name": "my_integer",
|
||||
"type": "integer",
|
||||
"description": "this param is an int",
|
||||
"authServices": []map[string]string{
|
||||
{
|
||||
"name": "my-google-auth-service",
|
||||
"field": "user_id",
|
||||
},
|
||||
{
|
||||
"name": "other-auth-service",
|
||||
"field": "user_id",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: tools.Parameters{
|
||||
tools.NewIntParameterWithAuth("my_integer", "this param is an int", authServices),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "int with authSources",
|
||||
in: []map[string]any{
|
||||
{
|
||||
"name": "my_integer",
|
||||
@@ -191,11 +246,34 @@ func TestAuthParametersMarshal(t *testing.T) {
|
||||
},
|
||||
},
|
||||
want: tools.Parameters{
|
||||
tools.NewIntParameterWithAuth("my_integer", "this param is an int", authSources),
|
||||
tools.NewIntParameterWithAuth("my_integer", "this param is an int", authServices),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "float",
|
||||
in: []map[string]any{
|
||||
{
|
||||
"name": "my_float",
|
||||
"type": "float",
|
||||
"description": "my param is a float",
|
||||
"authServices": []map[string]string{
|
||||
{
|
||||
"name": "my-google-auth-service",
|
||||
"field": "user_id",
|
||||
},
|
||||
{
|
||||
"name": "other-auth-service",
|
||||
"field": "user_id",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: tools.Parameters{
|
||||
tools.NewFloatParameterWithAuth("my_float", "my param is a float", authServices),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "float with authSources",
|
||||
in: []map[string]any{
|
||||
{
|
||||
"name": "my_float",
|
||||
@@ -214,11 +292,34 @@ func TestAuthParametersMarshal(t *testing.T) {
|
||||
},
|
||||
},
|
||||
want: tools.Parameters{
|
||||
tools.NewFloatParameterWithAuth("my_float", "my param is a float", authSources),
|
||||
tools.NewFloatParameterWithAuth("my_float", "my param is a float", authServices),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bool",
|
||||
in: []map[string]any{
|
||||
{
|
||||
"name": "my_bool",
|
||||
"type": "boolean",
|
||||
"description": "this param is a boolean",
|
||||
"authServices": []map[string]string{
|
||||
{
|
||||
"name": "my-google-auth-service",
|
||||
"field": "user_id",
|
||||
},
|
||||
{
|
||||
"name": "other-auth-service",
|
||||
"field": "user_id",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: tools.Parameters{
|
||||
tools.NewBooleanParameterWithAuth("my_bool", "this param is a boolean", authServices),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bool with authSources",
|
||||
in: []map[string]any{
|
||||
{
|
||||
"name": "my_bool",
|
||||
@@ -237,11 +338,39 @@ func TestAuthParametersMarshal(t *testing.T) {
|
||||
},
|
||||
},
|
||||
want: tools.Parameters{
|
||||
tools.NewBooleanParameterWithAuth("my_bool", "this param is a boolean", authSources),
|
||||
tools.NewBooleanParameterWithAuth("my_bool", "this param is a boolean", authServices),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "string array",
|
||||
in: []map[string]any{
|
||||
{
|
||||
"name": "my_array",
|
||||
"type": "array",
|
||||
"description": "this param is an array of strings",
|
||||
"items": map[string]string{
|
||||
"name": "my_string",
|
||||
"type": "string",
|
||||
"description": "string item",
|
||||
},
|
||||
"authServices": []map[string]string{
|
||||
{
|
||||
"name": "my-google-auth-service",
|
||||
"field": "user_id",
|
||||
},
|
||||
{
|
||||
"name": "other-auth-service",
|
||||
"field": "user_id",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: tools.Parameters{
|
||||
tools.NewArrayParameterWithAuth("my_array", "this param is an array of strings", tools.NewStringParameter("my_string", "string item"), authServices),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "string array with authSources",
|
||||
in: []map[string]any{
|
||||
{
|
||||
"name": "my_array",
|
||||
@@ -265,7 +394,7 @@ func TestAuthParametersMarshal(t *testing.T) {
|
||||
},
|
||||
},
|
||||
want: tools.Parameters{
|
||||
tools.NewArrayParameterWithAuth("my_array", "this param is an array of strings", tools.NewStringParameter("my_string", "string item"), authSources),
|
||||
tools.NewArrayParameterWithAuth("my_array", "this param is an array of strings", tools.NewStringParameter("my_string", "string item"), authServices),
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -280,7 +409,7 @@ func TestAuthParametersMarshal(t *testing.T) {
|
||||
"type": "float",
|
||||
"description": "float item",
|
||||
},
|
||||
"authSources": []map[string]string{
|
||||
"authServices": []map[string]string{
|
||||
{
|
||||
"name": "my-google-auth-service",
|
||||
"field": "user_id",
|
||||
@@ -293,7 +422,7 @@ func TestAuthParametersMarshal(t *testing.T) {
|
||||
},
|
||||
},
|
||||
want: tools.Parameters{
|
||||
tools.NewArrayParameterWithAuth("my_array", "this param is an array of floats", tools.NewFloatParameter("my_float", "float item"), authSources),
|
||||
tools.NewArrayParameterWithAuth("my_array", "this param is an array of floats", tools.NewFloatParameter("my_float", "float item"), authServices),
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -306,7 +435,7 @@ func TestAuthParametersMarshal(t *testing.T) {
|
||||
t.Fatalf("unable to marshal input to yaml: %s", err)
|
||||
}
|
||||
// parse bytes to object
|
||||
err = yaml.Unmarshal(data, &got)
|
||||
err = yaml.UnmarshalContext(ctx, data, &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
@@ -454,7 +583,7 @@ func TestParametersParse(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAuthParametersParse(t *testing.T) {
|
||||
authSources := []tools.ParamAuthSource{
|
||||
authServices := []tools.ParamAuthService{
|
||||
{
|
||||
Name: "my-google-auth-service",
|
||||
Field: "auth_field",
|
||||
@@ -473,7 +602,7 @@ func TestAuthParametersParse(t *testing.T) {
|
||||
{
|
||||
name: "string",
|
||||
params: tools.Parameters{
|
||||
tools.NewStringParameterWithAuth("my_string", "this param is a string", authSources),
|
||||
tools.NewStringParameterWithAuth("my_string", "this param is a string", authServices),
|
||||
},
|
||||
in: map[string]any{
|
||||
"my_string": "hello world",
|
||||
@@ -484,7 +613,7 @@ func TestAuthParametersParse(t *testing.T) {
|
||||
{
|
||||
name: "not string",
|
||||
params: tools.Parameters{
|
||||
tools.NewStringParameterWithAuth("my_string", "this param is a string", authSources),
|
||||
tools.NewStringParameterWithAuth("my_string", "this param is a string", authServices),
|
||||
},
|
||||
in: map[string]any{
|
||||
"my_string": 4,
|
||||
@@ -494,7 +623,7 @@ func TestAuthParametersParse(t *testing.T) {
|
||||
{
|
||||
name: "int",
|
||||
params: tools.Parameters{
|
||||
tools.NewIntParameterWithAuth("my_int", "this param is an int", authSources),
|
||||
tools.NewIntParameterWithAuth("my_int", "this param is an int", authServices),
|
||||
},
|
||||
in: map[string]any{
|
||||
"my_int": 100,
|
||||
@@ -505,7 +634,7 @@ func TestAuthParametersParse(t *testing.T) {
|
||||
{
|
||||
name: "not int",
|
||||
params: tools.Parameters{
|
||||
tools.NewIntParameterWithAuth("my_int", "this param is an int", authSources),
|
||||
tools.NewIntParameterWithAuth("my_int", "this param is an int", authServices),
|
||||
},
|
||||
in: map[string]any{
|
||||
"my_int": 14.5,
|
||||
@@ -515,7 +644,7 @@ func TestAuthParametersParse(t *testing.T) {
|
||||
{
|
||||
name: "float",
|
||||
params: tools.Parameters{
|
||||
tools.NewFloatParameterWithAuth("my_float", "this param is a float", authSources),
|
||||
tools.NewFloatParameterWithAuth("my_float", "this param is a float", authServices),
|
||||
},
|
||||
in: map[string]any{
|
||||
"my_float": 1.5,
|
||||
@@ -526,7 +655,7 @@ func TestAuthParametersParse(t *testing.T) {
|
||||
{
|
||||
name: "not float",
|
||||
params: tools.Parameters{
|
||||
tools.NewFloatParameterWithAuth("my_float", "this param is a float", authSources),
|
||||
tools.NewFloatParameterWithAuth("my_float", "this param is a float", authServices),
|
||||
},
|
||||
in: map[string]any{
|
||||
"my_float": true,
|
||||
@@ -536,7 +665,7 @@ func TestAuthParametersParse(t *testing.T) {
|
||||
{
|
||||
name: "bool",
|
||||
params: tools.Parameters{
|
||||
tools.NewBooleanParameterWithAuth("my_bool", "this param is a bool", authSources),
|
||||
tools.NewBooleanParameterWithAuth("my_bool", "this param is a bool", authServices),
|
||||
},
|
||||
in: map[string]any{
|
||||
"my_bool": true,
|
||||
@@ -547,7 +676,7 @@ func TestAuthParametersParse(t *testing.T) {
|
||||
{
|
||||
name: "not bool",
|
||||
params: tools.Parameters{
|
||||
tools.NewBooleanParameterWithAuth("my_bool", "this param is a bool", authSources),
|
||||
tools.NewBooleanParameterWithAuth("my_bool", "this param is a bool", authServices),
|
||||
},
|
||||
in: map[string]any{
|
||||
"my_bool": 1.5,
|
||||
@@ -557,7 +686,7 @@ func TestAuthParametersParse(t *testing.T) {
|
||||
{
|
||||
name: "username",
|
||||
params: tools.Parameters{
|
||||
tools.NewStringParameterWithAuth("username", "username string", authSources),
|
||||
tools.NewStringParameterWithAuth("username", "username string", authServices),
|
||||
},
|
||||
in: map[string]any{
|
||||
"username": "Violet",
|
||||
@@ -568,7 +697,7 @@ func TestAuthParametersParse(t *testing.T) {
|
||||
{
|
||||
name: "expect claim error",
|
||||
params: tools.Parameters{
|
||||
tools.NewStringParameterWithAuth("username", "username string", authSources),
|
||||
tools.NewStringParameterWithAuth("username", "username string", authServices),
|
||||
},
|
||||
in: map[string]any{
|
||||
"username": "Violet",
|
||||
@@ -679,32 +808,32 @@ func TestParamManifest(t *testing.T) {
|
||||
{
|
||||
name: "string",
|
||||
in: tools.NewStringParameter("foo-string", "bar"),
|
||||
want: tools.ParameterManifest{Name: "foo-string", Type: "string", Description: "bar", AuthSources: []string{}},
|
||||
want: tools.ParameterManifest{Name: "foo-string", Type: "string", Description: "bar", AuthServices: []string{}},
|
||||
},
|
||||
{
|
||||
name: "int",
|
||||
in: tools.NewIntParameter("foo-int", "bar"),
|
||||
want: tools.ParameterManifest{Name: "foo-int", Type: "integer", Description: "bar", AuthSources: []string{}},
|
||||
want: tools.ParameterManifest{Name: "foo-int", Type: "integer", Description: "bar", AuthServices: []string{}},
|
||||
},
|
||||
{
|
||||
name: "float",
|
||||
in: tools.NewFloatParameter("foo-float", "bar"),
|
||||
want: tools.ParameterManifest{Name: "foo-float", Type: "float", Description: "bar", AuthSources: []string{}},
|
||||
want: tools.ParameterManifest{Name: "foo-float", Type: "float", Description: "bar", AuthServices: []string{}},
|
||||
},
|
||||
{
|
||||
name: "boolean",
|
||||
in: tools.NewBooleanParameter("foo-bool", "bar"),
|
||||
want: tools.ParameterManifest{Name: "foo-bool", Type: "boolean", Description: "bar", AuthSources: []string{}},
|
||||
want: tools.ParameterManifest{Name: "foo-bool", Type: "boolean", Description: "bar", AuthServices: []string{}},
|
||||
},
|
||||
{
|
||||
name: "array",
|
||||
in: tools.NewArrayParameter("foo-array", "bar", tools.NewStringParameter("foo-string", "bar")),
|
||||
want: tools.ParameterManifest{
|
||||
Name: "foo-array",
|
||||
Type: "array",
|
||||
Description: "bar",
|
||||
AuthSources: []string{},
|
||||
Items: &tools.ParameterManifest{Name: "foo-string", Type: "string", Description: "bar", AuthSources: []string{}},
|
||||
Name: "foo-array",
|
||||
Type: "array",
|
||||
Description: "bar",
|
||||
AuthServices: []string{},
|
||||
Items: &tools.ParameterManifest{Name: "foo-string", Type: "string", Description: "bar", AuthServices: []string{}},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -719,6 +848,10 @@ func TestParamManifest(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFailParametersUnmarshal(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
name string
|
||||
in []map[string]any
|
||||
@@ -790,7 +923,7 @@ func TestFailParametersUnmarshal(t *testing.T) {
|
||||
t.Fatalf("unable to marshal input to yaml: %s", err)
|
||||
}
|
||||
// parse bytes to object
|
||||
err = yaml.Unmarshal(data, &got)
|
||||
err = yaml.UnmarshalContext(ctx, data, &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
|
||||
@@ -141,6 +141,6 @@ func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthSources []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthSources)
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
@@ -26,6 +26,10 @@ import (
|
||||
)
|
||||
|
||||
func TestParseFromYamlPostgres(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
@@ -48,7 +52,7 @@ func TestParseFromYamlPostgres(t *testing.T) {
|
||||
- name: country
|
||||
type: string
|
||||
description: some description
|
||||
authSources:
|
||||
authServices:
|
||||
- name: my-google-auth-service
|
||||
field: user_id
|
||||
- name: other-auth-service
|
||||
@@ -64,7 +68,7 @@ func TestParseFromYamlPostgres(t *testing.T) {
|
||||
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
|
||||
Parameters: []tools.Parameter{
|
||||
tools.NewStringParameterWithAuth("country", "some description",
|
||||
[]tools.ParamAuthSource{{Name: "my-google-auth-service", Field: "user_id"},
|
||||
[]tools.ParamAuthService{{Name: "my-google-auth-service", Field: "user_id"},
|
||||
{Name: "other-auth-service", Field: "user_id"}}),
|
||||
},
|
||||
},
|
||||
@@ -77,7 +81,7 @@ func TestParseFromYamlPostgres(t *testing.T) {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
|
||||
@@ -170,6 +170,6 @@ func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthSources []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthSources)
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
@@ -26,6 +26,10 @@ import (
|
||||
)
|
||||
|
||||
func TestParseFromYamlSpanner(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
@@ -66,7 +70,7 @@ func TestParseFromYamlSpanner(t *testing.T) {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
|
||||
@@ -39,13 +39,13 @@ type Manifest struct {
|
||||
}
|
||||
|
||||
// Helper function that returns if a tool invocation request is authorized
|
||||
func IsAuthorized(authRequiredSources []string, verifiedAuthSources []string) bool {
|
||||
func IsAuthorized(authRequiredSources []string, verifiedAuthServices []string) bool {
|
||||
if len(authRequiredSources) == 0 {
|
||||
// no authorization requirement
|
||||
return true
|
||||
}
|
||||
for _, a := range authRequiredSources {
|
||||
if slices.Contains(verifiedAuthSources, a) {
|
||||
if slices.Contains(verifiedAuthServices, a) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@ func GetToolsConfig(sourceConfig map[string]any, toolKind, param_tool_statement,
|
||||
"sources": map[string]any{
|
||||
"my-instance": sourceConfig,
|
||||
},
|
||||
"authSources": map[string]any{
|
||||
"authServices": map[string]any{
|
||||
"my-google-auth": map[string]any{
|
||||
"kind": "google",
|
||||
"clientId": ClientId,
|
||||
@@ -73,7 +73,7 @@ func GetToolsConfig(sourceConfig map[string]any, toolKind, param_tool_statement,
|
||||
"name": "email",
|
||||
"type": "string",
|
||||
"description": "user email",
|
||||
"authSources": []map[string]string{
|
||||
"authServices": []map[string]string{
|
||||
{
|
||||
"name": "my-google-auth",
|
||||
"field": "email",
|
||||
|
||||
Reference in New Issue
Block a user