Files
Fabric/plugins/ai/openai/openai_test.go
2024-10-29 20:38:15 +01:00

103 lines
2.4 KiB
Go

package openai
import (
"testing"
"github.com/danielmiessler/fabric/common"
"github.com/sashabaranov/go-openai"
goopenai "github.com/sashabaranov/go-openai"
"github.com/stretchr/testify/assert"
)
func TestBuildChatCompletionRequestPinSeed(t *testing.T) {
var msgs []*goopenai.ChatCompletionMessage
for i := 0; i < 2; i++ {
msgs = append(msgs, &goopenai.ChatCompletionMessage{
Role: "User",
Content: "My msg",
})
}
opts := &common.ChatOptions{
Temperature: 0.8,
TopP: 0.9,
PresencePenalty: 0.1,
FrequencyPenalty: 0.2,
Raw: false,
Seed: 1,
}
var expectedMessages []openai.ChatCompletionMessage
for i := 0; i < 2; i++ {
expectedMessages = append(expectedMessages,
openai.ChatCompletionMessage{
Role: msgs[i].Role,
Content: msgs[i].Content,
},
)
}
var expectedRequest = goopenai.ChatCompletionRequest{
Model: opts.Model,
Temperature: float32(opts.Temperature),
TopP: float32(opts.TopP),
PresencePenalty: float32(opts.PresencePenalty),
FrequencyPenalty: float32(opts.FrequencyPenalty),
Messages: expectedMessages,
Seed: &opts.Seed,
}
var client = NewClient()
request := client.buildChatCompletionRequest(msgs, opts)
assert.Equal(t, expectedRequest, request)
}
func TestBuildChatCompletionRequestNilSeed(t *testing.T) {
var msgs []*goopenai.ChatCompletionMessage
for i := 0; i < 2; i++ {
msgs = append(msgs, &goopenai.ChatCompletionMessage{
Role: "User",
Content: "My msg",
})
}
opts := &common.ChatOptions{
Temperature: 0.8,
TopP: 0.9,
PresencePenalty: 0.1,
FrequencyPenalty: 0.2,
Raw: false,
Seed: 0,
}
var expectedMessages []openai.ChatCompletionMessage
for i := 0; i < 2; i++ {
expectedMessages = append(expectedMessages,
openai.ChatCompletionMessage{
Role: msgs[i].Role,
Content: msgs[i].Content,
},
)
}
var expectedRequest = goopenai.ChatCompletionRequest{
Model: opts.Model,
Temperature: float32(opts.Temperature),
TopP: float32(opts.TopP),
PresencePenalty: float32(opts.PresencePenalty),
FrequencyPenalty: float32(opts.FrequencyPenalty),
Messages: expectedMessages,
Seed: nil,
}
var client = NewClient()
request := client.buildChatCompletionRequest(msgs, opts)
assert.Equal(t, expectedRequest, request)
}