mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-02-10 07:05:04 -05:00
437 lines
14 KiB
Go
437 lines
14 KiB
Go
// Copyright 2024 Google LLC
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package cmd
|
|
|
|
import (
|
|
"context"
|
|
_ "embed"
|
|
"fmt"
|
|
"maps"
|
|
"os"
|
|
"os/signal"
|
|
"path/filepath"
|
|
"runtime"
|
|
"slices"
|
|
"strings"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/fsnotify/fsnotify"
|
|
"github.com/googleapis/genai-toolbox/internal/auth"
|
|
"github.com/googleapis/genai-toolbox/internal/cli"
|
|
"github.com/googleapis/genai-toolbox/internal/cli/invoke"
|
|
"github.com/googleapis/genai-toolbox/internal/cli/skills"
|
|
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
|
"github.com/googleapis/genai-toolbox/internal/prompts"
|
|
"github.com/googleapis/genai-toolbox/internal/server"
|
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
|
"github.com/googleapis/genai-toolbox/internal/util"
|
|
"github.com/spf13/cobra"
|
|
|
|
// Import prompt packages for side effect of registration
|
|
_ "github.com/googleapis/genai-toolbox/cmd/imports"
|
|
)
|
|
|
|
var (
|
|
// versionString stores the full semantic version, including build metadata.
|
|
versionString string
|
|
// versionNum indicates the numerical part of the version
|
|
//go:embed version.txt
|
|
versionNum string
|
|
// buildType indicates additional build or distribution metadata.
|
|
buildType string = "dev" // should be one of "dev", "binary", or "container"
|
|
// commitSha is the git commit it was built from
|
|
commitSha string
|
|
)
|
|
|
|
func init() {
|
|
versionString = semanticVersion()
|
|
}
|
|
|
|
// semanticVersion returns the version of the CLI including a compile-time metadata.
|
|
func semanticVersion() string {
|
|
metadataStrings := []string{buildType, runtime.GOOS, runtime.GOARCH}
|
|
if commitSha != "" {
|
|
metadataStrings = append(metadataStrings, commitSha)
|
|
}
|
|
v := strings.TrimSpace(versionNum) + "+" + strings.Join(metadataStrings, ".")
|
|
return v
|
|
}
|
|
|
|
// Execute adds all child commands to the root command and sets flags appropriately.
|
|
// This is called by main.main(). It only needs to happen once to the rootCmd.
|
|
func Execute() {
|
|
// Initialize cli deps
|
|
opts := cli.NewToolboxOptions()
|
|
|
|
if err := NewCommand(opts).Execute(); err != nil {
|
|
exit := 1
|
|
os.Exit(exit)
|
|
}
|
|
}
|
|
|
|
// NewCommand returns a Command object representing an invocation of the CLI.
|
|
func NewCommand(opts *cli.ToolboxOptions) *cobra.Command {
|
|
cmd := &cobra.Command{
|
|
Use: "toolbox",
|
|
Version: versionString,
|
|
SilenceErrors: true,
|
|
}
|
|
|
|
// Do not print Usage on runtime error
|
|
cmd.SilenceUsage = true
|
|
|
|
// Set server version
|
|
opts.Cfg.Version = versionString
|
|
|
|
// set baseCmd in, out and err the same as cmd.
|
|
cmd.SetIn(opts.IOStreams.In)
|
|
cmd.SetOut(opts.IOStreams.Out)
|
|
cmd.SetErr(opts.IOStreams.ErrOut)
|
|
|
|
// setup flags that are common across all commands
|
|
cli.PersistentFlags(cmd, opts)
|
|
|
|
flags := cmd.Flags()
|
|
|
|
flags.StringVarP(&opts.Cfg.Address, "address", "a", "127.0.0.1", "Address of the interface the server will listen on.")
|
|
flags.IntVarP(&opts.Cfg.Port, "port", "p", 5000, "Port the server will listen on.")
|
|
|
|
flags.StringVar(&opts.ToolsFile, "tools_file", "", "File path specifying the tool configuration. Cannot be used with --tools-files, or --tools-folder.")
|
|
// deprecate tools_file
|
|
_ = flags.MarkDeprecated("tools_file", "please use --tools-file instead")
|
|
flags.BoolVar(&opts.Cfg.Stdio, "stdio", false, "Listens via MCP STDIO instead of acting as a remote HTTP server.")
|
|
flags.BoolVar(&opts.Cfg.DisableReload, "disable-reload", false, "Disables dynamic reloading of tools file.")
|
|
flags.BoolVar(&opts.Cfg.UI, "ui", false, "Launches the Toolbox UI web server.")
|
|
// TODO: Insecure by default. Might consider updating this for v1.0.0
|
|
flags.StringSliceVar(&opts.Cfg.AllowedOrigins, "allowed-origins", []string{"*"}, "Specifies a list of origins permitted to access this server. Defaults to '*'.")
|
|
flags.StringSliceVar(&opts.Cfg.AllowedHosts, "allowed-hosts", []string{"*"}, "Specifies a list of hosts permitted to access this server. Defaults to '*'.")
|
|
|
|
// wrap RunE command so that we have access to original Command object
|
|
cmd.RunE = func(*cobra.Command, []string) error { return run(cmd, opts) }
|
|
|
|
// Register subcommands for tool invocation
|
|
cmd.AddCommand(invoke.NewCommand(opts))
|
|
// Register subcommands for skill generation
|
|
cmd.AddCommand(skills.NewCommand(opts))
|
|
|
|
return cmd
|
|
}
|
|
|
|
func handleDynamicReload(ctx context.Context, toolsFile cli.ToolsFile, s *server.Server) error {
|
|
logger, err := util.LoggerFromContext(ctx)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := validateReloadEdits(ctx, toolsFile)
|
|
if err != nil {
|
|
errMsg := fmt.Errorf("unable to validate reloaded edits: %w", err)
|
|
logger.WarnContext(ctx, errMsg.Error())
|
|
return err
|
|
}
|
|
|
|
s.ResourceMgr.SetResources(sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap)
|
|
|
|
return nil
|
|
}
|
|
|
|
// validateReloadEdits checks that the reloaded tools file configs can initialized without failing
|
|
func validateReloadEdits(
|
|
ctx context.Context, toolsFile cli.ToolsFile,
|
|
) (map[string]sources.Source, map[string]auth.AuthService, map[string]embeddingmodels.EmbeddingModel, map[string]tools.Tool, map[string]tools.Toolset, map[string]prompts.Prompt, map[string]prompts.Promptset, error,
|
|
) {
|
|
logger, err := util.LoggerFromContext(ctx)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
instrumentation, err := util.InstrumentationFromContext(ctx)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
logger.DebugContext(ctx, "Attempting to parse and validate reloaded tools file.")
|
|
|
|
ctx, span := instrumentation.Tracer.Start(ctx, "toolbox/server/reload")
|
|
defer span.End()
|
|
|
|
reloadedConfig := server.ServerConfig{
|
|
Version: versionString,
|
|
SourceConfigs: toolsFile.Sources,
|
|
AuthServiceConfigs: toolsFile.AuthServices,
|
|
EmbeddingModelConfigs: toolsFile.EmbeddingModels,
|
|
ToolConfigs: toolsFile.Tools,
|
|
ToolsetConfigs: toolsFile.Toolsets,
|
|
PromptConfigs: toolsFile.Prompts,
|
|
}
|
|
|
|
sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := server.InitializeConfigs(ctx, reloadedConfig)
|
|
if err != nil {
|
|
errMsg := fmt.Errorf("unable to initialize reloaded configs: %w", err)
|
|
logger.WarnContext(ctx, errMsg.Error())
|
|
return nil, nil, nil, nil, nil, nil, nil, err
|
|
}
|
|
|
|
return sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, nil
|
|
}
|
|
|
|
// watchChanges checks for changes in the provided yaml tools file(s) or folder.
|
|
func watchChanges(ctx context.Context, watchDirs map[string]bool, watchedFiles map[string]bool, s *server.Server) {
|
|
logger, err := util.LoggerFromContext(ctx)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
w, err := fsnotify.NewWatcher()
|
|
if err != nil {
|
|
logger.WarnContext(ctx, "error setting up new watcher %s", err)
|
|
return
|
|
}
|
|
|
|
defer w.Close()
|
|
|
|
watchingFolder := false
|
|
var folderToWatch string
|
|
|
|
// if watchedFiles is empty, indicates that user passed entire folder instead
|
|
if len(watchedFiles) == 0 {
|
|
watchingFolder = true
|
|
|
|
// validate that watchDirs only has single element
|
|
if len(watchDirs) > 1 {
|
|
logger.WarnContext(ctx, "error setting watcher, expected single tools folder if no file(s) are defined.")
|
|
return
|
|
}
|
|
|
|
for onlyKey := range watchDirs {
|
|
folderToWatch = onlyKey
|
|
break
|
|
}
|
|
}
|
|
|
|
for dir := range watchDirs {
|
|
err := w.Add(dir)
|
|
if err != nil {
|
|
logger.WarnContext(ctx, fmt.Sprintf("Error adding path %s to watcher: %s", dir, err))
|
|
break
|
|
}
|
|
logger.DebugContext(ctx, fmt.Sprintf("Added directory %s to watcher.", dir))
|
|
}
|
|
|
|
// debounce timer is used to prevent multiple writes triggering multiple reloads
|
|
debounceDelay := 100 * time.Millisecond
|
|
debounce := time.NewTimer(1 * time.Minute)
|
|
debounce.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
logger.DebugContext(ctx, "file watcher context cancelled")
|
|
return
|
|
case err, ok := <-w.Errors:
|
|
if !ok {
|
|
logger.WarnContext(ctx, "file watcher was closed unexpectedly")
|
|
return
|
|
}
|
|
if err != nil {
|
|
logger.WarnContext(ctx, "file watcher error %s", err)
|
|
return
|
|
}
|
|
|
|
case e, ok := <-w.Events:
|
|
if !ok {
|
|
logger.WarnContext(ctx, "file watcher already closed")
|
|
return
|
|
}
|
|
|
|
// only check for events which indicate user saved a new tools file
|
|
// multiple operations checked due to various file update methods across editors
|
|
if !e.Has(fsnotify.Write | fsnotify.Create | fsnotify.Rename) {
|
|
continue
|
|
}
|
|
|
|
cleanedFilename := filepath.Clean(e.Name)
|
|
logger.DebugContext(ctx, fmt.Sprintf("%s event detected in %s", e.Op, cleanedFilename))
|
|
|
|
folderChanged := watchingFolder &&
|
|
(strings.HasSuffix(cleanedFilename, ".yaml") || strings.HasSuffix(cleanedFilename, ".yml"))
|
|
|
|
if folderChanged || watchedFiles[cleanedFilename] {
|
|
// indicates the write event is on a relevant file
|
|
debounce.Reset(debounceDelay)
|
|
}
|
|
|
|
case <-debounce.C:
|
|
debounce.Stop()
|
|
var reloadedToolsFile cli.ToolsFile
|
|
|
|
if watchingFolder {
|
|
logger.DebugContext(ctx, "Reloading tools folder.")
|
|
reloadedToolsFile, err = cli.LoadAndMergeToolsFolder(ctx, folderToWatch)
|
|
if err != nil {
|
|
logger.WarnContext(ctx, "error loading tools folder %s", err)
|
|
continue
|
|
}
|
|
} else {
|
|
logger.DebugContext(ctx, "Reloading tools file(s).")
|
|
reloadedToolsFile, err = cli.LoadAndMergeToolsFiles(ctx, slices.Collect(maps.Keys(watchedFiles)))
|
|
if err != nil {
|
|
logger.WarnContext(ctx, "error loading tools files %s", err)
|
|
continue
|
|
}
|
|
}
|
|
|
|
err = handleDynamicReload(ctx, reloadedToolsFile, s)
|
|
if err != nil {
|
|
errMsg := fmt.Errorf("unable to parse reloaded tools file at %q: %w", reloadedToolsFile, err)
|
|
logger.WarnContext(ctx, errMsg.Error())
|
|
continue
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// resolveWatcherInputs works out what the file watcher should observe.
//
// It returns two sets, in this order:
//   - the directories to add to the watcher (fsnotify prefers watching a
//     directory and then filtering for the file), and
//   - the cleaned paths of the individual files whose events are relevant.
//
// toolsFiles takes precedence when non-empty; otherwise a non-empty
// toolsFolder is watched as a whole (leaving the file set empty); otherwise the
// single toolsFile is used.
func resolveWatcherInputs(toolsFile string, toolsFiles []string, toolsFolder string) (map[string]bool, map[string]bool) {
	dirs := map[string]bool{}
	files := map[string]bool{}

	var candidates []string
	switch {
	case len(toolsFiles) > 0:
		candidates = toolsFiles
	case toolsFolder != "":
		dirs[filepath.Clean(toolsFolder)] = true
	default:
		candidates = []string{toolsFile}
	}

	// record each file and its parent directory; the maps dedup for us
	for _, candidate := range candidates {
		path := filepath.Clean(candidate)
		files[path] = true
		dirs[filepath.Dir(path)] = true
	}

	return dirs, files
}
|
|
|
|
func run(cmd *cobra.Command, opts *cli.ToolboxOptions) error {
|
|
ctx, cancel := context.WithCancel(cmd.Context())
|
|
defer cancel()
|
|
|
|
// watch for sigterm / sigint signals
|
|
signals := make(chan os.Signal, 1)
|
|
signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
|
|
go func(sCtx context.Context) {
|
|
var s os.Signal
|
|
select {
|
|
case <-sCtx.Done():
|
|
// this should only happen when the context supplied when testing is canceled
|
|
return
|
|
case s = <-signals:
|
|
}
|
|
switch s {
|
|
case syscall.SIGINT:
|
|
opts.Logger.DebugContext(sCtx, "Received SIGINT signal to shutdown.")
|
|
case syscall.SIGTERM:
|
|
opts.Logger.DebugContext(sCtx, "Sending SIGTERM signal to shutdown.")
|
|
}
|
|
cancel()
|
|
}(ctx)
|
|
|
|
ctx, shutdown, err := opts.Setup(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() {
|
|
_ = shutdown(ctx)
|
|
}()
|
|
|
|
isCustomConfigured, err := opts.LoadConfig(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// start server
|
|
s, err := server.NewServer(ctx, opts.Cfg)
|
|
if err != nil {
|
|
errMsg := fmt.Errorf("toolbox failed to initialize: %w", err)
|
|
opts.Logger.ErrorContext(ctx, errMsg.Error())
|
|
return errMsg
|
|
}
|
|
|
|
// run server in background
|
|
srvErr := make(chan error)
|
|
if opts.Cfg.Stdio {
|
|
go func() {
|
|
defer close(srvErr)
|
|
err = s.ServeStdio(ctx, opts.IOStreams.In, opts.IOStreams.Out)
|
|
if err != nil {
|
|
srvErr <- err
|
|
}
|
|
}()
|
|
} else {
|
|
err = s.Listen(ctx)
|
|
if err != nil {
|
|
errMsg := fmt.Errorf("toolbox failed to start listener: %w", err)
|
|
opts.Logger.ErrorContext(ctx, errMsg.Error())
|
|
return errMsg
|
|
}
|
|
opts.Logger.InfoContext(ctx, "Server ready to serve!")
|
|
if opts.Cfg.UI {
|
|
opts.Logger.InfoContext(ctx, fmt.Sprintf("Toolbox UI is up and running at: http://%s:%d/ui", opts.Cfg.Address, opts.Cfg.Port))
|
|
}
|
|
|
|
go func() {
|
|
defer close(srvErr)
|
|
err = s.Serve(ctx)
|
|
if err != nil {
|
|
srvErr <- err
|
|
}
|
|
}()
|
|
}
|
|
|
|
if isCustomConfigured && !opts.Cfg.DisableReload {
|
|
watchDirs, watchedFiles := resolveWatcherInputs(opts.ToolsFile, opts.ToolsFiles, opts.ToolsFolder)
|
|
// start watching the file(s) or folder for changes to trigger dynamic reloading
|
|
go watchChanges(ctx, watchDirs, watchedFiles, s)
|
|
}
|
|
|
|
// wait for either the server to error out or the command's context to be canceled
|
|
select {
|
|
case err := <-srvErr:
|
|
if err != nil {
|
|
errMsg := fmt.Errorf("toolbox crashed with the following error: %w", err)
|
|
opts.Logger.ErrorContext(ctx, errMsg.Error())
|
|
return errMsg
|
|
}
|
|
case <-ctx.Done():
|
|
shutdownContext, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
defer cancel()
|
|
opts.Logger.WarnContext(shutdownContext, "Shutting down gracefully...")
|
|
err := s.Shutdown(shutdownContext)
|
|
if err == context.DeadlineExceeded {
|
|
return fmt.Errorf("graceful shutdown timed out... forcing exit")
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|