refactor: abstract chat message structs and migrate to official openai-go SDK

### CHANGES

- Introduce local `chat` package for message abstraction
- Replace sashabaranov/go-openai with official openai-go SDK
- Update OpenAI, Azure, and Exolab plugins for new client
- Refactor all AI providers to use internal chat types
- Decouple codebase from third-party AI provider structs
- Replace deprecated `ioutil` functions with `os` equivalents
This commit is contained in:
Kayvan Sylvan
2025-06-28 07:28:49 -07:00
parent aa028a4a57
commit 09e01eddf4
27 changed files with 346 additions and 271 deletions

View File

@@ -9,9 +9,9 @@ import (
"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/option"
"github.com/danielmiessler/fabric/chat"
"github.com/danielmiessler/fabric/common"
"github.com/danielmiessler/fabric/plugins"
goopenai "github.com/sashabaranov/go-openai"
)
const defaultBaseUrl = "https://api.anthropic.com/"
@@ -87,7 +87,7 @@ func (an *Client) ListModels() (ret []string, err error) {
}
func (an *Client) SendStream(
msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions, channel chan string,
msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions, channel chan string,
) (err error) {
messages := an.toMessages(msgs)
if len(messages) == 0 {
@@ -151,7 +151,7 @@ func (an *Client) buildMessageParams(msgs []anthropic.MessageParam, opts *common
return
}
func (an *Client) Send(ctx context.Context, msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions) (
func (an *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions) (
ret string, err error) {
messages := an.toMessages(msgs)
@@ -176,7 +176,7 @@ func (an *Client) Send(ctx context.Context, msgs []*goopenai.ChatCompletionMessa
return
}
func (an *Client) toMessages(msgs []*goopenai.ChatCompletionMessage) (ret []anthropic.MessageParam) {
func (an *Client) toMessages(msgs []*chat.ChatCompletionMessage) (ret []anthropic.MessageParam) {
// Custom normalization for Anthropic:
// - System messages become the first part of the first user message.
// - Messages must alternate user/assistant.
@@ -193,14 +193,14 @@ func (an *Client) toMessages(msgs []*goopenai.ChatCompletionMessage) (ret []anth
}
switch msg.Role {
case goopenai.ChatMessageRoleSystem:
case chat.ChatMessageRoleSystem:
// Accumulate system content. It will be prepended to the first user message.
if systemContent != "" {
systemContent += "\\n" + msg.Content
} else {
systemContent = msg.Content
}
case goopenai.ChatMessageRoleUser:
case chat.ChatMessageRoleUser:
userContent := msg.Content
if isFirstUserMessage && systemContent != "" {
userContent = systemContent + "\\n\\n" + userContent
@@ -213,7 +213,7 @@ func (an *Client) toMessages(msgs []*goopenai.ChatCompletionMessage) (ret []anth
}
anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(anthropic.NewTextBlock(userContent)))
lastRoleWasUser = true
case goopenai.ChatMessageRoleAssistant:
case chat.ChatMessageRoleAssistant:
// If the first message is an assistant message, and we have system content,
// prepend a user message with the system content.
if isFirstUserMessage && systemContent != "" {

View File

@@ -5,7 +5,8 @@ import (
"github.com/danielmiessler/fabric/plugins"
"github.com/danielmiessler/fabric/plugins/ai/openai"
goopenai "github.com/sashabaranov/go-openai"
openaiapi "github.com/openai/openai-go"
"github.com/openai/openai-go/option"
)
func NewClient() (ret *Client) {
@@ -29,11 +30,15 @@ type Client struct {
func (oi *Client) configure() (err error) {
oi.apiDeployments = strings.Split(oi.ApiDeployments.Value, ",")
config := goopenai.DefaultAzureConfig(oi.ApiKey.Value, oi.ApiBaseURL.Value)
if oi.ApiVersion.Value != "" {
config.APIVersion = oi.ApiVersion.Value
opts := []option.RequestOption{option.WithAPIKey(oi.ApiKey.Value)}
if oi.ApiBaseURL.Value != "" {
opts = append(opts, option.WithBaseURL(oi.ApiBaseURL.Value))
}
oi.ApiClient = goopenai.NewClientWithConfig(config)
if oi.ApiVersion.Value != "" {
opts = append(opts, option.WithQuery("api-version", oi.ApiVersion.Value))
}
client := openaiapi.NewClient(opts...)
oi.ApiClient = &client
return
}

View File

@@ -20,7 +20,7 @@ import (
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
goopenai "github.com/sashabaranov/go-openai"
"github.com/danielmiessler/fabric/chat"
)
const (
@@ -154,7 +154,7 @@ func (c *BedrockClient) ListModels() ([]string, error) {
}
// SendStream sends the messages to the the Bedrock ConverseStream API
func (c *BedrockClient) SendStream(msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) (err error) {
func (c *BedrockClient) SendStream(msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) (err error) {
// Ensure channel is closed on all exit paths to prevent goroutine leaks
defer func() {
if r := recover(); r != nil {
@@ -208,7 +208,7 @@ func (c *BedrockClient) SendStream(msgs []*goopenai.ChatCompletionMessage, opts
}
// Send sends the messages the Bedrock Converse API
func (c *BedrockClient) Send(ctx context.Context, msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions) (ret string, err error) {
func (c *BedrockClient) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions) (ret string, err error) {
messages := c.toMessages(msgs)
@@ -249,12 +249,12 @@ func (c *BedrockClient) NeedsRawMode(modelName string) bool {
// Bedrock Converse Message type.
// The system role messages are mapped to the user role as they contain a mix of system messages,
// pattern content and user input.
func (c *BedrockClient) toMessages(inputMessages []*goopenai.ChatCompletionMessage) (messages []types.Message) {
func (c *BedrockClient) toMessages(inputMessages []*chat.ChatCompletionMessage) (messages []types.Message) {
for _, msg := range inputMessages {
roles := map[string]types.ConversationRole{
goopenai.ChatMessageRoleUser: types.ConversationRoleUser,
goopenai.ChatMessageRoleAssistant: types.ConversationRoleAssistant,
goopenai.ChatMessageRoleSystem: types.ConversationRoleUser,
chat.ChatMessageRoleUser: types.ConversationRoleUser,
chat.ChatMessageRoleAssistant: types.ConversationRoleAssistant,
chat.ChatMessageRoleSystem: types.ConversationRoleUser,
}
role, ok := roles[msg.Role]

View File

@@ -6,7 +6,7 @@ import (
"fmt"
"strings"
goopenai "github.com/sashabaranov/go-openai"
"github.com/danielmiessler/fabric/chat"
"github.com/danielmiessler/fabric/common"
"github.com/danielmiessler/fabric/plugins"
@@ -24,14 +24,14 @@ func (c *Client) ListModels() ([]string, error) {
return []string{"dry-run-model"}, nil
}
func (c *Client) formatMultiContentMessage(msg *goopenai.ChatCompletionMessage) string {
func (c *Client) formatMultiContentMessage(msg *chat.ChatCompletionMessage) string {
var builder strings.Builder
if len(msg.MultiContent) > 0 {
builder.WriteString(fmt.Sprintf("%s:\n", msg.Role))
for _, part := range msg.MultiContent {
builder.WriteString(fmt.Sprintf(" - Type: %s\n", part.Type))
if part.Type == goopenai.ChatMessagePartTypeImageURL {
if part.Type == chat.ChatMessagePartTypeImageURL {
builder.WriteString(fmt.Sprintf(" Image URL: %s\n", part.ImageURL.URL))
} else {
builder.WriteString(fmt.Sprintf(" Text: %s\n", part.Text))
@@ -45,16 +45,16 @@ func (c *Client) formatMultiContentMessage(msg *goopenai.ChatCompletionMessage)
return builder.String()
}
func (c *Client) formatMessages(msgs []*goopenai.ChatCompletionMessage) string {
func (c *Client) formatMessages(msgs []*chat.ChatCompletionMessage) string {
var builder strings.Builder
for _, msg := range msgs {
switch msg.Role {
case goopenai.ChatMessageRoleSystem:
case chat.ChatMessageRoleSystem:
builder.WriteString(fmt.Sprintf("System:\n%s\n\n", msg.Content))
case goopenai.ChatMessageRoleAssistant:
case chat.ChatMessageRoleAssistant:
builder.WriteString(c.formatMultiContentMessage(msg))
case goopenai.ChatMessageRoleUser:
case chat.ChatMessageRoleUser:
builder.WriteString(c.formatMultiContentMessage(msg))
default:
builder.WriteString(fmt.Sprintf("%s:\n%s\n\n", msg.Role, msg.Content))
@@ -80,7 +80,7 @@ func (c *Client) formatOptions(opts *common.ChatOptions) string {
return builder.String()
}
func (c *Client) SendStream(msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) error {
func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) error {
var builder strings.Builder
builder.WriteString("Dry run: Would send the following request:\n\n")
builder.WriteString(c.formatMessages(msgs))
@@ -91,7 +91,7 @@ func (c *Client) SendStream(msgs []*goopenai.ChatCompletionMessage, opts *common
return nil
}
func (c *Client) Send(_ context.Context, msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions) (string, error) {
func (c *Client) Send(_ context.Context, msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions) (string, error) {
fmt.Println("Dry run: Would send the following request:")
fmt.Print(c.formatMessages(msgs))
fmt.Print(c.formatOptions(opts))

View File

@@ -4,8 +4,8 @@ import (
"reflect"
"testing"
"github.com/danielmiessler/fabric/chat"
"github.com/danielmiessler/fabric/common"
"github.com/sashabaranov/go-openai"
)
// Test generated using Keploy
@@ -33,7 +33,7 @@ func TestSetup_ReturnsNil(t *testing.T) {
// Test generated using Keploy
func TestSendStream_SendsMessages(t *testing.T) {
client := NewClient()
msgs := []*openai.ChatCompletionMessage{
msgs := []*chat.ChatCompletionMessage{
{Role: "user", Content: "Test message"},
}
opts := &common.ChatOptions{

View File

@@ -5,8 +5,8 @@ import (
"github.com/danielmiessler/fabric/plugins"
"github.com/danielmiessler/fabric/plugins/ai/openai"
goopenai "github.com/sashabaranov/go-openai"
openaiapi "github.com/openai/openai-go"
"github.com/openai/openai-go/option"
)
func NewClient() (ret *Client) {
@@ -32,10 +32,12 @@ type Client struct {
func (oi *Client) configure() (err error) {
oi.apiModels = strings.Split(oi.ApiModels.Value, ",")
config := goopenai.DefaultConfig("")
config.BaseURL = oi.ApiBaseURL.Value
oi.ApiClient = goopenai.NewClientWithConfig(config)
opts := []option.RequestOption{option.WithAPIKey(oi.ApiKey.Value)}
if oi.ApiBaseURL.Value != "" {
opts = append(opts, option.WithBaseURL(oi.ApiBaseURL.Value))
}
client := openaiapi.NewClient(opts...)
oi.ApiClient = &client
return
}

View File

@@ -6,8 +6,8 @@ import (
"fmt"
"strings"
"github.com/danielmiessler/fabric/chat"
"github.com/danielmiessler/fabric/plugins"
goopenai "github.com/sashabaranov/go-openai"
"github.com/danielmiessler/fabric/common"
"github.com/google/generative-ai-go/genai"
@@ -60,7 +60,7 @@ func (o *Client) ListModels() (ret []string, err error) {
return
}
func (o *Client) Send(ctx context.Context, msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions) (ret string, err error) {
func (o *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions) (ret string, err error) {
systemInstruction, messages := toMessages(msgs)
var client *genai.Client
@@ -91,7 +91,7 @@ func (o *Client) buildModelNameFull(modelName string) string {
return fmt.Sprintf("%v%v", modelsNamePrefix, modelName)
}
func (o *Client) SendStream(msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) (err error) {
func (o *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) (err error) {
ctx := context.Background()
var client *genai.Client
if client, err = genai.NewClient(ctx, option.WithAPIKey(o.ApiKey.Value)); err != nil {
@@ -147,7 +147,7 @@ func (o *Client) NeedsRawMode(modelName string) bool {
return false
}
func toMessages(msgs []*goopenai.ChatCompletionMessage) (systemInstruction *genai.Content, messages []genai.Part) {
func toMessages(msgs []*chat.ChatCompletionMessage) (systemInstruction *genai.Content, messages []genai.Part) {
if len(msgs) >= 2 {
systemInstruction = &genai.Content{
Parts: []genai.Part{

View File

@@ -9,7 +9,7 @@ import (
"io"
"net/http"
goopenai "github.com/sashabaranov/go-openai"
"github.com/danielmiessler/fabric/chat"
"github.com/danielmiessler/fabric/common"
"github.com/danielmiessler/fabric/plugins"
@@ -87,7 +87,7 @@ func (c *Client) ListModels() ([]string, error) {
return models, nil
}
func (c *Client) SendStream(msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) (err error) {
func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) (err error) {
url := fmt.Sprintf("%s/chat/completions", c.ApiUrl.Value)
payload := map[string]interface{}{
@@ -173,7 +173,7 @@ func (c *Client) SendStream(msgs []*goopenai.ChatCompletionMessage, opts *common
return
}
func (c *Client) Send(ctx context.Context, msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions) (content string, err error) {
func (c *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions) (content string, err error) {
url := fmt.Sprintf("%s/chat/completions", c.ApiUrl.Value)
payload := map[string]interface{}{

View File

@@ -8,9 +8,9 @@ import (
"strings"
"time"
"github.com/danielmiessler/fabric/chat"
ollamaapi "github.com/ollama/ollama/api"
"github.com/samber/lo"
goopenai "github.com/sashabaranov/go-openai"
"github.com/danielmiessler/fabric/common"
"github.com/danielmiessler/fabric/plugins"
@@ -97,7 +97,7 @@ func (o *Client) ListModels() (ret []string, err error) {
return
}
func (o *Client) SendStream(msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) (err error) {
func (o *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) (err error) {
req := o.createChatRequest(msgs, opts)
respFunc := func(resp ollamaapi.ChatResponse) (streamErr error) {
@@ -115,7 +115,7 @@ func (o *Client) SendStream(msgs []*goopenai.ChatCompletionMessage, opts *common
return
}
func (o *Client) Send(ctx context.Context, msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions) (ret string, err error) {
func (o *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions) (ret string, err error) {
bf := false
req := o.createChatRequest(msgs, opts)
@@ -132,8 +132,8 @@ func (o *Client) Send(ctx context.Context, msgs []*goopenai.ChatCompletionMessag
return
}
func (o *Client) createChatRequest(msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions) (ret ollamaapi.ChatRequest) {
messages := lo.Map(msgs, func(message *goopenai.ChatCompletionMessage, _ int) (ret ollamaapi.Message) {
func (o *Client) createChatRequest(msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions) (ret ollamaapi.ChatRequest) {
messages := lo.Map(msgs, func(message *chat.ChatCompletionMessage, _ int) (ret ollamaapi.Message) {
return ollamaapi.Message{Role: message.Role, Content: message.Content}
})

View File

@@ -2,16 +2,16 @@ package openai
import (
"context"
"errors"
"fmt"
"io"
"log/slog"
"slices"
"strings"
"github.com/danielmiessler/fabric/chat"
"github.com/danielmiessler/fabric/common"
"github.com/danielmiessler/fabric/plugins"
goopenai "github.com/sashabaranov/go-openai"
openai "github.com/openai/openai-go"
"github.com/openai/openai-go/option"
"github.com/openai/openai-go/packages/pagination"
)
func NewClient() (ret *Client) {
@@ -48,73 +48,53 @@ type Client struct {
*plugins.PluginBase
ApiKey *plugins.SetupQuestion
ApiBaseURL *plugins.SetupQuestion
ApiClient *goopenai.Client
ApiClient *openai.Client
}
func (o *Client) configure() (ret error) {
config := goopenai.DefaultConfig(o.ApiKey.Value)
opts := []option.RequestOption{option.WithAPIKey(o.ApiKey.Value)}
if o.ApiBaseURL.Value != "" {
config.BaseURL = o.ApiBaseURL.Value
opts = append(opts, option.WithBaseURL(o.ApiBaseURL.Value))
}
o.ApiClient = goopenai.NewClientWithConfig(config)
client := openai.NewClient(opts...)
o.ApiClient = &client
return
}
func (o *Client) ListModels() (ret []string, err error) {
var models goopenai.ModelsList
if models, err = o.ApiClient.ListModels(context.Background()); err != nil {
var page *pagination.Page[openai.Model]
if page, err = o.ApiClient.Models.List(context.Background()); err != nil {
return
}
model := models.Models
for _, mod := range model {
for _, mod := range page.Data {
ret = append(ret, mod.ID)
}
return
}
func (o *Client) SendStream(
msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions, channel chan string,
msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions, channel chan string,
) (err error) {
req := o.buildChatCompletionRequest(msgs, opts)
req.Stream = true
var stream *goopenai.ChatCompletionStream
if stream, err = o.ApiClient.CreateChatCompletionStream(context.Background(), req); err != nil {
fmt.Printf("ChatCompletionStream error: %v\n", err)
return
}
defer stream.Close()
for {
var response goopenai.ChatCompletionStreamResponse
if response, err = stream.Recv(); err == nil {
if len(response.Choices) > 0 {
channel <- response.Choices[0].Delta.Content
} else {
channel <- "\n"
close(channel)
break
}
} else if errors.Is(err, io.EOF) {
channel <- "\n"
close(channel)
err = nil
break
} else if err != nil {
fmt.Printf("\nStream error: %v\n", err)
break
req := o.buildChatCompletionParams(msgs, opts)
stream := o.ApiClient.Chat.Completions.NewStreaming(context.Background(), req)
for stream.Next() {
chunk := stream.Current()
if len(chunk.Choices) > 0 {
channel <- chunk.Choices[0].Delta.Content
}
}
return
if stream.Err() == nil {
channel <- "\n"
}
close(channel)
return stream.Err()
}
func (o *Client) Send(ctx context.Context, msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions) (ret string, err error) {
req := o.buildChatCompletionRequest(msgs, opts)
func (o *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions) (ret string, err error) {
req := o.buildChatCompletionParams(msgs, opts)
var resp goopenai.ChatCompletionResponse
if resp, err = o.ApiClient.CreateChatCompletion(ctx, req); err != nil {
var resp *openai.ChatCompletion
if resp, err = o.ApiClient.Chat.Completions.New(ctx, req); err != nil {
return
}
if len(resp.Choices) > 0 {
@@ -146,57 +126,56 @@ func (o *Client) NeedsRawMode(modelName string) bool {
return slices.Contains(openAIModelsNeedingRaw, modelName)
}
func (o *Client) buildChatCompletionRequest(
inputMsgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions,
) (ret goopenai.ChatCompletionRequest) {
func (o *Client) buildChatCompletionParams(
inputMsgs []*chat.ChatCompletionMessage, opts *common.ChatOptions,
) (ret openai.ChatCompletionNewParams) {
// Create a new slice for messages to be sent, converting from []*Msg to []Msg.
// This also serves as a mutable copy for provider-specific modifications.
messagesForRequest := make([]goopenai.ChatCompletionMessage, len(inputMsgs))
messagesForRequest := make([]openai.ChatCompletionMessageParamUnion, len(inputMsgs))
for i, msgPtr := range inputMsgs {
messagesForRequest[i] = *msgPtr // Dereference and copy
}
// Provider-specific modification for DeepSeek:
// DeepSeek requires the last message to be a user message.
// If fabric constructs a single system message (common when a pattern includes user input),
// we change its role to user for DeepSeek.
if strings.Contains(opts.Model, "deepseek") { // Heuristic to identify DeepSeek models
if len(messagesForRequest) == 1 && messagesForRequest[0].Role == goopenai.ChatMessageRoleSystem {
messagesForRequest[0].Role = goopenai.ChatMessageRoleUser
msg := *msgPtr // copy
// Provider-specific modification for DeepSeek:
if strings.Contains(opts.Model, "deepseek") && len(inputMsgs) == 1 && msg.Role == chat.ChatMessageRoleSystem {
msg.Role = chat.ChatMessageRoleUser
}
// Note: This handles the most common case arising from pattern usage.
// More complex scenarios where a multi-message sequence ends in 'system'
// are not currently expected from chatter.go's BuildSession logic for OpenAI providers
// but might require further rules if they arise.
messagesForRequest[i] = convertMessage(msg)
}
if opts.Raw {
ret = goopenai.ChatCompletionRequest{
Model: opts.Model,
Messages: messagesForRequest,
}
} else {
if opts.Seed == 0 {
ret = goopenai.ChatCompletionRequest{
Model: opts.Model,
Temperature: float32(opts.Temperature),
TopP: float32(opts.TopP),
PresencePenalty: float32(opts.PresencePenalty),
FrequencyPenalty: float32(opts.FrequencyPenalty),
Messages: messagesForRequest,
}
} else {
ret = goopenai.ChatCompletionRequest{
Model: opts.Model,
Temperature: float32(opts.Temperature),
TopP: float32(opts.TopP),
PresencePenalty: float32(opts.PresencePenalty),
FrequencyPenalty: float32(opts.FrequencyPenalty),
Messages: messagesForRequest,
Seed: &opts.Seed,
}
ret = openai.ChatCompletionNewParams{
Model: openai.ChatModel(opts.Model),
Messages: messagesForRequest,
}
if !opts.Raw {
ret.Temperature = openai.Float(opts.Temperature)
ret.TopP = openai.Float(opts.TopP)
ret.PresencePenalty = openai.Float(opts.PresencePenalty)
ret.FrequencyPenalty = openai.Float(opts.FrequencyPenalty)
if opts.Seed != 0 {
ret.Seed = openai.Int(int64(opts.Seed))
}
}
return
}
func convertMessage(msg chat.ChatCompletionMessage) openai.ChatCompletionMessageParamUnion {
switch msg.Role {
case chat.ChatMessageRoleSystem:
return openai.SystemMessage(msg.Content)
case chat.ChatMessageRoleUser:
if len(msg.MultiContent) > 0 {
var parts []openai.ChatCompletionContentPartUnionParam
for _, p := range msg.MultiContent {
switch p.Type {
case chat.ChatMessagePartTypeText:
parts = append(parts, openai.TextContentPart(p.Text))
case chat.ChatMessagePartTypeImageURL:
parts = append(parts, openai.ImageContentPart(openai.ChatCompletionContentPartImageImageURLParam{URL: p.ImageURL.URL}))
}
}
return openai.UserMessage(parts)
}
return openai.UserMessage(msg.Content)
default:
return openai.AssistantMessage(msg.Content)
}
}

View File

@@ -3,18 +3,18 @@ package openai
import (
"testing"
"github.com/danielmiessler/fabric/chat"
"github.com/danielmiessler/fabric/common"
"github.com/sashabaranov/go-openai"
goopenai "github.com/sashabaranov/go-openai"
openai "github.com/openai/openai-go"
"github.com/stretchr/testify/assert"
)
func TestBuildChatCompletionRequestPinSeed(t *testing.T) {
var msgs []*goopenai.ChatCompletionMessage
var msgs []*chat.ChatCompletionMessage
for i := 0; i < 2; i++ {
msgs = append(msgs, &goopenai.ChatCompletionMessage{
msgs = append(msgs, &chat.ChatCompletionMessage{
Role: "User",
Content: "My msg",
})
@@ -29,38 +29,22 @@ func TestBuildChatCompletionRequestPinSeed(t *testing.T) {
Seed: 1,
}
var expectedMessages []openai.ChatCompletionMessage
for i := 0; i < 2; i++ {
expectedMessages = append(expectedMessages,
openai.ChatCompletionMessage{
Role: msgs[i].Role,
Content: msgs[i].Content,
},
)
}
var expectedRequest = goopenai.ChatCompletionRequest{
Model: opts.Model,
Temperature: float32(opts.Temperature),
TopP: float32(opts.TopP),
PresencePenalty: float32(opts.PresencePenalty),
FrequencyPenalty: float32(opts.FrequencyPenalty),
Messages: expectedMessages,
Seed: &opts.Seed,
}
var client = NewClient()
request := client.buildChatCompletionRequest(msgs, opts)
assert.Equal(t, expectedRequest, request)
request := client.buildChatCompletionParams(msgs, opts)
assert.Equal(t, openai.ChatModel(opts.Model), request.Model)
assert.Equal(t, openai.Float(opts.Temperature), request.Temperature)
assert.Equal(t, openai.Float(opts.TopP), request.TopP)
assert.Equal(t, openai.Float(opts.PresencePenalty), request.PresencePenalty)
assert.Equal(t, openai.Float(opts.FrequencyPenalty), request.FrequencyPenalty)
assert.Equal(t, openai.Int(int64(opts.Seed)), request.Seed)
}
func TestBuildChatCompletionRequestNilSeed(t *testing.T) {
var msgs []*goopenai.ChatCompletionMessage
var msgs []*chat.ChatCompletionMessage
for i := 0; i < 2; i++ {
msgs = append(msgs, &goopenai.ChatCompletionMessage{
msgs = append(msgs, &chat.ChatCompletionMessage{
Role: "User",
Content: "My msg",
})
@@ -75,28 +59,12 @@ func TestBuildChatCompletionRequestNilSeed(t *testing.T) {
Seed: 0,
}
var expectedMessages []openai.ChatCompletionMessage
for i := 0; i < 2; i++ {
expectedMessages = append(expectedMessages,
openai.ChatCompletionMessage{
Role: msgs[i].Role,
Content: msgs[i].Content,
},
)
}
var expectedRequest = goopenai.ChatCompletionRequest{
Model: opts.Model,
Temperature: float32(opts.Temperature),
TopP: float32(opts.TopP),
PresencePenalty: float32(opts.PresencePenalty),
FrequencyPenalty: float32(opts.FrequencyPenalty),
Messages: expectedMessages,
Seed: nil,
}
var client = NewClient()
request := client.buildChatCompletionRequest(msgs, opts)
assert.Equal(t, expectedRequest, request)
request := client.buildChatCompletionParams(msgs, opts)
assert.Equal(t, openai.ChatModel(opts.Model), request.Model)
assert.Equal(t, openai.Float(opts.Temperature), request.Temperature)
assert.Equal(t, openai.Float(opts.TopP), request.TopP)
assert.Equal(t, openai.Float(opts.PresencePenalty), request.PresencePenalty)
assert.Equal(t, openai.Float(opts.FrequencyPenalty), request.FrequencyPenalty)
assert.False(t, request.Seed.Valid())
}

View File

@@ -10,7 +10,7 @@ import (
"github.com/danielmiessler/fabric/plugins"
perplexity "github.com/sgaunet/perplexity-go/v2"
goopenai "github.com/sashabaranov/go-openai"
"github.com/danielmiessler/fabric/chat"
)
const (
@@ -61,7 +61,7 @@ func (c *Client) ListModels() ([]string, error) {
return models, nil
}
func (c *Client) Send(ctx context.Context, msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions) (string, error) {
func (c *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions) (string, error) {
if c.client == nil {
if err := c.Configure(); err != nil {
return "", fmt.Errorf("failed to configure Perplexity client: %w", err)
@@ -120,7 +120,7 @@ func (c *Client) Send(ctx context.Context, msgs []*goopenai.ChatCompletionMessag
return content, nil
}
func (c *Client) SendStream(msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) error {
func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) error {
if c.client == nil {
if err := c.Configure(); err != nil {
close(channel) // Ensure channel is closed on error

View File

@@ -3,8 +3,8 @@ package ai
import (
"context"
"github.com/danielmiessler/fabric/chat"
"github.com/danielmiessler/fabric/plugins"
goopenai "github.com/sashabaranov/go-openai"
"github.com/danielmiessler/fabric/common"
)
@@ -12,7 +12,7 @@ import (
type Vendor interface {
plugins.Plugin
ListModels() ([]string, error)
SendStream([]*goopenai.ChatCompletionMessage, *common.ChatOptions, chan string) error
Send(context.Context, []*goopenai.ChatCompletionMessage, *common.ChatOptions) (string, error)
SendStream([]*chat.ChatCompletionMessage, *common.ChatOptions, chan string) error
Send(context.Context, []*chat.ChatCompletionMessage, *common.ChatOptions) (string, error)
NeedsRawMode(modelName string) bool
}

View File

@@ -3,8 +3,8 @@ package fsdb
import (
"fmt"
"github.com/danielmiessler/fabric/chat"
"github.com/danielmiessler/fabric/common"
goopenai "github.com/sashabaranov/go-openai"
)
type SessionsEntity struct {
@@ -38,16 +38,16 @@ func (o *SessionsEntity) SaveSession(session *Session) (err error) {
type Session struct {
Name string
Messages []*goopenai.ChatCompletionMessage
Messages []*chat.ChatCompletionMessage
vendorMessages []*goopenai.ChatCompletionMessage
vendorMessages []*chat.ChatCompletionMessage
}
func (o *Session) IsEmpty() bool {
return len(o.Messages) == 0
}
func (o *Session) Append(messages ...*goopenai.ChatCompletionMessage) {
func (o *Session) Append(messages ...*chat.ChatCompletionMessage) {
if o.vendorMessages != nil {
for _, message := range messages {
o.Messages = append(o.Messages, message)
@@ -58,7 +58,7 @@ func (o *Session) Append(messages ...*goopenai.ChatCompletionMessage) {
}
}
func (o *Session) GetVendorMessages() (ret []*goopenai.ChatCompletionMessage) {
func (o *Session) GetVendorMessages() (ret []*chat.ChatCompletionMessage) {
if len(o.vendorMessages) == 0 {
for _, message := range o.Messages {
o.appendVendorMessage(message)
@@ -68,13 +68,13 @@ func (o *Session) GetVendorMessages() (ret []*goopenai.ChatCompletionMessage) {
return
}
func (o *Session) appendVendorMessage(message *goopenai.ChatCompletionMessage) {
func (o *Session) appendVendorMessage(message *chat.ChatCompletionMessage) {
if message.Role != common.ChatMessageRoleMeta {
o.vendorMessages = append(o.vendorMessages, message)
}
}
func (o *Session) GetLastMessage() (ret *goopenai.ChatCompletionMessage) {
func (o *Session) GetLastMessage() (ret *chat.ChatCompletionMessage) {
if len(o.Messages) > 0 {
ret = o.Messages[len(o.Messages)-1]
}
@@ -86,9 +86,9 @@ func (o *Session) String() (ret string) {
ret += fmt.Sprintf("\n--- \n[%v]\n%v", message.Role, message.Content)
if message.MultiContent != nil {
for _, part := range message.MultiContent {
if part.Type == goopenai.ChatMessagePartTypeImageURL {
if part.Type == chat.ChatMessagePartTypeImageURL {
ret += fmt.Sprintf("\n%v: %v", part.Type, *part.ImageURL)
} else if part.Type == goopenai.ChatMessagePartTypeText {
} else if part.Type == chat.ChatMessagePartTypeText {
ret += fmt.Sprintf("\n%v: %v", part.Type, part.Text)
}
}

View File

@@ -3,7 +3,7 @@ package fsdb
import (
"testing"
goopenai "github.com/sashabaranov/go-openai"
"github.com/danielmiessler/fabric/chat"
)
func TestSessions_GetOrCreateSession(t *testing.T) {
@@ -27,7 +27,7 @@ func TestSessions_SaveSession(t *testing.T) {
StorageEntity: &StorageEntity{Dir: dir, FileExtension: ".json"},
}
sessionName := "testSession"
session := &Session{Name: sessionName, Messages: []*goopenai.ChatCompletionMessage{{Content: "message1"}}}
session := &Session{Name: sessionName, Messages: []*chat.ChatCompletionMessage{{Content: "message1"}}}
err := sessions.SaveSession(session)
if err != nil {
t.Fatalf("failed to save session: %v", err)

View File

@@ -93,7 +93,7 @@ func TestSysPlugin(t *testing.T) {
if !filepath.IsAbs(got) {
return fmt.Errorf("expected absolute path, got %s", got)
}
if !strings.Contains(got, "home") && !strings.Contains(got, "Users") {
if !strings.Contains(got, "home") && !strings.Contains(got, "Users") && got != "/root" {
return fmt.Errorf("path %s doesn't look like a home directory", got)
}
return nil