mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-09 15:38:08 -05:00
feat: add address and port flags (#7)
Add flags for setting the address and port of the server.
This commit is contained in:
24
cmd/root.go
24
cmd/root.go
@@ -58,34 +58,38 @@ func Execute() {
|
||||
// Command represents an invocation of the CLI.
|
||||
type Command struct {
|
||||
*cobra.Command
|
||||
|
||||
cfg server.Config
|
||||
}
|
||||
|
||||
// NewCommand returns a Command object representing an invocation of the CLI.
|
||||
func NewCommand() *Command {
|
||||
rootCmd := &cobra.Command{
|
||||
Use: "toolbox",
|
||||
Version: versionString,
|
||||
c := &Command{
|
||||
Command: &cobra.Command{
|
||||
Use: "toolbox",
|
||||
Version: versionString,
|
||||
},
|
||||
}
|
||||
|
||||
c := &Command{
|
||||
Command: rootCmd,
|
||||
cfg: server.Config{},
|
||||
}
|
||||
flags := c.Flags()
|
||||
flags.StringVarP(&c.cfg.Address, "address", "a", "127.0.0.1", "Address of the interface the server will listen on.")
|
||||
flags.IntVarP(&c.cfg.Port, "port", "p", 5000, "Port the server will listen on.")
|
||||
|
||||
// wrap RunE command so that we have access to original Command object
|
||||
rootCmd.RunE = func(*cobra.Command, []string) error { return run(c) }
|
||||
c.RunE = func(*cobra.Command, []string) error { return run(c) }
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func run(cmd *Command) error {
|
||||
ctx := context.Background()
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
defer cancel()
|
||||
|
||||
// run server
|
||||
s := server.NewServer(cmd.cfg)
|
||||
err := s.ListenAndServe(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error while serving: %w", err)
|
||||
return fmt.Errorf("Toolbox crashed with the following error: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
106
cmd/root_test.go
106
cmd/root_test.go
@@ -12,19 +12,42 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cmd_test
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
_ "embed"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/cmd"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func invokeCommand(args []string) (*Command, string, error) {
|
||||
c := NewCommand()
|
||||
|
||||
// Keep the test output quiet
|
||||
c.SilenceUsage = true
|
||||
c.SilenceErrors = true
|
||||
|
||||
// Capture output
|
||||
buf := new(bytes.Buffer)
|
||||
c.SetOut(buf)
|
||||
c.SetErr(buf)
|
||||
c.SetArgs(args)
|
||||
|
||||
// Disable execute behavior
|
||||
c.RunE = func(*cobra.Command, []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := c.Execute()
|
||||
|
||||
return c, buf.String(), err
|
||||
}
|
||||
|
||||
func TestVersion(t *testing.T) {
|
||||
data, err := os.ReadFile("version.txt")
|
||||
if err != nil {
|
||||
@@ -32,24 +55,73 @@ func TestVersion(t *testing.T) {
|
||||
}
|
||||
want := strings.TrimSpace(string(data))
|
||||
|
||||
// run command with flag
|
||||
b := bytes.NewBufferString("")
|
||||
cmd := cmd.NewCommand()
|
||||
cmd.SetArgs([]string{"--version"})
|
||||
cmd.SetOut(b)
|
||||
|
||||
err = cmd.Execute()
|
||||
_, got, err := invokeCommand([]string{"--version"})
|
||||
if err != nil {
|
||||
t.Fatalf("unable to execute command: %q", err)
|
||||
t.Fatalf("error invoking command: %s", err)
|
||||
}
|
||||
|
||||
out, err := io.ReadAll(b)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got := string(out)
|
||||
|
||||
if !strings.Contains(got, want) {
|
||||
t.Errorf("cli did not return correct version: want %q, got %q", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlags(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
args []string
|
||||
want server.Config
|
||||
}{
|
||||
{
|
||||
desc: "default values",
|
||||
args: []string{},
|
||||
want: server.Config{
|
||||
Address: "127.0.0.1",
|
||||
Port: 5000,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "address short",
|
||||
args: []string{"-a", "127.0.1.1"},
|
||||
want: server.Config{
|
||||
Address: "127.0.1.1",
|
||||
Port: 5000,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "address long",
|
||||
args: []string{"--address", "0.0.0.0"},
|
||||
want: server.Config{
|
||||
Address: "0.0.0.0",
|
||||
Port: 5000,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "port short",
|
||||
args: []string{"-p", "5052"},
|
||||
want: server.Config{
|
||||
Address: "127.0.0.1",
|
||||
Port: 5052,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "port long",
|
||||
args: []string{"--port", "5050"},
|
||||
want: server.Config{
|
||||
Address: "127.0.0.1",
|
||||
Port: 5050,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
c, _, err := invokeCommand(tc.args)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error invoking command: %s", err)
|
||||
}
|
||||
|
||||
if c.cfg != tc.want {
|
||||
t.Fatalf("got %v, want %v", c.cfg, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,16 +14,8 @@
|
||||
package server
|
||||
|
||||
type Config struct {
|
||||
// Address is the address of the interface the server will listen on.
|
||||
// address is the address of the interface the server will listen on.
|
||||
Address string
|
||||
// Port is the port the server will listen on.
|
||||
Port string
|
||||
}
|
||||
|
||||
func NewConfig() Config {
|
||||
c := Config{
|
||||
Address: "127.0.0.1",
|
||||
Port: "5000",
|
||||
}
|
||||
return c
|
||||
// port is the port the server will listen on.
|
||||
Port int
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
@@ -32,7 +33,7 @@ type Server struct {
|
||||
}
|
||||
|
||||
// NewServer returns a Server object based on provided Config.
|
||||
func NewServer(conf Config) *Server {
|
||||
func NewServer(cfg Config) *Server {
|
||||
r := chi.NewRouter()
|
||||
r.Use(middleware.Logger)
|
||||
r.Use(middleware.Recoverer)
|
||||
@@ -41,7 +42,7 @@ func NewServer(conf Config) *Server {
|
||||
})
|
||||
|
||||
s := &Server{
|
||||
conf: conf,
|
||||
conf: cfg,
|
||||
router: r,
|
||||
}
|
||||
return s
|
||||
@@ -52,7 +53,7 @@ func (s *Server) ListenAndServe(ctx context.Context) error {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
addr := net.JoinHostPort(s.conf.Address, s.conf.Port)
|
||||
addr := net.JoinHostPort(s.conf.Address, strconv.Itoa(s.conf.Port))
|
||||
lc := net.ListenConfig{KeepAlive: 30 * time.Second}
|
||||
l, err := lc.Listen(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
|
||||
@@ -17,6 +17,7 @@ package server_test
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -37,12 +38,15 @@ func tryDial(addr string, attempts int) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func TestListenAndServe(t *testing.T) {
|
||||
func TestServe(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
host, port := "127.0.0.1", "5000"
|
||||
cfg := server.NewConfig()
|
||||
addr, port := "127.0.0.1", 5000
|
||||
cfg := server.Config{
|
||||
Address: addr,
|
||||
Port: port,
|
||||
}
|
||||
s := server.NewServer(cfg)
|
||||
|
||||
// start server in background
|
||||
@@ -55,7 +59,7 @@ func TestListenAndServe(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
if !tryDial(net.JoinHostPort(host, port), 10) {
|
||||
if !tryDial(net.JoinHostPort(addr, strconv.Itoa(port)), 10) {
|
||||
t.Fatalf("Unable to dial server!")
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user