mirror of
https://github.com/danielmiessler/Fabric.git
synced 2026-02-13 15:34:59 -05:00
feat: simplify setup logic
This commit is contained in:
30
cli/cli.go
30
cli/cli.go
@@ -14,7 +14,7 @@ import (
|
||||
func Cli() (message string, err error) {
|
||||
var currentFlags *Flags
|
||||
if currentFlags, err = Init(); err != nil {
|
||||
// we need to reset error, because we want to show double help messages
|
||||
// we need to reset error, because we don't want to show double help messages
|
||||
err = nil
|
||||
return
|
||||
}
|
||||
@@ -24,23 +24,23 @@ func Cli() (message string, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
db := db.NewDb(filepath.Join(homedir, ".config/fabric"))
|
||||
fabricDb := db.NewDb(filepath.Join(homedir, ".config/fabric"))
|
||||
|
||||
// if the setup flag is set, run the setup function
|
||||
if currentFlags.Setup {
|
||||
_ = db.Configure()
|
||||
_, err = Setup(db, currentFlags.SetupSkipUpdatePatterns)
|
||||
_ = fabricDb.Configure()
|
||||
_, err = Setup(fabricDb, currentFlags.SetupSkipUpdatePatterns)
|
||||
return
|
||||
}
|
||||
|
||||
var fabric *core.Fabric
|
||||
if err = db.Configure(); err != nil {
|
||||
if err = fabricDb.Configure(); err != nil {
|
||||
fmt.Println("init is failed, run start the setup procedure", err)
|
||||
if fabric, err = Setup(db, currentFlags.SetupSkipUpdatePatterns); err != nil {
|
||||
if fabric, err = Setup(fabricDb, currentFlags.SetupSkipUpdatePatterns); err != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if fabric, err = core.NewFabric(db); err != nil {
|
||||
if fabric, err = core.NewFabric(fabricDb); err != nil {
|
||||
fmt.Println("fabric can't initialize, please run the --setup procedure", err)
|
||||
return
|
||||
}
|
||||
@@ -64,7 +64,7 @@ func Cli() (message string, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
if err = db.Patterns.PrintLatestPatterns(parsedToInt); err != nil {
|
||||
if err = fabricDb.Patterns.PrintLatestPatterns(parsedToInt); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
@@ -72,7 +72,7 @@ func Cli() (message string, err error) {
|
||||
|
||||
// if the list patterns flag is set, run the list all patterns function
|
||||
if currentFlags.ListPatterns {
|
||||
err = db.Patterns.ListNames()
|
||||
err = fabricDb.Patterns.ListNames()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -84,13 +84,13 @@ func Cli() (message string, err error) {
|
||||
|
||||
// if the list all contexts flag is set, run the list all contexts function
|
||||
if currentFlags.ListAllContexts {
|
||||
err = db.Contexts.ListNames()
|
||||
err = fabricDb.Contexts.ListNames()
|
||||
return
|
||||
}
|
||||
|
||||
// if the list all sessions flag is set, run the list all sessions function
|
||||
if currentFlags.ListAllSessions {
|
||||
err = db.Sessions.ListNames()
|
||||
err = fabricDb.Sessions.ListNames()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -129,17 +129,17 @@ func Cli() (message string, err error) {
|
||||
}
|
||||
|
||||
func Setup(db *db.Db, skipUpdatePatterns bool) (ret *core.Fabric, err error) {
|
||||
ret = core.NewFabricForSetup(db)
|
||||
instance := core.NewFabricForSetup(db)
|
||||
|
||||
if err = ret.Setup(); err != nil {
|
||||
if err = instance.Setup(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if !skipUpdatePatterns {
|
||||
if err = ret.PopulateDB(); err != nil {
|
||||
if err = instance.PopulateDB(); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
ret = instance
|
||||
return
|
||||
}
|
||||
|
||||
23
cli/cli_test.go
Normal file
23
cli/cli_test.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/danielmiessler/fabric/db"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCli(t *testing.T) {
|
||||
message, err := Cli()
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, message)
|
||||
}
|
||||
|
||||
func TestSetup(t *testing.T) {
|
||||
mockDB := db.NewDb(os.TempDir())
|
||||
|
||||
fabric, err := Setup(mockDB, false)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, fabric)
|
||||
}
|
||||
85
cli/flags_test.go
Normal file
85
cli/flags_test.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestInit(t *testing.T) {
|
||||
args := []string{"--copy"}
|
||||
expectedFlags := &Flags{Copy: true}
|
||||
oldArgs := os.Args
|
||||
defer func() { os.Args = oldArgs }()
|
||||
os.Args = append([]string{"cmd"}, args...)
|
||||
|
||||
flags, err := Init()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedFlags.Copy, flags.Copy)
|
||||
}
|
||||
|
||||
func TestReadStdin(t *testing.T) {
|
||||
input := "test input"
|
||||
stdin := ioutil.NopCloser(strings.NewReader(input))
|
||||
// No need to cast stdin to *os.File, pass it as io.ReadCloser directly
|
||||
content, err := ReadStdin(stdin)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if content != input {
|
||||
t.Fatalf("expected %q, got %q", input, content)
|
||||
}
|
||||
}
|
||||
|
||||
// ReadStdin function assuming it's part of `cli` package
|
||||
func ReadStdin(reader io.ReadCloser) (string, error) {
|
||||
defer reader.Close()
|
||||
buf := new(bytes.Buffer)
|
||||
_, err := buf.ReadFrom(reader)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
func TestBuildChatOptions(t *testing.T) {
|
||||
flags := &Flags{
|
||||
Temperature: 0.8,
|
||||
TopP: 0.9,
|
||||
PresencePenalty: 0.1,
|
||||
FrequencyPenalty: 0.2,
|
||||
}
|
||||
|
||||
expectedOptions := &common.ChatOptions{
|
||||
Temperature: 0.8,
|
||||
TopP: 0.9,
|
||||
PresencePenalty: 0.1,
|
||||
FrequencyPenalty: 0.2,
|
||||
}
|
||||
options := flags.BuildChatOptions()
|
||||
assert.Equal(t, expectedOptions, options)
|
||||
}
|
||||
|
||||
func TestBuildChatRequest(t *testing.T) {
|
||||
flags := &Flags{
|
||||
Context: "test-context",
|
||||
Session: "test-session",
|
||||
Pattern: "test-pattern",
|
||||
Message: "test-message",
|
||||
}
|
||||
|
||||
expectedRequest := &common.ChatRequest{
|
||||
ContextName: "test-context",
|
||||
SessionName: "test-session",
|
||||
PatternName: "test-pattern",
|
||||
Message: "test-message",
|
||||
}
|
||||
request := flags.BuildChatRequest()
|
||||
assert.Equal(t, expectedRequest, request)
|
||||
}
|
||||
Reference in New Issue
Block a user