feat: simplify setup logic

This commit is contained in:
Eugen Eisler
2024-08-22 21:45:36 +02:00
parent 6996278c8f
commit 4b3afb3c8e
6 changed files with 148 additions and 33 deletions

View File

@@ -14,7 +14,7 @@ import (
func Cli() (message string, err error) {
var currentFlags *Flags
if currentFlags, err = Init(); err != nil {
// we need to reset error, because we want to show double help messages
// we need to reset error, because we don't want to show double help messages
err = nil
return
}
@@ -24,23 +24,23 @@ func Cli() (message string, err error) {
return
}
db := db.NewDb(filepath.Join(homedir, ".config/fabric"))
fabricDb := db.NewDb(filepath.Join(homedir, ".config/fabric"))
// if the setup flag is set, run the setup function
if currentFlags.Setup {
_ = db.Configure()
_, err = Setup(db, currentFlags.SetupSkipUpdatePatterns)
_ = fabricDb.Configure()
_, err = Setup(fabricDb, currentFlags.SetupSkipUpdatePatterns)
return
}
var fabric *core.Fabric
if err = db.Configure(); err != nil {
if err = fabricDb.Configure(); err != nil {
fmt.Println("init is failed, run start the setup procedure", err)
if fabric, err = Setup(db, currentFlags.SetupSkipUpdatePatterns); err != nil {
if fabric, err = Setup(fabricDb, currentFlags.SetupSkipUpdatePatterns); err != nil {
return
}
} else {
if fabric, err = core.NewFabric(db); err != nil {
if fabric, err = core.NewFabric(fabricDb); err != nil {
fmt.Println("fabric can't initialize, please run the --setup procedure", err)
return
}
@@ -64,7 +64,7 @@ func Cli() (message string, err error) {
return
}
if err = db.Patterns.PrintLatestPatterns(parsedToInt); err != nil {
if err = fabricDb.Patterns.PrintLatestPatterns(parsedToInt); err != nil {
return
}
return
@@ -72,7 +72,7 @@ func Cli() (message string, err error) {
// if the list patterns flag is set, run the list all patterns function
if currentFlags.ListPatterns {
err = db.Patterns.ListNames()
err = fabricDb.Patterns.ListNames()
return
}
@@ -84,13 +84,13 @@ func Cli() (message string, err error) {
// if the list all contexts flag is set, run the list all contexts function
if currentFlags.ListAllContexts {
err = db.Contexts.ListNames()
err = fabricDb.Contexts.ListNames()
return
}
// if the list all sessions flag is set, run the list all sessions function
if currentFlags.ListAllSessions {
err = db.Sessions.ListNames()
err = fabricDb.Sessions.ListNames()
return
}
@@ -129,17 +129,17 @@ func Cli() (message string, err error) {
}
func Setup(db *db.Db, skipUpdatePatterns bool) (ret *core.Fabric, err error) {
ret = core.NewFabricForSetup(db)
instance := core.NewFabricForSetup(db)
if err = ret.Setup(); err != nil {
if err = instance.Setup(); err != nil {
return
}
if !skipUpdatePatterns {
if err = ret.PopulateDB(); err != nil {
if err = instance.PopulateDB(); err != nil {
return
}
}
ret = instance
return
}

23
cli/cli_test.go Normal file
View File

@@ -0,0 +1,23 @@
package cli
import (
"os"
"testing"
"github.com/danielmiessler/fabric/db"
"github.com/stretchr/testify/assert"
)
func TestCli(t *testing.T) {
message, err := Cli()
assert.NoError(t, err)
assert.Empty(t, message)
}
func TestSetup(t *testing.T) {
mockDB := db.NewDb(os.TempDir())
fabric, err := Setup(mockDB, false)
assert.Error(t, err)
assert.Nil(t, fabric)
}

85
cli/flags_test.go Normal file
View File

@@ -0,0 +1,85 @@
package cli
import (
"bytes"
"io"
"io/ioutil"
"os"
"strings"
"testing"
"github.com/danielmiessler/fabric/common"
"github.com/stretchr/testify/assert"
)
func TestInit(t *testing.T) {
args := []string{"--copy"}
expectedFlags := &Flags{Copy: true}
oldArgs := os.Args
defer func() { os.Args = oldArgs }()
os.Args = append([]string{"cmd"}, args...)
flags, err := Init()
assert.NoError(t, err)
assert.Equal(t, expectedFlags.Copy, flags.Copy)
}
func TestReadStdin(t *testing.T) {
input := "test input"
stdin := ioutil.NopCloser(strings.NewReader(input))
// No need to cast stdin to *os.File, pass it as io.ReadCloser directly
content, err := ReadStdin(stdin)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if content != input {
t.Fatalf("expected %q, got %q", input, content)
}
}
// ReadStdin function assuming it's part of `cli` package
func ReadStdin(reader io.ReadCloser) (string, error) {
defer reader.Close()
buf := new(bytes.Buffer)
_, err := buf.ReadFrom(reader)
if err != nil {
return "", err
}
return buf.String(), nil
}
func TestBuildChatOptions(t *testing.T) {
flags := &Flags{
Temperature: 0.8,
TopP: 0.9,
PresencePenalty: 0.1,
FrequencyPenalty: 0.2,
}
expectedOptions := &common.ChatOptions{
Temperature: 0.8,
TopP: 0.9,
PresencePenalty: 0.1,
FrequencyPenalty: 0.2,
}
options := flags.BuildChatOptions()
assert.Equal(t, expectedOptions, options)
}
func TestBuildChatRequest(t *testing.T) {
flags := &Flags{
Context: "test-context",
Session: "test-session",
Pattern: "test-pattern",
Message: "test-message",
}
expectedRequest := &common.ChatRequest{
ContextName: "test-context",
SessionName: "test-session",
PatternName: "test-pattern",
Message: "test-message",
}
request := flags.BuildChatRequest()
assert.Equal(t, expectedRequest, request)
}