mirror of
https://github.com/danielmiessler/Fabric.git
synced 2026-01-09 22:38:10 -05:00
feat: add image file validation and format detection for image generation
## CHANGES • Add image file path validation with extension checking • Implement dynamic output format detection from file extensions • Update BuildChatOptions method to return error for validation • Add comprehensive test coverage for image file validation • Upgrade YAML library from v2 to v3 • Update shell completions to reflect supported image formats • Add error handling for existing file conflicts • Support PNG, JPEG, JPG, and WEBP image formats
This commit is contained in:
@@ -270,7 +270,11 @@ func Cli(version string) (err error) {
|
||||
if chatReq.Language == "" {
|
||||
chatReq.Language = registry.Language.DefaultLanguage.Value
|
||||
}
|
||||
if session, err = chatter.Send(chatReq, currentFlags.BuildChatOptions()); err != nil {
|
||||
var chatOptions *common.ChatOptions
|
||||
if chatOptions, err = currentFlags.BuildChatOptions(); err != nil {
|
||||
return
|
||||
}
|
||||
if session, err = chatter.Send(chatReq, chatOptions); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
34
cli/flags.go
34
cli/flags.go
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -14,7 +15,7 @@ import (
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"github.com/jessevdk/go-flags"
|
||||
"golang.org/x/text/language"
|
||||
"gopkg.in/yaml.v2"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Flags create flags struct. the users flags go into this, this will be passed to the chat struct in cli
|
||||
@@ -257,7 +258,36 @@ func readStdin() (ret string, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (o *Flags) BuildChatOptions() (ret *common.ChatOptions) {
|
||||
// validateImageFile validates the image file path and extension
|
||||
func validateImageFile(imagePath string) error {
|
||||
if imagePath == "" {
|
||||
return nil // No validation needed if no image file specified
|
||||
}
|
||||
|
||||
// Check if file already exists
|
||||
if _, err := os.Stat(imagePath); err == nil {
|
||||
return fmt.Errorf("image file already exists: %s", imagePath)
|
||||
}
|
||||
|
||||
// Check file extension
|
||||
ext := strings.ToLower(filepath.Ext(imagePath))
|
||||
validExtensions := []string{".png", ".jpeg", ".jpg", ".webp"}
|
||||
|
||||
for _, validExt := range validExtensions {
|
||||
if ext == validExt {
|
||||
return nil // Valid extension found
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("invalid image file extension '%s'. Supported formats: .png, .jpeg, .jpg, .webp", ext)
|
||||
}
|
||||
|
||||
func (o *Flags) BuildChatOptions() (ret *common.ChatOptions, err error) {
|
||||
// Validate image file if specified
|
||||
if err = validateImageFile(o.ImageFile); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret = &common.ChatOptions{
|
||||
Temperature: o.Temperature,
|
||||
TopP: o.TopP,
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -64,7 +65,8 @@ func TestBuildChatOptions(t *testing.T) {
|
||||
Raw: false,
|
||||
Seed: 1,
|
||||
}
|
||||
options := flags.BuildChatOptions()
|
||||
options, err := flags.BuildChatOptions()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedOptions, options)
|
||||
}
|
||||
|
||||
@@ -84,7 +86,8 @@ func TestBuildChatOptionsDefaultSeed(t *testing.T) {
|
||||
Raw: false,
|
||||
Seed: 0,
|
||||
}
|
||||
options := flags.BuildChatOptions()
|
||||
options, err := flags.BuildChatOptions()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedOptions, options)
|
||||
}
|
||||
|
||||
@@ -164,3 +167,91 @@ model: 123 # should be string
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestValidateImageFile(t *testing.T) {
|
||||
t.Run("Empty path should be valid", func(t *testing.T) {
|
||||
err := validateImageFile("")
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Valid extensions should pass", func(t *testing.T) {
|
||||
validExtensions := []string{".png", ".jpeg", ".jpg", ".webp"}
|
||||
for _, ext := range validExtensions {
|
||||
filename := "/tmp/test" + ext
|
||||
err := validateImageFile(filename)
|
||||
assert.NoError(t, err, "Extension %s should be valid", ext)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid extensions should fail", func(t *testing.T) {
|
||||
invalidExtensions := []string{".gif", ".bmp", ".tiff", ".svg", ".txt", ""}
|
||||
for _, ext := range invalidExtensions {
|
||||
filename := "/tmp/test" + ext
|
||||
err := validateImageFile(filename)
|
||||
assert.Error(t, err, "Extension %s should be invalid", ext)
|
||||
assert.Contains(t, err.Error(), "invalid image file extension")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Existing file should fail", func(t *testing.T) {
|
||||
// Create a temporary file
|
||||
tempFile, err := os.CreateTemp("", "test*.png")
|
||||
assert.NoError(t, err)
|
||||
defer os.Remove(tempFile.Name())
|
||||
tempFile.Close()
|
||||
|
||||
// Validation should fail because file exists
|
||||
err = validateImageFile(tempFile.Name())
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "image file already exists")
|
||||
})
|
||||
|
||||
t.Run("Non-existing file with valid extension should pass", func(t *testing.T) {
|
||||
nonExistentFile := filepath.Join(os.TempDir(), "non_existent_file.png")
|
||||
// Make sure the file doesn't exist
|
||||
os.Remove(nonExistentFile)
|
||||
|
||||
err := validateImageFile(nonExistentFile)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuildChatOptionsWithImageFileValidation(t *testing.T) {
|
||||
t.Run("Valid image file should pass", func(t *testing.T) {
|
||||
flags := &Flags{
|
||||
ImageFile: "/tmp/output.png",
|
||||
}
|
||||
|
||||
options, err := flags.BuildChatOptions()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "/tmp/output.png", options.ImageFile)
|
||||
})
|
||||
|
||||
t.Run("Invalid extension should fail", func(t *testing.T) {
|
||||
flags := &Flags{
|
||||
ImageFile: "/tmp/output.gif",
|
||||
}
|
||||
|
||||
options, err := flags.BuildChatOptions()
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, options)
|
||||
assert.Contains(t, err.Error(), "invalid image file extension")
|
||||
})
|
||||
|
||||
t.Run("Existing file should fail", func(t *testing.T) {
|
||||
// Create a temporary file
|
||||
tempFile, err := os.CreateTemp("", "existing*.png")
|
||||
assert.NoError(t, err)
|
||||
defer os.Remove(tempFile.Name())
|
||||
tempFile.Close()
|
||||
|
||||
flags := &Flags{
|
||||
ImageFile: tempFile.Name(),
|
||||
}
|
||||
|
||||
options, err := flags.BuildChatOptions()
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, options)
|
||||
assert.Contains(t, err.Error(), "image file already exists")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -98,7 +98,7 @@ _fabric() {
|
||||
'(--version)--version[Print current version]' \
|
||||
'(--search)--search[Enable web search tool for supported models (Anthropic, OpenAI)]' \
|
||||
'(--search-location)--search-location[Set location for web search results]:location:' \
|
||||
'(--image-file)--image-file[Save generated image to specified file path]:image file:_files -g "*.png *.jpg *.jpeg *.gif *.bmp"' \
|
||||
'(--image-file)--image-file[Save generated image to specified file path]:image file:_files -g "*.png *.webp *.jpeg *.jpg"' \
|
||||
'(--listextensions)--listextensions[List all registered extensions]' \
|
||||
'(--addextension)--addextension[Register a new extension from config file path]:config file:_files -g "*.yaml *.yml"' \
|
||||
'(--rmextension)--rmextension[Remove a registered extension by name]:extension:_fabric_extensions' \
|
||||
|
||||
@@ -61,7 +61,7 @@ complete -c fabric -l address -d "The address to bind the REST API (default: :80
|
||||
complete -c fabric -l api-key -d "API key used to secure server routes"
|
||||
complete -c fabric -l config -d "Path to YAML config file" -r -a "*.yaml *.yml"
|
||||
complete -c fabric -l search-location -d "Set location for web search results (e.g., 'America/Los_Angeles')"
|
||||
complete -c fabric -l image-file -d "Save generated image to specified file path (e.g., 'output.png')" -r -a "*.png *.jpg *.jpeg *.gif *.bmp"
|
||||
complete -c fabric -l image-file -d "Save generated image to specified file path (e.g., 'output.png')" -r -a "*.png *.webp *.jpeg *.jpg"
|
||||
complete -c fabric -l addextension -d "Register a new extension from config file path" -r -a "*.yaml *.yml"
|
||||
complete -c fabric -l rmextension -d "Remove a registered extension by name" -a "(__fabric_get_extensions)"
|
||||
complete -c fabric -l strategy -d "Choose a strategy from the available strategies" -a "(__fabric_get_strategies)"
|
||||
|
||||
1
go.mod
1
go.mod
@@ -27,7 +27,6 @@ require (
|
||||
github.com/stretchr/testify v1.10.0
|
||||
golang.org/x/text v0.26.0
|
||||
google.golang.org/api v0.236.0
|
||||
gopkg.in/yaml.v2 v2.4.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
|
||||
1
go.sum
1
go.sum
@@ -354,7 +354,6 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EV
|
||||
gopkg.in/warnings.v0 v0.1.2 h1:wFXVbFY8DY5/xOe1ECiWdKCzZlxgshcYVNkBHstARME=
|
||||
gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
|
||||
BIN
mars-colony.png
Normal file
BIN
mars-colony.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.8 MiB |
@@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"github.com/openai/openai-go/responses"
|
||||
@@ -17,15 +18,37 @@ import (
|
||||
const ImageGenerationResponseType = "image_generation_call"
|
||||
const ImageGenerationToolType = "image_generation"
|
||||
|
||||
// getOutputFormatFromExtension determines the API output format based on file extension
|
||||
func getOutputFormatFromExtension(imagePath string) string {
|
||||
if imagePath == "" {
|
||||
return "png" // Default format
|
||||
}
|
||||
|
||||
ext := strings.ToLower(filepath.Ext(imagePath))
|
||||
switch ext {
|
||||
case ".png":
|
||||
return "png"
|
||||
case ".webp":
|
||||
return "webp"
|
||||
case ".jpg":
|
||||
return "jpeg"
|
||||
case ".jpeg":
|
||||
return "jpeg"
|
||||
default:
|
||||
return "png" // Default fallback
|
||||
}
|
||||
}
|
||||
|
||||
// addImageGenerationTool adds the image generation tool to the request if needed
|
||||
func (o *Client) addImageGenerationTool(opts *common.ChatOptions, tools []responses.ToolUnionParam) []responses.ToolUnionParam {
|
||||
// Check if the request seems to be asking for image generation
|
||||
if o.shouldUseImageGeneration(opts) {
|
||||
outputFormat := getOutputFormatFromExtension(opts.ImageFile)
|
||||
imageGenTool := responses.ToolUnionParam{
|
||||
OfImageGeneration: &responses.ToolImageGenerationParam{
|
||||
Type: ImageGenerationToolType,
|
||||
Model: "gpt-image-1",
|
||||
OutputFormat: "png",
|
||||
OutputFormat: outputFormat,
|
||||
Quality: "auto",
|
||||
Size: "auto",
|
||||
},
|
||||
|
||||
@@ -112,3 +112,109 @@ func TestBuildResponseParams_WithBothSearchAndImage(t *testing.T) {
|
||||
assert.True(t, hasSearchTool, "Should have web search tool")
|
||||
assert.True(t, hasImageTool, "Should have image generation tool")
|
||||
}
|
||||
|
||||
func TestGetOutputFormatFromExtension(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
imagePath string
|
||||
expectedFormat string
|
||||
}{
|
||||
{
|
||||
name: "PNG extension",
|
||||
imagePath: "/tmp/output.png",
|
||||
expectedFormat: "png",
|
||||
},
|
||||
{
|
||||
name: "WEBP extension",
|
||||
imagePath: "/tmp/output.webp",
|
||||
expectedFormat: "webp",
|
||||
},
|
||||
{
|
||||
name: "JPG extension",
|
||||
imagePath: "/tmp/output.jpg",
|
||||
expectedFormat: "jpeg",
|
||||
},
|
||||
{
|
||||
name: "JPEG extension",
|
||||
imagePath: "/tmp/output.jpeg",
|
||||
expectedFormat: "jpeg",
|
||||
},
|
||||
{
|
||||
name: "Uppercase PNG extension",
|
||||
imagePath: "/tmp/output.PNG",
|
||||
expectedFormat: "png",
|
||||
},
|
||||
{
|
||||
name: "Mixed case JPEG extension",
|
||||
imagePath: "/tmp/output.JpEg",
|
||||
expectedFormat: "jpeg",
|
||||
},
|
||||
{
|
||||
name: "Empty path",
|
||||
imagePath: "",
|
||||
expectedFormat: "png",
|
||||
},
|
||||
{
|
||||
name: "No extension",
|
||||
imagePath: "/tmp/output",
|
||||
expectedFormat: "png",
|
||||
},
|
||||
{
|
||||
name: "Unsupported extension",
|
||||
imagePath: "/tmp/output.gif",
|
||||
expectedFormat: "png",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := getOutputFormatFromExtension(tt.imagePath)
|
||||
assert.Equal(t, tt.expectedFormat, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddImageGenerationToolWithDynamicFormat(t *testing.T) {
|
||||
client := NewClient()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
imageFile string
|
||||
expectedFormat string
|
||||
}{
|
||||
{
|
||||
name: "PNG file",
|
||||
imageFile: "/tmp/output.png",
|
||||
expectedFormat: "png",
|
||||
},
|
||||
{
|
||||
name: "WEBP file",
|
||||
imageFile: "/tmp/output.webp",
|
||||
expectedFormat: "webp",
|
||||
},
|
||||
{
|
||||
name: "JPG file",
|
||||
imageFile: "/tmp/output.jpg",
|
||||
expectedFormat: "jpeg",
|
||||
},
|
||||
{
|
||||
name: "JPEG file",
|
||||
imageFile: "/tmp/output.jpeg",
|
||||
expectedFormat: "jpeg",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
opts := &common.ChatOptions{
|
||||
ImageFile: tt.imageFile,
|
||||
}
|
||||
|
||||
tools := client.addImageGenerationTool(opts, []responses.ToolUnionParam{})
|
||||
|
||||
assert.Len(t, tools, 1, "Should have one tool")
|
||||
assert.NotNil(t, tools[0].OfImageGeneration, "Should be image generation tool")
|
||||
assert.Equal(t, tt.expectedFormat, tools[0].OfImageGeneration.OutputFormat, "Output format should match file extension")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user