mirror of
https://github.com/danielmiessler/Fabric.git
synced 2026-01-09 22:38:10 -05:00
Compare commits
19 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d081fd269c | ||
|
|
369a0a850d | ||
|
|
8dc5343ee6 | ||
|
|
eda552dac5 | ||
|
|
f13a56685b | ||
|
|
2f9afe0247 | ||
|
|
1ec525ad97 | ||
|
|
b7dc6748e0 | ||
|
|
f1b612d828 | ||
|
|
eac5a104f2 | ||
|
|
4bff88fae3 | ||
|
|
acf1be71ce | ||
|
|
236a3c5f38 | ||
|
|
b2418984f8 | ||
|
|
152d74d160 | ||
|
|
4e16bbccd8 | ||
|
|
60174f41a4 | ||
|
|
ad4683952e | ||
|
|
86a044735b |
167
README.md
167
README.md
@@ -114,13 +114,11 @@ Keep in mind that many of these were recorded when Fabric was Python-based, so r
|
||||
>
|
||||
> July 4, 2025
|
||||
>
|
||||
> - Fabric now supports web search using the `--search` and `--search-location` flags
|
||||
> - Web search is available for both Anthropic and OpenAI providers
|
||||
> - Previous plugin-level search configurations have been removed in favor of the new flag-based approach.
|
||||
> - If you used the previous approach, consider cleaning up your `~/.config/fabric/.env` file, removing the unused `ANTHROPIC_WEB_SEARCH_TOOL_ENABLED` and `ANTHROPIC_WEB_SEARCH_TOOL_LOCATION` variables.
|
||||
> - Fabric now supports image generation using the `--image-file` flag with OpenAI models
|
||||
> - Image generation works with both text prompts and input images (via `--attachment`) for image editing tasks
|
||||
>
|
||||
> - **Web Search**: Fabric now supports web search for Anthropic and OpenAI models using the `--search` and `--search-location` flags. This replaces the previous plugin-based search, so you may want to remove the old `ANTHROPIC_WEB_SEARCH_TOOL_*` variables from your `~/.config/fabric/.env` file.
|
||||
> - **Image Generation**: Fabric now has powerful image generation capabilities with OpenAI.
|
||||
> - Generate images from text prompts and save them using `--image-file`.
|
||||
> - Edit existing images by providing an input image with `--attachment`.
|
||||
> - Control image `size`, `quality`, `compression`, and `background` with the new `--image-*` flags.
|
||||
>
|
||||
>June 17, 2025
|
||||
>
|
||||
@@ -292,88 +290,88 @@ yt() {
|
||||
|
||||
You can add the below code for the equivalent aliases inside PowerShell by running `notepad $PROFILE` inside a PowerShell window:
|
||||
|
||||
```powershell
|
||||
# Path to the patterns directory
|
||||
$patternsPath = Join-Path $HOME ".config/fabric/patterns"
|
||||
foreach ($patternDir in Get-ChildItem -Path $patternsPath -Directory) {
|
||||
$patternName = $patternDir.Name
|
||||
```powershell
|
||||
# Path to the patterns directory
|
||||
$patternsPath = Join-Path $HOME ".config/fabric/patterns"
|
||||
foreach ($patternDir in Get-ChildItem -Path $patternsPath -Directory) {
|
||||
$patternName = $patternDir.Name
|
||||
|
||||
# Dynamically define a function for each pattern
|
||||
$functionDefinition = @"
|
||||
function $patternName {
|
||||
[CmdletBinding()]
|
||||
param(
|
||||
[Parameter(ValueFromPipeline = `$true)]
|
||||
[string] `$InputObject,
|
||||
# Dynamically define a function for each pattern
|
||||
$functionDefinition = @"
|
||||
function $patternName {
|
||||
[CmdletBinding()]
|
||||
param(
|
||||
[Parameter(ValueFromPipeline = `$true)]
|
||||
[string] `$InputObject,
|
||||
|
||||
[Parameter(ValueFromRemainingArguments = `$true)]
|
||||
[String[]] `$patternArgs
|
||||
)
|
||||
[Parameter(ValueFromRemainingArguments = `$true)]
|
||||
[String[]] `$patternArgs
|
||||
)
|
||||
|
||||
begin {
|
||||
# Initialize an array to collect pipeline input
|
||||
`$collector = @()
|
||||
}
|
||||
|
||||
process {
|
||||
# Collect pipeline input objects
|
||||
if (`$InputObject) {
|
||||
`$collector += `$InputObject
|
||||
}
|
||||
}
|
||||
|
||||
end {
|
||||
# Join all pipeline input into a single string, separated by newlines
|
||||
`$pipelineContent = `$collector -join "`n"
|
||||
|
||||
# If there's pipeline input, include it in the call to fabric
|
||||
if (`$pipelineContent) {
|
||||
`$pipelineContent | fabric --pattern $patternName `$patternArgs
|
||||
} else {
|
||||
# No pipeline input; just call fabric with the additional args
|
||||
fabric --pattern $patternName `$patternArgs
|
||||
}
|
||||
}
|
||||
}
|
||||
"@
|
||||
# Add the function to the current session
|
||||
Invoke-Expression $functionDefinition
|
||||
begin {
|
||||
# Initialize an array to collect pipeline input
|
||||
`$collector = @()
|
||||
}
|
||||
|
||||
# Define the 'yt' function as well
|
||||
function yt {
|
||||
[CmdletBinding()]
|
||||
param(
|
||||
[Parameter()]
|
||||
[Alias("timestamps")]
|
||||
[switch]$t,
|
||||
|
||||
[Parameter(Position = 0, ValueFromPipeline = $true)]
|
||||
[string]$videoLink
|
||||
)
|
||||
|
||||
begin {
|
||||
$transcriptFlag = "--transcript"
|
||||
if ($t) {
|
||||
$transcriptFlag = "--transcript-with-timestamps"
|
||||
}
|
||||
}
|
||||
|
||||
process {
|
||||
if (-not $videoLink) {
|
||||
Write-Error "Usage: yt [-t | --timestamps] youtube-link"
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
end {
|
||||
if ($videoLink) {
|
||||
# Execute and allow output to flow through the pipeline
|
||||
fabric -y $videoLink $transcriptFlag
|
||||
}
|
||||
process {
|
||||
# Collect pipeline input objects
|
||||
if (`$InputObject) {
|
||||
`$collector += `$InputObject
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
end {
|
||||
# Join all pipeline input into a single string, separated by newlines
|
||||
`$pipelineContent = `$collector -join "`n"
|
||||
|
||||
# If there's pipeline input, include it in the call to fabric
|
||||
if (`$pipelineContent) {
|
||||
`$pipelineContent | fabric --pattern $patternName `$patternArgs
|
||||
} else {
|
||||
# No pipeline input; just call fabric with the additional args
|
||||
fabric --pattern $patternName `$patternArgs
|
||||
}
|
||||
}
|
||||
}
|
||||
"@
|
||||
# Add the function to the current session
|
||||
Invoke-Expression $functionDefinition
|
||||
}
|
||||
|
||||
# Define the 'yt' function as well
|
||||
function yt {
|
||||
[CmdletBinding()]
|
||||
param(
|
||||
[Parameter()]
|
||||
[Alias("timestamps")]
|
||||
[switch]$t,
|
||||
|
||||
[Parameter(Position = 0, ValueFromPipeline = $true)]
|
||||
[string]$videoLink
|
||||
)
|
||||
|
||||
begin {
|
||||
$transcriptFlag = "--transcript"
|
||||
if ($t) {
|
||||
$transcriptFlag = "--transcript-with-timestamps"
|
||||
}
|
||||
}
|
||||
|
||||
process {
|
||||
if (-not $videoLink) {
|
||||
Write-Error "Usage: yt [-t | --timestamps] youtube-link"
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
end {
|
||||
if ($videoLink) {
|
||||
# Execute and allow output to flow through the pipeline
|
||||
fabric -y $videoLink $transcriptFlag
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
This also creates a `yt` alias that allows you to use `yt https://www.youtube.com/watch?v=4b0iet22VIk` to get transcripts, comments, and metadata.
|
||||
|
||||
@@ -559,6 +557,11 @@ Application Options:
|
||||
--search Enable web search tool for supported models (Anthropic, OpenAI)
|
||||
--search-location= Set location for web search results (e.g., 'America/Los_Angeles')
|
||||
--image-file= Save generated image to specified file path (e.g., 'output.png')
|
||||
--image-size= Image dimensions: 1024x1024, 1536x1024, 1024x1536, auto (default: auto)
|
||||
--image-quality= Image quality: low, medium, high, auto (default: auto)
|
||||
--image-compression= Compression level 0-100 for JPEG/WebP formats (default: not set)
|
||||
--image-background= Background type: opaque, transparent (default: opaque, only for
|
||||
PNG/WebP)
|
||||
|
||||
Help Options:
|
||||
-h, --help Show this help message
|
||||
|
||||
92
cli/flags.go
92
cli/flags.go
@@ -78,6 +78,10 @@ type Flags struct {
|
||||
Search bool `long:"search" description:"Enable web search tool for supported models (Anthropic, OpenAI)"`
|
||||
SearchLocation string `long:"search-location" description:"Set location for web search results (e.g., 'America/Los_Angeles')"`
|
||||
ImageFile string `long:"image-file" description:"Save generated image to specified file path (e.g., 'output.png')"`
|
||||
ImageSize string `long:"image-size" description:"Image dimensions: 1024x1024, 1536x1024, 1024x1536, auto (default: auto)"`
|
||||
ImageQuality string `long:"image-quality" description:"Image quality: low, medium, high, auto (default: auto)"`
|
||||
ImageCompression int `long:"image-compression" description:"Compression level 0-100 for JPEG/WebP formats (default: not set)"`
|
||||
ImageBackground string `long:"image-background" description:"Background type: opaque, transparent (default: opaque, only for PNG/WebP)"`
|
||||
}
|
||||
|
||||
var debug = false
|
||||
@@ -282,13 +286,97 @@ func validateImageFile(imagePath string) error {
|
||||
return fmt.Errorf("invalid image file extension '%s'. Supported formats: .png, .jpeg, .jpg, .webp", ext)
|
||||
}
|
||||
|
||||
// validateImageParameters validates image generation parameters
|
||||
func validateImageParameters(imagePath, size, quality, background string, compression int) error {
|
||||
if imagePath == "" {
|
||||
// Check if any image parameters are specified without --image-file
|
||||
if size != "" || quality != "" || background != "" || compression != 0 {
|
||||
return fmt.Errorf("image parameters (--image-size, --image-quality, --image-background, --image-compression) can only be used with --image-file")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate size
|
||||
if size != "" {
|
||||
validSizes := []string{"1024x1024", "1536x1024", "1024x1536", "auto"}
|
||||
valid := false
|
||||
for _, validSize := range validSizes {
|
||||
if size == validSize {
|
||||
valid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !valid {
|
||||
return fmt.Errorf("invalid image size '%s'. Supported sizes: 1024x1024, 1536x1024, 1024x1536, auto", size)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate quality
|
||||
if quality != "" {
|
||||
validQualities := []string{"low", "medium", "high", "auto"}
|
||||
valid := false
|
||||
for _, validQuality := range validQualities {
|
||||
if quality == validQuality {
|
||||
valid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !valid {
|
||||
return fmt.Errorf("invalid image quality '%s'. Supported qualities: low, medium, high, auto", quality)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate background
|
||||
if background != "" {
|
||||
validBackgrounds := []string{"opaque", "transparent"}
|
||||
valid := false
|
||||
for _, validBackground := range validBackgrounds {
|
||||
if background == validBackground {
|
||||
valid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !valid {
|
||||
return fmt.Errorf("invalid image background '%s'. Supported backgrounds: opaque, transparent", background)
|
||||
}
|
||||
}
|
||||
|
||||
// Get file format for format-specific validations
|
||||
ext := strings.ToLower(filepath.Ext(imagePath))
|
||||
|
||||
// Validate compression (only for jpeg/webp)
|
||||
if compression != 0 { // 0 means not set
|
||||
if ext != ".jpg" && ext != ".jpeg" && ext != ".webp" {
|
||||
return fmt.Errorf("image compression can only be used with JPEG and WebP formats, not %s", ext)
|
||||
}
|
||||
if compression < 0 || compression > 100 {
|
||||
return fmt.Errorf("image compression must be between 0 and 100, got %d", compression)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate background transparency (only for png/webp)
|
||||
if background == "transparent" {
|
||||
if ext != ".png" && ext != ".webp" {
|
||||
return fmt.Errorf("transparent background can only be used with PNG and WebP formats, not %s", ext)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *Flags) BuildChatOptions() (ret *common.ChatOptions, err error) {
|
||||
// Validate image file if specified
|
||||
if err = validateImageFile(o.ImageFile); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Validate image parameters
|
||||
if err = validateImageParameters(o.ImageFile, o.ImageSize, o.ImageQuality, o.ImageBackground, o.ImageCompression); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret = &common.ChatOptions{
|
||||
Model: o.Model,
|
||||
Temperature: o.Temperature,
|
||||
TopP: o.TopP,
|
||||
PresencePenalty: o.PresencePenalty,
|
||||
@@ -299,6 +387,10 @@ func (o *Flags) BuildChatOptions() (ret *common.ChatOptions, err error) {
|
||||
Search: o.Search,
|
||||
SearchLocation: o.SearchLocation,
|
||||
ImageFile: o.ImageFile,
|
||||
ImageSize: o.ImageSize,
|
||||
ImageQuality: o.ImageQuality,
|
||||
ImageCompression: o.ImageCompression,
|
||||
ImageBackground: o.ImageBackground,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -255,3 +255,181 @@ func TestBuildChatOptionsWithImageFileValidation(t *testing.T) {
|
||||
assert.Contains(t, err.Error(), "image file already exists")
|
||||
})
|
||||
}
|
||||
|
||||
func TestValidateImageParameters(t *testing.T) {
|
||||
t.Run("No image file and no parameters should pass", func(t *testing.T) {
|
||||
err := validateImageParameters("", "", "", "", 0)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Image parameters without image file should fail", func(t *testing.T) {
|
||||
// Test each parameter individually
|
||||
err := validateImageParameters("", "1024x1024", "", "", 0)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "image parameters")
|
||||
assert.Contains(t, err.Error(), "can only be used with --image-file")
|
||||
|
||||
err = validateImageParameters("", "", "high", "", 0)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "image parameters")
|
||||
|
||||
err = validateImageParameters("", "", "", "transparent", 0)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "image parameters")
|
||||
|
||||
err = validateImageParameters("", "", "", "", 50)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "image parameters")
|
||||
|
||||
// Test multiple parameters
|
||||
err = validateImageParameters("", "1024x1024", "high", "transparent", 50)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "image parameters")
|
||||
})
|
||||
|
||||
t.Run("Valid size values should pass", func(t *testing.T) {
|
||||
validSizes := []string{"1024x1024", "1536x1024", "1024x1536", "auto"}
|
||||
for _, size := range validSizes {
|
||||
err := validateImageParameters("/tmp/test.png", size, "", "", 0)
|
||||
assert.NoError(t, err, "Size %s should be valid", size)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid size should fail", func(t *testing.T) {
|
||||
err := validateImageParameters("/tmp/test.png", "invalid", "", "", 0)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid image size")
|
||||
})
|
||||
|
||||
t.Run("Valid quality values should pass", func(t *testing.T) {
|
||||
validQualities := []string{"low", "medium", "high", "auto"}
|
||||
for _, quality := range validQualities {
|
||||
err := validateImageParameters("/tmp/test.png", "", quality, "", 0)
|
||||
assert.NoError(t, err, "Quality %s should be valid", quality)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid quality should fail", func(t *testing.T) {
|
||||
err := validateImageParameters("/tmp/test.png", "", "invalid", "", 0)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid image quality")
|
||||
})
|
||||
|
||||
t.Run("Valid background values should pass", func(t *testing.T) {
|
||||
validBackgrounds := []string{"opaque", "transparent"}
|
||||
for _, background := range validBackgrounds {
|
||||
err := validateImageParameters("/tmp/test.png", "", "", background, 0)
|
||||
assert.NoError(t, err, "Background %s should be valid", background)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid background should fail", func(t *testing.T) {
|
||||
err := validateImageParameters("/tmp/test.png", "", "", "invalid", 0)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid image background")
|
||||
})
|
||||
|
||||
t.Run("Compression for JPEG should pass", func(t *testing.T) {
|
||||
err := validateImageParameters("/tmp/test.jpg", "", "", "", 75)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Compression for WebP should pass", func(t *testing.T) {
|
||||
err := validateImageParameters("/tmp/test.webp", "", "", "", 50)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Compression for PNG should fail", func(t *testing.T) {
|
||||
err := validateImageParameters("/tmp/test.png", "", "", "", 75)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "image compression can only be used with JPEG and WebP formats")
|
||||
})
|
||||
|
||||
t.Run("Invalid compression range should fail", func(t *testing.T) {
|
||||
err := validateImageParameters("/tmp/test.jpg", "", "", "", 150)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "image compression must be between 0 and 100")
|
||||
|
||||
err = validateImageParameters("/tmp/test.jpg", "", "", "", -10)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "image compression must be between 0 and 100")
|
||||
})
|
||||
|
||||
t.Run("Transparent background for PNG should pass", func(t *testing.T) {
|
||||
err := validateImageParameters("/tmp/test.png", "", "", "transparent", 0)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Transparent background for WebP should pass", func(t *testing.T) {
|
||||
err := validateImageParameters("/tmp/test.webp", "", "", "transparent", 0)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Transparent background for JPEG should fail", func(t *testing.T) {
|
||||
err := validateImageParameters("/tmp/test.jpg", "", "", "transparent", 0)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "transparent background can only be used with PNG and WebP formats")
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuildChatOptionsWithImageParameters(t *testing.T) {
|
||||
t.Run("Valid image parameters should pass", func(t *testing.T) {
|
||||
flags := &Flags{
|
||||
ImageFile: "/tmp/test.png",
|
||||
ImageSize: "1024x1024",
|
||||
ImageQuality: "high",
|
||||
ImageBackground: "transparent",
|
||||
ImageCompression: 0, // Not set for PNG
|
||||
}
|
||||
|
||||
options, err := flags.BuildChatOptions()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, options)
|
||||
assert.Equal(t, "/tmp/test.png", options.ImageFile)
|
||||
assert.Equal(t, "1024x1024", options.ImageSize)
|
||||
assert.Equal(t, "high", options.ImageQuality)
|
||||
assert.Equal(t, "transparent", options.ImageBackground)
|
||||
assert.Equal(t, 0, options.ImageCompression)
|
||||
})
|
||||
|
||||
t.Run("Invalid image parameters should fail", func(t *testing.T) {
|
||||
flags := &Flags{
|
||||
ImageFile: "/tmp/test.png",
|
||||
ImageSize: "invalid",
|
||||
ImageQuality: "high",
|
||||
ImageBackground: "transparent",
|
||||
}
|
||||
|
||||
options, err := flags.BuildChatOptions()
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, options)
|
||||
assert.Contains(t, err.Error(), "invalid image size")
|
||||
})
|
||||
|
||||
t.Run("JPEG with compression should pass", func(t *testing.T) {
|
||||
flags := &Flags{
|
||||
ImageFile: "/tmp/test.jpg",
|
||||
ImageSize: "1536x1024",
|
||||
ImageQuality: "medium",
|
||||
ImageBackground: "opaque",
|
||||
ImageCompression: 80,
|
||||
}
|
||||
|
||||
options, err := flags.BuildChatOptions()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, options)
|
||||
assert.Equal(t, 80, options.ImageCompression)
|
||||
})
|
||||
|
||||
t.Run("Image parameters without image file should fail in BuildChatOptions", func(t *testing.T) {
|
||||
flags := &Flags{
|
||||
ImageSize: "1024x1024", // Image parameter without ImageFile
|
||||
}
|
||||
|
||||
options, err := flags.BuildChatOptions()
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, options)
|
||||
assert.Contains(t, err.Error(), "image parameters")
|
||||
assert.Contains(t, err.Error(), "can only be used with --image-file")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -29,6 +29,10 @@ type ChatOptions struct {
|
||||
Search bool
|
||||
SearchLocation string
|
||||
ImageFile string
|
||||
ImageSize string
|
||||
ImageQuality string
|
||||
ImageCompression int
|
||||
ImageBackground string
|
||||
}
|
||||
|
||||
// NormalizeMessages remove empty messages and ensure messages order user-assist-user
|
||||
|
||||
124
common/oauth_storage.go
Normal file
124
common/oauth_storage.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
|
||||
// OAuthToken represents stored OAuth token information
|
||||
type OAuthToken struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
TokenType string `json:"token_type"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
// IsExpired checks if the token is expired or will expire within the buffer time
|
||||
func (t *OAuthToken) IsExpired(bufferMinutes int) bool {
|
||||
if t.ExpiresAt == 0 {
|
||||
return true
|
||||
}
|
||||
bufferTime := time.Duration(bufferMinutes) * time.Minute
|
||||
return time.Now().Add(bufferTime).Unix() >= t.ExpiresAt
|
||||
}
|
||||
|
||||
// OAuthStorage handles persistent storage of OAuth tokens
|
||||
type OAuthStorage struct {
|
||||
configDir string
|
||||
}
|
||||
|
||||
// NewOAuthStorage creates a new OAuth storage instance
|
||||
func NewOAuthStorage() (*OAuthStorage, error) {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user home directory: %w", err)
|
||||
}
|
||||
|
||||
configDir := filepath.Join(homeDir, ".config", "fabric")
|
||||
|
||||
// Ensure config directory exists
|
||||
if err := os.MkdirAll(configDir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create config directory: %w", err)
|
||||
}
|
||||
|
||||
return &OAuthStorage{configDir: configDir}, nil
|
||||
}
|
||||
|
||||
// GetTokenPath returns the file path for a provider's OAuth token
|
||||
func (s *OAuthStorage) GetTokenPath(provider string) string {
|
||||
return filepath.Join(s.configDir, fmt.Sprintf(".%s_oauth", provider))
|
||||
}
|
||||
|
||||
// SaveToken saves an OAuth token to disk with proper permissions
|
||||
func (s *OAuthStorage) SaveToken(provider string, token *OAuthToken) error {
|
||||
tokenPath := s.GetTokenPath(provider)
|
||||
|
||||
// Marshal token to JSON
|
||||
data, err := json.MarshalIndent(token, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal token: %w", err)
|
||||
}
|
||||
|
||||
// Write to temporary file first for atomic operation
|
||||
tempPath := tokenPath + ".tmp"
|
||||
if err := os.WriteFile(tempPath, data, 0600); err != nil {
|
||||
return fmt.Errorf("failed to write token file: %w", err)
|
||||
}
|
||||
|
||||
// Atomic rename
|
||||
if err := os.Rename(tempPath, tokenPath); err != nil {
|
||||
os.Remove(tempPath) // Clean up temp file
|
||||
return fmt.Errorf("failed to save token file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadToken loads an OAuth token from disk
|
||||
func (s *OAuthStorage) LoadToken(provider string) (*OAuthToken, error) {
|
||||
tokenPath := s.GetTokenPath(provider)
|
||||
|
||||
// Check if file exists
|
||||
if _, err := os.Stat(tokenPath); os.IsNotExist(err) {
|
||||
return nil, nil // No token stored
|
||||
}
|
||||
|
||||
// Read token file
|
||||
data, err := os.ReadFile(tokenPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read token file: %w", err)
|
||||
}
|
||||
|
||||
// Unmarshal token
|
||||
var token OAuthToken
|
||||
if err := json.Unmarshal(data, &token); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token file: %w", err)
|
||||
}
|
||||
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// DeleteToken removes a stored OAuth token
|
||||
func (s *OAuthStorage) DeleteToken(provider string) error {
|
||||
tokenPath := s.GetTokenPath(provider)
|
||||
|
||||
if err := os.Remove(tokenPath); err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to delete token file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasValidToken checks if a valid (non-expired) token exists for a provider
|
||||
func (s *OAuthStorage) HasValidToken(provider string, bufferMinutes int) bool {
|
||||
token, err := s.LoadToken(provider)
|
||||
if err != nil || token == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return !token.IsExpired(bufferMinutes)
|
||||
}
|
||||
232
common/oauth_storage_test.go
Normal file
232
common/oauth_storage_test.go
Normal file
@@ -0,0 +1,232 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestOAuthToken_IsExpired(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
expiresAt int64
|
||||
bufferMinutes int
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "token not expired",
|
||||
expiresAt: time.Now().Unix() + 3600, // 1 hour from now
|
||||
bufferMinutes: 5,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "token expired",
|
||||
expiresAt: time.Now().Unix() - 3600, // 1 hour ago
|
||||
bufferMinutes: 5,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "token expires within buffer",
|
||||
expiresAt: time.Now().Unix() + 120, // 2 minutes from now
|
||||
bufferMinutes: 5,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "zero expiry time",
|
||||
expiresAt: 0,
|
||||
bufferMinutes: 5,
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
token := &OAuthToken{ExpiresAt: tt.expiresAt}
|
||||
if got := token.IsExpired(tt.bufferMinutes); got != tt.expected {
|
||||
t.Errorf("IsExpired() = %v, want %v", got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthStorage_SaveAndLoadToken(t *testing.T) {
|
||||
// Create temporary directory for testing
|
||||
tempDir, err := os.MkdirTemp("", "fabric_oauth_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// Create storage with custom config dir
|
||||
storage := &OAuthStorage{configDir: tempDir}
|
||||
|
||||
// Test token
|
||||
token := &OAuthToken{
|
||||
AccessToken: "test_access_token",
|
||||
RefreshToken: "test_refresh_token",
|
||||
ExpiresAt: time.Now().Unix() + 3600,
|
||||
TokenType: "Bearer",
|
||||
Scope: "test_scope",
|
||||
}
|
||||
|
||||
// Test saving token
|
||||
err = storage.SaveToken("test_provider", token)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save token: %v", err)
|
||||
}
|
||||
|
||||
// Verify file exists and has correct permissions
|
||||
tokenPath := storage.GetTokenPath("test_provider")
|
||||
info, err := os.Stat(tokenPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Token file not created: %v", err)
|
||||
}
|
||||
if info.Mode().Perm() != 0600 {
|
||||
t.Errorf("Token file has wrong permissions: %v, want 0600", info.Mode().Perm())
|
||||
}
|
||||
|
||||
// Test loading token
|
||||
loadedToken, err := storage.LoadToken("test_provider")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load token: %v", err)
|
||||
}
|
||||
if loadedToken == nil {
|
||||
t.Fatal("Loaded token is nil")
|
||||
}
|
||||
|
||||
// Verify token data
|
||||
if loadedToken.AccessToken != token.AccessToken {
|
||||
t.Errorf("AccessToken mismatch: got %v, want %v", loadedToken.AccessToken, token.AccessToken)
|
||||
}
|
||||
if loadedToken.RefreshToken != token.RefreshToken {
|
||||
t.Errorf("RefreshToken mismatch: got %v, want %v", loadedToken.RefreshToken, token.RefreshToken)
|
||||
}
|
||||
if loadedToken.ExpiresAt != token.ExpiresAt {
|
||||
t.Errorf("ExpiresAt mismatch: got %v, want %v", loadedToken.ExpiresAt, token.ExpiresAt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthStorage_LoadNonExistentToken(t *testing.T) {
|
||||
// Create temporary directory for testing
|
||||
tempDir, err := os.MkdirTemp("", "fabric_oauth_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
storage := &OAuthStorage{configDir: tempDir}
|
||||
|
||||
// Try to load non-existent token
|
||||
token, err := storage.LoadToken("nonexistent")
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error loading non-existent token: %v", err)
|
||||
}
|
||||
if token != nil {
|
||||
t.Error("Expected nil token for non-existent provider")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthStorage_DeleteToken(t *testing.T) {
|
||||
// Create temporary directory for testing
|
||||
tempDir, err := os.MkdirTemp("", "fabric_oauth_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
storage := &OAuthStorage{configDir: tempDir}
|
||||
|
||||
// Create and save a token
|
||||
token := &OAuthToken{
|
||||
AccessToken: "test_token",
|
||||
RefreshToken: "test_refresh",
|
||||
ExpiresAt: time.Now().Unix() + 3600,
|
||||
}
|
||||
err = storage.SaveToken("test_provider", token)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save token: %v", err)
|
||||
}
|
||||
|
||||
// Verify token exists
|
||||
tokenPath := storage.GetTokenPath("test_provider")
|
||||
if _, err := os.Stat(tokenPath); os.IsNotExist(err) {
|
||||
t.Fatal("Token file should exist before deletion")
|
||||
}
|
||||
|
||||
// Delete token
|
||||
err = storage.DeleteToken("test_provider")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to delete token: %v", err)
|
||||
}
|
||||
|
||||
// Verify token is deleted
|
||||
if _, err := os.Stat(tokenPath); !os.IsNotExist(err) {
|
||||
t.Error("Token file should not exist after deletion")
|
||||
}
|
||||
|
||||
// Test deleting non-existent token (should not error)
|
||||
err = storage.DeleteToken("nonexistent")
|
||||
if err != nil {
|
||||
t.Errorf("Deleting non-existent token should not error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthStorage_HasValidToken(t *testing.T) {
|
||||
// Create temporary directory for testing
|
||||
tempDir, err := os.MkdirTemp("", "fabric_oauth_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
storage := &OAuthStorage{configDir: tempDir}
|
||||
|
||||
// Test with no token
|
||||
if storage.HasValidToken("test_provider", 5) {
|
||||
t.Error("Should return false when no token exists")
|
||||
}
|
||||
|
||||
// Save valid token
|
||||
validToken := &OAuthToken{
|
||||
AccessToken: "valid_token",
|
||||
RefreshToken: "refresh_token",
|
||||
ExpiresAt: time.Now().Unix() + 3600, // 1 hour from now
|
||||
}
|
||||
err = storage.SaveToken("test_provider", validToken)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save valid token: %v", err)
|
||||
}
|
||||
|
||||
// Test with valid token
|
||||
if !storage.HasValidToken("test_provider", 5) {
|
||||
t.Error("Should return true for valid token")
|
||||
}
|
||||
|
||||
// Save expired token
|
||||
expiredToken := &OAuthToken{
|
||||
AccessToken: "expired_token",
|
||||
RefreshToken: "refresh_token",
|
||||
ExpiresAt: time.Now().Unix() - 3600, // 1 hour ago
|
||||
}
|
||||
err = storage.SaveToken("expired_provider", expiredToken)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save expired token: %v", err)
|
||||
}
|
||||
|
||||
// Test with expired token
|
||||
if storage.HasValidToken("expired_provider", 5) {
|
||||
t.Error("Should return false for expired token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthStorage_GetTokenPath(t *testing.T) {
|
||||
storage := &OAuthStorage{configDir: "/test/config"}
|
||||
|
||||
expected := filepath.Join("/test/config", ".test_provider_oauth")
|
||||
actual := storage.GetTokenPath("test_provider")
|
||||
|
||||
if actual != expected {
|
||||
t.Errorf("GetTokenPath() = %v, want %v", actual, expected)
|
||||
}
|
||||
}
|
||||
@@ -99,6 +99,10 @@ _fabric() {
|
||||
'(--search)--search[Enable web search tool for supported models (Anthropic, OpenAI)]' \
|
||||
'(--search-location)--search-location[Set location for web search results]:location:' \
|
||||
'(--image-file)--image-file[Save generated image to specified file path]:image file:_files -g "*.png *.webp *.jpeg *.jpg"' \
|
||||
'(--image-size)--image-size[Image dimensions]:size:(1024x1024 1536x1024 1024x1536 auto)' \
|
||||
'(--image-quality)--image-quality[Image quality]:quality:(low medium high auto)' \
|
||||
'(--image-compression)--image-compression[Compression level 0-100 for JPEG/WebP formats]:compression:' \
|
||||
'(--image-background)--image-background[Background type]:background:(opaque transparent)' \
|
||||
'(--listextensions)--listextensions[List all registered extensions]' \
|
||||
'(--addextension)--addextension[Register a new extension from config file path]:config file:_files -g "*.yaml *.yml"' \
|
||||
'(--rmextension)--rmextension[Remove a registered extension by name]:extension:_fabric_extensions' \
|
||||
|
||||
@@ -13,7 +13,7 @@ _fabric() {
|
||||
_get_comp_words_by_ref -n : cur prev words cword
|
||||
|
||||
# Define all possible options/flags
|
||||
local opts="--pattern -p --variable -v --context -C --session --attachment -a --setup -S --temperature -t --topp -T --stream -s --presencepenalty -P --raw -r --frequencypenalty -F --listpatterns -l --listmodels -L --listcontexts -x --listsessions -X --updatepatterns -U --copy -c --model -m --modelContextLength --output -o --output-session --latest -n --changeDefaultModel -d --youtube -y --playlist --transcript --transcript-with-timestamps --comments --metadata --language -g --scrape_url -u --scrape_question -q --seed -e --wipecontext -w --wipesession -W --printcontext --printsession --readability --input-has-vars --dry-run --serve --serveOllama --address --api-key --config --search --search-location --image-file --version --listextensions --addextension --rmextension --strategy --liststrategies --listvendors --shell-complete-list --help -h"
|
||||
local opts="--pattern -p --variable -v --context -C --session --attachment -a --setup -S --temperature -t --topp -T --stream -s --presencepenalty -P --raw -r --frequencypenalty -F --listpatterns -l --listmodels -L --listcontexts -x --listsessions -X --updatepatterns -U --copy -c --model -m --modelContextLength --output -o --output-session --latest -n --changeDefaultModel -d --youtube -y --playlist --transcript --transcript-with-timestamps --comments --metadata --language -g --scrape_url -u --scrape_question -q --seed -e --wipecontext -w --wipesession -W --printcontext --printsession --readability --input-has-vars --dry-run --serve --serveOllama --address --api-key --config --search --search-location --image-file --image-size --image-quality --image-compression --image-background --version --listextensions --addextension --rmextension --strategy --liststrategies --listvendors --shell-complete-list --help -h"
|
||||
|
||||
# Helper function for dynamic completions
|
||||
_fabric_get_list() {
|
||||
@@ -67,8 +67,21 @@ _fabric() {
|
||||
_filedir
|
||||
return 0
|
||||
;;
|
||||
# Image generation options with specific values
|
||||
--image-size)
|
||||
COMPREPLY=($(compgen -W "1024x1024 1536x1024 1024x1536 auto" -- "$cur"))
|
||||
return 0
|
||||
;;
|
||||
--image-quality)
|
||||
COMPREPLY=($(compgen -W "low medium high auto" -- "$cur"))
|
||||
return 0
|
||||
;;
|
||||
--image-background)
|
||||
COMPREPLY=($(compgen -W "opaque transparent" -- "$cur"))
|
||||
return 0
|
||||
;;
|
||||
# Options requiring simple arguments (no specific completion logic here)
|
||||
-v | --variable | -t | --temperature | -T | --topp | -P | --presencepenalty | -F | --frequencypenalty | --modelContextLength | -n | --latest | -y | --youtube | -g | --language | -u | --scrape_url | -q | --scrape_question | -e | --seed | --address | --api-key | --search-location)
|
||||
-v | --variable | -t | --temperature | -T | --topp | -P | --presencepenalty | -F | --frequencypenalty | --modelContextLength | -n | --latest | -y | --youtube | -g | --language | -u | --scrape_url | -q | --scrape_question | -e | --seed | --address | --api-key | --search-location | --image-compression)
|
||||
# No specific completion suggestions, user types the value
|
||||
return 0
|
||||
;;
|
||||
|
||||
@@ -62,6 +62,10 @@ complete -c fabric -l api-key -d "API key used to secure server routes"
|
||||
complete -c fabric -l config -d "Path to YAML config file" -r -a "*.yaml *.yml"
|
||||
complete -c fabric -l search-location -d "Set location for web search results (e.g., 'America/Los_Angeles')"
|
||||
complete -c fabric -l image-file -d "Save generated image to specified file path (e.g., 'output.png')" -r -a "*.png *.webp *.jpeg *.jpg"
|
||||
complete -c fabric -l image-size -d "Image dimensions: 1024x1024, 1536x1024, 1024x1536, auto (default: auto)" -a "1024x1024 1536x1024 1024x1536 auto"
|
||||
complete -c fabric -l image-quality -d "Image quality: low, medium, high, auto (default: auto)" -a "low medium high auto"
|
||||
complete -c fabric -l image-compression -d "Compression level 0-100 for JPEG/WebP formats (default: not set)" -r
|
||||
complete -c fabric -l image-background -d "Background type: opaque, transparent (default: opaque, only for PNG/WebP)" -a "opaque transparent"
|
||||
complete -c fabric -l addextension -d "Register a new extension from config file path" -r -a "*.yaml *.yml"
|
||||
complete -c fabric -l rmextension -d "Remove a registered extension by name" -a "(__fabric_get_extensions)"
|
||||
complete -c fabric -l strategy -d "Choose a strategy from the available strategies" -a "(__fabric_get_strategies)"
|
||||
|
||||
2
go.mod
2
go.mod
@@ -25,6 +25,7 @@ require (
|
||||
github.com/samber/lo v1.50.0
|
||||
github.com/sgaunet/perplexity-go/v2 v2.8.0
|
||||
github.com/stretchr/testify v1.10.0
|
||||
golang.org/x/oauth2 v0.30.0
|
||||
golang.org/x/text v0.26.0
|
||||
google.golang.org/api v0.236.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
@@ -108,7 +109,6 @@ require (
|
||||
golang.org/x/arch v0.18.0 // indirect
|
||||
golang.org/x/crypto v0.39.0 // indirect
|
||||
golang.org/x/net v0.41.0 // indirect
|
||||
golang.org/x/oauth2 v0.30.0 // indirect
|
||||
golang.org/x/sync v0.15.0 // indirect
|
||||
golang.org/x/sys v0.33.0 // indirect
|
||||
golang.org/x/time v0.12.0 // indirect
|
||||
|
||||
BIN
mars-colony.png
BIN
mars-colony.png
Binary file not shown.
|
Before Width: | Height: | Size: 1.8 MiB |
@@ -1 +1 @@
|
||||
"1.4.228"
|
||||
"1.4.231"
|
||||
|
||||
@@ -3,6 +3,7 @@ package anthropic
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
@@ -30,7 +31,12 @@ func NewClient() (ret *Client) {
|
||||
|
||||
ret.ApiBaseURL = ret.AddSetupQuestion("API Base URL", false)
|
||||
ret.ApiBaseURL.Value = defaultBaseUrl
|
||||
ret.ApiKey = ret.PluginBase.AddSetupQuestion("API key", true)
|
||||
ret.UseOAuth = ret.AddSetupQuestionBool("Use OAuth login", false)
|
||||
if plugins.ParseBoolElseFalse(ret.UseOAuth.Value) {
|
||||
ret.ApiKey = ret.PluginBase.AddSetupQuestion("API key", false)
|
||||
} else {
|
||||
ret.ApiKey = ret.PluginBase.AddSetupQuestion("API key", true)
|
||||
}
|
||||
|
||||
ret.maxTokens = 4096
|
||||
ret.defaultRequiredUserMessage = "Hi"
|
||||
@@ -50,6 +56,7 @@ type Client struct {
|
||||
*plugins.PluginBase
|
||||
ApiBaseURL *plugins.SetupQuestion
|
||||
ApiKey *plugins.SetupQuestion
|
||||
UseOAuth *plugins.SetupQuestion
|
||||
|
||||
maxTokens int
|
||||
defaultRequiredUserMessage string
|
||||
@@ -58,24 +65,50 @@ type Client struct {
|
||||
client anthropic.Client
|
||||
}
|
||||
|
||||
func (an *Client) configure() (err error) {
|
||||
if an.ApiBaseURL.Value != "" {
|
||||
baseURL := an.ApiBaseURL.Value
|
||||
func (an *Client) Setup() (err error) {
|
||||
if err = an.PluginBase.Ask(an.Name); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// As of 2.0beta1, using v2 API endpoint.
|
||||
// https://github.com/anthropics/anthropic-sdk-go/blob/main/CHANGELOG.md#020-beta1-2025-03-25
|
||||
if strings.Contains(baseURL, "-") && !strings.HasSuffix(baseURL, "/v2") {
|
||||
baseURL = strings.TrimSuffix(baseURL, "/")
|
||||
baseURL = baseURL + "/v2"
|
||||
if plugins.ParseBoolElseFalse(an.UseOAuth.Value) {
|
||||
// Check if we have a valid stored token
|
||||
storage, err := common.NewOAuthStorage()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
an.client = anthropic.NewClient(
|
||||
option.WithAPIKey(an.ApiKey.Value),
|
||||
option.WithBaseURL(baseURL),
|
||||
)
|
||||
} else {
|
||||
an.client = anthropic.NewClient(option.WithAPIKey(an.ApiKey.Value))
|
||||
if !storage.HasValidToken("claude", 5) {
|
||||
// No valid token, run OAuth flow
|
||||
if _, err = RunOAuthFlow(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = an.configure()
|
||||
return
|
||||
}
|
||||
|
||||
func (an *Client) configure() (err error) {
|
||||
opts := []option.RequestOption{}
|
||||
|
||||
if an.ApiBaseURL.Value != "" {
|
||||
opts = append(opts, option.WithBaseURL(an.ApiBaseURL.Value))
|
||||
}
|
||||
|
||||
if plugins.ParseBoolElseFalse(an.UseOAuth.Value) {
|
||||
// For OAuth, use Bearer token with custom headers
|
||||
// Create custom HTTP client that adds OAuth Bearer token and beta header
|
||||
baseTransport := &http.Transport{}
|
||||
httpClient := &http.Client{
|
||||
Transport: NewOAuthTransport(an, baseTransport),
|
||||
}
|
||||
opts = append(opts, option.WithHTTPClient(httpClient))
|
||||
} else {
|
||||
opts = append(opts, option.WithAPIKey(an.ApiKey.Value))
|
||||
}
|
||||
|
||||
an.client = anthropic.NewClient(opts...)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -124,6 +157,17 @@ func (an *Client) buildMessageParams(msgs []anthropic.MessageParam, opts *common
|
||||
Messages: msgs,
|
||||
}
|
||||
|
||||
// Add Claude Code spoofing system message for OAuth authentication
|
||||
if plugins.ParseBoolElseFalse(an.UseOAuth.Value) {
|
||||
params.System = []anthropic.TextBlockParam{
|
||||
{
|
||||
Type: "text",
|
||||
Text: "You are Claude Code, Anthropic's official CLI for Claude.",
|
||||
},
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if opts.Search {
|
||||
// Build the web-search tool definition:
|
||||
webTool := anthropic.WebSearchTool20250305Param{
|
||||
@@ -207,6 +251,9 @@ func (an *Client) toMessages(msgs []*chat.ChatCompletionMessage) (ret []anthropi
|
||||
|
||||
var anthropicMessages []anthropic.MessageParam
|
||||
var systemContent string
|
||||
|
||||
// Note: Claude Code spoofing is now handled in buildMessageParams
|
||||
|
||||
isFirstUserMessage := true
|
||||
lastRoleWasUser := false
|
||||
|
||||
|
||||
300
plugins/ai/anthropic/oauth.go
Normal file
300
plugins/ai/anthropic/oauth.go
Normal file
@@ -0,0 +1,300 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// OAuth configuration constants
|
||||
const (
|
||||
oauthClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
||||
oauthAuthURL = "https://claude.ai/oauth/authorize"
|
||||
oauthTokenURL = "https://console.anthropic.com/v1/oauth/token"
|
||||
oauthRedirectURL = "https://console.anthropic.com/oauth/code/callback"
|
||||
)
|
||||
|
||||
// OAuthTransport is a custom HTTP transport that adds OAuth Bearer token and beta header
|
||||
type OAuthTransport struct {
|
||||
client *Client
|
||||
base http.RoundTripper
|
||||
}
|
||||
|
||||
// RoundTrip implements the http.RoundTripper interface
|
||||
func (t *OAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Clone the request to avoid modifying the original
|
||||
newReq := req.Clone(req.Context())
|
||||
|
||||
// Get current token (may refresh if needed)
|
||||
token, err := t.getValidToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get valid OAuth token: %w", err)
|
||||
}
|
||||
|
||||
// Add OAuth Bearer token
|
||||
newReq.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
// Add the anthropic-beta header for OAuth
|
||||
newReq.Header.Set("anthropic-beta", "oauth-2025-04-20")
|
||||
|
||||
// Set User-Agent to match AI SDK exactly
|
||||
newReq.Header.Set("User-Agent", "ai-sdk/anthropic")
|
||||
|
||||
// Remove x-api-key header if present (OAuth doesn't use it)
|
||||
newReq.Header.Del("x-api-key")
|
||||
|
||||
return t.base.RoundTrip(newReq)
|
||||
}
|
||||
|
||||
// getValidToken returns a valid access token, refreshing if necessary
|
||||
func (t *OAuthTransport) getValidToken() (string, error) {
|
||||
storage, err := common.NewOAuthStorage()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create OAuth storage: %w", err)
|
||||
}
|
||||
|
||||
// Load stored token
|
||||
token, err := storage.LoadToken("claude")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to load stored token: %w", err)
|
||||
}
|
||||
// If no token exists, run OAuth flow
|
||||
if token == nil {
|
||||
fmt.Println("No OAuth token found, initiating authentication...")
|
||||
newAccessToken, err := RunOAuthFlow()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to authenticate: %w", err)
|
||||
}
|
||||
return newAccessToken, nil
|
||||
}
|
||||
|
||||
// Check if token needs refresh (5 minute buffer)
|
||||
if token.IsExpired(5) {
|
||||
fmt.Println("OAuth token expired, refreshing...")
|
||||
newAccessToken, err := RefreshToken()
|
||||
if err != nil {
|
||||
// If refresh fails, try re-authentication
|
||||
fmt.Println("Token refresh failed, re-authenticating...")
|
||||
newAccessToken, err = RunOAuthFlow()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to refresh or re-authenticate: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return newAccessToken, nil
|
||||
}
|
||||
|
||||
return token.AccessToken, nil
|
||||
}
|
||||
|
||||
// NewOAuthTransport creates a new OAuth transport for the given client
|
||||
func NewOAuthTransport(client *Client, base http.RoundTripper) *OAuthTransport {
|
||||
return &OAuthTransport{
|
||||
client: client,
|
||||
base: base,
|
||||
}
|
||||
}
|
||||
|
||||
// generatePKCE generates PKCE code verifier and challenge
|
||||
func generatePKCE() (verifier, challenge string, err error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err = rand.Read(b); err != nil {
|
||||
return
|
||||
}
|
||||
verifier = base64.RawURLEncoding.EncodeToString(b)
|
||||
sum := sha256.Sum256([]byte(verifier))
|
||||
challenge = base64.RawURLEncoding.EncodeToString(sum[:])
|
||||
return
|
||||
}
|
||||
|
||||
// openBrowser attempts to open the given URL in the default browser
|
||||
func openBrowser(url string) {
|
||||
commands := [][]string{{"xdg-open", url}, {"open", url}, {"cmd", "/c", "start", url}}
|
||||
for _, cmd := range commands {
|
||||
if exec.Command(cmd[0], cmd[1:]...).Start() == nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RunOAuthFlow executes the complete OAuth authorization flow
|
||||
func RunOAuthFlow() (token string, err error) {
|
||||
verifier, challenge, err := generatePKCE()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
cfg := oauth2.Config{
|
||||
ClientID: oauthClientID,
|
||||
Endpoint: oauth2.Endpoint{AuthURL: oauthAuthURL, TokenURL: oauthTokenURL},
|
||||
RedirectURL: oauthRedirectURL,
|
||||
Scopes: []string{"org:create_api_key", "user:profile", "user:inference"},
|
||||
}
|
||||
|
||||
authURL := cfg.AuthCodeURL(verifier,
|
||||
oauth2.SetAuthURLParam("code_challenge", challenge),
|
||||
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
||||
oauth2.SetAuthURLParam("code", "true"),
|
||||
oauth2.SetAuthURLParam("state", verifier),
|
||||
)
|
||||
|
||||
fmt.Println("Open the following URL in your browser. Fabric would like to authorize:")
|
||||
fmt.Println(authURL)
|
||||
openBrowser(authURL)
|
||||
fmt.Print("Paste the authorization code here: ")
|
||||
var code string
|
||||
fmt.Scanln(&code)
|
||||
parts := strings.SplitN(code, "#", 2)
|
||||
state := verifier
|
||||
if len(parts) == 2 {
|
||||
state = parts[1]
|
||||
}
|
||||
|
||||
// Manual token exchange to match opencode implementation
|
||||
tokenReq := map[string]string{
|
||||
"code": parts[0],
|
||||
"state": state,
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": oauthClientID,
|
||||
"redirect_uri": oauthRedirectURL,
|
||||
"code_verifier": verifier,
|
||||
}
|
||||
|
||||
token, err = exchangeToken(tokenReq)
|
||||
return
|
||||
}
|
||||
|
||||
// exchangeToken exchanges authorization code for access token
|
||||
func exchangeToken(params map[string]string) (token string, err error) {
|
||||
reqBody, err := json.Marshal(params)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := http.Post(oauthTokenURL, "application/json", bytes.NewBuffer(reqBody))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
err = fmt.Errorf("token exchange failed: %s - %s", resp.Status, string(body))
|
||||
return
|
||||
}
|
||||
|
||||
var result struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
TokenType string `json:"token_type"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
if err = json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Save the complete token information
|
||||
storage, err := common.NewOAuthStorage()
|
||||
if err != nil {
|
||||
return result.AccessToken, fmt.Errorf("failed to create OAuth storage: %w", err)
|
||||
}
|
||||
|
||||
oauthToken := &common.OAuthToken{
|
||||
AccessToken: result.AccessToken,
|
||||
RefreshToken: result.RefreshToken,
|
||||
ExpiresAt: time.Now().Unix() + int64(result.ExpiresIn),
|
||||
TokenType: result.TokenType,
|
||||
Scope: result.Scope,
|
||||
}
|
||||
|
||||
if err = storage.SaveToken("claude", oauthToken); err != nil {
|
||||
return result.AccessToken, fmt.Errorf("failed to save OAuth token: %w", err)
|
||||
}
|
||||
|
||||
token = result.AccessToken
|
||||
return
|
||||
}
|
||||
|
||||
// RefreshToken refreshes an expired OAuth token using the refresh token
|
||||
func RefreshToken() (string, error) {
|
||||
storage, err := common.NewOAuthStorage()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create OAuth storage: %w", err)
|
||||
}
|
||||
|
||||
// Load existing token
|
||||
token, err := storage.LoadToken("claude")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to load stored token: %w", err)
|
||||
}
|
||||
if token == nil || token.RefreshToken == "" {
|
||||
return "", fmt.Errorf("no refresh token available")
|
||||
}
|
||||
|
||||
// Prepare refresh request
|
||||
refreshReq := map[string]string{
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": token.RefreshToken,
|
||||
"client_id": oauthClientID,
|
||||
}
|
||||
|
||||
reqBody, err := json.Marshal(refreshReq)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal refresh request: %w", err)
|
||||
}
|
||||
|
||||
// Make refresh request
|
||||
resp, err := http.Post(oauthTokenURL, "application/json", bytes.NewBuffer(reqBody))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("refresh request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return "", fmt.Errorf("token refresh failed: %s - %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
TokenType string `json:"token_type"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
if err = json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return "", fmt.Errorf("failed to parse refresh response: %w", err)
|
||||
}
|
||||
|
||||
// Update stored token
|
||||
newToken := &common.OAuthToken{
|
||||
AccessToken: result.AccessToken,
|
||||
RefreshToken: result.RefreshToken,
|
||||
ExpiresAt: time.Now().Unix() + int64(result.ExpiresIn),
|
||||
TokenType: result.TokenType,
|
||||
Scope: result.Scope,
|
||||
}
|
||||
|
||||
// Use existing refresh token if new one not provided
|
||||
if newToken.RefreshToken == "" {
|
||||
newToken.RefreshToken = token.RefreshToken
|
||||
}
|
||||
|
||||
if err = storage.SaveToken("claude", newToken); err != nil {
|
||||
return "", fmt.Errorf("failed to save refreshed token: %w", err)
|
||||
}
|
||||
|
||||
return result.AccessToken, nil
|
||||
}
|
||||
@@ -128,6 +128,11 @@ func (o *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, o
|
||||
}
|
||||
|
||||
func (o *Client) sendResponses(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions) (ret string, err error) {
|
||||
// Validate model supports image generation if image file is specified
|
||||
if opts.ImageFile != "" && !supportsImageGeneration(opts.Model) {
|
||||
return "", fmt.Errorf("model '%s' does not support image generation. Supported models: %s", opts.Model, strings.Join(ImageGenerationSupportedModels, ", "))
|
||||
}
|
||||
|
||||
req := o.buildResponseParams(msgs, opts)
|
||||
|
||||
var resp *responses.Response
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"github.com/openai/openai-go/packages/param"
|
||||
"github.com/openai/openai-go/responses"
|
||||
)
|
||||
|
||||
@@ -18,6 +19,26 @@ import (
|
||||
const ImageGenerationResponseType = "image_generation_call"
|
||||
const ImageGenerationToolType = "image_generation"
|
||||
|
||||
// ImageGenerationSupportedModels lists all models that support image generation
|
||||
var ImageGenerationSupportedModels = []string{
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
"gpt-4.1",
|
||||
"gpt-4.1-mini",
|
||||
"gpt-4.1-nano",
|
||||
"o3",
|
||||
}
|
||||
|
||||
// supportsImageGeneration checks if the given model supports the image_generation tool
|
||||
func supportsImageGeneration(model string) bool {
|
||||
for _, supportedModel := range ImageGenerationSupportedModels {
|
||||
if model == supportedModel {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// getOutputFormatFromExtension determines the API output format based on file extension
|
||||
func getOutputFormatFromExtension(imagePath string) string {
|
||||
if imagePath == "" {
|
||||
@@ -44,15 +65,36 @@ func (o *Client) addImageGenerationTool(opts *common.ChatOptions, tools []respon
|
||||
// Check if the request seems to be asking for image generation
|
||||
if o.shouldUseImageGeneration(opts) {
|
||||
outputFormat := getOutputFormatFromExtension(opts.ImageFile)
|
||||
|
||||
// Build the image generation tool with user parameters
|
||||
imageGenTool := responses.ToolUnionParam{
|
||||
OfImageGeneration: &responses.ToolImageGenerationParam{
|
||||
Type: ImageGenerationToolType,
|
||||
Model: "gpt-image-1",
|
||||
OutputFormat: outputFormat,
|
||||
Quality: "auto",
|
||||
Size: "auto",
|
||||
},
|
||||
}
|
||||
|
||||
// Set quality if specified by user (otherwise let OpenAI use default)
|
||||
if opts.ImageQuality != "" {
|
||||
imageGenTool.OfImageGeneration.Quality = opts.ImageQuality
|
||||
}
|
||||
|
||||
// Set size if specified by user (otherwise let OpenAI use default)
|
||||
if opts.ImageSize != "" {
|
||||
imageGenTool.OfImageGeneration.Size = opts.ImageSize
|
||||
}
|
||||
|
||||
// Set background if specified by user (otherwise let OpenAI use default)
|
||||
if opts.ImageBackground != "" {
|
||||
imageGenTool.OfImageGeneration.Background = opts.ImageBackground
|
||||
}
|
||||
|
||||
// Set compression if specified by user (only for jpeg/webp)
|
||||
if opts.ImageCompression != 0 {
|
||||
imageGenTool.OfImageGeneration.OutputCompression = param.NewOpt(int64(opts.ImageCompression))
|
||||
}
|
||||
|
||||
tools = append(tools, imageGenTool)
|
||||
}
|
||||
return tools
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/danielmiessler/fabric/chat"
|
||||
@@ -218,3 +220,225 @@ func TestAddImageGenerationToolWithDynamicFormat(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSupportsImageGeneration(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "gpt-4o supports image generation",
|
||||
model: "gpt-4o",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "gpt-4o-mini supports image generation",
|
||||
model: "gpt-4o-mini",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "gpt-4.1 supports image generation",
|
||||
model: "gpt-4.1",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "gpt-4.1-mini supports image generation",
|
||||
model: "gpt-4.1-mini",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "gpt-4.1-nano supports image generation",
|
||||
model: "gpt-4.1-nano",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "o3 supports image generation",
|
||||
model: "o3",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "o1 does not support image generation",
|
||||
model: "o1",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "o1-mini does not support image generation",
|
||||
model: "o1-mini",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "o3-mini does not support image generation",
|
||||
model: "o3-mini",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "gpt-4 does not support image generation",
|
||||
model: "gpt-4",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "gpt-3.5-turbo does not support image generation",
|
||||
model: "gpt-3.5-turbo",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "empty model does not support image generation",
|
||||
model: "",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := supportsImageGeneration(tt.model)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelValidationLogic(t *testing.T) {
|
||||
t.Run("Unsupported model with image file should return validation error", func(t *testing.T) {
|
||||
opts := &common.ChatOptions{
|
||||
Model: "o1-mini",
|
||||
ImageFile: "/tmp/output.png",
|
||||
}
|
||||
|
||||
// Test the validation logic directly
|
||||
if opts.ImageFile != "" && !supportsImageGeneration(opts.Model) {
|
||||
err := fmt.Errorf("model '%s' does not support image generation. Supported models: %s", opts.Model, strings.Join(ImageGenerationSupportedModels, ", "))
|
||||
|
||||
assert.Contains(t, err.Error(), "does not support image generation")
|
||||
assert.Contains(t, err.Error(), "o1-mini")
|
||||
assert.Contains(t, err.Error(), "Supported models:")
|
||||
} else {
|
||||
t.Error("Expected validation to trigger")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Supported model with image file should not trigger validation", func(t *testing.T) {
|
||||
opts := &common.ChatOptions{
|
||||
Model: "gpt-4o",
|
||||
ImageFile: "/tmp/output.png",
|
||||
}
|
||||
|
||||
// Test the validation logic directly
|
||||
shouldFail := opts.ImageFile != "" && !supportsImageGeneration(opts.Model)
|
||||
assert.False(t, shouldFail, "Validation should not trigger for supported model")
|
||||
})
|
||||
|
||||
t.Run("Unsupported model without image file should not trigger validation", func(t *testing.T) {
|
||||
opts := &common.ChatOptions{
|
||||
Model: "o1-mini",
|
||||
ImageFile: "", // No image file
|
||||
}
|
||||
|
||||
// Test the validation logic directly
|
||||
shouldFail := opts.ImageFile != "" && !supportsImageGeneration(opts.Model)
|
||||
assert.False(t, shouldFail, "Validation should not trigger when no image file is specified")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAddImageGenerationToolWithUserParameters(t *testing.T) {
|
||||
client := NewClient()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
opts *common.ChatOptions
|
||||
expected map[string]interface{}
|
||||
}{
|
||||
{
|
||||
name: "All parameters specified",
|
||||
opts: &common.ChatOptions{
|
||||
ImageFile: "/tmp/test.png",
|
||||
ImageSize: "1536x1024",
|
||||
ImageQuality: "high",
|
||||
ImageBackground: "transparent",
|
||||
ImageCompression: 0, // Not applicable for PNG
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"size": "1536x1024",
|
||||
"quality": "high",
|
||||
"background": "transparent",
|
||||
"output_format": "png",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "JPEG with compression",
|
||||
opts: &common.ChatOptions{
|
||||
ImageFile: "/tmp/test.jpg",
|
||||
ImageSize: "1024x1024",
|
||||
ImageQuality: "medium",
|
||||
ImageBackground: "opaque",
|
||||
ImageCompression: 75,
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"size": "1024x1024",
|
||||
"quality": "medium",
|
||||
"background": "opaque",
|
||||
"output_format": "jpeg",
|
||||
"output_compression": int64(75),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Only some parameters specified",
|
||||
opts: &common.ChatOptions{
|
||||
ImageFile: "/tmp/test.webp",
|
||||
ImageQuality: "low",
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"quality": "low",
|
||||
"output_format": "webp",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "No parameters specified (defaults)",
|
||||
opts: &common.ChatOptions{
|
||||
ImageFile: "/tmp/test.png",
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"output_format": "png",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tools := client.addImageGenerationTool(tt.opts, []responses.ToolUnionParam{})
|
||||
|
||||
assert.Len(t, tools, 1)
|
||||
assert.NotNil(t, tools[0].OfImageGeneration)
|
||||
|
||||
tool := tools[0].OfImageGeneration
|
||||
|
||||
// Check required fields
|
||||
assert.Equal(t, "gpt-image-1", tool.Model)
|
||||
assert.Equal(t, tt.expected["output_format"], tool.OutputFormat)
|
||||
|
||||
// Check optional fields
|
||||
if expectedSize, ok := tt.expected["size"]; ok {
|
||||
assert.Equal(t, expectedSize, tool.Size)
|
||||
} else {
|
||||
assert.Empty(t, tool.Size, "Size should not be set when not specified")
|
||||
}
|
||||
|
||||
if expectedQuality, ok := tt.expected["quality"]; ok {
|
||||
assert.Equal(t, expectedQuality, tool.Quality)
|
||||
} else {
|
||||
assert.Empty(t, tool.Quality, "Quality should not be set when not specified")
|
||||
}
|
||||
|
||||
if expectedBackground, ok := tt.expected["background"]; ok {
|
||||
assert.Equal(t, expectedBackground, tool.Background)
|
||||
} else {
|
||||
assert.Empty(t, tool.Background, "Background should not be set when not specified")
|
||||
}
|
||||
|
||||
if expectedCompression, ok := tt.expected["output_compression"]; ok {
|
||||
assert.Equal(t, expectedCompression, tool.OutputCompression.Value)
|
||||
} else {
|
||||
assert.Equal(t, int64(0), tool.OutputCompression.Value, "Compression should not be set when not specified")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,17 +57,18 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
|
||||
}
|
||||
|
||||
config := map[string]string{
|
||||
"openai": os.Getenv("OPENAI_API_KEY"),
|
||||
"anthropic": os.Getenv("ANTHROPIC_API_KEY"),
|
||||
"groq": os.Getenv("GROQ_API_KEY"),
|
||||
"mistral": os.Getenv("MISTRAL_API_KEY"),
|
||||
"gemini": os.Getenv("GEMINI_API_KEY"),
|
||||
"ollama": os.Getenv("OLLAMA_URL"),
|
||||
"openrouter": os.Getenv("OPENROUTER_API_KEY"),
|
||||
"silicon": os.Getenv("SILICON_API_KEY"),
|
||||
"deepseek": os.Getenv("DEEPSEEK_API_KEY"),
|
||||
"grokai": os.Getenv("GROKAI_API_KEY"),
|
||||
"lmstudio": os.Getenv("LM_STUDIO_API_BASE_URL"),
|
||||
"openai": os.Getenv("OPENAI_API_KEY"),
|
||||
"anthropic": os.Getenv("ANTHROPIC_API_KEY"),
|
||||
"anthropic_use_oauth_login": os.Getenv("ANTHROPIC_USE_OAUTH_LOGIN"),
|
||||
"groq": os.Getenv("GROQ_API_KEY"),
|
||||
"mistral": os.Getenv("MISTRAL_API_KEY"),
|
||||
"gemini": os.Getenv("GEMINI_API_KEY"),
|
||||
"ollama": os.Getenv("OLLAMA_URL"),
|
||||
"openrouter": os.Getenv("OPENROUTER_API_KEY"),
|
||||
"silicon": os.Getenv("SILICON_API_KEY"),
|
||||
"deepseek": os.Getenv("DEEPSEEK_API_KEY"),
|
||||
"grokai": os.Getenv("GROKAI_API_KEY"),
|
||||
"lmstudio": os.Getenv("LM_STUDIO_API_BASE_URL"),
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, config)
|
||||
@@ -80,17 +81,18 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||||
}
|
||||
|
||||
var config struct {
|
||||
OpenAIApiKey string `json:"openai_api_key"`
|
||||
AnthropicApiKey string `json:"anthropic_api_key"`
|
||||
GroqApiKey string `json:"groq_api_key"`
|
||||
MistralApiKey string `json:"mistral_api_key"`
|
||||
GeminiApiKey string `json:"gemini_api_key"`
|
||||
OllamaURL string `json:"ollama_url"`
|
||||
OpenRouterApiKey string `json:"openrouter_api_key"`
|
||||
SiliconApiKey string `json:"silicon_api_key"`
|
||||
DeepSeekApiKey string `json:"deepseek_api_key"`
|
||||
GrokaiApiKey string `json:"grokai_api_key"`
|
||||
LMStudioURL string `json:"lm_studio_base_url"`
|
||||
OpenAIApiKey string `json:"openai_api_key"`
|
||||
AnthropicApiKey string `json:"anthropic_api_key"`
|
||||
AnthropicUseAuthToken string `json:"anthropic_use_auth_token"`
|
||||
GroqApiKey string `json:"groq_api_key"`
|
||||
MistralApiKey string `json:"mistral_api_key"`
|
||||
GeminiApiKey string `json:"gemini_api_key"`
|
||||
OllamaURL string `json:"ollama_url"`
|
||||
OpenRouterApiKey string `json:"openrouter_api_key"`
|
||||
SiliconApiKey string `json:"silicon_api_key"`
|
||||
DeepSeekApiKey string `json:"deepseek_api_key"`
|
||||
GrokaiApiKey string `json:"grokai_api_key"`
|
||||
LMStudioURL string `json:"lm_studio_base_url"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&config); err != nil {
|
||||
@@ -99,17 +101,18 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||||
}
|
||||
|
||||
envVars := map[string]string{
|
||||
"OPENAI_API_KEY": config.OpenAIApiKey,
|
||||
"ANTHROPIC_API_KEY": config.AnthropicApiKey,
|
||||
"GROQ_API_KEY": config.GroqApiKey,
|
||||
"MISTRAL_API_KEY": config.MistralApiKey,
|
||||
"GEMINI_API_KEY": config.GeminiApiKey,
|
||||
"OLLAMA_URL": config.OllamaURL,
|
||||
"OPENROUTER_API_KEY": config.OpenRouterApiKey,
|
||||
"SILICON_API_KEY": config.SiliconApiKey,
|
||||
"DEEPSEEK_API_KEY": config.DeepSeekApiKey,
|
||||
"GROKAI_API_KEY": config.GrokaiApiKey,
|
||||
"LM_STUDIO_API_BASE_URL": config.LMStudioURL,
|
||||
"OPENAI_API_KEY": config.OpenAIApiKey,
|
||||
"ANTHROPIC_API_KEY": config.AnthropicApiKey,
|
||||
"ANTHROPIC_USE_OAUTH_LOGIN": config.AnthropicUseAuthToken,
|
||||
"GROQ_API_KEY": config.GroqApiKey,
|
||||
"MISTRAL_API_KEY": config.MistralApiKey,
|
||||
"GEMINI_API_KEY": config.GeminiApiKey,
|
||||
"OLLAMA_URL": config.OllamaURL,
|
||||
"OPENROUTER_API_KEY": config.OpenRouterApiKey,
|
||||
"SILICON_API_KEY": config.SiliconApiKey,
|
||||
"DEEPSEEK_API_KEY": config.DeepSeekApiKey,
|
||||
"GROKAI_API_KEY": config.GrokaiApiKey,
|
||||
"LM_STUDIO_API_BASE_URL": config.LMStudioURL,
|
||||
}
|
||||
|
||||
var envContent strings.Builder
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
package main
|
||||
|
||||
var version = "v1.4.228"
|
||||
var version = "v1.4.231"
|
||||
|
||||
Reference in New Issue
Block a user