mirror of
https://github.com/danielmiessler/Fabric.git
synced 2026-02-15 08:25:16 -05:00
103 lines
2.4 KiB
Go
103 lines
2.4 KiB
Go
package openai
|
|
|
|
import (
|
|
"testing"
|
|
|
|
"github.com/danielmiessler/fabric/common"
|
|
"github.com/sashabaranov/go-openai"
|
|
goopenai "github.com/sashabaranov/go-openai"
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
func TestBuildChatCompletionRequestPinSeed(t *testing.T) {
|
|
|
|
var msgs []*goopenai.ChatCompletionMessage
|
|
|
|
for i := 0; i < 2; i++ {
|
|
msgs = append(msgs, &goopenai.ChatCompletionMessage{
|
|
Role: "User",
|
|
Content: "My msg",
|
|
})
|
|
}
|
|
|
|
opts := &common.ChatOptions{
|
|
Temperature: 0.8,
|
|
TopP: 0.9,
|
|
PresencePenalty: 0.1,
|
|
FrequencyPenalty: 0.2,
|
|
Raw: false,
|
|
Seed: 1,
|
|
}
|
|
|
|
var expectedMessages []openai.ChatCompletionMessage
|
|
|
|
for i := 0; i < 2; i++ {
|
|
expectedMessages = append(expectedMessages,
|
|
openai.ChatCompletionMessage{
|
|
Role: msgs[i].Role,
|
|
Content: msgs[i].Content,
|
|
},
|
|
)
|
|
}
|
|
|
|
var expectedRequest = goopenai.ChatCompletionRequest{
|
|
Model: opts.Model,
|
|
Temperature: float32(opts.Temperature),
|
|
TopP: float32(opts.TopP),
|
|
PresencePenalty: float32(opts.PresencePenalty),
|
|
FrequencyPenalty: float32(opts.FrequencyPenalty),
|
|
Messages: expectedMessages,
|
|
Seed: &opts.Seed,
|
|
}
|
|
|
|
var client = NewClient()
|
|
request := client.buildChatCompletionRequest(msgs, opts)
|
|
assert.Equal(t, expectedRequest, request)
|
|
}
|
|
|
|
func TestBuildChatCompletionRequestNilSeed(t *testing.T) {
|
|
|
|
var msgs []*goopenai.ChatCompletionMessage
|
|
|
|
for i := 0; i < 2; i++ {
|
|
msgs = append(msgs, &goopenai.ChatCompletionMessage{
|
|
Role: "User",
|
|
Content: "My msg",
|
|
})
|
|
}
|
|
|
|
opts := &common.ChatOptions{
|
|
Temperature: 0.8,
|
|
TopP: 0.9,
|
|
PresencePenalty: 0.1,
|
|
FrequencyPenalty: 0.2,
|
|
Raw: false,
|
|
Seed: 0,
|
|
}
|
|
|
|
var expectedMessages []openai.ChatCompletionMessage
|
|
|
|
for i := 0; i < 2; i++ {
|
|
expectedMessages = append(expectedMessages,
|
|
openai.ChatCompletionMessage{
|
|
Role: msgs[i].Role,
|
|
Content: msgs[i].Content,
|
|
},
|
|
)
|
|
}
|
|
|
|
var expectedRequest = goopenai.ChatCompletionRequest{
|
|
Model: opts.Model,
|
|
Temperature: float32(opts.Temperature),
|
|
TopP: float32(opts.TopP),
|
|
PresencePenalty: float32(opts.PresencePenalty),
|
|
FrequencyPenalty: float32(opts.FrequencyPenalty),
|
|
Messages: expectedMessages,
|
|
Seed: nil,
|
|
}
|
|
|
|
var client = NewClient()
|
|
request := client.buildChatCompletionRequest(msgs, opts)
|
|
assert.Equal(t, expectedRequest, request)
|
|
}
|