feat: add address and port flags (#7)

Add flags for setting the address and port of the server.
This commit is contained in:
Kurtis Van Gent
2024-07-30 10:55:13 -05:00
committed by GitHub
parent e09ae30a90
commit df9ad9e33f
5 changed files with 118 additions and 45 deletions

View File

@@ -58,34 +58,38 @@ func Execute() {
// Command represents an invocation of the CLI.
type Command struct {
*cobra.Command
cfg server.Config
}
// NewCommand returns a Command object representing an invocation of the CLI.
func NewCommand() *Command {
rootCmd := &cobra.Command{
Use: "toolbox",
Version: versionString,
c := &Command{
Command: &cobra.Command{
Use: "toolbox",
Version: versionString,
},
}
c := &Command{
Command: rootCmd,
cfg: server.Config{},
}
flags := c.Flags()
flags.StringVarP(&c.cfg.Address, "address", "a", "127.0.0.1", "Address of the interface the server will listen on.")
flags.IntVarP(&c.cfg.Port, "port", "p", 5000, "Port the server will listen on.")
// wrap RunE command so that we have access to original Command object
rootCmd.RunE = func(*cobra.Command, []string) error { return run(c) }
c.RunE = func(*cobra.Command, []string) error { return run(c) }
return c
}
func run(cmd *Command) error {
ctx := context.Background()
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
// run server
s := server.NewServer(cmd.cfg)
err := s.ListenAndServe(ctx)
if err != nil {
return fmt.Errorf("Error while serving: %w", err)
return fmt.Errorf("Toolbox crashed with the following error: %w", err)
}
return nil

View File

@@ -12,19 +12,42 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package cmd_test
package cmd
import (
"bytes"
_ "embed"
"io"
"os"
"strings"
"testing"
"github.com/googleapis/genai-toolbox/cmd"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/spf13/cobra"
)
func invokeCommand(args []string) (*Command, string, error) {
c := NewCommand()
// Keep the test output quiet
c.SilenceUsage = true
c.SilenceErrors = true
// Capture output
buf := new(bytes.Buffer)
c.SetOut(buf)
c.SetErr(buf)
c.SetArgs(args)
// Disable execute behavior
c.RunE = func(*cobra.Command, []string) error {
return nil
}
err := c.Execute()
return c, buf.String(), err
}
func TestVersion(t *testing.T) {
data, err := os.ReadFile("version.txt")
if err != nil {
@@ -32,24 +55,73 @@ func TestVersion(t *testing.T) {
}
want := strings.TrimSpace(string(data))
// run command with flag
b := bytes.NewBufferString("")
cmd := cmd.NewCommand()
cmd.SetArgs([]string{"--version"})
cmd.SetOut(b)
err = cmd.Execute()
_, got, err := invokeCommand([]string{"--version"})
if err != nil {
t.Fatalf("unable to execute command: %q", err)
t.Fatalf("error invoking command: %s", err)
}
out, err := io.ReadAll(b)
if err != nil {
t.Fatal(err)
}
got := string(out)
if !strings.Contains(got, want) {
t.Errorf("cli did not return correct version: want %q, got %q", want, got)
}
}
func TestFlags(t *testing.T) {
tcs := []struct {
desc string
args []string
want server.Config
}{
{
desc: "default values",
args: []string{},
want: server.Config{
Address: "127.0.0.1",
Port: 5000,
},
},
{
desc: "address short",
args: []string{"-a", "127.0.1.1"},
want: server.Config{
Address: "127.0.1.1",
Port: 5000,
},
},
{
desc: "address long",
args: []string{"--address", "0.0.0.0"},
want: server.Config{
Address: "0.0.0.0",
Port: 5000,
},
},
{
desc: "port short",
args: []string{"-p", "5052"},
want: server.Config{
Address: "127.0.0.1",
Port: 5052,
},
},
{
desc: "port long",
args: []string{"--port", "5050"},
want: server.Config{
Address: "127.0.0.1",
Port: 5050,
},
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
c, _, err := invokeCommand(tc.args)
if err != nil {
t.Fatalf("unexpected error invoking command: %s", err)
}
if c.cfg != tc.want {
t.Fatalf("got %v, want %v", c.cfg, tc.want)
}
})
}
}

View File

@@ -14,16 +14,8 @@
package server
type Config struct {
// Address is the address of the interface the server will listen on.
// address is the address of the interface the server will listen on.
Address string
// Port is the port the server will listen on.
Port string
}
func NewConfig() Config {
c := Config{
Address: "127.0.0.1",
Port: "5000",
}
return c
// port is the port the server will listen on.
Port int
}

View File

@@ -19,6 +19,7 @@ import (
"fmt"
"net"
"net/http"
"strconv"
"time"
"github.com/go-chi/chi/v5"
@@ -32,7 +33,7 @@ type Server struct {
}
// NewServer returns a Server object based on provided Config.
func NewServer(conf Config) *Server {
func NewServer(cfg Config) *Server {
r := chi.NewRouter()
r.Use(middleware.Logger)
r.Use(middleware.Recoverer)
@@ -41,7 +42,7 @@ func NewServer(conf Config) *Server {
})
s := &Server{
conf: conf,
conf: cfg,
router: r,
}
return s
@@ -52,7 +53,7 @@ func (s *Server) ListenAndServe(ctx context.Context) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
addr := net.JoinHostPort(s.conf.Address, s.conf.Port)
addr := net.JoinHostPort(s.conf.Address, strconv.Itoa(s.conf.Port))
lc := net.ListenConfig{KeepAlive: 30 * time.Second}
l, err := lc.Listen(ctx, "tcp", addr)
if err != nil {

View File

@@ -17,6 +17,7 @@ package server_test
import (
"context"
"net"
"strconv"
"testing"
"time"
@@ -37,12 +38,15 @@ func tryDial(addr string, attempts int) bool {
return false
}
func TestListenAndServe(t *testing.T) {
func TestServe(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
host, port := "127.0.0.1", "5000"
cfg := server.NewConfig()
addr, port := "127.0.0.1", 5000
cfg := server.Config{
Address: addr,
Port: port,
}
s := server.NewServer(cfg)
// start server in background
@@ -55,7 +59,7 @@ func TestListenAndServe(t *testing.T) {
}
}()
if !tryDial(net.JoinHostPort(host, port), 10) {
if !tryDial(net.JoinHostPort(addr, strconv.Itoa(port)), 10) {
t.Fatalf("Unable to dial server!")
}