mirror of
https://github.com/danielmiessler/Fabric.git
synced 2026-01-10 06:48:04 -05:00
feat: Add YAML configuration support
Add support for persistent configuration via YAML files. Users can now specify common options in a config file while maintaining the ability to override with CLI flags. Currently supports core options like model, temperature, and pattern settings. - Add --config flag for specifying YAML config path - Support standard option precedence (CLI > YAML > defaults) - Add type-safe YAML parsing with reflection - Add tests for YAML config functionality
This commit is contained in:
188
cli/flags.go
188
cli/flags.go
@@ -6,29 +6,31 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/jessevdk/go-flags"
|
||||
goopenai "github.com/sashabaranov/go-openai"
|
||||
"golang.org/x/text/language"
|
||||
"gopkg.in/yaml.v2"
|
||||
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
)
|
||||
|
||||
// Flags create flags struct. the users flags go into this, this will be passed to the chat struct in cli
|
||||
type Flags struct {
|
||||
Pattern string `short:"p" long:"pattern" description:"Choose a pattern from the available patterns" default:""`
|
||||
Pattern string `short:"p" long:"pattern" yaml:"pattern" description:"Choose a pattern from the available patterns" default:""`
|
||||
PatternVariables map[string]string `short:"v" long:"variable" description:"Values for pattern variables, e.g. -v=#role:expert -v=#points:30"`
|
||||
Context string `short:"C" long:"context" description:"Choose a context from the available contexts" default:""`
|
||||
Session string `long:"session" description:"Choose a session from the available sessions"`
|
||||
Attachments []string `short:"a" long:"attachment" description:"Attachment path or URL (e.g. for OpenAI image recognition messages)"`
|
||||
Setup bool `short:"S" long:"setup" description:"Run setup for all reconfigurable parts of fabric"`
|
||||
Temperature float64 `short:"t" long:"temperature" description:"Set temperature" default:"0.7"`
|
||||
TopP float64 `short:"T" long:"topp" description:"Set top P" default:"0.9"`
|
||||
Stream bool `short:"s" long:"stream" description:"Stream"`
|
||||
PresencePenalty float64 `short:"P" long:"presencepenalty" description:"Set presence penalty" default:"0.0"`
|
||||
Raw bool `short:"r" long:"raw" description:"Use the defaults of the model without sending chat options (like temperature etc.) and use the user role instead of the system role for patterns."`
|
||||
FrequencyPenalty float64 `short:"F" long:"frequencypenalty" description:"Set frequency penalty" default:"0.0"`
|
||||
Temperature float64 `short:"t" long:"temperature" yaml:"temperature" description:"Set temperature" default:"0.7"`
|
||||
TopP float64 `short:"T" long:"topp" yaml:"topp" description:"Set top P" default:"0.9"`
|
||||
Stream bool `short:"s" long:"stream" yaml:"stream" description:"Stream"`
|
||||
PresencePenalty float64 `short:"P" long:"presencepenalty" yaml:"presencepenalty" description:"Set presence penalty" default:"0.0"`
|
||||
Raw bool `short:"r" long:"raw" yaml:"raw" description:"Use the defaults of the model without sending chat options (like temperature etc.) and use the user role instead of the system role for patterns."`
|
||||
FrequencyPenalty float64 `short:"F" long:"frequencypenalty" yaml:"frequencypenalty" description:"Set frequency penalty" default:"0.0"`
|
||||
ListPatterns bool `short:"l" long:"listpatterns" description:"List all patterns"`
|
||||
ListAllModels bool `short:"L" long:"listmodels" description:"List all available models"`
|
||||
ListAllContexts bool `short:"x" long:"listcontexts" description:"List all contexts"`
|
||||
@@ -36,8 +38,8 @@ type Flags struct {
|
||||
UpdatePatterns bool `short:"U" long:"updatepatterns" description:"Update patterns"`
|
||||
Message string `hidden:"true" description:"Messages to send to chat"`
|
||||
Copy bool `short:"c" long:"copy" description:"Copy to clipboard"`
|
||||
Model string `short:"m" long:"model" description:"Choose model"`
|
||||
ModelContextLength int `long:"modelContextLength" description:"Model context length (only affects ollama)"`
|
||||
Model string `short:"m" long:"model" yaml:"model" description:"Choose model"`
|
||||
ModelContextLength int `long:"modelContextLength" yaml:"modelContextLength" description:"Model context length (only affects ollama)"`
|
||||
Output string `short:"o" long:"output" description:"Output to file" default:""`
|
||||
OutputSession bool `long:"output-session" description:"Output the entire session (also a temporary one) to the output file"`
|
||||
LatestPatterns string `short:"n" long:"latest" description:"Number of latest patterns to list" default:"0"`
|
||||
@@ -49,7 +51,7 @@ type Flags struct {
|
||||
Language string `short:"g" long:"language" description:"Specify the Language Code for the chat, e.g. -g=en -g=zh" default:""`
|
||||
ScrapeURL string `short:"u" long:"scrape_url" description:"Scrape website URL to markdown using Jina AI"`
|
||||
ScrapeQuestion string `short:"q" long:"scrape_question" description:"Search question using Jina AI"`
|
||||
Seed int `short:"e" long:"seed" description:"Seed to be used for LMM generation"`
|
||||
Seed int `short:"e" long:"seed" yaml:"seed" description:"Seed to be used for LMM generation"`
|
||||
WipeContext string `short:"w" long:"wipecontext" description:"Wipe context"`
|
||||
WipeSession string `short:"W" long:"wipesession" description:"Wipe session"`
|
||||
PrintContext string `long:"printcontext" description:"Print context"`
|
||||
@@ -59,37 +61,175 @@ type Flags struct {
|
||||
DryRun bool `long:"dry-run" description:"Show what would be sent to the model without actually sending it"`
|
||||
Serve bool `long:"serve" description:"Serve the Fabric Rest API"`
|
||||
ServeAddress string `long:"address" description:"The address to bind the REST API" default:":8080"`
|
||||
Config string `long:"config" description:"Path to YAML config file"`
|
||||
Version bool `long:"version" description:"Print current version"`
|
||||
}
|
||||
|
||||
var debug = false
|
||||
|
||||
func Debugf(format string, a ...interface{}) {
|
||||
if debug {
|
||||
fmt.Printf("DEBUG: "+format, a...)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Init Initialize flags. returns a Flags struct and an error
|
||||
func Init() (ret *Flags, err error) {
|
||||
ret = &Flags{}
|
||||
parser := flags.NewParser(ret, flags.Default)
|
||||
var args []string
|
||||
if args, err = parser.Parse(); err != nil {
|
||||
return
|
||||
// Track which yaml-configured flags were set on CLI
|
||||
usedFlags := make(map[string]bool)
|
||||
args := os.Args[1:]
|
||||
|
||||
// Get list of fields that have yaml tags, could be in yaml config
|
||||
yamlFields := make(map[string]bool)
|
||||
t := reflect.TypeOf(Flags{})
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
if yamlTag := t.Field(i).Tag.Get("yaml"); yamlTag != "" {
|
||||
yamlFields[yamlTag] = true
|
||||
//Debugf("Found yaml-configured field: %s\n", yamlTag)
|
||||
}
|
||||
}
|
||||
|
||||
// Scan args for that are provided by cli and might be in yaml
|
||||
for _, arg := range args {
|
||||
if strings.HasPrefix(arg, "--") {
|
||||
flag := strings.TrimPrefix(arg, "--")
|
||||
if i := strings.Index(flag, "="); i > 0 {
|
||||
flag = flag[:i]
|
||||
}
|
||||
if yamlFields[flag] {
|
||||
usedFlags[flag] = true
|
||||
Debugf("CLI flag used: %s\n", flag)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Parse CLI flags first
|
||||
ret = &Flags{}
|
||||
parser := flags.NewParser(ret, flags.Default)
|
||||
if _, err = parser.Parse(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If config specified, load and apply YAML for unused flags
|
||||
if ret.Config != "" {
|
||||
yamlFlags, err := loadYAMLConfig(ret.Config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Apply YAML values where CLI flags weren't used
|
||||
flagsVal := reflect.ValueOf(ret).Elem()
|
||||
yamlVal := reflect.ValueOf(yamlFlags).Elem()
|
||||
flagsType := flagsVal.Type()
|
||||
|
||||
for i := 0; i < flagsType.NumField(); i++ {
|
||||
field := flagsType.Field(i)
|
||||
if yamlTag := field.Tag.Get("yaml"); yamlTag != "" {
|
||||
if !usedFlags[yamlTag] {
|
||||
flagField := flagsVal.Field(i)
|
||||
yamlField := yamlVal.Field(i)
|
||||
if flagField.CanSet() {
|
||||
if yamlField.Type() != flagField.Type() {
|
||||
if err := assignWithConversion(flagField, yamlField); err != nil {
|
||||
Debugf("Type conversion failed for %s: %v\n", yamlTag, err)
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
flagField.Set(yamlField)
|
||||
}
|
||||
Debugf("Applied YAML value for %s: %v\n", yamlTag, yamlField.Interface())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle stdin and messages
|
||||
info, _ := os.Stdin.Stat()
|
||||
pipedToStdin := (info.Mode() & os.ModeCharDevice) == 0
|
||||
|
||||
//custom message
|
||||
if len(args) > 0 {
|
||||
ret.Message = AppendMessage(ret.Message, args[len(args)-1])
|
||||
ret.Message = AppendMessage(ret.Message, args[len(args)-1])
|
||||
}
|
||||
|
||||
// takes input from stdin if it exists, otherwise takes input from args (the last argument)
|
||||
if pipedToStdin {
|
||||
var pipedMessage string
|
||||
if pipedMessage, err = readStdin(); err != nil {
|
||||
return
|
||||
}
|
||||
ret.Message = AppendMessage(ret.Message, pipedMessage)
|
||||
var pipedMessage string
|
||||
if pipedMessage, err = readStdin(); err != nil {
|
||||
return
|
||||
}
|
||||
ret.Message = AppendMessage(ret.Message, pipedMessage)
|
||||
}
|
||||
return
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
|
||||
|
||||
func assignWithConversion(targetField, sourceField reflect.Value) error {
|
||||
switch targetField.Kind() {
|
||||
case reflect.Float64:
|
||||
if sourceField.Kind() == reflect.Int || sourceField.Kind() == reflect.Float32 {
|
||||
targetField.SetFloat(float64(sourceField.Convert(reflect.TypeOf(float64(0))).Float()))
|
||||
Debugf("Converted field %s : %v\n", targetField.Type(), targetField.Interface())
|
||||
return nil
|
||||
}
|
||||
case reflect.Int:
|
||||
if sourceField.Kind() == reflect.Float64 || sourceField.Kind() == reflect.Float32 {
|
||||
targetField.SetInt(int64(sourceField.Convert(reflect.TypeOf(int64(0))).Int()))
|
||||
Debugf("Converted field %s : %v\n", targetField.Type(), targetField.Interface())
|
||||
return nil
|
||||
}
|
||||
case reflect.String:
|
||||
if sourceField.Kind() == reflect.Interface {
|
||||
if str, ok := sourceField.Interface().(string); ok {
|
||||
targetField.SetString(str)
|
||||
Debugf("Converted field %s : %v\n", targetField.Type(), targetField.Interface())
|
||||
return nil
|
||||
}
|
||||
}
|
||||
case reflect.Bool:
|
||||
if sourceField.Kind() == reflect.Interface {
|
||||
if b, ok := sourceField.Interface().(bool); ok {
|
||||
targetField.SetBool(b)
|
||||
Debugf("Converted field %s : %v\n", targetField.Type(), targetField.Interface())
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("unsupported conversion: %s to %s", sourceField.Type(), targetField.Type())
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
func loadYAMLConfig(configPath string) (*Flags, error) {
|
||||
absPath, err := common.GetAbsolutePath(configPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid config path: %w", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(absPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("config file not found: %s", absPath)
|
||||
}
|
||||
return nil, fmt.Errorf("error reading config file: %w", err)
|
||||
}
|
||||
|
||||
// Use the existing Flags struct for YAML unmarshal
|
||||
config := &Flags{}
|
||||
if err := yaml.Unmarshal(data, config); err != nil {
|
||||
return nil, fmt.Errorf("error parsing config file: %w", err)
|
||||
}
|
||||
|
||||
Debugf("Config: %v\n", config)
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
|
||||
// readStdin reads from stdin and returns the input as a string or an error
|
||||
func readStdin() (ret string, err error) {
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
|
||||
@@ -87,3 +87,80 @@ func TestBuildChatOptionsDefaultSeed(t *testing.T) {
|
||||
options := flags.BuildChatOptions()
|
||||
assert.Equal(t, expectedOptions, options)
|
||||
}
|
||||
|
||||
func TestInitWithYAMLConfig(t *testing.T) {
|
||||
// Create a temporary YAML config file
|
||||
configContent := `
|
||||
temperature: 0.9
|
||||
model: gpt-4
|
||||
pattern: analyze
|
||||
stream: true
|
||||
`
|
||||
tmpfile, err := os.CreateTemp("", "config.*.yaml")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.Remove(tmpfile.Name())
|
||||
|
||||
if _, err := tmpfile.Write([]byte(configContent)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := tmpfile.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Test 1: Basic YAML loading
|
||||
t.Run("Load YAML config", func(t *testing.T) {
|
||||
oldArgs := os.Args
|
||||
defer func() { os.Args = oldArgs }()
|
||||
os.Args = []string{"cmd", "--config", tmpfile.Name()}
|
||||
|
||||
flags, err := Init()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0.9, flags.Temperature)
|
||||
assert.Equal(t, "gpt-4", flags.Model)
|
||||
assert.Equal(t, "analyze", flags.Pattern)
|
||||
assert.True(t, flags.Stream)
|
||||
})
|
||||
|
||||
// Test 2: CLI overrides YAML
|
||||
t.Run("CLI overrides YAML", func(t *testing.T) {
|
||||
oldArgs := os.Args
|
||||
defer func() { os.Args = oldArgs }()
|
||||
os.Args = []string{"cmd", "--config", tmpfile.Name(), "--temperature", "0.7", "--model", "gpt-3.5-turbo"}
|
||||
|
||||
flags, err := Init()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0.7, flags.Temperature)
|
||||
assert.Equal(t, "gpt-3.5-turbo", flags.Model)
|
||||
assert.Equal(t, "analyze", flags.Pattern) // unchanged from YAML
|
||||
assert.True(t, flags.Stream) // unchanged from YAML
|
||||
})
|
||||
|
||||
// Test 3: Invalid YAML config
|
||||
t.Run("Invalid YAML config", func(t *testing.T) {
|
||||
badConfig := `
|
||||
temperature: "not a float"
|
||||
model: 123 # should be string
|
||||
`
|
||||
badfile, err := os.CreateTemp("", "bad-config.*.yaml")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.Remove(badfile.Name())
|
||||
|
||||
if _, err := badfile.Write([]byte(badConfig)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := badfile.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
oldArgs := os.Args
|
||||
defer func() { os.Args = oldArgs }()
|
||||
os.Args = []string{"cmd", "--config", badfile.Name()}
|
||||
|
||||
_, err = Init()
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
6
go.mod
6
go.mod
@@ -6,14 +6,15 @@ toolchain go1.23.1
|
||||
|
||||
require (
|
||||
github.com/anaskhan96/soup v1.2.5
|
||||
github.com/anthropics/anthropic-sdk-go v0.2.0-alpha.4
|
||||
github.com/atotto/clipboard v0.1.4
|
||||
github.com/gabriel-vasile/mimetype v1.4.6
|
||||
github.com/gin-gonic/gin v1.10.0
|
||||
github.com/go-git/go-git/v5 v5.12.0
|
||||
github.com/go-shiori/go-readability v0.0.0-20241012063810-92284fa8a71f
|
||||
github.com/google/generative-ai-go v0.18.0
|
||||
github.com/jessevdk/go-flags v1.6.1
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/liushuangls/go-anthropic/v2 v2.11.0
|
||||
github.com/ollama/ollama v0.4.1
|
||||
github.com/otiai10/copy v1.14.0
|
||||
github.com/pkg/errors v0.9.1
|
||||
@@ -22,6 +23,7 @@ require (
|
||||
github.com/stretchr/testify v1.9.0
|
||||
golang.org/x/text v0.20.0
|
||||
google.golang.org/api v0.205.0
|
||||
gopkg.in/yaml.v2 v2.4.0
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -35,7 +37,6 @@ require (
|
||||
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||
github.com/ProtonMail/go-crypto v1.1.2 // indirect
|
||||
github.com/andybalholm/cascadia v1.3.2 // indirect
|
||||
github.com/anthropics/anthropic-sdk-go v0.2.0-alpha.4 // indirect
|
||||
github.com/araddon/dateparse v0.0.0-20210429162001-6b43995a97de // indirect
|
||||
github.com/bytedance/sonic v1.12.4 // indirect
|
||||
github.com/bytedance/sonic/loader v0.2.1 // indirect
|
||||
@@ -58,7 +59,6 @@ require (
|
||||
github.com/goccy/go-json v0.10.3 // indirect
|
||||
github.com/gogs/chardet v0.0.0-20211120154057-b7413eaefb8f // indirect
|
||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
|
||||
github.com/google/generative-ai-go v0.18.0 // indirect
|
||||
github.com/google/s2a-go v0.1.8 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect
|
||||
|
||||
3
go.sum
3
go.sum
@@ -158,8 +158,6 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/liushuangls/go-anthropic/v2 v2.11.0 h1:YKyxDWQNaKPPgtLCgBH+JqzuznNWw8ZqQVeSdQNDMds=
|
||||
github.com/liushuangls/go-anthropic/v2 v2.11.0/go.mod h1:8BKv/fkeTaL5R9R9bGkaknYBueyw2WxY20o7bImbOek=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-runewidth v0.0.10/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk=
|
||||
@@ -360,6 +358,7 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EV
|
||||
gopkg.in/warnings.v0 v0.1.2 h1:wFXVbFY8DY5/xOe1ECiWdKCzZlxgshcYVNkBHstARME=
|
||||
gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
|
||||
Reference in New Issue
Block a user