diff --git a/cli/cli.go b/cli/cli.go index c4843951..a6c9d4eb 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -2,6 +2,7 @@ package cli import ( "fmt" + "github.com/danielmiessler/fabric/common" "github.com/danielmiessler/fabric/core" "github.com/danielmiessler/fabric/plugins/ai" "github.com/danielmiessler/fabric/plugins/db/fsdb" @@ -213,7 +214,11 @@ func Cli(version string) (err error) { } var session *fsdb.Session - chatReq := currentFlags.BuildChatRequest(strings.Join(os.Args[1:], " ")) + var chatReq *common.ChatRequest + if chatReq, err = currentFlags.BuildChatRequest(strings.Join(os.Args[1:], " ")); err != nil { + return + } + if chatReq.Language == "" { chatReq.Language = registry.Language.DefaultLanguage.Value } diff --git a/cli/flags.go b/cli/flags.go index 91fabbf4..fe59dc05 100644 --- a/cli/flags.go +++ b/cli/flags.go @@ -4,8 +4,10 @@ import ( "bufio" "errors" "fmt" + goopenai "github.com/sashabaranov/go-openai" "io" "os" + "strings" "github.com/danielmiessler/fabric/common" "github.com/jessevdk/go-flags" @@ -18,8 +20,7 @@ type Flags struct { PatternVariables map[string]string `short:"v" long:"variable" description:"Values for pattern variables, e.g. -v=#role:expert -v=#points:30"` Context string `short:"C" long:"context" description:"Choose a context from the available contexts" default:""` Session string `long:"session" description:"Choose a session from the available sessions"` - Attachment string `short:"a" long:"attachment" description:"Attachment path or URL" default:""` - AttachmentType string `long:"--attachment-type" description:"Attachment with explicit mimetype" default:""` + Attachments []string `short:"i" long:"image" description:"Attachment path or URL" default:""` Setup bool `short:"S" long:"setup" description:"Run setup for all reconfigurable parts of fabric"` Temperature float64 `short:"t" long:"temperature" description:"Set temperature" default:"0.7"` TopP float64 `short:"T" long:"topp" description:"Set top P" default:"0.9"` @@ -32,7 +33,7 @@ type Flags struct { ListAllContexts bool `short:"x" long:"listcontexts" description:"List all contexts"` ListAllSessions bool `short:"X" long:"listsessions" description:"List all sessions"` UpdatePatterns bool `short:"U" long:"updatepatterns" description:"Update patterns"` - Message string `hidden:"true" description:"Message to send to chat"` + Message string `hidden:"true" description:"Messages to send to chat"` Copy bool `short:"c" long:"copy" description:"Copy to clipboard"` Model string `short:"m" long:"model" description:"Choose model"` Output string `short:"o" long:"output" description:"Output to file" default:""` @@ -115,18 +116,63 @@ func (o *Flags) BuildChatOptions() (ret *common.ChatOptions) { return } -func (o *Flags) BuildChatRequest(Meta string) (ret *common.ChatRequest) { +func (o *Flags) BuildChatRequest(Meta string) (ret *common.ChatRequest, err error) { ret = &common.ChatRequest{ ContextName: o.Context, SessionName: o.Session, PatternName: o.Pattern, PatternVariables: o.PatternVariables, - Message: o.Message, Meta: Meta, } + + if o.Attachments == nil || len(o.Attachments) > 0 { + if o.Message != "" { + ret.Message = &goopenai.ChatCompletionMessage{ + Role: goopenai.ChatMessageRoleUser, + Content: strings.TrimSpace(o.Message), + } + } + } else { + ret.Message = &goopenai.ChatCompletionMessage{ + Role: goopenai.ChatMessageRoleUser, + } + + if o.Message != "" { + ret.Message.MultiContent = append(ret.Message.MultiContent, goopenai.ChatMessagePart{ + Type: goopenai.ChatMessagePartTypeText, + Text: strings.TrimSpace(o.Message), + }) + } + + for _, attachmentValue := range o.Attachments { + var attachment *common.Attachment + if attachment, err = common.NewAttachment(attachmentValue); err != nil { + return + } + url := attachment.URL + if url == nil { + var base64Image string + if base64Image, err = attachment.Base64Content(); err != nil { + return + } + var mimeType string + if mimeType, err = attachment.ResolveType(); err != nil { + return + } + dataURL := fmt.Sprintf("data:%s;base64,%s", mimeType, base64Image) + url = &dataURL + } + ret.Message.MultiContent = append(ret.Message.MultiContent, goopenai.ChatMessagePart{ + Type: goopenai.ChatMessagePartTypeImageURL, + ImageURL: &goopenai.ChatMessageImageURL{ + URL: *url, + }, + }) + } + } + if o.Language != "" { - langTag, err := language.Parse(o.Language) - if err == nil { + if langTag, langErr := language.Parse(o.Language); langErr == nil { ret.Language = langTag.String() } } @@ -143,6 +189,6 @@ func (o *Flags) AppendMessage(message string) { } func (o *Flags) IsChatRequest() (ret bool) { - ret = o.Message != "" || o.Context != "" + ret = o.Message != "" || o.Context != "" || o.Session != "" || o.Pattern != "" || len(o.Attachments) > 0 return } diff --git a/cli/flags_test.go b/cli/flags_test.go index a2df1afa..93e230ea 100644 --- a/cli/flags_test.go +++ b/cli/flags_test.go @@ -100,7 +100,7 @@ func TestBuildChatRequest(t *testing.T) { ContextName: "test-context", SessionName: "test-session", PatternName: "test-pattern", - Message: "test-message", + Messages: "test-message", Meta: "test", } request := flags.BuildChatRequest("test") diff --git a/common/attachment.go b/common/attachment.go index f7d09644..f1e35e8a 100644 --- a/common/attachment.go +++ b/common/attachment.go @@ -1,6 +1,7 @@ package common import ( + "bytes" "crypto/sha256" "encoding/base64" "encoding/json" @@ -8,6 +9,8 @@ import ( "github.com/gabriel-vasile/mimetype" "io/ioutil" "net/http" + "os" + "path/filepath" ) type Attachment struct { @@ -108,23 +111,68 @@ func (a *Attachment) Base64Content() (ret string, err error) { return } -func FromRow(row map[string]interface{}) (ret *Attachment, err error) { - attachment := &Attachment{} - if id, ok := row["id"].(string); ok { - attachment.ID = &id +func NewAttachment(value string) (ret *Attachment, err error) { + if isURL(value) { + var mimeType string + if mimeType, err = detectMimeTypeFromURL(value); err != nil { + return + } + ret = &Attachment{ + Type: &mimeType, + URL: &value, + } + return } - if typ, ok := row["type"].(string); ok { - attachment.Type = &typ + + var absPath string + if absPath, err = filepath.Abs(value); err != nil { + return } - if path, ok := row["path"].(string); ok { - attachment.Path = &path + if _, err = os.Stat(absPath); os.IsNotExist(err) { + err = fmt.Errorf("file %s does not exist", value) + return } - if url, ok := row["url"].(string); ok { - attachment.URL = &url + + var mimeType string + if mimeType, err = detectMimeTypeFromFile(absPath); err != nil { + return } - if content, ok := row["content"].([]byte); ok { - attachment.Content = content + ret = &Attachment{ + Type: &mimeType, + Path: &absPath, } - ret = attachment return } + +func detectMimeTypeFromBytes(content []byte) (string, error) { + mime := mimetype.Detect(content) + if mime == nil { + return "", fmt.Errorf("could not determine mimetype of stdin") + } + return mime.String(), nil +} + +func detectMimeTypeFromURL(url string) (string, error) { + resp, err := http.Head(url) + if err != nil { + return "", err + } + defer resp.Body.Close() + mimeType := resp.Header.Get("Content-Type") + if mimeType == "" { + return "", fmt.Errorf("could not determine mimetype of URL") + } + return mimeType, nil +} + +func detectMimeTypeFromFile(path string) (string, error) { + mime, err := mimetype.DetectFile(path) + if err != nil { + return "", err + } + return mime.String(), nil +} + +func isURL(value string) bool { + return bytes.Contains([]byte(value), []byte("://")) +} diff --git a/common/domain.go b/common/domain.go index 525ededc..40633d4c 100644 --- a/common/domain.go +++ b/common/domain.go @@ -9,7 +9,7 @@ type ChatRequest struct { SessionName string PatternName string PatternVariables map[string]string - Message string + Message *goopenai.ChatCompletionMessage Language string Meta string } diff --git a/core/chatter.go b/core/chatter.go index 029480f3..b84ee483 100644 --- a/core/chatter.go +++ b/core/chatter.go @@ -108,21 +108,25 @@ func (o *Chatter) BuildSession(request *common.ChatRequest, raw bool) (session * if request.Language != "" { systemMessage = fmt.Sprintf("%s. Please use the language '%s' for the output.", systemMessage, request.Language) } - userMessage := strings.TrimSpace(request.Message) if raw { - // use the user role instead of the system role in raw mode - message := systemMessage + userMessage - if message != "" { - session.Append(&goopenai.ChatCompletionMessage{Role: goopenai.ChatMessageRoleUser, Content: message}) + if request.Message != nil { + if systemMessage != "" { + request.Message.Content = systemMessage + request.Message.Content + } + } else { + if systemMessage != "" { + request.Message = &goopenai.ChatCompletionMessage{Role: goopenai.ChatMessageRoleSystem, Content: systemMessage} + } } } else { if systemMessage != "" { session.Append(&goopenai.ChatCompletionMessage{Role: goopenai.ChatMessageRoleSystem, Content: systemMessage}) } - if userMessage != "" { - session.Append(&goopenai.ChatCompletionMessage{Role: goopenai.ChatMessageRoleUser, Content: userMessage}) - } + } + + if request.Message != nil { + session.Append(request.Message) } if session.IsEmpty() {