mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-14 09:57:58 -05:00
Compare commits
12 Commits
host-error
...
debug-mac
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2c62647c0f | ||
|
|
19465c5359 | ||
|
|
6906fa98c2 | ||
|
|
b36fa6ef4b | ||
|
|
8b0b4e6391 | ||
|
|
e908dd204a | ||
|
|
5818cacfb2 | ||
|
|
84006f0dad | ||
|
|
9acd4492d1 | ||
|
|
5a1c495187 | ||
|
|
a044c468d0 | ||
|
|
8b0a18352e |
@@ -151,6 +151,8 @@ execute `toolbox` to start the server:
|
|||||||
```sh
|
```sh
|
||||||
./toolbox --tools-file "tools.yaml"
|
./toolbox --tools-file "tools.yaml"
|
||||||
```
|
```
|
||||||
|
> [!NOTE]
|
||||||
|
> Toolbox enables dynamic reloading by default. To disable, use the `--disable-reload` flag.
|
||||||
|
|
||||||
You can use `toolbox help` for a full list of flags! To stop the server, send a
|
You can use `toolbox help` for a full list of flags! To stop the server, send a
|
||||||
terminate signal (`ctrl+c` on most platforms).
|
terminate signal (`ctrl+c` on most platforms).
|
||||||
|
|||||||
233
cmd/root.go
233
cmd/root.go
@@ -19,20 +19,26 @@ import (
|
|||||||
_ "embed"
|
_ "embed"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"maps"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"regexp"
|
"regexp"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/fsnotify/fsnotify"
|
||||||
yaml "github.com/goccy/go-yaml"
|
yaml "github.com/goccy/go-yaml"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/auth"
|
||||||
"github.com/googleapis/genai-toolbox/internal/log"
|
"github.com/googleapis/genai-toolbox/internal/log"
|
||||||
"github.com/googleapis/genai-toolbox/internal/prebuiltconfigs"
|
"github.com/googleapis/genai-toolbox/internal/prebuiltconfigs"
|
||||||
"github.com/googleapis/genai-toolbox/internal/server"
|
"github.com/googleapis/genai-toolbox/internal/server"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
"github.com/googleapis/genai-toolbox/internal/telemetry"
|
"github.com/googleapis/genai-toolbox/internal/telemetry"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util"
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
|
|
||||||
// Import tool packages for side effect of registration
|
// Import tool packages for side effect of registration
|
||||||
@@ -178,6 +184,7 @@ func NewCommand(opts ...Option) *Command {
|
|||||||
flags.StringVar(&cmd.cfg.TelemetryServiceName, "telemetry-service-name", "toolbox", "Sets the value of the service.name resource attribute for telemetry data.")
|
flags.StringVar(&cmd.cfg.TelemetryServiceName, "telemetry-service-name", "toolbox", "Sets the value of the service.name resource attribute for telemetry data.")
|
||||||
flags.StringVar(&cmd.prebuiltConfig, "prebuilt", "", "Use a prebuilt tool configuration by source type. Cannot be used with --tools-file. Allowed: 'alloydb-postgres', 'bigquery', 'cloud-sql-mysql', 'cloud-sql-postgres', 'cloud-sql-mssql', 'postgres', 'spanner', 'spanner-postgres'.")
|
flags.StringVar(&cmd.prebuiltConfig, "prebuilt", "", "Use a prebuilt tool configuration by source type. Cannot be used with --tools-file. Allowed: 'alloydb-postgres', 'bigquery', 'cloud-sql-mysql', 'cloud-sql-postgres', 'cloud-sql-mssql', 'postgres', 'spanner', 'spanner-postgres'.")
|
||||||
flags.BoolVar(&cmd.cfg.Stdio, "stdio", false, "Listens via MCP STDIO instead of acting as a remote HTTP server.")
|
flags.BoolVar(&cmd.cfg.Stdio, "stdio", false, "Listens via MCP STDIO instead of acting as a remote HTTP server.")
|
||||||
|
flags.BoolVar(&cmd.cfg.DisableReload, "disable-reload", false, "Disables dynamic reloading of tools file.")
|
||||||
|
|
||||||
// wrap RunE command so that we have access to original Command object
|
// wrap RunE command so that we have access to original Command object
|
||||||
cmd.RunE = func(*cobra.Command, []string) error { return run(cmd) }
|
cmd.RunE = func(*cobra.Command, []string) error { return run(cmd) }
|
||||||
@@ -347,7 +354,7 @@ func loadAndMergeToolsFolder(ctx context.Context, folderPath string) (ToolsFile,
|
|||||||
|
|
||||||
// Combine both file lists
|
// Combine both file lists
|
||||||
allFiles := append(yamlFiles, ymlFiles...)
|
allFiles := append(yamlFiles, ymlFiles...)
|
||||||
|
|
||||||
if len(allFiles) == 0 {
|
if len(allFiles) == 0 {
|
||||||
return ToolsFile{}, fmt.Errorf("no YAML files found in directory %q", folderPath)
|
return ToolsFile{}, fmt.Errorf("no YAML files found in directory %q", folderPath)
|
||||||
}
|
}
|
||||||
@@ -356,6 +363,177 @@ func loadAndMergeToolsFolder(ctx context.Context, folderPath string) (ToolsFile,
|
|||||||
return loadAndMergeToolsFiles(ctx, allFiles)
|
return loadAndMergeToolsFiles(ctx, allFiles)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func handleDynamicReload(ctx context.Context, toolsFile ToolsFile, s *server.Server) error {
|
||||||
|
logger, err := util.LoggerFromContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sourcesMap, authServicesMap, toolsMap, toolsetsMap, err := validateReloadEdits(ctx, toolsFile)
|
||||||
|
if err != nil {
|
||||||
|
errMsg := fmt.Errorf("unable to validate reloaded edits: %w", err)
|
||||||
|
logger.WarnContext(ctx, errMsg.Error())
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.ResourceMgr.SetResources(sourcesMap, authServicesMap, toolsMap, toolsetsMap)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateReloadEdits checks that the reloaded tools file configs can initialized without failing
|
||||||
|
func validateReloadEdits(
|
||||||
|
ctx context.Context, toolsFile ToolsFile,
|
||||||
|
) (map[string]sources.Source, map[string]auth.AuthService, map[string]tools.Tool, map[string]tools.Toolset, error,
|
||||||
|
) {
|
||||||
|
logger, err := util.LoggerFromContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
instrumentation, err := util.InstrumentationFromContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.DebugContext(ctx, "Attempting to parse and validate reloaded tools file.")
|
||||||
|
|
||||||
|
ctx, span := instrumentation.Tracer.Start(ctx, "toolbox/server/reload")
|
||||||
|
defer span.End()
|
||||||
|
|
||||||
|
reloadedConfig := server.ServerConfig{
|
||||||
|
Version: versionString,
|
||||||
|
SourceConfigs: toolsFile.Sources,
|
||||||
|
AuthServiceConfigs: toolsFile.AuthServices,
|
||||||
|
ToolConfigs: toolsFile.Tools,
|
||||||
|
ToolsetConfigs: toolsFile.Toolsets,
|
||||||
|
}
|
||||||
|
|
||||||
|
sourcesMap, authServicesMap, toolsMap, toolsetsMap, err := server.InitializeConfigs(ctx, reloadedConfig)
|
||||||
|
if err != nil {
|
||||||
|
errMsg := fmt.Errorf("unable to initialize reloaded configs: %w", err)
|
||||||
|
logger.WarnContext(ctx, errMsg.Error())
|
||||||
|
return nil, nil, nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return sourcesMap, authServicesMap, toolsMap, toolsetsMap, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// watchChanges checks for changes in the provided yaml tools file(s) or folder.
|
||||||
|
func watchChanges(ctx context.Context, watchDirs map[string]bool, watchedFiles map[string]bool, s *server.Server) {
|
||||||
|
logger, err := util.LoggerFromContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
w, err := fsnotify.NewWatcher()
|
||||||
|
if err != nil {
|
||||||
|
logger.WarnContext(ctx, "error setting up new watcher %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
defer w.Close()
|
||||||
|
|
||||||
|
watchingFolder := false
|
||||||
|
var folderToWatch string
|
||||||
|
|
||||||
|
// if watchedFiles is empty, indicates that user passed entire folder instead
|
||||||
|
if len(watchedFiles) == 0 {
|
||||||
|
watchingFolder = true
|
||||||
|
|
||||||
|
// validate that watchDirs only has single element
|
||||||
|
if len(watchDirs) > 1 {
|
||||||
|
logger.WarnContext(ctx, "error setting watcher, expected single tools folder if no file(s) are defined.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for onlyKey := range watchDirs {
|
||||||
|
folderToWatch = onlyKey
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for dir := range watchDirs {
|
||||||
|
err := w.Add(dir)
|
||||||
|
if err != nil {
|
||||||
|
logger.WarnContext(ctx, fmt.Sprintf("Error adding path %s to watcher: %s", dir, err))
|
||||||
|
break
|
||||||
|
}
|
||||||
|
logger.DebugContext(ctx, fmt.Sprintf("Added directory %s to watcher.", dir))
|
||||||
|
}
|
||||||
|
|
||||||
|
// debounce timer is used to prevent multiple writes triggering multiple reloads
|
||||||
|
debounceDelay := 100 * time.Millisecond
|
||||||
|
debounce := time.NewTimer(1 * time.Minute)
|
||||||
|
debounce.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
logger.DebugContext(ctx, "file watcher context cancelled")
|
||||||
|
return
|
||||||
|
case err, ok := <-w.Errors:
|
||||||
|
if !ok {
|
||||||
|
logger.WarnContext(ctx, "file watcher was closed unexpectedly")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
logger.WarnContext(ctx, "file watcher error %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
case e, ok := <-w.Events:
|
||||||
|
if !ok {
|
||||||
|
logger.WarnContext(ctx, "file watcher already closed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// only check for write events which indicate user saved a new tools file
|
||||||
|
if !e.Has(fsnotify.Write) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanedFilename := filepath.Clean(e.Name)
|
||||||
|
logger.DebugContext(ctx, fmt.Sprintf("WRITE event detected in %s", cleanedFilename))
|
||||||
|
|
||||||
|
folderChanged := watchingFolder &&
|
||||||
|
(strings.HasSuffix(cleanedFilename, ".yaml") || strings.HasSuffix(cleanedFilename, ".yml"))
|
||||||
|
|
||||||
|
if folderChanged || watchedFiles[cleanedFilename] {
|
||||||
|
// indicates the write event is on a relevant file
|
||||||
|
debounce.Reset(debounceDelay)
|
||||||
|
}
|
||||||
|
|
||||||
|
case <-debounce.C:
|
||||||
|
debounce.Stop()
|
||||||
|
var reloadedToolsFile ToolsFile
|
||||||
|
|
||||||
|
if watchingFolder {
|
||||||
|
logger.DebugContext(ctx, "Reloading tools folder.")
|
||||||
|
reloadedToolsFile, err = loadAndMergeToolsFolder(ctx, folderToWatch)
|
||||||
|
if err != nil {
|
||||||
|
logger.WarnContext(ctx, "error loading tools folder %s", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logger.DebugContext(ctx, "Reloading tools file(s).")
|
||||||
|
reloadedToolsFile, err = loadAndMergeToolsFiles(ctx, slices.Collect(maps.Keys(watchedFiles)))
|
||||||
|
if err != nil {
|
||||||
|
logger.WarnContext(ctx, "error loading tools files %s", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = handleDynamicReload(ctx, reloadedToolsFile, s)
|
||||||
|
if err != nil {
|
||||||
|
errMsg := fmt.Errorf("unable to parse reloaded tools file at %q: %w", reloadedToolsFile, err)
|
||||||
|
logger.WarnContext(ctx, errMsg.Error())
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// updateLogLevel checks if Toolbox have to update the existing log level set by users.
|
// updateLogLevel checks if Toolbox have to update the existing log level set by users.
|
||||||
// stdio doesn't support "debug" and "info" logs.
|
// stdio doesn't support "debug" and "info" logs.
|
||||||
func updateLogLevel(stdio bool, logLevel string) bool {
|
func updateLogLevel(stdio bool, logLevel string) bool {
|
||||||
@@ -370,6 +548,33 @@ func updateLogLevel(stdio bool, logLevel string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func resolveWatcherInputs(toolsFile string, toolsFiles []string, toolsFolder string) (map[string]bool, map[string]bool) {
|
||||||
|
var relevantFiles []string
|
||||||
|
|
||||||
|
// map for efficiently checking if a file is relevant
|
||||||
|
watchedFiles := make(map[string]bool)
|
||||||
|
|
||||||
|
// dirs that will be added to watcher (fsnotify prefers watching directory then filtering for file)
|
||||||
|
watchDirs := make(map[string]bool)
|
||||||
|
|
||||||
|
if len(toolsFiles) > 0 {
|
||||||
|
relevantFiles = toolsFiles
|
||||||
|
} else if toolsFolder != "" {
|
||||||
|
watchDirs[filepath.Clean(toolsFolder)] = true
|
||||||
|
} else {
|
||||||
|
relevantFiles = []string{toolsFile}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extract parent dir for relevant files and dedup
|
||||||
|
for _, f := range relevantFiles {
|
||||||
|
cleanFile := filepath.Clean(f)
|
||||||
|
watchedFiles[cleanFile] = true
|
||||||
|
watchDirs[filepath.Dir(cleanFile)] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return watchDirs, watchedFiles
|
||||||
|
}
|
||||||
|
|
||||||
func run(cmd *Command) error {
|
func run(cmd *Command) error {
|
||||||
if updateLogLevel(cmd.cfg.Stdio, cmd.cfg.LogLevel.String()) {
|
if updateLogLevel(cmd.cfg.Stdio, cmd.cfg.LogLevel.String()) {
|
||||||
cmd.cfg.LogLevel = server.StringLevel(log.Warn)
|
cmd.cfg.LogLevel = server.StringLevel(log.Warn)
|
||||||
@@ -466,6 +671,7 @@ func run(cmd *Command) error {
|
|||||||
cmd.logger.ErrorContext(ctx, errMsg.Error())
|
cmd.logger.ErrorContext(ctx, errMsg.Error())
|
||||||
return errMsg
|
return errMsg
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use multiple tools files
|
// Use multiple tools files
|
||||||
cmd.logger.InfoContext(ctx, fmt.Sprintf("Loading and merging %d tool configuration files", len(cmd.tools_files)))
|
cmd.logger.InfoContext(ctx, fmt.Sprintf("Loading and merging %d tool configuration files", len(cmd.tools_files)))
|
||||||
var err error
|
var err error
|
||||||
@@ -481,6 +687,7 @@ func run(cmd *Command) error {
|
|||||||
cmd.logger.ErrorContext(ctx, errMsg.Error())
|
cmd.logger.ErrorContext(ctx, errMsg.Error())
|
||||||
return errMsg
|
return errMsg
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use tools folder
|
// Use tools folder
|
||||||
cmd.logger.InfoContext(ctx, fmt.Sprintf("Loading and merging all YAML files from directory: %s", cmd.tools_folder))
|
cmd.logger.InfoContext(ctx, fmt.Sprintf("Loading and merging all YAML files from directory: %s", cmd.tools_folder))
|
||||||
var err error
|
var err error
|
||||||
@@ -494,6 +701,7 @@ func run(cmd *Command) error {
|
|||||||
if cmd.tools_file == "" {
|
if cmd.tools_file == "" {
|
||||||
cmd.tools_file = "tools.yaml"
|
cmd.tools_file = "tools.yaml"
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read single tool file contents
|
// Read single tool file contents
|
||||||
buf, err := os.ReadFile(cmd.tools_file)
|
buf, err := os.ReadFile(cmd.tools_file)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -516,9 +724,23 @@ func run(cmd *Command) error {
|
|||||||
cmd.logger.WarnContext(ctx, "`authSources` is deprecated, use `authServices` instead")
|
cmd.logger.WarnContext(ctx, "`authSources` is deprecated, use `authServices` instead")
|
||||||
cmd.cfg.AuthServiceConfigs = authSourceConfigs
|
cmd.cfg.AuthServiceConfigs = authSourceConfigs
|
||||||
}
|
}
|
||||||
|
if err != nil {
|
||||||
|
errMsg := fmt.Errorf("unable to parse tool file at %q: %w", cmd.tools_file, err)
|
||||||
|
cmd.logger.ErrorContext(ctx, errMsg.Error())
|
||||||
|
return errMsg
|
||||||
|
}
|
||||||
|
|
||||||
|
instrumentation, err := telemetry.CreateTelemetryInstrumentation(versionString)
|
||||||
|
if err != nil {
|
||||||
|
errMsg := fmt.Errorf("unable to create telemetry instrumentation: %w", err)
|
||||||
|
cmd.logger.ErrorContext(ctx, errMsg.Error())
|
||||||
|
return errMsg
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx = util.WithInstrumentation(ctx, instrumentation)
|
||||||
|
|
||||||
// start server
|
// start server
|
||||||
s, err := server.NewServer(ctx, cmd.cfg, cmd.logger)
|
s, err := server.NewServer(ctx, cmd.cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errMsg := fmt.Errorf("toolbox failed to initialize: %w", err)
|
errMsg := fmt.Errorf("toolbox failed to initialize: %w", err)
|
||||||
cmd.logger.ErrorContext(ctx, errMsg.Error())
|
cmd.logger.ErrorContext(ctx, errMsg.Error())
|
||||||
@@ -553,6 +775,13 @@ func run(cmd *Command) error {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
watchDirs, watchedFiles := resolveWatcherInputs(cmd.tools_file, cmd.tools_files, cmd.tools_folder)
|
||||||
|
|
||||||
|
if !cmd.cfg.DisableReload {
|
||||||
|
// start watching the file(s) or folder for changes to trigger dynamic reloading
|
||||||
|
go watchChanges(ctx, watchDirs, watchedFiles, s)
|
||||||
|
}
|
||||||
|
|
||||||
// wait for either the server to error out or the command's context to be canceled
|
// wait for either the server to error out or the command's context to be canceled
|
||||||
select {
|
select {
|
||||||
case err := <-srvErr:
|
case err := <-srvErr:
|
||||||
|
|||||||
196
cmd/root_test.go
196
cmd/root_test.go
@@ -16,23 +16,33 @@ package cmd
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
_ "embed"
|
_ "embed"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
"os"
|
"os"
|
||||||
|
"path"
|
||||||
|
"path/filepath"
|
||||||
|
"regexp"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/auth/google"
|
"github.com/googleapis/genai-toolbox/internal/auth/google"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/log"
|
||||||
"github.com/googleapis/genai-toolbox/internal/prebuiltconfigs"
|
"github.com/googleapis/genai-toolbox/internal/prebuiltconfigs"
|
||||||
"github.com/googleapis/genai-toolbox/internal/server"
|
"github.com/googleapis/genai-toolbox/internal/server"
|
||||||
cloudsqlpgsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg"
|
cloudsqlpgsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg"
|
||||||
httpsrc "github.com/googleapis/genai-toolbox/internal/sources/http"
|
httpsrc "github.com/googleapis/genai-toolbox/internal/sources/http"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/telemetry"
|
||||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools/http"
|
"github.com/googleapis/genai-toolbox/internal/tools/http"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools/postgres/postgressql"
|
"github.com/googleapis/genai-toolbox/internal/tools/postgres/postgressql"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -174,6 +184,13 @@ func TestServerConfigFlags(t *testing.T) {
|
|||||||
Stdio: true,
|
Stdio: true,
|
||||||
}),
|
}),
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
desc: "disable reload",
|
||||||
|
args: []string{"--disable-reload"},
|
||||||
|
want: withDefaults(server.ServerConfig{
|
||||||
|
DisableReload: true,
|
||||||
|
}),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
for _, tc := range tcs {
|
for _, tc := range tcs {
|
||||||
t.Run(tc.desc, func(t *testing.T) {
|
t.Run(tc.desc, func(t *testing.T) {
|
||||||
@@ -965,6 +982,185 @@ func TestEnvVarReplacement(t *testing.T) {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// normalizeFilepaths is a helper function to allow same filepath formats for Mac and Windows.
|
||||||
|
// this prevents needing multiple "want" cases for TestResolveWatcherInputs
|
||||||
|
func normalizeFilepaths(m map[string]bool) map[string]bool {
|
||||||
|
newMap := make(map[string]bool)
|
||||||
|
for k, v := range m {
|
||||||
|
newMap[filepath.ToSlash(k)] = v
|
||||||
|
}
|
||||||
|
return newMap
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveWatcherInputs(t *testing.T) {
|
||||||
|
tcs := []struct {
|
||||||
|
description string
|
||||||
|
toolsFile string
|
||||||
|
toolsFiles []string
|
||||||
|
toolsFolder string
|
||||||
|
wantWatchDirs map[string]bool
|
||||||
|
wantWatchedFiles map[string]bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
description: "single tools file",
|
||||||
|
toolsFile: "tools_folder/example_tools.yaml",
|
||||||
|
toolsFiles: []string{},
|
||||||
|
toolsFolder: "",
|
||||||
|
wantWatchDirs: map[string]bool{"tools_folder": true},
|
||||||
|
wantWatchedFiles: map[string]bool{"tools_folder/example_tools.yaml": true},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "default tools file (root dir)",
|
||||||
|
toolsFile: "tools.yaml",
|
||||||
|
toolsFiles: []string{},
|
||||||
|
toolsFolder: "",
|
||||||
|
wantWatchDirs: map[string]bool{".": true},
|
||||||
|
wantWatchedFiles: map[string]bool{"tools.yaml": true},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "multiple files in different folders",
|
||||||
|
toolsFile: "",
|
||||||
|
toolsFiles: []string{"tools_folder/example_tools.yaml", "tools_folder2/example_tools.yaml"},
|
||||||
|
toolsFolder: "",
|
||||||
|
wantWatchDirs: map[string]bool{"tools_folder": true, "tools_folder2": true},
|
||||||
|
wantWatchedFiles: map[string]bool{
|
||||||
|
"tools_folder/example_tools.yaml": true,
|
||||||
|
"tools_folder2/example_tools.yaml": true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "multiple files in same folder",
|
||||||
|
toolsFile: "",
|
||||||
|
toolsFiles: []string{"tools_folder/example_tools.yaml", "tools_folder/example_tools2.yaml"},
|
||||||
|
toolsFolder: "",
|
||||||
|
wantWatchDirs: map[string]bool{"tools_folder": true},
|
||||||
|
wantWatchedFiles: map[string]bool{
|
||||||
|
"tools_folder/example_tools.yaml": true,
|
||||||
|
"tools_folder/example_tools2.yaml": true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "multiple files in different levels",
|
||||||
|
toolsFile: "",
|
||||||
|
toolsFiles: []string{
|
||||||
|
"tools_folder/example_tools.yaml",
|
||||||
|
"tools_folder/special_tools/example_tools2.yaml"},
|
||||||
|
toolsFolder: "",
|
||||||
|
wantWatchDirs: map[string]bool{"tools_folder": true, "tools_folder/special_tools": true},
|
||||||
|
wantWatchedFiles: map[string]bool{
|
||||||
|
"tools_folder/example_tools.yaml": true,
|
||||||
|
"tools_folder/special_tools/example_tools2.yaml": true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "tools folder",
|
||||||
|
toolsFile: "",
|
||||||
|
toolsFiles: []string{},
|
||||||
|
toolsFolder: "tools_folder",
|
||||||
|
wantWatchDirs: map[string]bool{"tools_folder": true},
|
||||||
|
wantWatchedFiles: map[string]bool{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tc := range tcs {
|
||||||
|
t.Run(tc.description, func(t *testing.T) {
|
||||||
|
gotWatchDirs, gotWatchedFiles := resolveWatcherInputs(tc.toolsFile, tc.toolsFiles, tc.toolsFolder)
|
||||||
|
|
||||||
|
normalizedGotWatchDirs := normalizeFilepaths(gotWatchDirs)
|
||||||
|
normalizedGotWatchedFiles := normalizeFilepaths(gotWatchedFiles)
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tc.wantWatchDirs, normalizedGotWatchDirs); diff != "" {
|
||||||
|
t.Errorf("incorrect watchDirs: diff %v", diff)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(tc.wantWatchedFiles, normalizedGotWatchedFiles); diff != "" {
|
||||||
|
t.Errorf("incorrect watchedFiles: diff %v", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// helper function for testing file detection in dynamic reloading
|
||||||
|
func tmpFileWithCleanup(content []byte) (string, func(), error) {
|
||||||
|
f, err := os.CreateTemp("", "*")
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
cleanup := func() { os.Remove(f.Name()) }
|
||||||
|
|
||||||
|
if _, err := f.Write(content); err != nil {
|
||||||
|
cleanup()
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
if err := f.Close(); err != nil {
|
||||||
|
cleanup()
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
return f.Name(), cleanup, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSingleEdit(t *testing.T) {
|
||||||
|
ctx, cancelCtx := context.WithTimeout(context.Background(), time.Minute)
|
||||||
|
defer cancelCtx()
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
defer pw.Close()
|
||||||
|
defer pr.Close()
|
||||||
|
|
||||||
|
fileToWatch, cleanup, err := tmpFileWithCleanup([]byte("initial content"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error editing tools file %s", err)
|
||||||
|
}
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
logger, err := log.NewStdLogger(pw, pw, "DEBUG")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to setup logger %s", err)
|
||||||
|
}
|
||||||
|
ctx = util.WithLogger(ctx, logger)
|
||||||
|
|
||||||
|
instrumentation, err := telemetry.CreateTelemetryInstrumentation(versionString)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to setup instrumentation %s", err)
|
||||||
|
}
|
||||||
|
ctx = util.WithInstrumentation(ctx, instrumentation)
|
||||||
|
|
||||||
|
mockServer := &server.Server{}
|
||||||
|
|
||||||
|
cleanFileToWatch := filepath.Clean(fileToWatch)
|
||||||
|
watchDir := filepath.Dir(cleanFileToWatch)
|
||||||
|
|
||||||
|
watchedFiles := map[string]bool{cleanFileToWatch: true}
|
||||||
|
watchDirs := map[string]bool{watchDir: true}
|
||||||
|
|
||||||
|
go watchChanges(ctx, watchDirs, watchedFiles, mockServer)
|
||||||
|
|
||||||
|
// escape backslash so regex doesn't fail on windows filepaths
|
||||||
|
regexEscapedPathFile := strings.ReplaceAll(cleanFileToWatch, `\`, `\\\\*\\`)
|
||||||
|
regexEscapedPathFile = path.Clean(regexEscapedPathFile)
|
||||||
|
|
||||||
|
regexEscapedPathDir := strings.ReplaceAll(watchDir, `\`, `\\\\*\\`)
|
||||||
|
regexEscapedPathDir = path.Clean(regexEscapedPathDir)
|
||||||
|
|
||||||
|
begunWatchingDir := regexp.MustCompile(fmt.Sprintf(`DEBUG "Added directory %s to watcher."`, regexEscapedPathDir))
|
||||||
|
res, err := testutils.WaitForString(ctx, begunWatchingDir, pr)
|
||||||
|
t.Log("log result: ", res)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("timeout or error waiting for watcher to start: %s, actually got: ", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = os.WriteFile(fileToWatch, []byte("modification"), 0777)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error writing to file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
detectedFileChange := regexp.MustCompile(fmt.Sprintf(`DEBUG "WRITE event detected in %s"`, regexEscapedPathFile))
|
||||||
|
res, err = testutils.WaitForString(ctx, detectedFileChange, pr)
|
||||||
|
t.Log("log result: ", res)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("timeout or error waiting for file to detect write: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestPrebuiltTools(t *testing.T) {
|
func TestPrebuiltTools(t *testing.T) {
|
||||||
alloydb_config, _ := prebuiltconfigs.Get("alloydb-postgres")
|
alloydb_config, _ := prebuiltconfigs.Get("alloydb-postgres")
|
||||||
bigquery_config, _ := prebuiltconfigs.Get("bigquery")
|
bigquery_config, _ := prebuiltconfigs.Get("bigquery")
|
||||||
|
|||||||
@@ -123,6 +123,9 @@ execute `toolbox` to start the server:
|
|||||||
```sh
|
```sh
|
||||||
./toolbox --tools-file "tools.yaml"
|
./toolbox --tools-file "tools.yaml"
|
||||||
```
|
```
|
||||||
|
{{< notice note >}}
|
||||||
|
Toolbox enables dynamic reloading by default. To disable, use the `--disable-reload` flag.
|
||||||
|
{{< /notice >}}
|
||||||
|
|
||||||
You can use `toolbox help` for a full list of flags! To stop the server, send a
|
You can use `toolbox help` for a full list of flags! To stop the server, send a
|
||||||
terminate signal (`ctrl+c` on most platforms).
|
terminate signal (`ctrl+c` on most platforms).
|
||||||
|
|||||||
@@ -257,6 +257,9 @@ In this section, we will download Toolbox, configure our tools in a
|
|||||||
```bash
|
```bash
|
||||||
./toolbox --tools-file "tools.yaml"
|
./toolbox --tools-file "tools.yaml"
|
||||||
```
|
```
|
||||||
|
{{< notice note >}}
|
||||||
|
Toolbox enables dynamic reloading by default. To disable, use the `--disable-reload` flag.
|
||||||
|
{{< /notice >}}
|
||||||
|
|
||||||
## Step 3: Connect your agent to Toolbox
|
## Step 3: Connect your agent to Toolbox
|
||||||
|
|
||||||
|
|||||||
@@ -63,6 +63,10 @@ When running with stdio, Toolbox will listen via stdio instead of acting as a
|
|||||||
remote HTTP server. Logs will be set to the `warn` level by default. `debug` and
|
remote HTTP server. Logs will be set to the `warn` level by default. `debug` and
|
||||||
`info` logs are not supported with stdio.
|
`info` logs are not supported with stdio.
|
||||||
|
|
||||||
|
{{< notice note >}}
|
||||||
|
Toolbox enables dynamic reloading by default. To disable, use the `--disable-reload` flag.
|
||||||
|
{{< /notice >}}
|
||||||
|
|
||||||
### Connecting via HTTP
|
### Connecting via HTTP
|
||||||
|
|
||||||
Toolbox supports the HTTP transport protocol with and without SSE.
|
Toolbox supports the HTTP transport protocol with and without SSE.
|
||||||
|
|||||||
@@ -292,6 +292,9 @@ to use BigQuery, and then run the Toolbox server.
|
|||||||
```bash
|
```bash
|
||||||
./toolbox --tools-file "tools.yaml"
|
./toolbox --tools-file "tools.yaml"
|
||||||
```
|
```
|
||||||
|
{{< notice note >}}
|
||||||
|
Toolbox enables dynamic reloading by default. To disable, use the `--disable-reload` flag.
|
||||||
|
{{< /notice >}}
|
||||||
|
|
||||||
## Step 3: Connect your agent to Toolbox
|
## Step 3: Connect your agent to Toolbox
|
||||||
|
|
||||||
|
|||||||
1
go.mod
1
go.mod
@@ -14,6 +14,7 @@ require (
|
|||||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/trace v1.29.0
|
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/trace v1.29.0
|
||||||
github.com/couchbase/gocb/v2 v2.10.0
|
github.com/couchbase/gocb/v2 v2.10.0
|
||||||
github.com/couchbase/tools-common/http v1.0.9
|
github.com/couchbase/tools-common/http v1.0.9
|
||||||
|
github.com/fsnotify/fsnotify v1.9.0
|
||||||
github.com/go-chi/chi/v5 v5.2.2
|
github.com/go-chi/chi/v5 v5.2.2
|
||||||
github.com/go-chi/httplog/v2 v2.1.1
|
github.com/go-chi/httplog/v2 v2.1.1
|
||||||
github.com/go-chi/render v1.0.3
|
github.com/go-chi/render v1.0.3
|
||||||
|
|||||||
2
go.sum
2
go.sum
@@ -765,6 +765,8 @@ github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2
|
|||||||
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||||
github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
|
github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
|
||||||
github.com/fogleman/gg v1.3.0/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
|
github.com/fogleman/gg v1.3.0/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
|
||||||
|
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
||||||
|
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||||
github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM=
|
github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM=
|
||||||
github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8=
|
github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8=
|
||||||
github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
|
github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ func toolsetHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
|||||||
)
|
)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
toolset, ok := s.toolsets[toolsetName]
|
toolset, ok := s.ResourceMgr.GetToolset(toolsetName)
|
||||||
if !ok {
|
if !ok {
|
||||||
err = fmt.Errorf("toolset %q does not exist", toolsetName)
|
err = fmt.Errorf("toolset %q does not exist", toolsetName)
|
||||||
s.logger.DebugContext(ctx, err.Error())
|
s.logger.DebugContext(ctx, err.Error())
|
||||||
@@ -111,7 +111,7 @@ func toolGetHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
|||||||
metric.WithAttributes(attribute.String("toolbox.operation.status", status)),
|
metric.WithAttributes(attribute.String("toolbox.operation.status", status)),
|
||||||
)
|
)
|
||||||
}()
|
}()
|
||||||
tool, ok := s.tools[toolName]
|
tool, ok := s.ResourceMgr.GetTool(toolName)
|
||||||
if !ok {
|
if !ok {
|
||||||
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
|
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
|
||||||
s.logger.DebugContext(ctx, err.Error())
|
s.logger.DebugContext(ctx, err.Error())
|
||||||
@@ -156,7 +156,7 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
|||||||
)
|
)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
tool, ok := s.tools[toolName]
|
tool, ok := s.ResourceMgr.GetTool(toolName)
|
||||||
if !ok {
|
if !ok {
|
||||||
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
|
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
|
||||||
s.logger.DebugContext(ctx, err.Error())
|
s.logger.DebugContext(ctx, err.Error())
|
||||||
@@ -167,7 +167,7 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
|||||||
// Tool authentication
|
// Tool authentication
|
||||||
// claimsFromAuth maps the name of the authservice to the claims retrieved from it.
|
// claimsFromAuth maps the name of the authservice to the claims retrieved from it.
|
||||||
claimsFromAuth := make(map[string]map[string]any)
|
claimsFromAuth := make(map[string]map[string]any)
|
||||||
for _, aS := range s.authServices {
|
for _, aS := range s.ResourceMgr.GetAuthServiceMap() {
|
||||||
claims, err := aS.GetClaimsFromHeader(ctx, r.Header)
|
claims, err := aS.GetClaimsFromHeader(ctx, r.Header)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.DebugContext(ctx, err.Error())
|
s.logger.DebugContext(ctx, err.Error())
|
||||||
|
|||||||
@@ -147,14 +147,23 @@ func setUpServer(t *testing.T, router string, tools map[string]tools.Tool, tools
|
|||||||
t.Fatalf("unable to setup otel: %s", err)
|
t.Fatalf("unable to setup otel: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
instrumentation, err := CreateTelemetryInstrumentation(fakeVersionString)
|
instrumentation, err := telemetry.CreateTelemetryInstrumentation(fakeVersionString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to create custom metrics: %s", err)
|
t.Fatalf("unable to create custom metrics: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
sseManager := newSseManager(ctx)
|
sseManager := newSseManager(ctx)
|
||||||
|
|
||||||
server := Server{version: fakeVersionString, logger: testLogger, instrumentation: instrumentation, sseManager: sseManager, tools: tools, toolsets: toolsets}
|
resourceManager := NewResourceManager(nil, nil, tools, toolsets)
|
||||||
|
|
||||||
|
server := Server{
|
||||||
|
version: fakeVersionString,
|
||||||
|
logger: testLogger,
|
||||||
|
instrumentation: instrumentation,
|
||||||
|
sseManager: sseManager,
|
||||||
|
ResourceMgr: resourceManager,
|
||||||
|
}
|
||||||
|
|
||||||
var r chi.Router
|
var r chi.Router
|
||||||
switch router {
|
switch router {
|
||||||
case "api":
|
case "api":
|
||||||
|
|||||||
@@ -53,6 +53,8 @@ type ServerConfig struct {
|
|||||||
TelemetryServiceName string
|
TelemetryServiceName string
|
||||||
// Stdio indicates if Toolbox is listening via MCP stdio.
|
// Stdio indicates if Toolbox is listening via MCP stdio.
|
||||||
Stdio bool
|
Stdio bool
|
||||||
|
// DisableReload indicates if the user has disabled dynamic reloading for Toolbox.
|
||||||
|
DisableReload bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type logFormat string
|
type logFormat string
|
||||||
|
|||||||
@@ -479,12 +479,12 @@ func processMcpMessage(ctx context.Context, body []byte, s *Server, protocolVers
|
|||||||
}
|
}
|
||||||
return v, res, err
|
return v, res, err
|
||||||
default:
|
default:
|
||||||
toolset, ok := s.toolsets[toolsetName]
|
toolset, ok := s.ResourceMgr.GetToolset(toolsetName)
|
||||||
if !ok {
|
if !ok {
|
||||||
err = fmt.Errorf("toolset does not exist")
|
err = fmt.Errorf("toolset does not exist")
|
||||||
return "", jsonrpc.NewError(baseMessage.Id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
return "", jsonrpc.NewError(baseMessage.Id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
}
|
}
|
||||||
res, err := mcp.ProcessMethod(ctx, protocolVersion, baseMessage.Id, baseMessage.Method, toolset, s.tools, body)
|
res, err := mcp.ProcessMethod(ctx, protocolVersion, baseMessage.Id, baseMessage.Method, toolset, s.ResourceMgr.GetToolsMap(), body)
|
||||||
return "", res, err
|
return "", res, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -693,14 +693,22 @@ func TestStdioSession(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
instrumentation, err := CreateTelemetryInstrumentation(fakeVersionString)
|
instrumentation, err := telemetry.CreateTelemetryInstrumentation(fakeVersionString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to create custom metrics: %s", err)
|
t.Fatalf("unable to create custom metrics: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
sseManager := newSseManager(ctx)
|
sseManager := newSseManager(ctx)
|
||||||
|
|
||||||
server := &Server{version: fakeVersionString, logger: testLogger, instrumentation: instrumentation, sseManager: sseManager, tools: toolsMap, toolsets: toolsets}
|
resourceManager := NewResourceManager(nil, nil, toolsMap, toolsets)
|
||||||
|
|
||||||
|
server := &Server{
|
||||||
|
version: fakeVersionString,
|
||||||
|
logger: testLogger,
|
||||||
|
instrumentation: instrumentation,
|
||||||
|
sseManager: sseManager,
|
||||||
|
ResourceMgr: resourceManager,
|
||||||
|
}
|
||||||
|
|
||||||
in := bufio.NewReader(pr)
|
in := bufio.NewReader(pr)
|
||||||
stdioSession := NewStdioSession(server, in, pw)
|
stdioSession := NewStdioSession(server, in, pw)
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
@@ -29,6 +30,7 @@ import (
|
|||||||
"github.com/googleapis/genai-toolbox/internal/auth"
|
"github.com/googleapis/genai-toolbox/internal/auth"
|
||||||
"github.com/googleapis/genai-toolbox/internal/log"
|
"github.com/googleapis/genai-toolbox/internal/log"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/telemetry"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util"
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
"go.opentelemetry.io/otel/attribute"
|
"go.opentelemetry.io/otel/attribute"
|
||||||
@@ -42,26 +44,225 @@ type Server struct {
|
|||||||
listener net.Listener
|
listener net.Listener
|
||||||
root chi.Router
|
root chi.Router
|
||||||
logger log.Logger
|
logger log.Logger
|
||||||
instrumentation *Instrumentation
|
instrumentation *telemetry.Instrumentation
|
||||||
sseManager *sseManager
|
sseManager *sseManager
|
||||||
|
ResourceMgr *ResourceManager
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResourceManager contains available resources for the server. Should be initialized with NewResourceManager().
|
||||||
|
type ResourceManager struct {
|
||||||
|
mu sync.RWMutex
|
||||||
sources map[string]sources.Source
|
sources map[string]sources.Source
|
||||||
authServices map[string]auth.AuthService
|
authServices map[string]auth.AuthService
|
||||||
tools map[string]tools.Tool
|
tools map[string]tools.Tool
|
||||||
toolsets map[string]tools.Toolset
|
toolsets map[string]tools.Toolset
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServer returns a Server object based on provided Config.
|
func NewResourceManager(
|
||||||
func NewServer(ctx context.Context, cfg ServerConfig, l log.Logger) (*Server, error) {
|
sourcesMap map[string]sources.Source,
|
||||||
instrumentation, err := CreateTelemetryInstrumentation(cfg.Version)
|
authServicesMap map[string]auth.AuthService,
|
||||||
|
toolsMap map[string]tools.Tool, toolsetsMap map[string]tools.Toolset,
|
||||||
|
) *ResourceManager {
|
||||||
|
resourceMgr := &ResourceManager{
|
||||||
|
mu: sync.RWMutex{},
|
||||||
|
sources: sourcesMap,
|
||||||
|
authServices: authServicesMap,
|
||||||
|
tools: toolsMap,
|
||||||
|
toolsets: toolsetsMap,
|
||||||
|
}
|
||||||
|
|
||||||
|
return resourceMgr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ResourceManager) GetSource(sourceName string) (sources.Source, bool) {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
source, ok := r.sources[sourceName]
|
||||||
|
return source, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ResourceManager) GetAuthService(authServiceName string) (auth.AuthService, bool) {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
authService, ok := r.authServices[authServiceName]
|
||||||
|
return authService, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ResourceManager) GetTool(toolName string) (tools.Tool, bool) {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
tool, ok := r.tools[toolName]
|
||||||
|
return tool, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ResourceManager) GetToolset(toolsetName string) (tools.Toolset, bool) {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
toolset, ok := r.toolsets[toolsetName]
|
||||||
|
return toolset, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ResourceManager) SetResources(sourcesMap map[string]sources.Source, authServicesMap map[string]auth.AuthService, toolsMap map[string]tools.Tool, toolsetsMap map[string]tools.Toolset) {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
r.sources = sourcesMap
|
||||||
|
r.authServices = authServicesMap
|
||||||
|
r.tools = toolsMap
|
||||||
|
r.toolsets = toolsetsMap
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ResourceManager) GetAuthServiceMap() map[string]auth.AuthService {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
return r.authServices
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ResourceManager) GetToolsMap() map[string]tools.Tool {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
return r.tools
|
||||||
|
}
|
||||||
|
|
||||||
|
func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
|
||||||
|
map[string]sources.Source,
|
||||||
|
map[string]auth.AuthService,
|
||||||
|
map[string]tools.Tool,
|
||||||
|
map[string]tools.Toolset,
|
||||||
|
error,
|
||||||
|
) {
|
||||||
|
ctx = util.WithUserAgent(ctx, cfg.Version)
|
||||||
|
instrumentation, err := util.InstrumentationFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to create telemetry instrumentation: %w", err)
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
l, err := util.LoggerFromContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// initialize and validate the sources from configs
|
||||||
|
sourcesMap := make(map[string]sources.Source)
|
||||||
|
for name, sc := range cfg.SourceConfigs {
|
||||||
|
s, err := func() (sources.Source, error) {
|
||||||
|
childCtx, span := instrumentation.Tracer.Start(
|
||||||
|
ctx,
|
||||||
|
"toolbox/server/source/init",
|
||||||
|
trace.WithAttributes(attribute.String("source_kind", sc.SourceConfigKind())),
|
||||||
|
trace.WithAttributes(attribute.String("source_name", name)),
|
||||||
|
)
|
||||||
|
defer span.End()
|
||||||
|
s, err := sc.Initialize(childCtx, instrumentation.Tracer)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to initialize source %q: %w", name, err)
|
||||||
|
}
|
||||||
|
return s, nil
|
||||||
|
}()
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, nil, err
|
||||||
|
}
|
||||||
|
sourcesMap[name] = s
|
||||||
|
}
|
||||||
|
l.InfoContext(ctx, fmt.Sprintf("Initialized %d sources.", len(sourcesMap)))
|
||||||
|
|
||||||
|
// initialize and validate the auth services from configs
|
||||||
|
authServicesMap := make(map[string]auth.AuthService)
|
||||||
|
for name, sc := range cfg.AuthServiceConfigs {
|
||||||
|
a, err := func() (auth.AuthService, error) {
|
||||||
|
_, span := instrumentation.Tracer.Start(
|
||||||
|
ctx,
|
||||||
|
"toolbox/server/auth/init",
|
||||||
|
trace.WithAttributes(attribute.String("auth_kind", sc.AuthServiceConfigKind())),
|
||||||
|
trace.WithAttributes(attribute.String("auth_name", name)),
|
||||||
|
)
|
||||||
|
defer span.End()
|
||||||
|
a, err := sc.Initialize()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to initialize auth service %q: %w", name, err)
|
||||||
|
}
|
||||||
|
return a, nil
|
||||||
|
}()
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, nil, err
|
||||||
|
}
|
||||||
|
authServicesMap[name] = a
|
||||||
|
}
|
||||||
|
l.InfoContext(ctx, fmt.Sprintf("Initialized %d authServices.", len(authServicesMap)))
|
||||||
|
|
||||||
|
// initialize and validate the tools from configs
|
||||||
|
toolsMap := make(map[string]tools.Tool)
|
||||||
|
for name, tc := range cfg.ToolConfigs {
|
||||||
|
t, err := func() (tools.Tool, error) {
|
||||||
|
_, span := instrumentation.Tracer.Start(
|
||||||
|
ctx,
|
||||||
|
"toolbox/server/tool/init",
|
||||||
|
trace.WithAttributes(attribute.String("tool_kind", tc.ToolConfigKind())),
|
||||||
|
trace.WithAttributes(attribute.String("tool_name", name)),
|
||||||
|
)
|
||||||
|
defer span.End()
|
||||||
|
t, err := tc.Initialize(sourcesMap)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to initialize tool %q: %w", name, err)
|
||||||
|
}
|
||||||
|
return t, nil
|
||||||
|
}()
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, nil, err
|
||||||
|
}
|
||||||
|
toolsMap[name] = t
|
||||||
|
}
|
||||||
|
l.InfoContext(ctx, fmt.Sprintf("Initialized %d tools.", len(toolsMap)))
|
||||||
|
|
||||||
|
// create a default toolset that contains all tools
|
||||||
|
allToolNames := make([]string, 0, len(toolsMap))
|
||||||
|
for name := range toolsMap {
|
||||||
|
allToolNames = append(allToolNames, name)
|
||||||
|
}
|
||||||
|
if cfg.ToolsetConfigs == nil {
|
||||||
|
cfg.ToolsetConfigs = make(ToolsetConfigs)
|
||||||
|
}
|
||||||
|
cfg.ToolsetConfigs[""] = tools.ToolsetConfig{Name: "", ToolNames: allToolNames}
|
||||||
|
|
||||||
|
// initialize and validate the toolsets from configs
|
||||||
|
toolsetsMap := make(map[string]tools.Toolset)
|
||||||
|
for name, tc := range cfg.ToolsetConfigs {
|
||||||
|
t, err := func() (tools.Toolset, error) {
|
||||||
|
_, span := instrumentation.Tracer.Start(
|
||||||
|
ctx,
|
||||||
|
"toolbox/server/toolset/init",
|
||||||
|
trace.WithAttributes(attribute.String("toolset_name", name)),
|
||||||
|
)
|
||||||
|
defer span.End()
|
||||||
|
t, err := tc.Initialize(cfg.Version, toolsMap)
|
||||||
|
if err != nil {
|
||||||
|
return tools.Toolset{}, fmt.Errorf("unable to initialize toolset %q: %w", name, err)
|
||||||
|
}
|
||||||
|
return t, err
|
||||||
|
}()
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, nil, err
|
||||||
|
}
|
||||||
|
toolsetsMap[name] = t
|
||||||
|
}
|
||||||
|
l.InfoContext(ctx, fmt.Sprintf("Initialized %d toolsets.", len(toolsetsMap)))
|
||||||
|
|
||||||
|
return sourcesMap, authServicesMap, toolsMap, toolsetsMap, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewServer returns a Server object based on provided Config.
|
||||||
|
func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) {
|
||||||
|
instrumentation, err := util.InstrumentationFromContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, span := instrumentation.Tracer.Start(ctx, "toolbox/server/init")
|
ctx, span := instrumentation.Tracer.Start(ctx, "toolbox/server/init")
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
ctx = util.WithUserAgent(ctx, cfg.Version)
|
l, err := util.LoggerFromContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
// set up http serving
|
// set up http serving
|
||||||
r := chi.NewRouter()
|
r := chi.NewRouter()
|
||||||
@@ -97,116 +298,18 @@ func NewServer(ctx context.Context, cfg ServerConfig, l log.Logger) (*Server, er
|
|||||||
httpLogger := httplog.NewLogger("httplog", httpOpts)
|
httpLogger := httplog.NewLogger("httplog", httpOpts)
|
||||||
r.Use(httplog.RequestLogger(httpLogger))
|
r.Use(httplog.RequestLogger(httpLogger))
|
||||||
|
|
||||||
// initialize and validate the sources from configs
|
sourcesMap, authServicesMap, toolsMap, toolsetsMap, err := InitializeConfigs(ctx, cfg)
|
||||||
sourcesMap := make(map[string]sources.Source)
|
if err != nil {
|
||||||
for name, sc := range cfg.SourceConfigs {
|
return nil, fmt.Errorf("unable to initialize configs: %w", err)
|
||||||
s, err := func() (sources.Source, error) {
|
|
||||||
childCtx, span := instrumentation.Tracer.Start(
|
|
||||||
ctx,
|
|
||||||
"toolbox/server/source/init",
|
|
||||||
trace.WithAttributes(attribute.String("source_kind", sc.SourceConfigKind())),
|
|
||||||
trace.WithAttributes(attribute.String("source_name", name)),
|
|
||||||
)
|
|
||||||
defer span.End()
|
|
||||||
s, err := sc.Initialize(childCtx, instrumentation.Tracer)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to initialize source %q: %w", name, err)
|
|
||||||
}
|
|
||||||
return s, nil
|
|
||||||
}()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
sourcesMap[name] = s
|
|
||||||
}
|
}
|
||||||
l.InfoContext(ctx, fmt.Sprintf("Initialized %d sources.", len(sourcesMap)))
|
|
||||||
|
|
||||||
// initialize and validate the auth services from configs
|
|
||||||
authServicesMap := make(map[string]auth.AuthService)
|
|
||||||
for name, sc := range cfg.AuthServiceConfigs {
|
|
||||||
a, err := func() (auth.AuthService, error) {
|
|
||||||
_, span := instrumentation.Tracer.Start(
|
|
||||||
ctx,
|
|
||||||
"toolbox/server/auth/init",
|
|
||||||
trace.WithAttributes(attribute.String("auth_kind", sc.AuthServiceConfigKind())),
|
|
||||||
trace.WithAttributes(attribute.String("auth_name", name)),
|
|
||||||
)
|
|
||||||
defer span.End()
|
|
||||||
a, err := sc.Initialize()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to initialize auth service %q: %w", name, err)
|
|
||||||
}
|
|
||||||
return a, nil
|
|
||||||
}()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
authServicesMap[name] = a
|
|
||||||
}
|
|
||||||
l.InfoContext(ctx, fmt.Sprintf("Initialized %d authServices.", len(authServicesMap)))
|
|
||||||
|
|
||||||
// initialize and validate the tools from configs
|
|
||||||
toolsMap := make(map[string]tools.Tool)
|
|
||||||
for name, tc := range cfg.ToolConfigs {
|
|
||||||
t, err := func() (tools.Tool, error) {
|
|
||||||
_, span := instrumentation.Tracer.Start(
|
|
||||||
ctx,
|
|
||||||
"toolbox/server/tool/init",
|
|
||||||
trace.WithAttributes(attribute.String("tool_kind", tc.ToolConfigKind())),
|
|
||||||
trace.WithAttributes(attribute.String("tool_name", name)),
|
|
||||||
)
|
|
||||||
defer span.End()
|
|
||||||
t, err := tc.Initialize(sourcesMap)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to initialize tool %q: %w", name, err)
|
|
||||||
}
|
|
||||||
return t, nil
|
|
||||||
}()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
toolsMap[name] = t
|
|
||||||
}
|
|
||||||
l.InfoContext(ctx, fmt.Sprintf("Initialized %d tools.", len(toolsMap)))
|
|
||||||
|
|
||||||
// create a default toolset that contains all tools
|
|
||||||
allToolNames := make([]string, 0, len(toolsMap))
|
|
||||||
for name := range toolsMap {
|
|
||||||
allToolNames = append(allToolNames, name)
|
|
||||||
}
|
|
||||||
if cfg.ToolsetConfigs == nil {
|
|
||||||
cfg.ToolsetConfigs = make(ToolsetConfigs)
|
|
||||||
}
|
|
||||||
cfg.ToolsetConfigs[""] = tools.ToolsetConfig{Name: "", ToolNames: allToolNames}
|
|
||||||
|
|
||||||
// initialize and validate the toolsets from configs
|
|
||||||
toolsetsMap := make(map[string]tools.Toolset)
|
|
||||||
for name, tc := range cfg.ToolsetConfigs {
|
|
||||||
t, err := func() (tools.Toolset, error) {
|
|
||||||
_, span := instrumentation.Tracer.Start(
|
|
||||||
ctx,
|
|
||||||
"toolbox/server/toolset/init",
|
|
||||||
trace.WithAttributes(attribute.String("toolset_name", name)),
|
|
||||||
)
|
|
||||||
defer span.End()
|
|
||||||
t, err := tc.Initialize(cfg.Version, toolsMap)
|
|
||||||
if err != nil {
|
|
||||||
return tools.Toolset{}, fmt.Errorf("unable to initialize toolset %q: %w", name, err)
|
|
||||||
}
|
|
||||||
return t, err
|
|
||||||
}()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
toolsetsMap[name] = t
|
|
||||||
}
|
|
||||||
l.InfoContext(ctx, fmt.Sprintf("Initialized %d toolsets.", len(toolsetsMap)))
|
|
||||||
|
|
||||||
addr := net.JoinHostPort(cfg.Address, strconv.Itoa(cfg.Port))
|
addr := net.JoinHostPort(cfg.Address, strconv.Itoa(cfg.Port))
|
||||||
srv := &http.Server{Addr: addr, Handler: r}
|
srv := &http.Server{Addr: addr, Handler: r}
|
||||||
|
|
||||||
sseManager := newSseManager(ctx)
|
sseManager := newSseManager(ctx)
|
||||||
|
|
||||||
|
resourceManager := NewResourceManager(sourcesMap, authServicesMap, toolsMap, toolsetsMap)
|
||||||
|
|
||||||
s := &Server{
|
s := &Server{
|
||||||
version: cfg.Version,
|
version: cfg.Version,
|
||||||
srv: srv,
|
srv: srv,
|
||||||
@@ -214,11 +317,7 @@ func NewServer(ctx context.Context, cfg ServerConfig, l log.Logger) (*Server, er
|
|||||||
logger: l,
|
logger: l,
|
||||||
instrumentation: instrumentation,
|
instrumentation: instrumentation,
|
||||||
sseManager: sseManager,
|
sseManager: sseManager,
|
||||||
|
ResourceMgr: resourceManager,
|
||||||
sources: sourcesMap,
|
|
||||||
authServices: authServicesMap,
|
|
||||||
tools: toolsMap,
|
|
||||||
toolsets: toolsetsMap,
|
|
||||||
}
|
}
|
||||||
// control plane
|
// control plane
|
||||||
apiR, err := apiRouter(s)
|
apiR, err := apiRouter(s)
|
||||||
|
|||||||
@@ -23,9 +23,16 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/auth"
|
||||||
"github.com/googleapis/genai-toolbox/internal/log"
|
"github.com/googleapis/genai-toolbox/internal/log"
|
||||||
"github.com/googleapis/genai-toolbox/internal/server"
|
"github.com/googleapis/genai-toolbox/internal/server"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/sources/alloydbpg"
|
||||||
"github.com/googleapis/genai-toolbox/internal/telemetry"
|
"github.com/googleapis/genai-toolbox/internal/telemetry"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestServe(t *testing.T) {
|
func TestServe(t *testing.T) {
|
||||||
@@ -54,8 +61,16 @@ func TestServe(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %s", err)
|
t.Fatalf("unexpected error: %s", err)
|
||||||
}
|
}
|
||||||
|
ctx = util.WithLogger(ctx, testLogger)
|
||||||
|
|
||||||
s, err := server.NewServer(ctx, cfg, testLogger)
|
instrumentation, err := telemetry.CreateTelemetryInstrumentation(cfg.Version)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx = util.WithInstrumentation(ctx, instrumentation)
|
||||||
|
|
||||||
|
s, err := server.NewServer(ctx, cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to initialize server: %v", err)
|
t.Fatalf("unable to initialize server: %v", err)
|
||||||
}
|
}
|
||||||
@@ -93,3 +108,67 @@ func TestServe(t *testing.T) {
|
|||||||
t.Fatalf("version missing from output: %q", got)
|
t.Fatalf("version missing from output: %q", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUpdateServer(t *testing.T) {
|
||||||
|
ctx, err := testutils.ContextWithNewLogger()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error setting up logger: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
addr, port := "127.0.0.1", 5000
|
||||||
|
cfg := server.ServerConfig{
|
||||||
|
Version: "0.0.0",
|
||||||
|
Address: addr,
|
||||||
|
Port: port,
|
||||||
|
}
|
||||||
|
|
||||||
|
instrumentation, err := telemetry.CreateTelemetryInstrumentation(cfg.Version)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx = util.WithInstrumentation(ctx, instrumentation)
|
||||||
|
|
||||||
|
s, err := server.NewServer(ctx, cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error setting up server: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
newSources := map[string]sources.Source{
|
||||||
|
"example-source": &alloydbpg.Source{
|
||||||
|
Name: "example-alloydb-source",
|
||||||
|
Kind: "alloydb-postgres",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
newAuth := map[string]auth.AuthService{"example-auth": nil}
|
||||||
|
newTools := map[string]tools.Tool{"example-tool": nil}
|
||||||
|
newToolsets := map[string]tools.Toolset{
|
||||||
|
"example-toolset": {
|
||||||
|
Name: "example-toolset", Tools: []*tools.Tool{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
s.ResourceMgr.SetResources(newSources, newAuth, newTools, newToolsets)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("error updating server: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
gotSource, _ := s.ResourceMgr.GetSource("example-source")
|
||||||
|
if diff := cmp.Diff(gotSource, newSources["example-source"]); diff != "" {
|
||||||
|
t.Errorf("error updating server, sources (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
gotAuthService, _ := s.ResourceMgr.GetAuthService("example-auth")
|
||||||
|
if diff := cmp.Diff(gotAuthService, newAuth["example-auth"]); diff != "" {
|
||||||
|
t.Errorf("error updating server, authServices (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
gotTool, _ := s.ResourceMgr.GetTool("example-tool")
|
||||||
|
if diff := cmp.Diff(gotTool, newTools["example-tool"]); diff != "" {
|
||||||
|
t.Errorf("error updating server, tools (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
gotToolset, _ := s.ResourceMgr.GetToolset("example-toolset")
|
||||||
|
if diff := cmp.Diff(gotToolset, newToolsets["example-toolset"]); diff != "" {
|
||||||
|
t.Errorf("error updating server, toolset (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -12,7 +12,7 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
package server
|
package telemetry
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -15,9 +15,12 @@
|
|||||||
package testutils
|
package testutils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"os"
|
"os"
|
||||||
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/log"
|
"github.com/googleapis/genai-toolbox/internal/log"
|
||||||
@@ -42,3 +45,63 @@ func ContextWithNewLogger() (context.Context, error) {
|
|||||||
}
|
}
|
||||||
return util.WithLogger(ctx, logger), nil
|
return util.WithLogger(ctx, logger), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WaitForString waits until the server logs a single line that matches the provided regex.
|
||||||
|
// returns the output of whatever the server sent so far.
|
||||||
|
func WaitForString(ctx context.Context, re *regexp.Regexp, pr io.ReadCloser) (string, error) {
|
||||||
|
in := bufio.NewReader(pr)
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// read lines in background, sending result of each read over a channel
|
||||||
|
// this allows us to use in.ReadString without blocking
|
||||||
|
type result struct {
|
||||||
|
s string
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
output := make(chan result)
|
||||||
|
go func() {
|
||||||
|
defer close(output)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
// if the context is canceled, the orig thread will send back the error
|
||||||
|
// so we can just exit the goroutine here
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
// otherwise read a line from the output
|
||||||
|
s, err := in.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
output <- result{err: err}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
output <- result{s: s}
|
||||||
|
// if that last string matched, exit the goroutine
|
||||||
|
if re.MatchString(s) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// collect the output until the ctx is canceled, an error was hit,
|
||||||
|
// or match was found (which is indicated the channel is closed)
|
||||||
|
var sb strings.Builder
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
// if ctx is done, return that error
|
||||||
|
return sb.String(), ctx.Err()
|
||||||
|
case o, ok := <-output:
|
||||||
|
if !ok {
|
||||||
|
// match was found!
|
||||||
|
return sb.String(), nil
|
||||||
|
}
|
||||||
|
if o.err != nil {
|
||||||
|
// error was found!
|
||||||
|
return sb.String(), o.err
|
||||||
|
}
|
||||||
|
sb.WriteString(o.s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import (
|
|||||||
"github.com/go-playground/validator/v10"
|
"github.com/go-playground/validator/v10"
|
||||||
yaml "github.com/goccy/go-yaml"
|
yaml "github.com/goccy/go-yaml"
|
||||||
"github.com/googleapis/genai-toolbox/internal/log"
|
"github.com/googleapis/genai-toolbox/internal/log"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/telemetry"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DecodeJSON decodes a given reader into an interface using the json decoder.
|
// DecodeJSON decodes a given reader into an interface using the json decoder.
|
||||||
@@ -98,10 +99,25 @@ func WithLogger(ctx context.Context, logger log.Logger) context.Context {
|
|||||||
return context.WithValue(ctx, loggerKey, logger)
|
return context.WithValue(ctx, loggerKey, logger)
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoggerFromContext retreives the logger or return an error
|
// LoggerFromContext retrieves the logger or return an error
|
||||||
func LoggerFromContext(ctx context.Context) (log.Logger, error) {
|
func LoggerFromContext(ctx context.Context) (log.Logger, error) {
|
||||||
if logger, ok := ctx.Value(loggerKey).(log.Logger); ok {
|
if logger, ok := ctx.Value(loggerKey).(log.Logger); ok {
|
||||||
return logger, nil
|
return logger, nil
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("unable to retrieve logger")
|
return nil, fmt.Errorf("unable to retrieve logger")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const instrumentationKey contextKey = "instrumentation"
|
||||||
|
|
||||||
|
// WithInstrumentation adds an instrumentation into the context as a value
|
||||||
|
func WithInstrumentation(ctx context.Context, instrumentation *telemetry.Instrumentation) context.Context {
|
||||||
|
return context.WithValue(ctx, instrumentationKey, instrumentation)
|
||||||
|
}
|
||||||
|
|
||||||
|
// InstrumentationFromContext retrieves the instrumentation or return an error
|
||||||
|
func InstrumentationFromContext(ctx context.Context) (*telemetry.Instrumentation, error) {
|
||||||
|
if instrumentation, ok := ctx.Value(instrumentationKey).(*telemetry.Instrumentation); ok {
|
||||||
|
return instrumentation, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("unable to retrieve instrumentation")
|
||||||
|
}
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||||
"github.com/googleapis/genai-toolbox/tests"
|
"github.com/googleapis/genai-toolbox/tests"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -90,7 +91,7 @@ func TestAlloyDBAINLToolEndpoints(t *testing.T) {
|
|||||||
|
|
||||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Logf("toolbox command logs: \n%s", out)
|
t.Logf("toolbox command logs: \n%s", out)
|
||||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ import (
|
|||||||
|
|
||||||
"cloud.google.com/go/alloydbconn"
|
"cloud.google.com/go/alloydbconn"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||||
"github.com/googleapis/genai-toolbox/tests"
|
"github.com/googleapis/genai-toolbox/tests"
|
||||||
"github.com/jackc/pgx/v5/pgxpool"
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
)
|
)
|
||||||
@@ -157,7 +158,7 @@ func TestAlloyDBPgToolEndpoints(t *testing.T) {
|
|||||||
|
|
||||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Logf("toolbox command logs: \n%s", out)
|
t.Logf("toolbox command logs: \n%s", out)
|
||||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ import (
|
|||||||
|
|
||||||
bigqueryapi "cloud.google.com/go/bigquery"
|
bigqueryapi "cloud.google.com/go/bigquery"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||||
"github.com/googleapis/genai-toolbox/tests"
|
"github.com/googleapis/genai-toolbox/tests"
|
||||||
"golang.org/x/oauth2/google"
|
"golang.org/x/oauth2/google"
|
||||||
"google.golang.org/api/googleapi"
|
"google.golang.org/api/googleapi"
|
||||||
@@ -124,7 +125,7 @@ func TestBigQueryToolEndpoints(t *testing.T) {
|
|||||||
|
|
||||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Logf("toolbox command logs: \n%s", out)
|
t.Logf("toolbox command logs: \n%s", out)
|
||||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ import (
|
|||||||
|
|
||||||
"cloud.google.com/go/bigtable"
|
"cloud.google.com/go/bigtable"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
"github.com/googleapis/genai-toolbox/tests"
|
"github.com/googleapis/genai-toolbox/tests"
|
||||||
)
|
)
|
||||||
@@ -103,7 +104,7 @@ func TestBigtableToolEndpoints(t *testing.T) {
|
|||||||
|
|
||||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Logf("toolbox command logs: \n%s", out)
|
t.Logf("toolbox command logs: \n%s", out)
|
||||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ import (
|
|||||||
"cloud.google.com/go/cloudsqlconn"
|
"cloud.google.com/go/cloudsqlconn"
|
||||||
"cloud.google.com/go/cloudsqlconn/sqlserver/mssql"
|
"cloud.google.com/go/cloudsqlconn/sqlserver/mssql"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||||
"github.com/googleapis/genai-toolbox/tests"
|
"github.com/googleapis/genai-toolbox/tests"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -151,7 +152,7 @@ func TestCloudSQLMSSQLToolEndpoints(t *testing.T) {
|
|||||||
|
|
||||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Logf("toolbox command logs: \n%s", out)
|
t.Logf("toolbox command logs: \n%s", out)
|
||||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ import (
|
|||||||
"cloud.google.com/go/cloudsqlconn"
|
"cloud.google.com/go/cloudsqlconn"
|
||||||
"cloud.google.com/go/cloudsqlconn/mysql/mysql"
|
"cloud.google.com/go/cloudsqlconn/mysql/mysql"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||||
"github.com/googleapis/genai-toolbox/tests"
|
"github.com/googleapis/genai-toolbox/tests"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -138,7 +139,7 @@ func TestCloudSQLMySQLToolEndpoints(t *testing.T) {
|
|||||||
|
|
||||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Logf("toolbox command logs: \n%s", out)
|
t.Logf("toolbox command logs: \n%s", out)
|
||||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ import (
|
|||||||
|
|
||||||
"cloud.google.com/go/cloudsqlconn"
|
"cloud.google.com/go/cloudsqlconn"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||||
"github.com/googleapis/genai-toolbox/tests"
|
"github.com/googleapis/genai-toolbox/tests"
|
||||||
"github.com/jackc/pgx/v5/pgxpool"
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
)
|
)
|
||||||
@@ -142,7 +143,7 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) {
|
|||||||
|
|
||||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Logf("toolbox command logs: \n%s", out)
|
t.Logf("toolbox command logs: \n%s", out)
|
||||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ import (
|
|||||||
|
|
||||||
"github.com/couchbase/gocb/v2"
|
"github.com/couchbase/gocb/v2"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||||
"github.com/googleapis/genai-toolbox/tests"
|
"github.com/googleapis/genai-toolbox/tests"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -128,7 +129,7 @@ func TestCouchbaseToolEndpoints(t *testing.T) {
|
|||||||
|
|
||||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Logf("toolbox command logs: \n%s", out)
|
t.Logf("toolbox command logs: \n%s", out)
|
||||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||||
"github.com/googleapis/genai-toolbox/tests"
|
"github.com/googleapis/genai-toolbox/tests"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -78,7 +79,7 @@ func TestDgraphToolEndpoints(t *testing.T) {
|
|||||||
|
|
||||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Logf("toolbox command logs: \n%s", out)
|
t.Logf("toolbox command logs: \n%s", out)
|
||||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
"github.com/googleapis/genai-toolbox/tests"
|
"github.com/googleapis/genai-toolbox/tests"
|
||||||
)
|
)
|
||||||
@@ -254,7 +255,7 @@ func TestHttpToolEndpoints(t *testing.T) {
|
|||||||
|
|
||||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Logf("toolbox command logs: \n%s", out)
|
t.Logf("toolbox command logs: \n%s", out)
|
||||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||||
"github.com/googleapis/genai-toolbox/tests"
|
"github.com/googleapis/genai-toolbox/tests"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -124,7 +125,7 @@ func TestMSSQLToolEndpoints(t *testing.T) {
|
|||||||
|
|
||||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Logf("toolbox command logs: \n%s", out)
|
t.Logf("toolbox command logs: \n%s", out)
|
||||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||||
"github.com/googleapis/genai-toolbox/tests"
|
"github.com/googleapis/genai-toolbox/tests"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -115,7 +116,7 @@ func TestMySQLToolEndpoints(t *testing.T) {
|
|||||||
|
|
||||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Logf("toolbox command logs: \n%s", out)
|
t.Logf("toolbox command logs: \n%s", out)
|
||||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||||
"github.com/googleapis/genai-toolbox/tests"
|
"github.com/googleapis/genai-toolbox/tests"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -87,7 +88,7 @@ func TestNeo4jToolEndpoints(t *testing.T) {
|
|||||||
|
|
||||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Logf("toolbox command logs: \n%s", out)
|
t.Logf("toolbox command logs: \n%s", out)
|
||||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||||
"github.com/googleapis/genai-toolbox/tests"
|
"github.com/googleapis/genai-toolbox/tests"
|
||||||
"github.com/jackc/pgx/v5/pgxpool"
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
)
|
)
|
||||||
@@ -121,7 +122,7 @@ func TestPostgres(t *testing.T) {
|
|||||||
|
|
||||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Logf("toolbox command logs: \n%s", out)
|
t.Logf("toolbox command logs: \n%s", out)
|
||||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||||
"github.com/googleapis/genai-toolbox/tests"
|
"github.com/googleapis/genai-toolbox/tests"
|
||||||
"github.com/redis/go-redis/v9"
|
"github.com/redis/go-redis/v9"
|
||||||
)
|
)
|
||||||
@@ -90,7 +91,7 @@ func TestRedisToolEndpoints(t *testing.T) {
|
|||||||
|
|
||||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Logf("toolbox command logs: \n%s", out)
|
t.Logf("toolbox command logs: \n%s", out)
|
||||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||||
|
|||||||
@@ -15,13 +15,10 @@
|
|||||||
package tests
|
package tests
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"regexp"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
yaml "github.com/goccy/go-yaml"
|
yaml "github.com/goccy/go-yaml"
|
||||||
|
|
||||||
@@ -133,63 +130,3 @@ func (c *CmdExec) Close() {
|
|||||||
c.Close()
|
c.Close()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WaitForString waits until the server logs a single line that matches the provided regex.
|
|
||||||
// returns the output of whatever the server sent so far.
|
|
||||||
func (c *CmdExec) WaitForString(ctx context.Context, re *regexp.Regexp) (string, error) {
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
|
||||||
defer cancel()
|
|
||||||
in := bufio.NewReader(c.Out)
|
|
||||||
|
|
||||||
// read lines in background, sending result of each read over a channel
|
|
||||||
// this allows us to use in.ReadString without blocking
|
|
||||||
type result struct {
|
|
||||||
s string
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
output := make(chan result)
|
|
||||||
go func() {
|
|
||||||
defer close(output)
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
// if the context is canceled, the orig thread will send back the error
|
|
||||||
// so we can just exit the goroutine here
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
// otherwise read a line from the output
|
|
||||||
s, err := in.ReadString('\n')
|
|
||||||
if err != nil {
|
|
||||||
output <- result{err: err}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
output <- result{s: s}
|
|
||||||
// if that last string matched, exit the goroutine
|
|
||||||
if re.MatchString(s) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// collect the output until the ctx is canceled, an error was hit,
|
|
||||||
// or match was found (which is indicated the channel is closed)
|
|
||||||
var sb strings.Builder
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
// if ctx is done, return that error
|
|
||||||
return sb.String(), ctx.Err()
|
|
||||||
case o, ok := <-output:
|
|
||||||
if !ok {
|
|
||||||
// match was found!
|
|
||||||
return sb.String(), nil
|
|
||||||
}
|
|
||||||
if o.err != nil {
|
|
||||||
// error was found!
|
|
||||||
return sb.String(), o.err
|
|
||||||
}
|
|
||||||
sb.WriteString(o.s)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"cloud.google.com/go/cloudsqlconn"
|
"cloud.google.com/go/cloudsqlconn"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RunSourceConnection test for source connection
|
// RunSourceConnection test for source connection
|
||||||
@@ -57,7 +58,7 @@ func RunSourceConnectionTest(t *testing.T, sourceConfig map[string]any, toolKind
|
|||||||
|
|
||||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Logf("toolbox command logs: \n%s", out)
|
t.Logf("toolbox command logs: \n%s", out)
|
||||||
return fmt.Errorf("toolbox didn't start successfully: %s", err)
|
return fmt.Errorf("toolbox didn't start successfully: %s", err)
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ import (
|
|||||||
database "cloud.google.com/go/spanner/admin/database/apiv1"
|
database "cloud.google.com/go/spanner/admin/database/apiv1"
|
||||||
"cloud.google.com/go/spanner/admin/database/apiv1/databasepb"
|
"cloud.google.com/go/spanner/admin/database/apiv1/databasepb"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
"github.com/googleapis/genai-toolbox/tests"
|
"github.com/googleapis/genai-toolbox/tests"
|
||||||
)
|
)
|
||||||
@@ -141,7 +142,7 @@ func TestSpannerToolEndpoints(t *testing.T) {
|
|||||||
|
|
||||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Logf("toolbox command logs: \n%s", out)
|
t.Logf("toolbox command logs: \n%s", out)
|
||||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||||
"github.com/googleapis/genai-toolbox/tests"
|
"github.com/googleapis/genai-toolbox/tests"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -143,7 +144,7 @@ func TestSQLiteToolEndpoint(t *testing.T) {
|
|||||||
|
|
||||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Logf("toolbox command logs: \n%s", out)
|
t.Logf("toolbox command logs: \n%s", out)
|
||||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||||
"github.com/googleapis/genai-toolbox/tests"
|
"github.com/googleapis/genai-toolbox/tests"
|
||||||
"github.com/valkey-io/valkey-go"
|
"github.com/valkey-io/valkey-go"
|
||||||
)
|
)
|
||||||
@@ -93,7 +94,7 @@ func TestValkeyToolEndpoints(t *testing.T) {
|
|||||||
|
|
||||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Logf("toolbox command logs: \n%s", out)
|
t.Logf("toolbox command logs: \n%s", out)
|
||||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||||
|
|||||||
Reference in New Issue
Block a user