Compare commits

...

3 Commits

Author SHA1 Message Date
github-actions[bot]
12284ad3db Update version to v1.4.122 and commit 2024-12-14 15:02:50 +00:00
Eugen Eisler
f180e8fc6b Merge pull request #1201 from mattjoyce/feature/config-yaml
feat: Add YAML configuration support
2024-12-14 20:31:42 +05:30
Matt Joyce
89153dd235 feat: Add YAML configuration support
Add support for persistent configuration via YAML files. Users can now specify
common options in a config file while maintaining the ability to override with
CLI flags. Currently supports core options like model, temperature, and pattern
settings.

- Add --config flag for specifying YAML config path
- Support standard option precedence (CLI > YAML > defaults)
- Add type-safe YAML parsing with reflection
- Add tests for YAML config functionality
2024-12-14 14:37:12 +11:00
7 changed files with 237 additions and 28 deletions

View File

@@ -6,29 +6,31 @@ import (
"fmt"
"io"
"os"
"reflect"
"strings"
"github.com/jessevdk/go-flags"
goopenai "github.com/sashabaranov/go-openai"
"golang.org/x/text/language"
"gopkg.in/yaml.v2"
"github.com/danielmiessler/fabric/common"
)
// Flags create flags struct. the users flags go into this, this will be passed to the chat struct in cli
type Flags struct {
Pattern string `short:"p" long:"pattern" description:"Choose a pattern from the available patterns" default:""`
Pattern string `short:"p" long:"pattern" yaml:"pattern" description:"Choose a pattern from the available patterns" default:""`
PatternVariables map[string]string `short:"v" long:"variable" description:"Values for pattern variables, e.g. -v=#role:expert -v=#points:30"`
Context string `short:"C" long:"context" description:"Choose a context from the available contexts" default:""`
Session string `long:"session" description:"Choose a session from the available sessions"`
Attachments []string `short:"a" long:"attachment" description:"Attachment path or URL (e.g. for OpenAI image recognition messages)"`
Setup bool `short:"S" long:"setup" description:"Run setup for all reconfigurable parts of fabric"`
Temperature float64 `short:"t" long:"temperature" description:"Set temperature" default:"0.7"`
TopP float64 `short:"T" long:"topp" description:"Set top P" default:"0.9"`
Stream bool `short:"s" long:"stream" description:"Stream"`
PresencePenalty float64 `short:"P" long:"presencepenalty" description:"Set presence penalty" default:"0.0"`
Raw bool `short:"r" long:"raw" description:"Use the defaults of the model without sending chat options (like temperature etc.) and use the user role instead of the system role for patterns."`
FrequencyPenalty float64 `short:"F" long:"frequencypenalty" description:"Set frequency penalty" default:"0.0"`
Temperature float64 `short:"t" long:"temperature" yaml:"temperature" description:"Set temperature" default:"0.7"`
TopP float64 `short:"T" long:"topp" yaml:"topp" description:"Set top P" default:"0.9"`
Stream bool `short:"s" long:"stream" yaml:"stream" description:"Stream"`
PresencePenalty float64 `short:"P" long:"presencepenalty" yaml:"presencepenalty" description:"Set presence penalty" default:"0.0"`
Raw bool `short:"r" long:"raw" yaml:"raw" description:"Use the defaults of the model without sending chat options (like temperature etc.) and use the user role instead of the system role for patterns."`
FrequencyPenalty float64 `short:"F" long:"frequencypenalty" yaml:"frequencypenalty" description:"Set frequency penalty" default:"0.0"`
ListPatterns bool `short:"l" long:"listpatterns" description:"List all patterns"`
ListAllModels bool `short:"L" long:"listmodels" description:"List all available models"`
ListAllContexts bool `short:"x" long:"listcontexts" description:"List all contexts"`
@@ -36,8 +38,8 @@ type Flags struct {
UpdatePatterns bool `short:"U" long:"updatepatterns" description:"Update patterns"`
Message string `hidden:"true" description:"Messages to send to chat"`
Copy bool `short:"c" long:"copy" description:"Copy to clipboard"`
Model string `short:"m" long:"model" description:"Choose model"`
ModelContextLength int `long:"modelContextLength" description:"Model context length (only affects ollama)"`
Model string `short:"m" long:"model" yaml:"model" description:"Choose model"`
ModelContextLength int `long:"modelContextLength" yaml:"modelContextLength" description:"Model context length (only affects ollama)"`
Output string `short:"o" long:"output" description:"Output to file" default:""`
OutputSession bool `long:"output-session" description:"Output the entire session (also a temporary one) to the output file"`
LatestPatterns string `short:"n" long:"latest" description:"Number of latest patterns to list" default:"0"`
@@ -49,7 +51,7 @@ type Flags struct {
Language string `short:"g" long:"language" description:"Specify the Language Code for the chat, e.g. -g=en -g=zh" default:""`
ScrapeURL string `short:"u" long:"scrape_url" description:"Scrape website URL to markdown using Jina AI"`
ScrapeQuestion string `short:"q" long:"scrape_question" description:"Search question using Jina AI"`
Seed int `short:"e" long:"seed" description:"Seed to be used for LMM generation"`
Seed int `short:"e" long:"seed" yaml:"seed" description:"Seed to be used for LMM generation"`
WipeContext string `short:"w" long:"wipecontext" description:"Wipe context"`
WipeSession string `short:"W" long:"wipesession" description:"Wipe session"`
PrintContext string `long:"printcontext" description:"Print context"`
@@ -59,27 +61,97 @@ type Flags struct {
DryRun bool `long:"dry-run" description:"Show what would be sent to the model without actually sending it"`
Serve bool `long:"serve" description:"Serve the Fabric Rest API"`
ServeAddress string `long:"address" description:"The address to bind the REST API" default:":8080"`
Config string `long:"config" description:"Path to YAML config file"`
Version bool `long:"version" description:"Print current version"`
}
var debug = false
func Debugf(format string, a ...interface{}) {
if debug {
fmt.Printf("DEBUG: "+format, a...)
}
}
// Init Initialize flags. returns a Flags struct and an error
func Init() (ret *Flags, err error) {
ret = &Flags{}
parser := flags.NewParser(ret, flags.Default)
var args []string
if args, err = parser.Parse(); err != nil {
return
// Track which yaml-configured flags were set on CLI
usedFlags := make(map[string]bool)
args := os.Args[1:]
// Get list of fields that have yaml tags, could be in yaml config
yamlFields := make(map[string]bool)
t := reflect.TypeOf(Flags{})
for i := 0; i < t.NumField(); i++ {
if yamlTag := t.Field(i).Tag.Get("yaml"); yamlTag != "" {
yamlFields[yamlTag] = true
//Debugf("Found yaml-configured field: %s\n", yamlTag)
}
}
// Scan args for that are provided by cli and might be in yaml
for _, arg := range args {
if strings.HasPrefix(arg, "--") {
flag := strings.TrimPrefix(arg, "--")
if i := strings.Index(flag, "="); i > 0 {
flag = flag[:i]
}
if yamlFields[flag] {
usedFlags[flag] = true
Debugf("CLI flag used: %s\n", flag)
}
}
}
// Parse CLI flags first
ret = &Flags{}
parser := flags.NewParser(ret, flags.Default)
if _, err = parser.Parse(); err != nil {
return nil, err
}
// If config specified, load and apply YAML for unused flags
if ret.Config != "" {
yamlFlags, err := loadYAMLConfig(ret.Config)
if err != nil {
return nil, err
}
// Apply YAML values where CLI flags weren't used
flagsVal := reflect.ValueOf(ret).Elem()
yamlVal := reflect.ValueOf(yamlFlags).Elem()
flagsType := flagsVal.Type()
for i := 0; i < flagsType.NumField(); i++ {
field := flagsType.Field(i)
if yamlTag := field.Tag.Get("yaml"); yamlTag != "" {
if !usedFlags[yamlTag] {
flagField := flagsVal.Field(i)
yamlField := yamlVal.Field(i)
if flagField.CanSet() {
if yamlField.Type() != flagField.Type() {
if err := assignWithConversion(flagField, yamlField); err != nil {
Debugf("Type conversion failed for %s: %v\n", yamlTag, err)
continue
}
} else {
flagField.Set(yamlField)
}
Debugf("Applied YAML value for %s: %v\n", yamlTag, yamlField.Interface())
}
}
}
}
}
// Handle stdin and messages
info, _ := os.Stdin.Stat()
pipedToStdin := (info.Mode() & os.ModeCharDevice) == 0
//custom message
if len(args) > 0 {
ret.Message = AppendMessage(ret.Message, args[len(args)-1])
}
// takes input from stdin if it exists, otherwise takes input from args (the last argument)
if pipedToStdin {
var pipedMessage string
if pipedMessage, err = readStdin(); err != nil {
@@ -87,7 +159,68 @@ func Init() (ret *Flags, err error) {
}
ret.Message = AppendMessage(ret.Message, pipedMessage)
}
return
return ret, nil
}
func assignWithConversion(targetField, sourceField reflect.Value) error {
switch targetField.Kind() {
case reflect.Float64:
if sourceField.Kind() == reflect.Int || sourceField.Kind() == reflect.Float32 {
targetField.SetFloat(float64(sourceField.Convert(reflect.TypeOf(float64(0))).Float()))
Debugf("Converted field %s : %v\n", targetField.Type(), targetField.Interface())
return nil
}
case reflect.Int:
if sourceField.Kind() == reflect.Float64 || sourceField.Kind() == reflect.Float32 {
targetField.SetInt(int64(sourceField.Convert(reflect.TypeOf(int64(0))).Int()))
Debugf("Converted field %s : %v\n", targetField.Type(), targetField.Interface())
return nil
}
case reflect.String:
if sourceField.Kind() == reflect.Interface {
if str, ok := sourceField.Interface().(string); ok {
targetField.SetString(str)
Debugf("Converted field %s : %v\n", targetField.Type(), targetField.Interface())
return nil
}
}
case reflect.Bool:
if sourceField.Kind() == reflect.Interface {
if b, ok := sourceField.Interface().(bool); ok {
targetField.SetBool(b)
Debugf("Converted field %s : %v\n", targetField.Type(), targetField.Interface())
return nil
}
}
}
return fmt.Errorf("unsupported conversion: %s to %s", sourceField.Type(), targetField.Type())
}
func loadYAMLConfig(configPath string) (*Flags, error) {
absPath, err := common.GetAbsolutePath(configPath)
if err != nil {
return nil, fmt.Errorf("invalid config path: %w", err)
}
data, err := os.ReadFile(absPath)
if err != nil {
if os.IsNotExist(err) {
return nil, fmt.Errorf("config file not found: %s", absPath)
}
return nil, fmt.Errorf("error reading config file: %w", err)
}
// Use the existing Flags struct for YAML unmarshal
config := &Flags{}
if err := yaml.Unmarshal(data, config); err != nil {
return nil, fmt.Errorf("error parsing config file: %w", err)
}
Debugf("Config: %v\n", config)
return config, nil
}
// readStdin reads from stdin and returns the input as a string or an error

View File

@@ -87,3 +87,80 @@ func TestBuildChatOptionsDefaultSeed(t *testing.T) {
options := flags.BuildChatOptions()
assert.Equal(t, expectedOptions, options)
}
func TestInitWithYAMLConfig(t *testing.T) {
// Create a temporary YAML config file
configContent := `
temperature: 0.9
model: gpt-4
pattern: analyze
stream: true
`
tmpfile, err := os.CreateTemp("", "config.*.yaml")
if err != nil {
t.Fatal(err)
}
defer os.Remove(tmpfile.Name())
if _, err := tmpfile.Write([]byte(configContent)); err != nil {
t.Fatal(err)
}
if err := tmpfile.Close(); err != nil {
t.Fatal(err)
}
// Test 1: Basic YAML loading
t.Run("Load YAML config", func(t *testing.T) {
oldArgs := os.Args
defer func() { os.Args = oldArgs }()
os.Args = []string{"cmd", "--config", tmpfile.Name()}
flags, err := Init()
assert.NoError(t, err)
assert.Equal(t, 0.9, flags.Temperature)
assert.Equal(t, "gpt-4", flags.Model)
assert.Equal(t, "analyze", flags.Pattern)
assert.True(t, flags.Stream)
})
// Test 2: CLI overrides YAML
t.Run("CLI overrides YAML", func(t *testing.T) {
oldArgs := os.Args
defer func() { os.Args = oldArgs }()
os.Args = []string{"cmd", "--config", tmpfile.Name(), "--temperature", "0.7", "--model", "gpt-3.5-turbo"}
flags, err := Init()
assert.NoError(t, err)
assert.Equal(t, 0.7, flags.Temperature)
assert.Equal(t, "gpt-3.5-turbo", flags.Model)
assert.Equal(t, "analyze", flags.Pattern) // unchanged from YAML
assert.True(t, flags.Stream) // unchanged from YAML
})
// Test 3: Invalid YAML config
t.Run("Invalid YAML config", func(t *testing.T) {
badConfig := `
temperature: "not a float"
model: 123 # should be string
`
badfile, err := os.CreateTemp("", "bad-config.*.yaml")
if err != nil {
t.Fatal(err)
}
defer os.Remove(badfile.Name())
if _, err := badfile.Write([]byte(badConfig)); err != nil {
t.Fatal(err)
}
if err := badfile.Close(); err != nil {
t.Fatal(err)
}
oldArgs := os.Args
defer func() { os.Args = oldArgs }()
os.Args = []string{"cmd", "--config", badfile.Name()}
_, err = Init()
assert.Error(t, err)
})
}

6
go.mod
View File

@@ -6,14 +6,15 @@ toolchain go1.23.1
require (
github.com/anaskhan96/soup v1.2.5
github.com/anthropics/anthropic-sdk-go v0.2.0-alpha.4
github.com/atotto/clipboard v0.1.4
github.com/gabriel-vasile/mimetype v1.4.6
github.com/gin-gonic/gin v1.10.0
github.com/go-git/go-git/v5 v5.12.0
github.com/go-shiori/go-readability v0.0.0-20241012063810-92284fa8a71f
github.com/google/generative-ai-go v0.18.0
github.com/jessevdk/go-flags v1.6.1
github.com/joho/godotenv v1.5.1
github.com/liushuangls/go-anthropic/v2 v2.11.0
github.com/ollama/ollama v0.4.1
github.com/otiai10/copy v1.14.0
github.com/pkg/errors v0.9.1
@@ -22,6 +23,7 @@ require (
github.com/stretchr/testify v1.9.0
golang.org/x/text v0.20.0
google.golang.org/api v0.205.0
gopkg.in/yaml.v2 v2.4.0
)
require (
@@ -35,7 +37,6 @@ require (
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/ProtonMail/go-crypto v1.1.2 // indirect
github.com/andybalholm/cascadia v1.3.2 // indirect
github.com/anthropics/anthropic-sdk-go v0.2.0-alpha.4 // indirect
github.com/araddon/dateparse v0.0.0-20210429162001-6b43995a97de // indirect
github.com/bytedance/sonic v1.12.4 // indirect
github.com/bytedance/sonic/loader v0.2.1 // indirect
@@ -58,7 +59,6 @@ require (
github.com/goccy/go-json v0.10.3 // indirect
github.com/gogs/chardet v0.0.0-20211120154057-b7413eaefb8f // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/google/generative-ai-go v0.18.0 // indirect
github.com/google/s2a-go v0.1.8 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect

3
go.sum
View File

@@ -158,8 +158,6 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/liushuangls/go-anthropic/v2 v2.11.0 h1:YKyxDWQNaKPPgtLCgBH+JqzuznNWw8ZqQVeSdQNDMds=
github.com/liushuangls/go-anthropic/v2 v2.11.0/go.mod h1:8BKv/fkeTaL5R9R9bGkaknYBueyw2WxY20o7bImbOek=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.10/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk=
@@ -360,6 +358,7 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EV
gopkg.in/warnings.v0 v0.1.2 h1:wFXVbFY8DY5/xOe1ECiWdKCzZlxgshcYVNkBHstARME=
gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

View File

@@ -154,9 +154,6 @@ schema = 3
[mod."github.com/leodido/go-urn"]
version = "v1.4.0"
hash = "sha256-Q6kplWkY37Tzy6GOme3Wut40jFK4Izun+ij/BJvcEu0="
[mod."github.com/liushuangls/go-anthropic/v2"]
version = "v2.11.0"
hash = "sha256-VvQ6RT8qcP19mRzBtFKh19czlRk5obHzh1NVs3z/Gkc="
[mod."github.com/mattn/go-isatty"]
version = "v0.0.20"
hash = "sha256-qhw9hWtU5wnyFyuMbKx+7RB8ckQaFQ8D+8GKPkN3HHQ="
@@ -280,6 +277,9 @@ schema = 3
[mod."gopkg.in/warnings.v0"]
version = "v0.1.2"
hash = "sha256-ATVL9yEmgYbkJ1DkltDGRn/auGAjqGOfjQyBYyUo8s8="
[mod."gopkg.in/yaml.v2"]
version = "v2.4.0"
hash = "sha256-uVEGglIedjOIGZzHW4YwN1VoRSTK8o0eGZqzd+TNdd0="
[mod."gopkg.in/yaml.v3"]
version = "v3.0.1"
hash = "sha256-FqL9TKYJ0XkNwJFnq9j0VvJ5ZUU1RvH/52h/f5bkYAU="

View File

@@ -1 +1 @@
"1.4.121"
"1.4.122"

View File

@@ -1,3 +1,3 @@
package main
var version = "v1.4.121"
var version = "v1.4.122"