From 09e01eddf4f1fee87fedac7a38d6a0cdf30532fe Mon Sep 17 00:00:00 2001 From: Kayvan Sylvan Date: Sat, 28 Jun 2025 07:28:49 -0700 Subject: [PATCH] refactor: abstract chat message structs and migrate to official openai-go SDK ### CHANGES - Introduce local `chat` package for message abstraction - Replace sashabaranov/go-openai with official openai-go SDK - Update OpenAI, Azure, and Exolab plugins for new client - Refactor all AI providers to use internal chat types - Decouple codebase from third-party AI provider structs - Replace deprecated `ioutil` functions with `os` equivalents --- chat/chat.go | 132 ++++++++++++++++++++++ cli/flags.go | 22 ++-- common/domain.go | 10 +- common/domain_test.go | 22 ++-- core/chatter.go | 26 ++--- go.mod | 2 +- go.sum | 4 +- nix/pkgs/fabric/gomod2nix.toml | 3 - plugins/ai/anthropic/anthropic.go | 14 +-- plugins/ai/azure/azure.go | 15 ++- plugins/ai/bedrock/bedrock.go | 14 +-- plugins/ai/dryrun/dryrun.go | 18 +-- plugins/ai/dryrun/dryrun_test.go | 4 +- plugins/ai/exolab/exolab.go | 14 ++- plugins/ai/gemini/gemini.go | 8 +- plugins/ai/lmstudio/lmstudio.go | 6 +- plugins/ai/ollama/ollama.go | 10 +- plugins/ai/openai/openai.go | 163 ++++++++++++---------------- plugins/ai/openai/openai_test.go | 72 ++++-------- plugins/ai/perplexity/perplexity.go | 6 +- plugins/ai/vendor.go | 6 +- plugins/db/fsdb/sessions.go | 18 +-- plugins/db/fsdb/sessions_test.go | 4 +- plugins/template/sys_test.go | 2 +- restapi/chat.go | 7 +- restapi/ollama.go | 10 +- restapi/strategies.go | 5 +- 27 files changed, 346 insertions(+), 271 deletions(-) create mode 100644 chat/chat.go diff --git a/chat/chat.go b/chat/chat.go new file mode 100644 index 00000000..b5e7f027 --- /dev/null +++ b/chat/chat.go @@ -0,0 +1,132 @@ +package chat + +import ( + "encoding/json" + "errors" +) + +const ( + ChatMessageRoleSystem = "system" + ChatMessageRoleUser = "user" + ChatMessageRoleAssistant = "assistant" + ChatMessageRoleFunction = "function" + ChatMessageRoleTool = "tool" + ChatMessageRoleDeveloper = "developer" +) + +var ErrContentFieldsMisused = errors.New("can't use both Content and MultiContent properties simultaneously") + +type ChatMessagePartType string + +const ( + ChatMessagePartTypeText ChatMessagePartType = "text" + ChatMessagePartTypeImageURL ChatMessagePartType = "image_url" +) + +type ChatMessageImageURL struct { + URL string `json:"url,omitempty"` +} + +type ChatMessagePart struct { + Type ChatMessagePartType `json:"type,omitempty"` + Text string `json:"text,omitempty"` + ImageURL *ChatMessageImageURL `json:"image_url,omitempty"` +} + +type FunctionCall struct { + Name string `json:"name,omitempty"` + Arguments string `json:"arguments,omitempty"` +} + +type ToolType string + +const ( + ToolTypeFunction ToolType = "function" +) + +type ToolCall struct { + Index *int `json:"index,omitempty"` + ID string `json:"id,omitempty"` + Type ToolType `json:"type"` + Function FunctionCall `json:"function"` +} + +type ChatCompletionMessage struct { + Role string `json:"role"` + Content string `json:"content,omitempty"` + Refusal string `json:"refusal,omitempty"` + MultiContent []ChatMessagePart `json:"-"` + Name string `json:"name,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` +} + +func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) { + if m.Content != "" && m.MultiContent != nil { + return nil, ErrContentFieldsMisused + } + if len(m.MultiContent) > 0 { + msg := struct { + Role string `json:"role"` + Content string `json:"-"` + Refusal string `json:"refusal,omitempty"` + MultiContent []ChatMessagePart `json:"content,omitempty"` + Name string `json:"name,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + }(m) + return json.Marshal(msg) + } + + msg := struct { + Role string `json:"role"` + Content string `json:"content,omitempty"` + Refusal string `json:"refusal,omitempty"` + MultiContent []ChatMessagePart `json:"-"` + Name string `json:"name,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + }(m) + return json.Marshal(msg) +} + +func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error { + msg := struct { + Role string `json:"role"` + Content string `json:"content"` + Refusal string `json:"refusal,omitempty"` + MultiContent []ChatMessagePart + Name string `json:"name,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + }{} + + if err := json.Unmarshal(bs, &msg); err == nil { + *m = ChatCompletionMessage(msg) + return nil + } + multiMsg := struct { + Role string `json:"role"` + Content string + Refusal string `json:"refusal,omitempty"` + MultiContent []ChatMessagePart `json:"content"` + Name string `json:"name,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + }{} + if err := json.Unmarshal(bs, &multiMsg); err != nil { + return err + } + *m = ChatCompletionMessage(multiMsg) + return nil +} diff --git a/cli/flags.go b/cli/flags.go index 39bb39ce..37d3cbd9 100644 --- a/cli/flags.go +++ b/cli/flags.go @@ -10,9 +10,9 @@ import ( "strconv" "strings" + "github.com/danielmiessler/fabric/chat" "github.com/danielmiessler/fabric/common" "github.com/jessevdk/go-flags" - goopenai "github.com/sashabaranov/go-openai" "golang.org/x/text/language" "gopkg.in/yaml.v2" ) @@ -278,15 +278,15 @@ func (o *Flags) BuildChatRequest(Meta string) (ret *common.ChatRequest, err erro Meta: Meta, } - var message *goopenai.ChatCompletionMessage + var message *chat.ChatCompletionMessage if len(o.Attachments) > 0 { - message = &goopenai.ChatCompletionMessage{ - Role: goopenai.ChatMessageRoleUser, + message = &chat.ChatCompletionMessage{ + Role: chat.ChatMessageRoleUser, } if o.Message != "" { - message.MultiContent = append(message.MultiContent, goopenai.ChatMessagePart{ - Type: goopenai.ChatMessagePartTypeText, + message.MultiContent = append(message.MultiContent, chat.ChatMessagePart{ + Type: chat.ChatMessagePartTypeText, Text: strings.TrimSpace(o.Message), }) } @@ -309,16 +309,16 @@ func (o *Flags) BuildChatRequest(Meta string) (ret *common.ChatRequest, err erro dataURL := fmt.Sprintf("data:%s;base64,%s", mimeType, base64Image) url = &dataURL } - message.MultiContent = append(message.MultiContent, goopenai.ChatMessagePart{ - Type: goopenai.ChatMessagePartTypeImageURL, - ImageURL: &goopenai.ChatMessageImageURL{ + message.MultiContent = append(message.MultiContent, chat.ChatMessagePart{ + Type: chat.ChatMessagePartTypeImageURL, + ImageURL: &chat.ChatMessageImageURL{ URL: *url, }, }) } } else if o.Message != "" { - message = &goopenai.ChatCompletionMessage{ - Role: goopenai.ChatMessageRoleUser, + message = &chat.ChatCompletionMessage{ + Role: chat.ChatMessageRoleUser, Content: strings.TrimSpace(o.Message), } } diff --git a/common/domain.go b/common/domain.go index 345eebca..c0554ce5 100644 --- a/common/domain.go +++ b/common/domain.go @@ -1,6 +1,6 @@ package common -import goopenai "github.com/sashabaranov/go-openai" +import "github.com/danielmiessler/fabric/chat" const ChatMessageRoleMeta = "meta" @@ -9,7 +9,7 @@ type ChatRequest struct { SessionName string PatternName string PatternVariables map[string]string - Message *goopenai.ChatCompletionMessage + Message *chat.ChatCompletionMessage Language string Meta string InputHasVars bool @@ -29,7 +29,7 @@ type ChatOptions struct { } // NormalizeMessages remove empty messages and ensure messages order user-assist-user -func NormalizeMessages(msgs []*goopenai.ChatCompletionMessage, defaultUserMessage string) (ret []*goopenai.ChatCompletionMessage) { +func NormalizeMessages(msgs []*chat.ChatCompletionMessage, defaultUserMessage string) (ret []*chat.ChatCompletionMessage) { // Iterate over messages to enforce the odd position rule for user messages fullMessageIndex := 0 for _, message := range msgs { @@ -39,8 +39,8 @@ func NormalizeMessages(msgs []*goopenai.ChatCompletionMessage, defaultUserMessag } // Ensure, that each odd position shall be a user message - if fullMessageIndex%2 == 0 && message.Role != goopenai.ChatMessageRoleUser { - ret = append(ret, &goopenai.ChatCompletionMessage{Role: goopenai.ChatMessageRoleUser, Content: defaultUserMessage}) + if fullMessageIndex%2 == 0 && message.Role != chat.ChatMessageRoleUser { + ret = append(ret, &chat.ChatCompletionMessage{Role: chat.ChatMessageRoleUser, Content: defaultUserMessage}) fullMessageIndex++ } ret = append(ret, message) diff --git a/common/domain_test.go b/common/domain_test.go index a8f39c95..3fe5dba9 100644 --- a/common/domain_test.go +++ b/common/domain_test.go @@ -3,23 +3,23 @@ package common import ( "testing" - goopenai "github.com/sashabaranov/go-openai" + "github.com/danielmiessler/fabric/chat" "github.com/stretchr/testify/assert" ) func TestNormalizeMessages(t *testing.T) { - msgs := []*goopenai.ChatCompletionMessage{ - {Role: goopenai.ChatMessageRoleUser, Content: "Hello"}, - {Role: goopenai.ChatMessageRoleAssistant, Content: "Hi there!"}, - {Role: goopenai.ChatMessageRoleUser, Content: ""}, - {Role: goopenai.ChatMessageRoleUser, Content: ""}, - {Role: goopenai.ChatMessageRoleUser, Content: "How are you?"}, + msgs := []*chat.ChatCompletionMessage{ + {Role: chat.ChatMessageRoleUser, Content: "Hello"}, + {Role: chat.ChatMessageRoleAssistant, Content: "Hi there!"}, + {Role: chat.ChatMessageRoleUser, Content: ""}, + {Role: chat.ChatMessageRoleUser, Content: ""}, + {Role: chat.ChatMessageRoleUser, Content: "How are you?"}, } - expected := []*goopenai.ChatCompletionMessage{ - {Role: goopenai.ChatMessageRoleUser, Content: "Hello"}, - {Role: goopenai.ChatMessageRoleAssistant, Content: "Hi there!"}, - {Role: goopenai.ChatMessageRoleUser, Content: "How are you?"}, + expected := []*chat.ChatCompletionMessage{ + {Role: chat.ChatMessageRoleUser, Content: "Hello"}, + {Role: chat.ChatMessageRoleAssistant, Content: "Hi there!"}, + {Role: chat.ChatMessageRoleUser, Content: "How are you?"}, } actual := NormalizeMessages(msgs, "default") diff --git a/core/chatter.go b/core/chatter.go index 8d6cc08a..c4c05ac1 100644 --- a/core/chatter.go +++ b/core/chatter.go @@ -7,7 +7,7 @@ import ( "os" "strings" - goopenai "github.com/sashabaranov/go-openai" + "github.com/danielmiessler/fabric/chat" "github.com/danielmiessler/fabric/common" "github.com/danielmiessler/fabric/plugins/ai" @@ -110,7 +110,7 @@ func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (s message = summary } - session.Append(&goopenai.ChatCompletionMessage{Role: goopenai.ChatMessageRoleAssistant, Content: message}) + session.Append(&chat.ChatCompletionMessage{Role: chat.ChatMessageRoleAssistant, Content: message}) if session.Name != "" { err = o.db.Sessions.SaveSession(session) @@ -131,7 +131,7 @@ func (o *Chatter) BuildSession(request *common.ChatRequest, raw bool) (session * } if request.Meta != "" { - session.Append(&goopenai.ChatCompletionMessage{Role: common.ChatMessageRoleMeta, Content: request.Meta}) + session.Append(&chat.ChatCompletionMessage{Role: common.ChatMessageRoleMeta, Content: request.Meta}) } // if a context name is provided, retrieve it from the database @@ -149,8 +149,8 @@ func (o *Chatter) BuildSession(request *common.ChatRequest, raw bool) (session * // Double curly braces {{variable}} indicate template substitution // Ensure we have a message before processing if request.Message == nil { - request.Message = &goopenai.ChatCompletionMessage{ - Role: goopenai.ChatMessageRoleUser, + request.Message = &chat.ChatCompletionMessage{ + Role: chat.ChatMessageRoleUser, Content: " ", } } @@ -206,26 +206,26 @@ func (o *Chatter) BuildSession(request *common.ChatRequest, raw bool) (session * // Handle MultiContent properly in raw mode if len(request.Message.MultiContent) > 0 { // When we have attachments, add the text as a text part in MultiContent - newMultiContent := []goopenai.ChatMessagePart{ + newMultiContent := []chat.ChatMessagePart{ { - Type: goopenai.ChatMessagePartTypeText, + Type: chat.ChatMessagePartTypeText, Text: finalContent, }, } // Add existing non-text parts (like images) for _, part := range request.Message.MultiContent { - if part.Type != goopenai.ChatMessagePartTypeText { + if part.Type != chat.ChatMessagePartTypeText { newMultiContent = append(newMultiContent, part) } } - request.Message = &goopenai.ChatCompletionMessage{ - Role: goopenai.ChatMessageRoleUser, + request.Message = &chat.ChatCompletionMessage{ + Role: chat.ChatMessageRoleUser, MultiContent: newMultiContent, } } else { // No attachments, use regular Content field - request.Message = &goopenai.ChatCompletionMessage{ - Role: goopenai.ChatMessageRoleUser, + request.Message = &chat.ChatCompletionMessage{ + Role: chat.ChatMessageRoleUser, Content: finalContent, } } @@ -235,7 +235,7 @@ func (o *Chatter) BuildSession(request *common.ChatRequest, raw bool) (session * } } else { if systemMessage != "" { - session.Append(&goopenai.ChatCompletionMessage{Role: goopenai.ChatMessageRoleSystem, Content: systemMessage}) + session.Append(&chat.ChatCompletionMessage{Role: chat.ChatMessageRoleSystem, Content: systemMessage}) } // If multi-part content, it is in the user message, and should be added. // Otherwise, we should only add it if we have not already used it in the systemMessage. diff --git a/go.mod b/go.mod index 4365f088..03e5efb3 100644 --- a/go.mod +++ b/go.mod @@ -19,10 +19,10 @@ require ( github.com/jessevdk/go-flags v1.6.1 github.com/joho/godotenv v1.5.1 github.com/ollama/ollama v0.9.0 + github.com/openai/openai-go v1.8.2 github.com/otiai10/copy v1.14.1 github.com/pkg/errors v0.9.1 github.com/samber/lo v1.50.0 - github.com/sashabaranov/go-openai v1.40.3 github.com/sgaunet/perplexity-go/v2 v2.8.0 github.com/stretchr/testify v1.10.0 golang.org/x/text v0.26.0 diff --git a/go.sum b/go.sum index 967068cf..a3a2cefe 100644 --- a/go.sum +++ b/go.sum @@ -172,6 +172,8 @@ github.com/ollama/ollama v0.9.0 h1:GvdGhi8G/QMnFrY0TMLDy1bXua+Ify8KTkFe4ZY/OZs= github.com/ollama/ollama v0.9.0/go.mod h1:aio9yQ7nc4uwIbn6S0LkGEPgn8/9bNQLL1nHuH+OcD0= github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k= github.com/onsi/gomega v1.34.1/go.mod h1:kU1QgUvBDLXBJq618Xvm2LUX6rSAfRaFRTcdOeDLwwY= +github.com/openai/openai-go v1.8.2 h1:UqSkJ1vCOPUpz9Ka5tS0324EJFEuOvMc+lA/EarJWP8= +github.com/openai/openai-go v1.8.2/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y= github.com/otiai10/copy v1.14.1 h1:5/7E6qsUMBaH5AnQ0sSLzzTg1oTECmcCmT6lvF45Na8= github.com/otiai10/copy v1.14.1/go.mod h1:oQwrEDDOci3IM8dJF0d8+jnbfPDllW6vUjNc3DoZm9I= github.com/otiai10/mint v1.6.3 h1:87qsV/aw1F5as1eH1zS/yqHY85ANKVMgkDrf9rcxbQs= @@ -189,8 +191,6 @@ github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0t github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/samber/lo v1.50.0 h1:XrG0xOeHs+4FQ8gJR97zDz5uOFMW7OwFWiFVzqopKgY= github.com/samber/lo v1.50.0/go.mod h1:RjZyNk6WSnUFRKK6EyOhsRJMqft3G+pg7dCWHQCWvsc= -github.com/sashabaranov/go-openai v1.40.3 h1:PkOw0SK34wrvYVOuXF1HZzuTBRh992qRZHil4kG3eYE= -github.com/sashabaranov/go-openai v1.40.3/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/scylladb/termtables v0.0.0-20191203121021-c4c0b6d42ff4/go.mod h1:C1a7PQSMz9NShzorzCiG2fk9+xuCgLkPeCvMHYR2OWg= github.com/sergi/go-diff v1.4.0 h1:n/SP9D5ad1fORl+llWyN+D6qoUETXNZARKjyY2/KVCw= github.com/sergi/go-diff v1.4.0/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4= diff --git a/nix/pkgs/fabric/gomod2nix.toml b/nix/pkgs/fabric/gomod2nix.toml index b4dde69d..08bee61b 100644 --- a/nix/pkgs/fabric/gomod2nix.toml +++ b/nix/pkgs/fabric/gomod2nix.toml @@ -229,9 +229,6 @@ schema = 3 [mod."github.com/samber/lo"] version = "v1.50.0" hash = "sha256-KDFks82BKu39sGt0f972IyOkohV2U0r1YvsnlNLdugY=" - [mod."github.com/sashabaranov/go-openai"] - version = "v1.40.3" - hash = "sha256-Q2+la99lgKwcpgGHuf5p23gH5hmhYKp77YDUtbt35eo=" [mod."github.com/sergi/go-diff"] version = "v1.4.0" hash = "sha256-rs9NKpv/qcQEMRg7CmxGdP4HGuFdBxlpWf9LbA9wS4k=" diff --git a/plugins/ai/anthropic/anthropic.go b/plugins/ai/anthropic/anthropic.go index 7ef1638c..0e7d947a 100644 --- a/plugins/ai/anthropic/anthropic.go +++ b/plugins/ai/anthropic/anthropic.go @@ -9,9 +9,9 @@ import ( "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/option" + "github.com/danielmiessler/fabric/chat" "github.com/danielmiessler/fabric/common" "github.com/danielmiessler/fabric/plugins" - goopenai "github.com/sashabaranov/go-openai" ) const defaultBaseUrl = "https://api.anthropic.com/" @@ -87,7 +87,7 @@ func (an *Client) ListModels() (ret []string, err error) { } func (an *Client) SendStream( - msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions, channel chan string, + msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions, channel chan string, ) (err error) { messages := an.toMessages(msgs) if len(messages) == 0 { @@ -151,7 +151,7 @@ func (an *Client) buildMessageParams(msgs []anthropic.MessageParam, opts *common return } -func (an *Client) Send(ctx context.Context, msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions) ( +func (an *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions) ( ret string, err error) { messages := an.toMessages(msgs) @@ -176,7 +176,7 @@ func (an *Client) Send(ctx context.Context, msgs []*goopenai.ChatCompletionMessa return } -func (an *Client) toMessages(msgs []*goopenai.ChatCompletionMessage) (ret []anthropic.MessageParam) { +func (an *Client) toMessages(msgs []*chat.ChatCompletionMessage) (ret []anthropic.MessageParam) { // Custom normalization for Anthropic: // - System messages become the first part of the first user message. // - Messages must alternate user/assistant. @@ -193,14 +193,14 @@ func (an *Client) toMessages(msgs []*goopenai.ChatCompletionMessage) (ret []anth } switch msg.Role { - case goopenai.ChatMessageRoleSystem: + case chat.ChatMessageRoleSystem: // Accumulate system content. It will be prepended to the first user message. if systemContent != "" { systemContent += "\\n" + msg.Content } else { systemContent = msg.Content } - case goopenai.ChatMessageRoleUser: + case chat.ChatMessageRoleUser: userContent := msg.Content if isFirstUserMessage && systemContent != "" { userContent = systemContent + "\\n\\n" + userContent @@ -213,7 +213,7 @@ func (an *Client) toMessages(msgs []*goopenai.ChatCompletionMessage) (ret []anth } anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(anthropic.NewTextBlock(userContent))) lastRoleWasUser = true - case goopenai.ChatMessageRoleAssistant: + case chat.ChatMessageRoleAssistant: // If the first message is an assistant message, and we have system content, // prepend a user message with the system content. if isFirstUserMessage && systemContent != "" { diff --git a/plugins/ai/azure/azure.go b/plugins/ai/azure/azure.go index a2f6c63f..89cf08fc 100644 --- a/plugins/ai/azure/azure.go +++ b/plugins/ai/azure/azure.go @@ -5,7 +5,8 @@ import ( "github.com/danielmiessler/fabric/plugins" "github.com/danielmiessler/fabric/plugins/ai/openai" - goopenai "github.com/sashabaranov/go-openai" + openaiapi "github.com/openai/openai-go" + "github.com/openai/openai-go/option" ) func NewClient() (ret *Client) { @@ -29,11 +30,15 @@ type Client struct { func (oi *Client) configure() (err error) { oi.apiDeployments = strings.Split(oi.ApiDeployments.Value, ",") - config := goopenai.DefaultAzureConfig(oi.ApiKey.Value, oi.ApiBaseURL.Value) - if oi.ApiVersion.Value != "" { - config.APIVersion = oi.ApiVersion.Value + opts := []option.RequestOption{option.WithAPIKey(oi.ApiKey.Value)} + if oi.ApiBaseURL.Value != "" { + opts = append(opts, option.WithBaseURL(oi.ApiBaseURL.Value)) } - oi.ApiClient = goopenai.NewClientWithConfig(config) + if oi.ApiVersion.Value != "" { + opts = append(opts, option.WithQuery("api-version", oi.ApiVersion.Value)) + } + client := openaiapi.NewClient(opts...) + oi.ApiClient = &client return } diff --git a/plugins/ai/bedrock/bedrock.go b/plugins/ai/bedrock/bedrock.go index 87f25e75..e3fc705d 100644 --- a/plugins/ai/bedrock/bedrock.go +++ b/plugins/ai/bedrock/bedrock.go @@ -20,7 +20,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types" - goopenai "github.com/sashabaranov/go-openai" + "github.com/danielmiessler/fabric/chat" ) const ( @@ -154,7 +154,7 @@ func (c *BedrockClient) ListModels() ([]string, error) { } // SendStream sends the messages to the the Bedrock ConverseStream API -func (c *BedrockClient) SendStream(msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) (err error) { +func (c *BedrockClient) SendStream(msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) (err error) { // Ensure channel is closed on all exit paths to prevent goroutine leaks defer func() { if r := recover(); r != nil { @@ -208,7 +208,7 @@ func (c *BedrockClient) SendStream(msgs []*goopenai.ChatCompletionMessage, opts } // Send sends the messages the Bedrock Converse API -func (c *BedrockClient) Send(ctx context.Context, msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions) (ret string, err error) { +func (c *BedrockClient) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions) (ret string, err error) { messages := c.toMessages(msgs) @@ -249,12 +249,12 @@ func (c *BedrockClient) NeedsRawMode(modelName string) bool { // Bedrock Converse Message type. // The system role messages are mapped to the user role as they contain a mix of system messages, // pattern content and user input. -func (c *BedrockClient) toMessages(inputMessages []*goopenai.ChatCompletionMessage) (messages []types.Message) { +func (c *BedrockClient) toMessages(inputMessages []*chat.ChatCompletionMessage) (messages []types.Message) { for _, msg := range inputMessages { roles := map[string]types.ConversationRole{ - goopenai.ChatMessageRoleUser: types.ConversationRoleUser, - goopenai.ChatMessageRoleAssistant: types.ConversationRoleAssistant, - goopenai.ChatMessageRoleSystem: types.ConversationRoleUser, + chat.ChatMessageRoleUser: types.ConversationRoleUser, + chat.ChatMessageRoleAssistant: types.ConversationRoleAssistant, + chat.ChatMessageRoleSystem: types.ConversationRoleUser, } role, ok := roles[msg.Role] diff --git a/plugins/ai/dryrun/dryrun.go b/plugins/ai/dryrun/dryrun.go index 1f5e64df..01b3a112 100644 --- a/plugins/ai/dryrun/dryrun.go +++ b/plugins/ai/dryrun/dryrun.go @@ -6,7 +6,7 @@ import ( "fmt" "strings" - goopenai "github.com/sashabaranov/go-openai" + "github.com/danielmiessler/fabric/chat" "github.com/danielmiessler/fabric/common" "github.com/danielmiessler/fabric/plugins" @@ -24,14 +24,14 @@ func (c *Client) ListModels() ([]string, error) { return []string{"dry-run-model"}, nil } -func (c *Client) formatMultiContentMessage(msg *goopenai.ChatCompletionMessage) string { +func (c *Client) formatMultiContentMessage(msg *chat.ChatCompletionMessage) string { var builder strings.Builder if len(msg.MultiContent) > 0 { builder.WriteString(fmt.Sprintf("%s:\n", msg.Role)) for _, part := range msg.MultiContent { builder.WriteString(fmt.Sprintf(" - Type: %s\n", part.Type)) - if part.Type == goopenai.ChatMessagePartTypeImageURL { + if part.Type == chat.ChatMessagePartTypeImageURL { builder.WriteString(fmt.Sprintf(" Image URL: %s\n", part.ImageURL.URL)) } else { builder.WriteString(fmt.Sprintf(" Text: %s\n", part.Text)) @@ -45,16 +45,16 @@ func (c *Client) formatMultiContentMessage(msg *goopenai.ChatCompletionMessage) return builder.String() } -func (c *Client) formatMessages(msgs []*goopenai.ChatCompletionMessage) string { +func (c *Client) formatMessages(msgs []*chat.ChatCompletionMessage) string { var builder strings.Builder for _, msg := range msgs { switch msg.Role { - case goopenai.ChatMessageRoleSystem: + case chat.ChatMessageRoleSystem: builder.WriteString(fmt.Sprintf("System:\n%s\n\n", msg.Content)) - case goopenai.ChatMessageRoleAssistant: + case chat.ChatMessageRoleAssistant: builder.WriteString(c.formatMultiContentMessage(msg)) - case goopenai.ChatMessageRoleUser: + case chat.ChatMessageRoleUser: builder.WriteString(c.formatMultiContentMessage(msg)) default: builder.WriteString(fmt.Sprintf("%s:\n%s\n\n", msg.Role, msg.Content)) @@ -80,7 +80,7 @@ func (c *Client) formatOptions(opts *common.ChatOptions) string { return builder.String() } -func (c *Client) SendStream(msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) error { +func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) error { var builder strings.Builder builder.WriteString("Dry run: Would send the following request:\n\n") builder.WriteString(c.formatMessages(msgs)) @@ -91,7 +91,7 @@ func (c *Client) SendStream(msgs []*goopenai.ChatCompletionMessage, opts *common return nil } -func (c *Client) Send(_ context.Context, msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions) (string, error) { +func (c *Client) Send(_ context.Context, msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions) (string, error) { fmt.Println("Dry run: Would send the following request:") fmt.Print(c.formatMessages(msgs)) fmt.Print(c.formatOptions(opts)) diff --git a/plugins/ai/dryrun/dryrun_test.go b/plugins/ai/dryrun/dryrun_test.go index ba20a420..e2f3a5f8 100644 --- a/plugins/ai/dryrun/dryrun_test.go +++ b/plugins/ai/dryrun/dryrun_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" + "github.com/danielmiessler/fabric/chat" "github.com/danielmiessler/fabric/common" - "github.com/sashabaranov/go-openai" ) // Test generated using Keploy @@ -33,7 +33,7 @@ func TestSetup_ReturnsNil(t *testing.T) { // Test generated using Keploy func TestSendStream_SendsMessages(t *testing.T) { client := NewClient() - msgs := []*openai.ChatCompletionMessage{ + msgs := []*chat.ChatCompletionMessage{ {Role: "user", Content: "Test message"}, } opts := &common.ChatOptions{ diff --git a/plugins/ai/exolab/exolab.go b/plugins/ai/exolab/exolab.go index 2213876d..e80e983e 100644 --- a/plugins/ai/exolab/exolab.go +++ b/plugins/ai/exolab/exolab.go @@ -5,8 +5,8 @@ import ( "github.com/danielmiessler/fabric/plugins" "github.com/danielmiessler/fabric/plugins/ai/openai" - - goopenai "github.com/sashabaranov/go-openai" + openaiapi "github.com/openai/openai-go" + "github.com/openai/openai-go/option" ) func NewClient() (ret *Client) { @@ -32,10 +32,12 @@ type Client struct { func (oi *Client) configure() (err error) { oi.apiModels = strings.Split(oi.ApiModels.Value, ",") - config := goopenai.DefaultConfig("") - config.BaseURL = oi.ApiBaseURL.Value - - oi.ApiClient = goopenai.NewClientWithConfig(config) + opts := []option.RequestOption{option.WithAPIKey(oi.ApiKey.Value)} + if oi.ApiBaseURL.Value != "" { + opts = append(opts, option.WithBaseURL(oi.ApiBaseURL.Value)) + } + client := openaiapi.NewClient(opts...) + oi.ApiClient = &client return } diff --git a/plugins/ai/gemini/gemini.go b/plugins/ai/gemini/gemini.go index de9fc475..73404208 100644 --- a/plugins/ai/gemini/gemini.go +++ b/plugins/ai/gemini/gemini.go @@ -6,8 +6,8 @@ import ( "fmt" "strings" + "github.com/danielmiessler/fabric/chat" "github.com/danielmiessler/fabric/plugins" - goopenai "github.com/sashabaranov/go-openai" "github.com/danielmiessler/fabric/common" "github.com/google/generative-ai-go/genai" @@ -60,7 +60,7 @@ func (o *Client) ListModels() (ret []string, err error) { return } -func (o *Client) Send(ctx context.Context, msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions) (ret string, err error) { +func (o *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions) (ret string, err error) { systemInstruction, messages := toMessages(msgs) var client *genai.Client @@ -91,7 +91,7 @@ func (o *Client) buildModelNameFull(modelName string) string { return fmt.Sprintf("%v%v", modelsNamePrefix, modelName) } -func (o *Client) SendStream(msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) (err error) { +func (o *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) (err error) { ctx := context.Background() var client *genai.Client if client, err = genai.NewClient(ctx, option.WithAPIKey(o.ApiKey.Value)); err != nil { @@ -147,7 +147,7 @@ func (o *Client) NeedsRawMode(modelName string) bool { return false } -func toMessages(msgs []*goopenai.ChatCompletionMessage) (systemInstruction *genai.Content, messages []genai.Part) { +func toMessages(msgs []*chat.ChatCompletionMessage) (systemInstruction *genai.Content, messages []genai.Part) { if len(msgs) >= 2 { systemInstruction = &genai.Content{ Parts: []genai.Part{ diff --git a/plugins/ai/lmstudio/lmstudio.go b/plugins/ai/lmstudio/lmstudio.go index 55c8dd82..a0c7a69b 100644 --- a/plugins/ai/lmstudio/lmstudio.go +++ b/plugins/ai/lmstudio/lmstudio.go @@ -9,7 +9,7 @@ import ( "io" "net/http" - goopenai "github.com/sashabaranov/go-openai" + "github.com/danielmiessler/fabric/chat" "github.com/danielmiessler/fabric/common" "github.com/danielmiessler/fabric/plugins" @@ -87,7 +87,7 @@ func (c *Client) ListModels() ([]string, error) { return models, nil } -func (c *Client) SendStream(msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) (err error) { +func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) (err error) { url := fmt.Sprintf("%s/chat/completions", c.ApiUrl.Value) payload := map[string]interface{}{ @@ -173,7 +173,7 @@ func (c *Client) SendStream(msgs []*goopenai.ChatCompletionMessage, opts *common return } -func (c *Client) Send(ctx context.Context, msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions) (content string, err error) { +func (c *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions) (content string, err error) { url := fmt.Sprintf("%s/chat/completions", c.ApiUrl.Value) payload := map[string]interface{}{ diff --git a/plugins/ai/ollama/ollama.go b/plugins/ai/ollama/ollama.go index 39e39943..55a97453 100644 --- a/plugins/ai/ollama/ollama.go +++ b/plugins/ai/ollama/ollama.go @@ -8,9 +8,9 @@ import ( "strings" "time" + "github.com/danielmiessler/fabric/chat" ollamaapi "github.com/ollama/ollama/api" "github.com/samber/lo" - goopenai "github.com/sashabaranov/go-openai" "github.com/danielmiessler/fabric/common" "github.com/danielmiessler/fabric/plugins" @@ -97,7 +97,7 @@ func (o *Client) ListModels() (ret []string, err error) { return } -func (o *Client) SendStream(msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) (err error) { +func (o *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) (err error) { req := o.createChatRequest(msgs, opts) respFunc := func(resp ollamaapi.ChatResponse) (streamErr error) { @@ -115,7 +115,7 @@ func (o *Client) SendStream(msgs []*goopenai.ChatCompletionMessage, opts *common return } -func (o *Client) Send(ctx context.Context, msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions) (ret string, err error) { +func (o *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions) (ret string, err error) { bf := false req := o.createChatRequest(msgs, opts) @@ -132,8 +132,8 @@ func (o *Client) Send(ctx context.Context, msgs []*goopenai.ChatCompletionMessag return } -func (o *Client) createChatRequest(msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions) (ret ollamaapi.ChatRequest) { - messages := lo.Map(msgs, func(message *goopenai.ChatCompletionMessage, _ int) (ret ollamaapi.Message) { +func (o *Client) createChatRequest(msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions) (ret ollamaapi.ChatRequest) { + messages := lo.Map(msgs, func(message *chat.ChatCompletionMessage, _ int) (ret ollamaapi.Message) { return ollamaapi.Message{Role: message.Role, Content: message.Content} }) diff --git a/plugins/ai/openai/openai.go b/plugins/ai/openai/openai.go index d9107021..cc23eb43 100644 --- a/plugins/ai/openai/openai.go +++ b/plugins/ai/openai/openai.go @@ -2,16 +2,16 @@ package openai import ( "context" - "errors" - "fmt" - "io" "log/slog" "slices" "strings" + "github.com/danielmiessler/fabric/chat" "github.com/danielmiessler/fabric/common" "github.com/danielmiessler/fabric/plugins" - goopenai "github.com/sashabaranov/go-openai" + openai "github.com/openai/openai-go" + "github.com/openai/openai-go/option" + "github.com/openai/openai-go/packages/pagination" ) func NewClient() (ret *Client) { @@ -48,73 +48,53 @@ type Client struct { *plugins.PluginBase ApiKey *plugins.SetupQuestion ApiBaseURL *plugins.SetupQuestion - ApiClient *goopenai.Client + ApiClient *openai.Client } func (o *Client) configure() (ret error) { - config := goopenai.DefaultConfig(o.ApiKey.Value) + opts := []option.RequestOption{option.WithAPIKey(o.ApiKey.Value)} if o.ApiBaseURL.Value != "" { - config.BaseURL = o.ApiBaseURL.Value + opts = append(opts, option.WithBaseURL(o.ApiBaseURL.Value)) } - o.ApiClient = goopenai.NewClientWithConfig(config) + client := openai.NewClient(opts...) + o.ApiClient = &client return } func (o *Client) ListModels() (ret []string, err error) { - var models goopenai.ModelsList - if models, err = o.ApiClient.ListModels(context.Background()); err != nil { + var page *pagination.Page[openai.Model] + if page, err = o.ApiClient.Models.List(context.Background()); err != nil { return } - - model := models.Models - for _, mod := range model { + for _, mod := range page.Data { ret = append(ret, mod.ID) } return } func (o *Client) SendStream( - msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions, channel chan string, + msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions, channel chan string, ) (err error) { - req := o.buildChatCompletionRequest(msgs, opts) - req.Stream = true - - var stream *goopenai.ChatCompletionStream - if stream, err = o.ApiClient.CreateChatCompletionStream(context.Background(), req); err != nil { - fmt.Printf("ChatCompletionStream error: %v\n", err) - return - } - - defer stream.Close() - - for { - var response goopenai.ChatCompletionStreamResponse - if response, err = stream.Recv(); err == nil { - if len(response.Choices) > 0 { - channel <- response.Choices[0].Delta.Content - } else { - channel <- "\n" - close(channel) - break - } - } else if errors.Is(err, io.EOF) { - channel <- "\n" - close(channel) - err = nil - break - } else if err != nil { - fmt.Printf("\nStream error: %v\n", err) - break + req := o.buildChatCompletionParams(msgs, opts) + stream := o.ApiClient.Chat.Completions.NewStreaming(context.Background(), req) + for stream.Next() { + chunk := stream.Current() + if len(chunk.Choices) > 0 { + channel <- chunk.Choices[0].Delta.Content } } - return + if stream.Err() == nil { + channel <- "\n" + } + close(channel) + return stream.Err() } -func (o *Client) Send(ctx context.Context, msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions) (ret string, err error) { - req := o.buildChatCompletionRequest(msgs, opts) +func (o *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions) (ret string, err error) { + req := o.buildChatCompletionParams(msgs, opts) - var resp goopenai.ChatCompletionResponse - if resp, err = o.ApiClient.CreateChatCompletion(ctx, req); err != nil { + var resp *openai.ChatCompletion + if resp, err = o.ApiClient.Chat.Completions.New(ctx, req); err != nil { return } if len(resp.Choices) > 0 { @@ -146,57 +126,56 @@ func (o *Client) NeedsRawMode(modelName string) bool { return slices.Contains(openAIModelsNeedingRaw, modelName) } -func (o *Client) buildChatCompletionRequest( - inputMsgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions, -) (ret goopenai.ChatCompletionRequest) { +func (o *Client) buildChatCompletionParams( + inputMsgs []*chat.ChatCompletionMessage, opts *common.ChatOptions, +) (ret openai.ChatCompletionNewParams) { // Create a new slice for messages to be sent, converting from []*Msg to []Msg. // This also serves as a mutable copy for provider-specific modifications. - messagesForRequest := make([]goopenai.ChatCompletionMessage, len(inputMsgs)) + messagesForRequest := make([]openai.ChatCompletionMessageParamUnion, len(inputMsgs)) for i, msgPtr := range inputMsgs { - messagesForRequest[i] = *msgPtr // Dereference and copy - } - - // Provider-specific modification for DeepSeek: - // DeepSeek requires the last message to be a user message. - // If fabric constructs a single system message (common when a pattern includes user input), - // we change its role to user for DeepSeek. - if strings.Contains(opts.Model, "deepseek") { // Heuristic to identify DeepSeek models - if len(messagesForRequest) == 1 && messagesForRequest[0].Role == goopenai.ChatMessageRoleSystem { - messagesForRequest[0].Role = goopenai.ChatMessageRoleUser + msg := *msgPtr // copy + // Provider-specific modification for DeepSeek: + if strings.Contains(opts.Model, "deepseek") && len(inputMsgs) == 1 && msg.Role == chat.ChatMessageRoleSystem { + msg.Role = chat.ChatMessageRoleUser } - // Note: This handles the most common case arising from pattern usage. - // More complex scenarios where a multi-message sequence ends in 'system' - // are not currently expected from chatter.go's BuildSession logic for OpenAI providers - // but might require further rules if they arise. + messagesForRequest[i] = convertMessage(msg) } - - if opts.Raw { - ret = goopenai.ChatCompletionRequest{ - Model: opts.Model, - Messages: messagesForRequest, - } - } else { - if opts.Seed == 0 { - ret = goopenai.ChatCompletionRequest{ - Model: opts.Model, - Temperature: float32(opts.Temperature), - TopP: float32(opts.TopP), - PresencePenalty: float32(opts.PresencePenalty), - FrequencyPenalty: float32(opts.FrequencyPenalty), - Messages: messagesForRequest, - } - } else { - ret = goopenai.ChatCompletionRequest{ - Model: opts.Model, - Temperature: float32(opts.Temperature), - TopP: float32(opts.TopP), - PresencePenalty: float32(opts.PresencePenalty), - FrequencyPenalty: float32(opts.FrequencyPenalty), - Messages: messagesForRequest, - Seed: &opts.Seed, - } + ret = openai.ChatCompletionNewParams{ + Model: openai.ChatModel(opts.Model), + Messages: messagesForRequest, + } + if !opts.Raw { + ret.Temperature = openai.Float(opts.Temperature) + ret.TopP = openai.Float(opts.TopP) + ret.PresencePenalty = openai.Float(opts.PresencePenalty) + ret.FrequencyPenalty = openai.Float(opts.FrequencyPenalty) + if opts.Seed != 0 { + ret.Seed = openai.Int(int64(opts.Seed)) } } return } + +func convertMessage(msg chat.ChatCompletionMessage) openai.ChatCompletionMessageParamUnion { + switch msg.Role { + case chat.ChatMessageRoleSystem: + return openai.SystemMessage(msg.Content) + case chat.ChatMessageRoleUser: + if len(msg.MultiContent) > 0 { + var parts []openai.ChatCompletionContentPartUnionParam + for _, p := range msg.MultiContent { + switch p.Type { + case chat.ChatMessagePartTypeText: + parts = append(parts, openai.TextContentPart(p.Text)) + case chat.ChatMessagePartTypeImageURL: + parts = append(parts, openai.ImageContentPart(openai.ChatCompletionContentPartImageImageURLParam{URL: p.ImageURL.URL})) + } + } + return openai.UserMessage(parts) + } + return openai.UserMessage(msg.Content) + default: + return openai.AssistantMessage(msg.Content) + } +} diff --git a/plugins/ai/openai/openai_test.go b/plugins/ai/openai/openai_test.go index 1ec162dc..e4798742 100644 --- a/plugins/ai/openai/openai_test.go +++ b/plugins/ai/openai/openai_test.go @@ -3,18 +3,18 @@ package openai import ( "testing" + "github.com/danielmiessler/fabric/chat" "github.com/danielmiessler/fabric/common" - "github.com/sashabaranov/go-openai" - goopenai "github.com/sashabaranov/go-openai" + openai "github.com/openai/openai-go" "github.com/stretchr/testify/assert" ) func TestBuildChatCompletionRequestPinSeed(t *testing.T) { - var msgs []*goopenai.ChatCompletionMessage + var msgs []*chat.ChatCompletionMessage for i := 0; i < 2; i++ { - msgs = append(msgs, &goopenai.ChatCompletionMessage{ + msgs = append(msgs, &chat.ChatCompletionMessage{ Role: "User", Content: "My msg", }) @@ -29,38 +29,22 @@ func TestBuildChatCompletionRequestPinSeed(t *testing.T) { Seed: 1, } - var expectedMessages []openai.ChatCompletionMessage - - for i := 0; i < 2; i++ { - expectedMessages = append(expectedMessages, - openai.ChatCompletionMessage{ - Role: msgs[i].Role, - Content: msgs[i].Content, - }, - ) - } - - var expectedRequest = goopenai.ChatCompletionRequest{ - Model: opts.Model, - Temperature: float32(opts.Temperature), - TopP: float32(opts.TopP), - PresencePenalty: float32(opts.PresencePenalty), - FrequencyPenalty: float32(opts.FrequencyPenalty), - Messages: expectedMessages, - Seed: &opts.Seed, - } - var client = NewClient() - request := client.buildChatCompletionRequest(msgs, opts) - assert.Equal(t, expectedRequest, request) + request := client.buildChatCompletionParams(msgs, opts) + assert.Equal(t, openai.ChatModel(opts.Model), request.Model) + assert.Equal(t, openai.Float(opts.Temperature), request.Temperature) + assert.Equal(t, openai.Float(opts.TopP), request.TopP) + assert.Equal(t, openai.Float(opts.PresencePenalty), request.PresencePenalty) + assert.Equal(t, openai.Float(opts.FrequencyPenalty), request.FrequencyPenalty) + assert.Equal(t, openai.Int(int64(opts.Seed)), request.Seed) } func TestBuildChatCompletionRequestNilSeed(t *testing.T) { - var msgs []*goopenai.ChatCompletionMessage + var msgs []*chat.ChatCompletionMessage for i := 0; i < 2; i++ { - msgs = append(msgs, &goopenai.ChatCompletionMessage{ + msgs = append(msgs, &chat.ChatCompletionMessage{ Role: "User", Content: "My msg", }) @@ -75,28 +59,12 @@ func TestBuildChatCompletionRequestNilSeed(t *testing.T) { Seed: 0, } - var expectedMessages []openai.ChatCompletionMessage - - for i := 0; i < 2; i++ { - expectedMessages = append(expectedMessages, - openai.ChatCompletionMessage{ - Role: msgs[i].Role, - Content: msgs[i].Content, - }, - ) - } - - var expectedRequest = goopenai.ChatCompletionRequest{ - Model: opts.Model, - Temperature: float32(opts.Temperature), - TopP: float32(opts.TopP), - PresencePenalty: float32(opts.PresencePenalty), - FrequencyPenalty: float32(opts.FrequencyPenalty), - Messages: expectedMessages, - Seed: nil, - } - var client = NewClient() - request := client.buildChatCompletionRequest(msgs, opts) - assert.Equal(t, expectedRequest, request) + request := client.buildChatCompletionParams(msgs, opts) + assert.Equal(t, openai.ChatModel(opts.Model), request.Model) + assert.Equal(t, openai.Float(opts.Temperature), request.Temperature) + assert.Equal(t, openai.Float(opts.TopP), request.TopP) + assert.Equal(t, openai.Float(opts.PresencePenalty), request.PresencePenalty) + assert.Equal(t, openai.Float(opts.FrequencyPenalty), request.FrequencyPenalty) + assert.False(t, request.Seed.Valid()) } diff --git a/plugins/ai/perplexity/perplexity.go b/plugins/ai/perplexity/perplexity.go index f0330f2d..bc254b62 100644 --- a/plugins/ai/perplexity/perplexity.go +++ b/plugins/ai/perplexity/perplexity.go @@ -10,7 +10,7 @@ import ( "github.com/danielmiessler/fabric/plugins" perplexity "github.com/sgaunet/perplexity-go/v2" - goopenai "github.com/sashabaranov/go-openai" + "github.com/danielmiessler/fabric/chat" ) const ( @@ -61,7 +61,7 @@ func (c *Client) ListModels() ([]string, error) { return models, nil } -func (c *Client) Send(ctx context.Context, msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions) (string, error) { +func (c *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions) (string, error) { if c.client == nil { if err := c.Configure(); err != nil { return "", fmt.Errorf("failed to configure Perplexity client: %w", err) @@ -120,7 +120,7 @@ func (c *Client) Send(ctx context.Context, msgs []*goopenai.ChatCompletionMessag return content, nil } -func (c *Client) SendStream(msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) error { +func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) error { if c.client == nil { if err := c.Configure(); err != nil { close(channel) // Ensure channel is closed on error diff --git a/plugins/ai/vendor.go b/plugins/ai/vendor.go index 22b69408..d16c340e 100644 --- a/plugins/ai/vendor.go +++ b/plugins/ai/vendor.go @@ -3,8 +3,8 @@ package ai import ( "context" + "github.com/danielmiessler/fabric/chat" "github.com/danielmiessler/fabric/plugins" - goopenai "github.com/sashabaranov/go-openai" "github.com/danielmiessler/fabric/common" ) @@ -12,7 +12,7 @@ import ( type Vendor interface { plugins.Plugin ListModels() ([]string, error) - SendStream([]*goopenai.ChatCompletionMessage, *common.ChatOptions, chan string) error - Send(context.Context, []*goopenai.ChatCompletionMessage, *common.ChatOptions) (string, error) + SendStream([]*chat.ChatCompletionMessage, *common.ChatOptions, chan string) error + Send(context.Context, []*chat.ChatCompletionMessage, *common.ChatOptions) (string, error) NeedsRawMode(modelName string) bool } diff --git a/plugins/db/fsdb/sessions.go b/plugins/db/fsdb/sessions.go index 54d5b961..2c096f84 100644 --- a/plugins/db/fsdb/sessions.go +++ b/plugins/db/fsdb/sessions.go @@ -3,8 +3,8 @@ package fsdb import ( "fmt" + "github.com/danielmiessler/fabric/chat" "github.com/danielmiessler/fabric/common" - goopenai "github.com/sashabaranov/go-openai" ) type SessionsEntity struct { @@ -38,16 +38,16 @@ func (o *SessionsEntity) SaveSession(session *Session) (err error) { type Session struct { Name string - Messages []*goopenai.ChatCompletionMessage + Messages []*chat.ChatCompletionMessage - vendorMessages []*goopenai.ChatCompletionMessage + vendorMessages []*chat.ChatCompletionMessage } func (o *Session) IsEmpty() bool { return len(o.Messages) == 0 } -func (o *Session) Append(messages ...*goopenai.ChatCompletionMessage) { +func (o *Session) Append(messages ...*chat.ChatCompletionMessage) { if o.vendorMessages != nil { for _, message := range messages { o.Messages = append(o.Messages, message) @@ -58,7 +58,7 @@ func (o *Session) Append(messages ...*goopenai.ChatCompletionMessage) { } } -func (o *Session) GetVendorMessages() (ret []*goopenai.ChatCompletionMessage) { +func (o *Session) GetVendorMessages() (ret []*chat.ChatCompletionMessage) { if len(o.vendorMessages) == 0 { for _, message := range o.Messages { o.appendVendorMessage(message) @@ -68,13 +68,13 @@ func (o *Session) GetVendorMessages() (ret []*goopenai.ChatCompletionMessage) { return } -func (o *Session) appendVendorMessage(message *goopenai.ChatCompletionMessage) { +func (o *Session) appendVendorMessage(message *chat.ChatCompletionMessage) { if message.Role != common.ChatMessageRoleMeta { o.vendorMessages = append(o.vendorMessages, message) } } -func (o *Session) GetLastMessage() (ret *goopenai.ChatCompletionMessage) { +func (o *Session) GetLastMessage() (ret *chat.ChatCompletionMessage) { if len(o.Messages) > 0 { ret = o.Messages[len(o.Messages)-1] } @@ -86,9 +86,9 @@ func (o *Session) String() (ret string) { ret += fmt.Sprintf("\n--- \n[%v]\n%v", message.Role, message.Content) if message.MultiContent != nil { for _, part := range message.MultiContent { - if part.Type == goopenai.ChatMessagePartTypeImageURL { + if part.Type == chat.ChatMessagePartTypeImageURL { ret += fmt.Sprintf("\n%v: %v", part.Type, *part.ImageURL) - } else if part.Type == goopenai.ChatMessagePartTypeText { + } else if part.Type == chat.ChatMessagePartTypeText { ret += fmt.Sprintf("\n%v: %v", part.Type, part.Text) } } diff --git a/plugins/db/fsdb/sessions_test.go b/plugins/db/fsdb/sessions_test.go index 9477ce14..0a2787c3 100644 --- a/plugins/db/fsdb/sessions_test.go +++ b/plugins/db/fsdb/sessions_test.go @@ -3,7 +3,7 @@ package fsdb import ( "testing" - goopenai "github.com/sashabaranov/go-openai" + "github.com/danielmiessler/fabric/chat" ) func TestSessions_GetOrCreateSession(t *testing.T) { @@ -27,7 +27,7 @@ func TestSessions_SaveSession(t *testing.T) { StorageEntity: &StorageEntity{Dir: dir, FileExtension: ".json"}, } sessionName := "testSession" - session := &Session{Name: sessionName, Messages: []*goopenai.ChatCompletionMessage{{Content: "message1"}}} + session := &Session{Name: sessionName, Messages: []*chat.ChatCompletionMessage{{Content: "message1"}}} err := sessions.SaveSession(session) if err != nil { t.Fatalf("failed to save session: %v", err) diff --git a/plugins/template/sys_test.go b/plugins/template/sys_test.go index add00590..ab55ed8d 100644 --- a/plugins/template/sys_test.go +++ b/plugins/template/sys_test.go @@ -93,7 +93,7 @@ func TestSysPlugin(t *testing.T) { if !filepath.IsAbs(got) { return fmt.Errorf("expected absolute path, got %s", got) } - if !strings.Contains(got, "home") && !strings.Contains(got, "Users") { + if !strings.Contains(got, "home") && !strings.Contains(got, "Users") && got != "/root" { return fmt.Errorf("path %s doesn't look like a home directory", got) } return nil diff --git a/restapi/chat.go b/restapi/chat.go index 36d5d79d..7ae13932 100755 --- a/restapi/chat.go +++ b/restapi/chat.go @@ -3,14 +3,13 @@ package restapi import ( "encoding/json" "fmt" - "io/ioutil" "log" "net/http" "os" "path/filepath" "strings" - goopenai "github.com/sashabaranov/go-openai" + "github.com/danielmiessler/fabric/chat" "github.com/danielmiessler/fabric/common" "github.com/danielmiessler/fabric/core" @@ -95,7 +94,7 @@ func (h *ChatHandler) HandleChat(c *gin.Context) { // Load and prepend strategy prompt if strategyName is set if p.StrategyName != "" { strategyFile := filepath.Join(os.Getenv("HOME"), ".config", "fabric", "strategies", p.StrategyName+".json") - data, err := ioutil.ReadFile(strategyFile) + data, err := os.ReadFile(strategyFile) if err == nil { var s struct { Prompt string `json:"prompt"` @@ -115,7 +114,7 @@ func (h *ChatHandler) HandleChat(c *gin.Context) { // Pass the language received in the initial request to the common.ChatRequest chatReq := &common.ChatRequest{ - Message: &goopenai.ChatCompletionMessage{ + Message: &chat.ChatCompletionMessage{ Role: "user", Content: p.UserInput, }, diff --git a/restapi/ollama.go b/restapi/ollama.go index fbc6d02a..c96e1ebf 100644 --- a/restapi/ollama.go +++ b/restapi/ollama.go @@ -103,7 +103,6 @@ func ServeOllama(registry *core.PluginRegistry, address string, version string) r.GET("/api/tags", typeConversion.ollamaTags) r.GET("/api/version", func(c *gin.Context) { c.Data(200, "application/json", []byte(fmt.Sprintf("{\"%s\"}", version))) - return }) r.POST("/api/chat", typeConversion.ollamaChat) @@ -262,15 +261,10 @@ func (f APIConvert) ollamaChat(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": err}) return } - for _, bytein := range marshalled { - res = append(res, bytein) - } - for _, bytebreak := range []byte("\n") { - res = append(res, bytebreak) - } + res = append(res, marshalled...) + res = append(res, '\n') } c.Data(200, "application/json", res) //c.JSON(200, forwardedResponse) - return } diff --git a/restapi/strategies.go b/restapi/strategies.go index 06e35c7d..f91f4e4f 100644 --- a/restapi/strategies.go +++ b/restapi/strategies.go @@ -2,7 +2,6 @@ package restapi import ( "encoding/json" - "io/ioutil" "net/http" "os" "path/filepath" @@ -23,7 +22,7 @@ func NewStrategiesHandler(r *gin.Engine) { r.GET("/strategies", func(c *gin.Context) { strategiesDir := filepath.Join(os.Getenv("HOME"), ".config", "fabric", "strategies") - files, err := ioutil.ReadDir(strategiesDir) + files, err := os.ReadDir(strategiesDir) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to read strategies directory"}) return @@ -37,7 +36,7 @@ func NewStrategiesHandler(r *gin.Engine) { } fullPath := filepath.Join(strategiesDir, file.Name()) - data, err := ioutil.ReadFile(fullPath) + data, err := os.ReadFile(fullPath) if err != nil { continue }