mirror of
https://github.com/danielmiessler/Fabric.git
synced 2026-02-13 07:25:10 -05:00
refactor: abstract chat message structs and migrate to official openai-go SDK
### CHANGES - Introduce local `chat` package for message abstraction - Replace sashabaranov/go-openai with official openai-go SDK - Update OpenAI, Azure, and Exolab plugins for new client - Refactor all AI providers to use internal chat types - Decouple codebase from third-party AI provider structs - Replace deprecated `ioutil` functions with `os` equivalents
This commit is contained in:
@@ -9,9 +9,9 @@ import (
|
||||
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
"github.com/anthropics/anthropic-sdk-go/option"
|
||||
"github.com/danielmiessler/fabric/chat"
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"github.com/danielmiessler/fabric/plugins"
|
||||
goopenai "github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
const defaultBaseUrl = "https://api.anthropic.com/"
|
||||
@@ -87,7 +87,7 @@ func (an *Client) ListModels() (ret []string, err error) {
|
||||
}
|
||||
|
||||
func (an *Client) SendStream(
|
||||
msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions, channel chan string,
|
||||
msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions, channel chan string,
|
||||
) (err error) {
|
||||
messages := an.toMessages(msgs)
|
||||
if len(messages) == 0 {
|
||||
@@ -151,7 +151,7 @@ func (an *Client) buildMessageParams(msgs []anthropic.MessageParam, opts *common
|
||||
return
|
||||
}
|
||||
|
||||
func (an *Client) Send(ctx context.Context, msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions) (
|
||||
func (an *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions) (
|
||||
ret string, err error) {
|
||||
|
||||
messages := an.toMessages(msgs)
|
||||
@@ -176,7 +176,7 @@ func (an *Client) Send(ctx context.Context, msgs []*goopenai.ChatCompletionMessa
|
||||
return
|
||||
}
|
||||
|
||||
func (an *Client) toMessages(msgs []*goopenai.ChatCompletionMessage) (ret []anthropic.MessageParam) {
|
||||
func (an *Client) toMessages(msgs []*chat.ChatCompletionMessage) (ret []anthropic.MessageParam) {
|
||||
// Custom normalization for Anthropic:
|
||||
// - System messages become the first part of the first user message.
|
||||
// - Messages must alternate user/assistant.
|
||||
@@ -193,14 +193,14 @@ func (an *Client) toMessages(msgs []*goopenai.ChatCompletionMessage) (ret []anth
|
||||
}
|
||||
|
||||
switch msg.Role {
|
||||
case goopenai.ChatMessageRoleSystem:
|
||||
case chat.ChatMessageRoleSystem:
|
||||
// Accumulate system content. It will be prepended to the first user message.
|
||||
if systemContent != "" {
|
||||
systemContent += "\\n" + msg.Content
|
||||
} else {
|
||||
systemContent = msg.Content
|
||||
}
|
||||
case goopenai.ChatMessageRoleUser:
|
||||
case chat.ChatMessageRoleUser:
|
||||
userContent := msg.Content
|
||||
if isFirstUserMessage && systemContent != "" {
|
||||
userContent = systemContent + "\\n\\n" + userContent
|
||||
@@ -213,7 +213,7 @@ func (an *Client) toMessages(msgs []*goopenai.ChatCompletionMessage) (ret []anth
|
||||
}
|
||||
anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(anthropic.NewTextBlock(userContent)))
|
||||
lastRoleWasUser = true
|
||||
case goopenai.ChatMessageRoleAssistant:
|
||||
case chat.ChatMessageRoleAssistant:
|
||||
// If the first message is an assistant message, and we have system content,
|
||||
// prepend a user message with the system content.
|
||||
if isFirstUserMessage && systemContent != "" {
|
||||
|
||||
@@ -5,7 +5,8 @@ import (
|
||||
|
||||
"github.com/danielmiessler/fabric/plugins"
|
||||
"github.com/danielmiessler/fabric/plugins/ai/openai"
|
||||
goopenai "github.com/sashabaranov/go-openai"
|
||||
openaiapi "github.com/openai/openai-go"
|
||||
"github.com/openai/openai-go/option"
|
||||
)
|
||||
|
||||
func NewClient() (ret *Client) {
|
||||
@@ -29,11 +30,15 @@ type Client struct {
|
||||
|
||||
func (oi *Client) configure() (err error) {
|
||||
oi.apiDeployments = strings.Split(oi.ApiDeployments.Value, ",")
|
||||
config := goopenai.DefaultAzureConfig(oi.ApiKey.Value, oi.ApiBaseURL.Value)
|
||||
if oi.ApiVersion.Value != "" {
|
||||
config.APIVersion = oi.ApiVersion.Value
|
||||
opts := []option.RequestOption{option.WithAPIKey(oi.ApiKey.Value)}
|
||||
if oi.ApiBaseURL.Value != "" {
|
||||
opts = append(opts, option.WithBaseURL(oi.ApiBaseURL.Value))
|
||||
}
|
||||
oi.ApiClient = goopenai.NewClientWithConfig(config)
|
||||
if oi.ApiVersion.Value != "" {
|
||||
opts = append(opts, option.WithQuery("api-version", oi.ApiVersion.Value))
|
||||
}
|
||||
client := openaiapi.NewClient(opts...)
|
||||
oi.ApiClient = &client
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
|
||||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
|
||||
|
||||
goopenai "github.com/sashabaranov/go-openai"
|
||||
"github.com/danielmiessler/fabric/chat"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -154,7 +154,7 @@ func (c *BedrockClient) ListModels() ([]string, error) {
|
||||
}
|
||||
|
||||
// SendStream sends the messages to the the Bedrock ConverseStream API
|
||||
func (c *BedrockClient) SendStream(msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) (err error) {
|
||||
func (c *BedrockClient) SendStream(msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) (err error) {
|
||||
// Ensure channel is closed on all exit paths to prevent goroutine leaks
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
@@ -208,7 +208,7 @@ func (c *BedrockClient) SendStream(msgs []*goopenai.ChatCompletionMessage, opts
|
||||
}
|
||||
|
||||
// Send sends the messages the Bedrock Converse API
|
||||
func (c *BedrockClient) Send(ctx context.Context, msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions) (ret string, err error) {
|
||||
func (c *BedrockClient) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions) (ret string, err error) {
|
||||
|
||||
messages := c.toMessages(msgs)
|
||||
|
||||
@@ -249,12 +249,12 @@ func (c *BedrockClient) NeedsRawMode(modelName string) bool {
|
||||
// Bedrock Converse Message type.
|
||||
// The system role messages are mapped to the user role as they contain a mix of system messages,
|
||||
// pattern content and user input.
|
||||
func (c *BedrockClient) toMessages(inputMessages []*goopenai.ChatCompletionMessage) (messages []types.Message) {
|
||||
func (c *BedrockClient) toMessages(inputMessages []*chat.ChatCompletionMessage) (messages []types.Message) {
|
||||
for _, msg := range inputMessages {
|
||||
roles := map[string]types.ConversationRole{
|
||||
goopenai.ChatMessageRoleUser: types.ConversationRoleUser,
|
||||
goopenai.ChatMessageRoleAssistant: types.ConversationRoleAssistant,
|
||||
goopenai.ChatMessageRoleSystem: types.ConversationRoleUser,
|
||||
chat.ChatMessageRoleUser: types.ConversationRoleUser,
|
||||
chat.ChatMessageRoleAssistant: types.ConversationRoleAssistant,
|
||||
chat.ChatMessageRoleSystem: types.ConversationRoleUser,
|
||||
}
|
||||
|
||||
role, ok := roles[msg.Role]
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
goopenai "github.com/sashabaranov/go-openai"
|
||||
"github.com/danielmiessler/fabric/chat"
|
||||
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"github.com/danielmiessler/fabric/plugins"
|
||||
@@ -24,14 +24,14 @@ func (c *Client) ListModels() ([]string, error) {
|
||||
return []string{"dry-run-model"}, nil
|
||||
}
|
||||
|
||||
func (c *Client) formatMultiContentMessage(msg *goopenai.ChatCompletionMessage) string {
|
||||
func (c *Client) formatMultiContentMessage(msg *chat.ChatCompletionMessage) string {
|
||||
var builder strings.Builder
|
||||
|
||||
if len(msg.MultiContent) > 0 {
|
||||
builder.WriteString(fmt.Sprintf("%s:\n", msg.Role))
|
||||
for _, part := range msg.MultiContent {
|
||||
builder.WriteString(fmt.Sprintf(" - Type: %s\n", part.Type))
|
||||
if part.Type == goopenai.ChatMessagePartTypeImageURL {
|
||||
if part.Type == chat.ChatMessagePartTypeImageURL {
|
||||
builder.WriteString(fmt.Sprintf(" Image URL: %s\n", part.ImageURL.URL))
|
||||
} else {
|
||||
builder.WriteString(fmt.Sprintf(" Text: %s\n", part.Text))
|
||||
@@ -45,16 +45,16 @@ func (c *Client) formatMultiContentMessage(msg *goopenai.ChatCompletionMessage)
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
func (c *Client) formatMessages(msgs []*goopenai.ChatCompletionMessage) string {
|
||||
func (c *Client) formatMessages(msgs []*chat.ChatCompletionMessage) string {
|
||||
var builder strings.Builder
|
||||
|
||||
for _, msg := range msgs {
|
||||
switch msg.Role {
|
||||
case goopenai.ChatMessageRoleSystem:
|
||||
case chat.ChatMessageRoleSystem:
|
||||
builder.WriteString(fmt.Sprintf("System:\n%s\n\n", msg.Content))
|
||||
case goopenai.ChatMessageRoleAssistant:
|
||||
case chat.ChatMessageRoleAssistant:
|
||||
builder.WriteString(c.formatMultiContentMessage(msg))
|
||||
case goopenai.ChatMessageRoleUser:
|
||||
case chat.ChatMessageRoleUser:
|
||||
builder.WriteString(c.formatMultiContentMessage(msg))
|
||||
default:
|
||||
builder.WriteString(fmt.Sprintf("%s:\n%s\n\n", msg.Role, msg.Content))
|
||||
@@ -80,7 +80,7 @@ func (c *Client) formatOptions(opts *common.ChatOptions) string {
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
func (c *Client) SendStream(msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) error {
|
||||
func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) error {
|
||||
var builder strings.Builder
|
||||
builder.WriteString("Dry run: Would send the following request:\n\n")
|
||||
builder.WriteString(c.formatMessages(msgs))
|
||||
@@ -91,7 +91,7 @@ func (c *Client) SendStream(msgs []*goopenai.ChatCompletionMessage, opts *common
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) Send(_ context.Context, msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions) (string, error) {
|
||||
func (c *Client) Send(_ context.Context, msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions) (string, error) {
|
||||
fmt.Println("Dry run: Would send the following request:")
|
||||
fmt.Print(c.formatMessages(msgs))
|
||||
fmt.Print(c.formatOptions(opts))
|
||||
|
||||
@@ -4,8 +4,8 @@ import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/danielmiessler/fabric/chat"
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
// Test generated using Keploy
|
||||
@@ -33,7 +33,7 @@ func TestSetup_ReturnsNil(t *testing.T) {
|
||||
// Test generated using Keploy
|
||||
func TestSendStream_SendsMessages(t *testing.T) {
|
||||
client := NewClient()
|
||||
msgs := []*openai.ChatCompletionMessage{
|
||||
msgs := []*chat.ChatCompletionMessage{
|
||||
{Role: "user", Content: "Test message"},
|
||||
}
|
||||
opts := &common.ChatOptions{
|
||||
|
||||
@@ -5,8 +5,8 @@ import (
|
||||
|
||||
"github.com/danielmiessler/fabric/plugins"
|
||||
"github.com/danielmiessler/fabric/plugins/ai/openai"
|
||||
|
||||
goopenai "github.com/sashabaranov/go-openai"
|
||||
openaiapi "github.com/openai/openai-go"
|
||||
"github.com/openai/openai-go/option"
|
||||
)
|
||||
|
||||
func NewClient() (ret *Client) {
|
||||
@@ -32,10 +32,12 @@ type Client struct {
|
||||
func (oi *Client) configure() (err error) {
|
||||
oi.apiModels = strings.Split(oi.ApiModels.Value, ",")
|
||||
|
||||
config := goopenai.DefaultConfig("")
|
||||
config.BaseURL = oi.ApiBaseURL.Value
|
||||
|
||||
oi.ApiClient = goopenai.NewClientWithConfig(config)
|
||||
opts := []option.RequestOption{option.WithAPIKey(oi.ApiKey.Value)}
|
||||
if oi.ApiBaseURL.Value != "" {
|
||||
opts = append(opts, option.WithBaseURL(oi.ApiBaseURL.Value))
|
||||
}
|
||||
client := openaiapi.NewClient(opts...)
|
||||
oi.ApiClient = &client
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -6,8 +6,8 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/danielmiessler/fabric/chat"
|
||||
"github.com/danielmiessler/fabric/plugins"
|
||||
goopenai "github.com/sashabaranov/go-openai"
|
||||
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"github.com/google/generative-ai-go/genai"
|
||||
@@ -60,7 +60,7 @@ func (o *Client) ListModels() (ret []string, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (o *Client) Send(ctx context.Context, msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions) (ret string, err error) {
|
||||
func (o *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions) (ret string, err error) {
|
||||
systemInstruction, messages := toMessages(msgs)
|
||||
|
||||
var client *genai.Client
|
||||
@@ -91,7 +91,7 @@ func (o *Client) buildModelNameFull(modelName string) string {
|
||||
return fmt.Sprintf("%v%v", modelsNamePrefix, modelName)
|
||||
}
|
||||
|
||||
func (o *Client) SendStream(msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) (err error) {
|
||||
func (o *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) (err error) {
|
||||
ctx := context.Background()
|
||||
var client *genai.Client
|
||||
if client, err = genai.NewClient(ctx, option.WithAPIKey(o.ApiKey.Value)); err != nil {
|
||||
@@ -147,7 +147,7 @@ func (o *Client) NeedsRawMode(modelName string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func toMessages(msgs []*goopenai.ChatCompletionMessage) (systemInstruction *genai.Content, messages []genai.Part) {
|
||||
func toMessages(msgs []*chat.ChatCompletionMessage) (systemInstruction *genai.Content, messages []genai.Part) {
|
||||
if len(msgs) >= 2 {
|
||||
systemInstruction = &genai.Content{
|
||||
Parts: []genai.Part{
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
goopenai "github.com/sashabaranov/go-openai"
|
||||
"github.com/danielmiessler/fabric/chat"
|
||||
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"github.com/danielmiessler/fabric/plugins"
|
||||
@@ -87,7 +87,7 @@ func (c *Client) ListModels() ([]string, error) {
|
||||
return models, nil
|
||||
}
|
||||
|
||||
func (c *Client) SendStream(msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) (err error) {
|
||||
func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) (err error) {
|
||||
url := fmt.Sprintf("%s/chat/completions", c.ApiUrl.Value)
|
||||
|
||||
payload := map[string]interface{}{
|
||||
@@ -173,7 +173,7 @@ func (c *Client) SendStream(msgs []*goopenai.ChatCompletionMessage, opts *common
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Client) Send(ctx context.Context, msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions) (content string, err error) {
|
||||
func (c *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions) (content string, err error) {
|
||||
url := fmt.Sprintf("%s/chat/completions", c.ApiUrl.Value)
|
||||
|
||||
payload := map[string]interface{}{
|
||||
|
||||
@@ -8,9 +8,9 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/danielmiessler/fabric/chat"
|
||||
ollamaapi "github.com/ollama/ollama/api"
|
||||
"github.com/samber/lo"
|
||||
goopenai "github.com/sashabaranov/go-openai"
|
||||
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"github.com/danielmiessler/fabric/plugins"
|
||||
@@ -97,7 +97,7 @@ func (o *Client) ListModels() (ret []string, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (o *Client) SendStream(msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) (err error) {
|
||||
func (o *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) (err error) {
|
||||
req := o.createChatRequest(msgs, opts)
|
||||
|
||||
respFunc := func(resp ollamaapi.ChatResponse) (streamErr error) {
|
||||
@@ -115,7 +115,7 @@ func (o *Client) SendStream(msgs []*goopenai.ChatCompletionMessage, opts *common
|
||||
return
|
||||
}
|
||||
|
||||
func (o *Client) Send(ctx context.Context, msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions) (ret string, err error) {
|
||||
func (o *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions) (ret string, err error) {
|
||||
bf := false
|
||||
|
||||
req := o.createChatRequest(msgs, opts)
|
||||
@@ -132,8 +132,8 @@ func (o *Client) Send(ctx context.Context, msgs []*goopenai.ChatCompletionMessag
|
||||
return
|
||||
}
|
||||
|
||||
func (o *Client) createChatRequest(msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions) (ret ollamaapi.ChatRequest) {
|
||||
messages := lo.Map(msgs, func(message *goopenai.ChatCompletionMessage, _ int) (ret ollamaapi.Message) {
|
||||
func (o *Client) createChatRequest(msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions) (ret ollamaapi.ChatRequest) {
|
||||
messages := lo.Map(msgs, func(message *chat.ChatCompletionMessage, _ int) (ret ollamaapi.Message) {
|
||||
return ollamaapi.Message{Role: message.Role, Content: message.Content}
|
||||
})
|
||||
|
||||
|
||||
@@ -2,16 +2,16 @@ package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/danielmiessler/fabric/chat"
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"github.com/danielmiessler/fabric/plugins"
|
||||
goopenai "github.com/sashabaranov/go-openai"
|
||||
openai "github.com/openai/openai-go"
|
||||
"github.com/openai/openai-go/option"
|
||||
"github.com/openai/openai-go/packages/pagination"
|
||||
)
|
||||
|
||||
func NewClient() (ret *Client) {
|
||||
@@ -48,73 +48,53 @@ type Client struct {
|
||||
*plugins.PluginBase
|
||||
ApiKey *plugins.SetupQuestion
|
||||
ApiBaseURL *plugins.SetupQuestion
|
||||
ApiClient *goopenai.Client
|
||||
ApiClient *openai.Client
|
||||
}
|
||||
|
||||
func (o *Client) configure() (ret error) {
|
||||
config := goopenai.DefaultConfig(o.ApiKey.Value)
|
||||
opts := []option.RequestOption{option.WithAPIKey(o.ApiKey.Value)}
|
||||
if o.ApiBaseURL.Value != "" {
|
||||
config.BaseURL = o.ApiBaseURL.Value
|
||||
opts = append(opts, option.WithBaseURL(o.ApiBaseURL.Value))
|
||||
}
|
||||
o.ApiClient = goopenai.NewClientWithConfig(config)
|
||||
client := openai.NewClient(opts...)
|
||||
o.ApiClient = &client
|
||||
return
|
||||
}
|
||||
|
||||
func (o *Client) ListModels() (ret []string, err error) {
|
||||
var models goopenai.ModelsList
|
||||
if models, err = o.ApiClient.ListModels(context.Background()); err != nil {
|
||||
var page *pagination.Page[openai.Model]
|
||||
if page, err = o.ApiClient.Models.List(context.Background()); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
model := models.Models
|
||||
for _, mod := range model {
|
||||
for _, mod := range page.Data {
|
||||
ret = append(ret, mod.ID)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (o *Client) SendStream(
|
||||
msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions, channel chan string,
|
||||
msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions, channel chan string,
|
||||
) (err error) {
|
||||
req := o.buildChatCompletionRequest(msgs, opts)
|
||||
req.Stream = true
|
||||
|
||||
var stream *goopenai.ChatCompletionStream
|
||||
if stream, err = o.ApiClient.CreateChatCompletionStream(context.Background(), req); err != nil {
|
||||
fmt.Printf("ChatCompletionStream error: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
defer stream.Close()
|
||||
|
||||
for {
|
||||
var response goopenai.ChatCompletionStreamResponse
|
||||
if response, err = stream.Recv(); err == nil {
|
||||
if len(response.Choices) > 0 {
|
||||
channel <- response.Choices[0].Delta.Content
|
||||
} else {
|
||||
channel <- "\n"
|
||||
close(channel)
|
||||
break
|
||||
}
|
||||
} else if errors.Is(err, io.EOF) {
|
||||
channel <- "\n"
|
||||
close(channel)
|
||||
err = nil
|
||||
break
|
||||
} else if err != nil {
|
||||
fmt.Printf("\nStream error: %v\n", err)
|
||||
break
|
||||
req := o.buildChatCompletionParams(msgs, opts)
|
||||
stream := o.ApiClient.Chat.Completions.NewStreaming(context.Background(), req)
|
||||
for stream.Next() {
|
||||
chunk := stream.Current()
|
||||
if len(chunk.Choices) > 0 {
|
||||
channel <- chunk.Choices[0].Delta.Content
|
||||
}
|
||||
}
|
||||
return
|
||||
if stream.Err() == nil {
|
||||
channel <- "\n"
|
||||
}
|
||||
close(channel)
|
||||
return stream.Err()
|
||||
}
|
||||
|
||||
func (o *Client) Send(ctx context.Context, msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions) (ret string, err error) {
|
||||
req := o.buildChatCompletionRequest(msgs, opts)
|
||||
func (o *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions) (ret string, err error) {
|
||||
req := o.buildChatCompletionParams(msgs, opts)
|
||||
|
||||
var resp goopenai.ChatCompletionResponse
|
||||
if resp, err = o.ApiClient.CreateChatCompletion(ctx, req); err != nil {
|
||||
var resp *openai.ChatCompletion
|
||||
if resp, err = o.ApiClient.Chat.Completions.New(ctx, req); err != nil {
|
||||
return
|
||||
}
|
||||
if len(resp.Choices) > 0 {
|
||||
@@ -146,57 +126,56 @@ func (o *Client) NeedsRawMode(modelName string) bool {
|
||||
return slices.Contains(openAIModelsNeedingRaw, modelName)
|
||||
}
|
||||
|
||||
func (o *Client) buildChatCompletionRequest(
|
||||
inputMsgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions,
|
||||
) (ret goopenai.ChatCompletionRequest) {
|
||||
func (o *Client) buildChatCompletionParams(
|
||||
inputMsgs []*chat.ChatCompletionMessage, opts *common.ChatOptions,
|
||||
) (ret openai.ChatCompletionNewParams) {
|
||||
|
||||
// Create a new slice for messages to be sent, converting from []*Msg to []Msg.
|
||||
// This also serves as a mutable copy for provider-specific modifications.
|
||||
messagesForRequest := make([]goopenai.ChatCompletionMessage, len(inputMsgs))
|
||||
messagesForRequest := make([]openai.ChatCompletionMessageParamUnion, len(inputMsgs))
|
||||
for i, msgPtr := range inputMsgs {
|
||||
messagesForRequest[i] = *msgPtr // Dereference and copy
|
||||
}
|
||||
|
||||
// Provider-specific modification for DeepSeek:
|
||||
// DeepSeek requires the last message to be a user message.
|
||||
// If fabric constructs a single system message (common when a pattern includes user input),
|
||||
// we change its role to user for DeepSeek.
|
||||
if strings.Contains(opts.Model, "deepseek") { // Heuristic to identify DeepSeek models
|
||||
if len(messagesForRequest) == 1 && messagesForRequest[0].Role == goopenai.ChatMessageRoleSystem {
|
||||
messagesForRequest[0].Role = goopenai.ChatMessageRoleUser
|
||||
msg := *msgPtr // copy
|
||||
// Provider-specific modification for DeepSeek:
|
||||
if strings.Contains(opts.Model, "deepseek") && len(inputMsgs) == 1 && msg.Role == chat.ChatMessageRoleSystem {
|
||||
msg.Role = chat.ChatMessageRoleUser
|
||||
}
|
||||
// Note: This handles the most common case arising from pattern usage.
|
||||
// More complex scenarios where a multi-message sequence ends in 'system'
|
||||
// are not currently expected from chatter.go's BuildSession logic for OpenAI providers
|
||||
// but might require further rules if they arise.
|
||||
messagesForRequest[i] = convertMessage(msg)
|
||||
}
|
||||
|
||||
if opts.Raw {
|
||||
ret = goopenai.ChatCompletionRequest{
|
||||
Model: opts.Model,
|
||||
Messages: messagesForRequest,
|
||||
}
|
||||
} else {
|
||||
if opts.Seed == 0 {
|
||||
ret = goopenai.ChatCompletionRequest{
|
||||
Model: opts.Model,
|
||||
Temperature: float32(opts.Temperature),
|
||||
TopP: float32(opts.TopP),
|
||||
PresencePenalty: float32(opts.PresencePenalty),
|
||||
FrequencyPenalty: float32(opts.FrequencyPenalty),
|
||||
Messages: messagesForRequest,
|
||||
}
|
||||
} else {
|
||||
ret = goopenai.ChatCompletionRequest{
|
||||
Model: opts.Model,
|
||||
Temperature: float32(opts.Temperature),
|
||||
TopP: float32(opts.TopP),
|
||||
PresencePenalty: float32(opts.PresencePenalty),
|
||||
FrequencyPenalty: float32(opts.FrequencyPenalty),
|
||||
Messages: messagesForRequest,
|
||||
Seed: &opts.Seed,
|
||||
}
|
||||
ret = openai.ChatCompletionNewParams{
|
||||
Model: openai.ChatModel(opts.Model),
|
||||
Messages: messagesForRequest,
|
||||
}
|
||||
if !opts.Raw {
|
||||
ret.Temperature = openai.Float(opts.Temperature)
|
||||
ret.TopP = openai.Float(opts.TopP)
|
||||
ret.PresencePenalty = openai.Float(opts.PresencePenalty)
|
||||
ret.FrequencyPenalty = openai.Float(opts.FrequencyPenalty)
|
||||
if opts.Seed != 0 {
|
||||
ret.Seed = openai.Int(int64(opts.Seed))
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func convertMessage(msg chat.ChatCompletionMessage) openai.ChatCompletionMessageParamUnion {
|
||||
switch msg.Role {
|
||||
case chat.ChatMessageRoleSystem:
|
||||
return openai.SystemMessage(msg.Content)
|
||||
case chat.ChatMessageRoleUser:
|
||||
if len(msg.MultiContent) > 0 {
|
||||
var parts []openai.ChatCompletionContentPartUnionParam
|
||||
for _, p := range msg.MultiContent {
|
||||
switch p.Type {
|
||||
case chat.ChatMessagePartTypeText:
|
||||
parts = append(parts, openai.TextContentPart(p.Text))
|
||||
case chat.ChatMessagePartTypeImageURL:
|
||||
parts = append(parts, openai.ImageContentPart(openai.ChatCompletionContentPartImageImageURLParam{URL: p.ImageURL.URL}))
|
||||
}
|
||||
}
|
||||
return openai.UserMessage(parts)
|
||||
}
|
||||
return openai.UserMessage(msg.Content)
|
||||
default:
|
||||
return openai.AssistantMessage(msg.Content)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,18 +3,18 @@ package openai
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/danielmiessler/fabric/chat"
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
goopenai "github.com/sashabaranov/go-openai"
|
||||
openai "github.com/openai/openai-go"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestBuildChatCompletionRequestPinSeed(t *testing.T) {
|
||||
|
||||
var msgs []*goopenai.ChatCompletionMessage
|
||||
var msgs []*chat.ChatCompletionMessage
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
msgs = append(msgs, &goopenai.ChatCompletionMessage{
|
||||
msgs = append(msgs, &chat.ChatCompletionMessage{
|
||||
Role: "User",
|
||||
Content: "My msg",
|
||||
})
|
||||
@@ -29,38 +29,22 @@ func TestBuildChatCompletionRequestPinSeed(t *testing.T) {
|
||||
Seed: 1,
|
||||
}
|
||||
|
||||
var expectedMessages []openai.ChatCompletionMessage
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
expectedMessages = append(expectedMessages,
|
||||
openai.ChatCompletionMessage{
|
||||
Role: msgs[i].Role,
|
||||
Content: msgs[i].Content,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
var expectedRequest = goopenai.ChatCompletionRequest{
|
||||
Model: opts.Model,
|
||||
Temperature: float32(opts.Temperature),
|
||||
TopP: float32(opts.TopP),
|
||||
PresencePenalty: float32(opts.PresencePenalty),
|
||||
FrequencyPenalty: float32(opts.FrequencyPenalty),
|
||||
Messages: expectedMessages,
|
||||
Seed: &opts.Seed,
|
||||
}
|
||||
|
||||
var client = NewClient()
|
||||
request := client.buildChatCompletionRequest(msgs, opts)
|
||||
assert.Equal(t, expectedRequest, request)
|
||||
request := client.buildChatCompletionParams(msgs, opts)
|
||||
assert.Equal(t, openai.ChatModel(opts.Model), request.Model)
|
||||
assert.Equal(t, openai.Float(opts.Temperature), request.Temperature)
|
||||
assert.Equal(t, openai.Float(opts.TopP), request.TopP)
|
||||
assert.Equal(t, openai.Float(opts.PresencePenalty), request.PresencePenalty)
|
||||
assert.Equal(t, openai.Float(opts.FrequencyPenalty), request.FrequencyPenalty)
|
||||
assert.Equal(t, openai.Int(int64(opts.Seed)), request.Seed)
|
||||
}
|
||||
|
||||
func TestBuildChatCompletionRequestNilSeed(t *testing.T) {
|
||||
|
||||
var msgs []*goopenai.ChatCompletionMessage
|
||||
var msgs []*chat.ChatCompletionMessage
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
msgs = append(msgs, &goopenai.ChatCompletionMessage{
|
||||
msgs = append(msgs, &chat.ChatCompletionMessage{
|
||||
Role: "User",
|
||||
Content: "My msg",
|
||||
})
|
||||
@@ -75,28 +59,12 @@ func TestBuildChatCompletionRequestNilSeed(t *testing.T) {
|
||||
Seed: 0,
|
||||
}
|
||||
|
||||
var expectedMessages []openai.ChatCompletionMessage
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
expectedMessages = append(expectedMessages,
|
||||
openai.ChatCompletionMessage{
|
||||
Role: msgs[i].Role,
|
||||
Content: msgs[i].Content,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
var expectedRequest = goopenai.ChatCompletionRequest{
|
||||
Model: opts.Model,
|
||||
Temperature: float32(opts.Temperature),
|
||||
TopP: float32(opts.TopP),
|
||||
PresencePenalty: float32(opts.PresencePenalty),
|
||||
FrequencyPenalty: float32(opts.FrequencyPenalty),
|
||||
Messages: expectedMessages,
|
||||
Seed: nil,
|
||||
}
|
||||
|
||||
var client = NewClient()
|
||||
request := client.buildChatCompletionRequest(msgs, opts)
|
||||
assert.Equal(t, expectedRequest, request)
|
||||
request := client.buildChatCompletionParams(msgs, opts)
|
||||
assert.Equal(t, openai.ChatModel(opts.Model), request.Model)
|
||||
assert.Equal(t, openai.Float(opts.Temperature), request.Temperature)
|
||||
assert.Equal(t, openai.Float(opts.TopP), request.TopP)
|
||||
assert.Equal(t, openai.Float(opts.PresencePenalty), request.PresencePenalty)
|
||||
assert.Equal(t, openai.Float(opts.FrequencyPenalty), request.FrequencyPenalty)
|
||||
assert.False(t, request.Seed.Valid())
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"github.com/danielmiessler/fabric/plugins"
|
||||
perplexity "github.com/sgaunet/perplexity-go/v2"
|
||||
|
||||
goopenai "github.com/sashabaranov/go-openai"
|
||||
"github.com/danielmiessler/fabric/chat"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -61,7 +61,7 @@ func (c *Client) ListModels() ([]string, error) {
|
||||
return models, nil
|
||||
}
|
||||
|
||||
func (c *Client) Send(ctx context.Context, msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions) (string, error) {
|
||||
func (c *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions) (string, error) {
|
||||
if c.client == nil {
|
||||
if err := c.Configure(); err != nil {
|
||||
return "", fmt.Errorf("failed to configure Perplexity client: %w", err)
|
||||
@@ -120,7 +120,7 @@ func (c *Client) Send(ctx context.Context, msgs []*goopenai.ChatCompletionMessag
|
||||
return content, nil
|
||||
}
|
||||
|
||||
func (c *Client) SendStream(msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) error {
|
||||
func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) error {
|
||||
if c.client == nil {
|
||||
if err := c.Configure(); err != nil {
|
||||
close(channel) // Ensure channel is closed on error
|
||||
|
||||
@@ -3,8 +3,8 @@ package ai
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/danielmiessler/fabric/chat"
|
||||
"github.com/danielmiessler/fabric/plugins"
|
||||
goopenai "github.com/sashabaranov/go-openai"
|
||||
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
)
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
type Vendor interface {
|
||||
plugins.Plugin
|
||||
ListModels() ([]string, error)
|
||||
SendStream([]*goopenai.ChatCompletionMessage, *common.ChatOptions, chan string) error
|
||||
Send(context.Context, []*goopenai.ChatCompletionMessage, *common.ChatOptions) (string, error)
|
||||
SendStream([]*chat.ChatCompletionMessage, *common.ChatOptions, chan string) error
|
||||
Send(context.Context, []*chat.ChatCompletionMessage, *common.ChatOptions) (string, error)
|
||||
NeedsRawMode(modelName string) bool
|
||||
}
|
||||
|
||||
@@ -3,8 +3,8 @@ package fsdb
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/danielmiessler/fabric/chat"
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
goopenai "github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
type SessionsEntity struct {
|
||||
@@ -38,16 +38,16 @@ func (o *SessionsEntity) SaveSession(session *Session) (err error) {
|
||||
|
||||
type Session struct {
|
||||
Name string
|
||||
Messages []*goopenai.ChatCompletionMessage
|
||||
Messages []*chat.ChatCompletionMessage
|
||||
|
||||
vendorMessages []*goopenai.ChatCompletionMessage
|
||||
vendorMessages []*chat.ChatCompletionMessage
|
||||
}
|
||||
|
||||
func (o *Session) IsEmpty() bool {
|
||||
return len(o.Messages) == 0
|
||||
}
|
||||
|
||||
func (o *Session) Append(messages ...*goopenai.ChatCompletionMessage) {
|
||||
func (o *Session) Append(messages ...*chat.ChatCompletionMessage) {
|
||||
if o.vendorMessages != nil {
|
||||
for _, message := range messages {
|
||||
o.Messages = append(o.Messages, message)
|
||||
@@ -58,7 +58,7 @@ func (o *Session) Append(messages ...*goopenai.ChatCompletionMessage) {
|
||||
}
|
||||
}
|
||||
|
||||
func (o *Session) GetVendorMessages() (ret []*goopenai.ChatCompletionMessage) {
|
||||
func (o *Session) GetVendorMessages() (ret []*chat.ChatCompletionMessage) {
|
||||
if len(o.vendorMessages) == 0 {
|
||||
for _, message := range o.Messages {
|
||||
o.appendVendorMessage(message)
|
||||
@@ -68,13 +68,13 @@ func (o *Session) GetVendorMessages() (ret []*goopenai.ChatCompletionMessage) {
|
||||
return
|
||||
}
|
||||
|
||||
func (o *Session) appendVendorMessage(message *goopenai.ChatCompletionMessage) {
|
||||
func (o *Session) appendVendorMessage(message *chat.ChatCompletionMessage) {
|
||||
if message.Role != common.ChatMessageRoleMeta {
|
||||
o.vendorMessages = append(o.vendorMessages, message)
|
||||
}
|
||||
}
|
||||
|
||||
func (o *Session) GetLastMessage() (ret *goopenai.ChatCompletionMessage) {
|
||||
func (o *Session) GetLastMessage() (ret *chat.ChatCompletionMessage) {
|
||||
if len(o.Messages) > 0 {
|
||||
ret = o.Messages[len(o.Messages)-1]
|
||||
}
|
||||
@@ -86,9 +86,9 @@ func (o *Session) String() (ret string) {
|
||||
ret += fmt.Sprintf("\n--- \n[%v]\n%v", message.Role, message.Content)
|
||||
if message.MultiContent != nil {
|
||||
for _, part := range message.MultiContent {
|
||||
if part.Type == goopenai.ChatMessagePartTypeImageURL {
|
||||
if part.Type == chat.ChatMessagePartTypeImageURL {
|
||||
ret += fmt.Sprintf("\n%v: %v", part.Type, *part.ImageURL)
|
||||
} else if part.Type == goopenai.ChatMessagePartTypeText {
|
||||
} else if part.Type == chat.ChatMessagePartTypeText {
|
||||
ret += fmt.Sprintf("\n%v: %v", part.Type, part.Text)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ package fsdb
|
||||
import (
|
||||
"testing"
|
||||
|
||||
goopenai "github.com/sashabaranov/go-openai"
|
||||
"github.com/danielmiessler/fabric/chat"
|
||||
)
|
||||
|
||||
func TestSessions_GetOrCreateSession(t *testing.T) {
|
||||
@@ -27,7 +27,7 @@ func TestSessions_SaveSession(t *testing.T) {
|
||||
StorageEntity: &StorageEntity{Dir: dir, FileExtension: ".json"},
|
||||
}
|
||||
sessionName := "testSession"
|
||||
session := &Session{Name: sessionName, Messages: []*goopenai.ChatCompletionMessage{{Content: "message1"}}}
|
||||
session := &Session{Name: sessionName, Messages: []*chat.ChatCompletionMessage{{Content: "message1"}}}
|
||||
err := sessions.SaveSession(session)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to save session: %v", err)
|
||||
|
||||
@@ -93,7 +93,7 @@ func TestSysPlugin(t *testing.T) {
|
||||
if !filepath.IsAbs(got) {
|
||||
return fmt.Errorf("expected absolute path, got %s", got)
|
||||
}
|
||||
if !strings.Contains(got, "home") && !strings.Contains(got, "Users") {
|
||||
if !strings.Contains(got, "home") && !strings.Contains(got, "Users") && got != "/root" {
|
||||
return fmt.Errorf("path %s doesn't look like a home directory", got)
|
||||
}
|
||||
return nil
|
||||
|
||||
Reference in New Issue
Block a user