Files
genai-toolbox/cmd/root.go
Kurtis Van Gent 66ab70f702 fix: handle shutdown gracefully (#178)
Adds logic to make the server shut down gracefully, including better
handling of canceled contexts and allowing up to 10 seconds for current
connections to finish.
2024-12-31 09:53:53 -07:00

237 lines
6.6 KiB
Go

// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cmd
import (
"context"
_ "embed"
"fmt"
"io"
"os"
"os/signal"
"strings"
"syscall"
"time"
"github.com/googleapis/genai-toolbox/internal/log"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/spf13/cobra"
"gopkg.in/yaml.v3"
)
// Version information for the CLI.
var (
// versionString holds the version of this CLI. It is initially populated
// with the contents of version.txt through the go:embed directive below and
// is then overwritten in init with the result of semanticVersion.
//go:embed version.txt
versionString string
// metadataString holds additional build or distribution metadata. When it
// is non-empty, semanticVersion appends it to the version after a "+".
// NOTE(review): nothing in this file assigns it; presumably it is injected
// at build time (e.g. with -ldflags "-X ...") — confirm.
metadataString string
)
// init replaces the raw text embedded in versionString with the normalized
// value produced by semanticVersion, so every later reader of versionString
// sees the trimmed version (plus build metadata, when set).
func init() { versionString = semanticVersion() }
// semanticVersion returns the version of the CLI: versionString with leading
// and trailing whitespace removed, followed by "+" and metadataString when
// build metadata is present.
func semanticVersion() string {
	version := strings.TrimSpace(versionString)
	if metadataString == "" {
		return version
	}
	return version + "+" + metadataString
}
// Execute builds the root command with NewCommand and runs it. If the command
// returns an error, the process exits with status 1. It is meant to be called
// once, from main.main().
func Execute() {
	if err := NewCommand().Execute(); err == nil {
		return
	}
	os.Exit(1)
}
// Command represents an invocation of the CLI. It embeds the underlying
// cobra.Command and carries the configuration, logger, and I/O streams used
// while the command runs.
type Command struct {
*cobra.Command
// cfg is the server configuration. NewCommand binds several of its fields
// to CLI flags and run fills in the configs parsed from the tools file.
cfg server.ServerConfig
// logger is assigned in run, based on the logging format and log level
// held in cfg.
logger log.Logger
// tools_file is the path of the tool configuration file, bound to the
// "tools_file" flag.
tools_file string
// outStream and errStream are the writers used for normal and error
// output. NewCommand defaults them to os.Stdout and os.Stderr before
// applying options.
outStream io.Writer
errStream io.Writer
}
// NewCommand returns a Command object representing an invocation of the CLI.
//
// The command writes to os.Stdout and os.Stderr unless an Option changes the
// streams. Options are applied before the streams are handed to cobra and
// before the server version is set, so the version always reflects
// versionString.
func NewCommand(opts ...Option) *Command {
	base := &cobra.Command{
		Use:           "toolbox",
		Version:       versionString,
		SilenceErrors: true,
	}
	c := &Command{
		Command:   base,
		outStream: os.Stdout,
		errStream: os.Stderr,
	}
	for _, apply := range opts {
		apply(c)
	}

	// The server reports the same version as the CLI.
	c.cfg.Version = versionString

	// Point cobra's own output at whatever streams the options selected.
	base.SetOut(c.outStream)
	base.SetErr(c.errStream)

	// Register the CLI flags, each bound directly to the field it configures.
	fs := c.Flags()
	fs.StringVarP(&c.cfg.Address, "address", "a", "127.0.0.1", "Address of the interface the server will listen on.")
	fs.IntVarP(&c.cfg.Port, "port", "p", 5000, "Port the server will listen on.")
	fs.StringVar(&c.tools_file, "tools_file", "tools.yaml", "File path specifying the tool configuration.")
	fs.Var(&c.cfg.LogLevel, "log-level", "Specify the minimum level logged. Allowed: 'DEBUG', 'INFO', 'WARN', 'ERROR'.")
	fs.Var(&c.cfg.LoggingFormat, "logging-format", "Specify logging format to use. Allowed: 'standard' or 'JSON'.")

	// Close over c so that run receives the full Command rather than only the
	// embedded cobra.Command that cobra passes to RunE.
	c.RunE = func(*cobra.Command, []string) error {
		return run(c)
	}
	return c
}
// ToolsFile is the structure a tools configuration file is decoded into by
// parseToolsFile. Each field is read from the top-level YAML key named in its
// struct tag.
type ToolsFile struct {
// Sources is decoded from the "sources" key.
Sources server.SourceConfigs `yaml:"sources"`
// AuthSources is decoded from the "authSources" key.
AuthSources server.AuthSourceConfigs `yaml:"authSources"`
// Tools is decoded from the "tools" key.
Tools server.ToolConfigs `yaml:"tools"`
// Toolsets is decoded from the "toolsets" key.
Toolsets server.ToolsetConfigs `yaml:"toolsets"`
}
// parseToolsFile decodes the raw YAML contents of a tools file into a
// ToolsFile. The error from yaml.Unmarshal is returned unchanged, together
// with whatever was decoded before the failure.
func parseToolsFile(raw []byte) (ToolsFile, error) {
	parsed := ToolsFile{}
	err := yaml.Unmarshal(raw, &parsed)
	return parsed, err
}
// run executes the toolbox server for one CLI invocation and blocks until the
// server stops.
//
// Steps, in order:
//  1. Build the logger selected by the logging-format and log-level flags.
//  2. Derive a cancellable context from the command's context and cancel it
//     when SIGINT or SIGTERM is received.
//  3. Read and parse the tools file and copy its configs into cmd.cfg.
//  4. Initialize the server, start its listener, and serve in the background.
//  5. Wait for either a serve error or context cancellation; on cancellation
//     the server is given up to 10 seconds to shut down.
//
// It returns nil after a clean shutdown and a non-nil error if any step fails,
// if the server stops with an error, or if graceful shutdown times out.
// Failures that occur once the logger exists are also logged, because the root
// command is created with SilenceErrors set.
func run(cmd *Command) error {
	// Handle logger separately from config. This must happen before the
	// signal goroutine is started: that goroutine logs through cmd.logger, and
	// a signal arriving while cmd.logger was still unset would dereference nil.
	switch strings.ToLower(cmd.cfg.LoggingFormat.String()) {
	case "json":
		logger, err := log.NewStructuredLogger(cmd.outStream, cmd.errStream, cmd.cfg.LogLevel.String())
		if err != nil {
			return fmt.Errorf("unable to initialize logger: %w", err)
		}
		cmd.logger = logger
	case "standard":
		logger, err := log.NewStdLogger(cmd.outStream, cmd.errStream, cmd.cfg.LogLevel.String())
		if err != nil {
			return fmt.Errorf("unable to initialize logger: %w", err)
		}
		cmd.logger = logger
	default:
		return fmt.Errorf("logging format invalid.")
	}

	ctx, cancel := context.WithCancel(cmd.Context())
	defer cancel()

	// Watch for SIGTERM / SIGINT. The channel is buffered so that a signal
	// delivered before the goroutine reaches its select is not dropped.
	signals := make(chan os.Signal, 1)
	signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
	// Undo the registration when run returns so repeated invocations in one
	// process (e.g. tests) do not accumulate signal handlers.
	defer signal.Stop(signals)
	go func() {
		select {
		case <-ctx.Done():
			// The context was canceled by something other than a signal (for
			// example by a test, or because run is returning).
			return
		case sig := <-signals:
			switch sig {
			case syscall.SIGINT:
				cmd.logger.DebugContext(ctx, "Received SIGINT signal to shutdown.")
			case syscall.SIGTERM:
				cmd.logger.DebugContext(ctx, "Received SIGTERM signal to shutdown.")
			}
			cancel()
		}
	}()

	// Read tool file contents.
	buf, err := os.ReadFile(cmd.tools_file)
	if err != nil {
		errMsg := fmt.Errorf("unable to read tool file at %q: %w", cmd.tools_file, err)
		cmd.logger.ErrorContext(ctx, errMsg.Error())
		return errMsg
	}

	toolsFile, err := parseToolsFile(buf)
	if err != nil {
		errMsg := fmt.Errorf("unable to parse tool file at %q: %w", cmd.tools_file, err)
		cmd.logger.ErrorContext(ctx, errMsg.Error())
		return errMsg
	}
	// Only copy the parsed configs once parsing is known to have succeeded.
	cmd.cfg.SourceConfigs = toolsFile.Sources
	cmd.cfg.AuthSourceConfigs = toolsFile.AuthSources
	cmd.cfg.ToolConfigs = toolsFile.Tools
	cmd.cfg.ToolsetConfigs = toolsFile.Toolsets

	// Start server.
	s, err := server.NewServer(ctx, cmd.cfg, cmd.logger)
	if err != nil {
		errMsg := fmt.Errorf("toolbox failed to initialize: %w", err)
		cmd.logger.ErrorContext(ctx, errMsg.Error())
		return errMsg
	}
	if err := s.Listen(ctx); err != nil {
		errMsg := fmt.Errorf("toolbox failed to start listener: %w", err)
		cmd.logger.ErrorContext(ctx, errMsg.Error())
		return errMsg
	}
	cmd.logger.InfoContext(ctx, "Server ready to serve!")

	// Run the server in the background. The channel has capacity 1 so the
	// goroutine can always deliver its result and exit, even when run has
	// already taken the shutdown branch below and nobody is receiving anymore.
	srvErr := make(chan error, 1)
	go func() {
		defer close(srvErr)
		// Use a goroutine-local error so the enclosing function's err is not
		// written from another goroutine.
		if err := s.Serve(); err != nil {
			srvErr <- err
		}
	}()

	// Wait for either the server to error out or the command's context to be
	// canceled.
	select {
	case err := <-srvErr:
		// A nil value means the channel was closed without an error.
		if err != nil {
			errMsg := fmt.Errorf("toolbox crashed with the following error: %w", err)
			cmd.logger.ErrorContext(ctx, errMsg.Error())
			return errMsg
		}
	case <-ctx.Done():
		// ctx is already canceled, so shutdown gets its own context with a
		// 10 second budget.
		shutdownCtx, cancelShutdown := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancelShutdown()
		cmd.logger.WarnContext(shutdownCtx, "Shutting down gracefully...")
		if err := s.Shutdown(shutdownCtx); err != nil {
			// Also consult the shutdown context itself so the timeout is still
			// detected if Shutdown wraps the context error.
			if err == context.DeadlineExceeded || shutdownCtx.Err() == context.DeadlineExceeded {
				errMsg := fmt.Errorf("graceful shutdown timed out... forcing exit.")
				cmd.logger.ErrorContext(shutdownCtx, errMsg.Error())
				return errMsg
			}
			// Any other shutdown error is reported but does not change the
			// result: the process is exiting either way.
			cmd.logger.WarnContext(shutdownCtx, fmt.Sprintf("error during graceful shutdown: %s", err))
		}
	}
	return nil
}