mirror of
https://github.com/danielmiessler/Fabric.git
synced 2026-01-09 22:38:10 -05:00
Compare commits
47 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cb3f8ed43d | ||
|
|
4c1803cb6d | ||
|
|
d1c614d44e | ||
|
|
dbaa0b9754 | ||
|
|
4cfe2375ab | ||
|
|
2b371b69c7 | ||
|
|
6222a613e4 | ||
|
|
0882c43532 | ||
|
|
f0e1a1b77f | ||
|
|
a774f991ab | ||
|
|
a40bacaf34 | ||
|
|
969b85380c | ||
|
|
e8fe4434db | ||
|
|
7c7ceca264 | ||
|
|
c19d7ccd9d | ||
|
|
bd0c5f730e | ||
|
|
5900dac58f | ||
|
|
237219c3cc | ||
|
|
26fd700098 | ||
|
|
6bd926dd0f | ||
|
|
16ac519415 | ||
|
|
a32cc5fa01 | ||
|
|
26b5bb2e9e | ||
|
|
b751d323b1 | ||
|
|
d081fd269c | ||
|
|
369a0a850d | ||
|
|
8dc5343ee6 | ||
|
|
eda552dac5 | ||
|
|
f13a56685b | ||
|
|
2f9afe0247 | ||
|
|
1ec525ad97 | ||
|
|
b7dc6748e0 | ||
|
|
f1b612d828 | ||
|
|
eac5a104f2 | ||
|
|
4bff88fae3 | ||
|
|
acf1be71ce | ||
|
|
236a3c5f38 | ||
|
|
b2418984f8 | ||
|
|
152d74d160 | ||
|
|
4e16bbccd8 | ||
|
|
60174f41a4 | ||
|
|
ad4683952e | ||
|
|
86a044735b | ||
|
|
58583114cb | ||
|
|
36524cd2e4 | ||
|
|
e59156ac2b | ||
|
|
1eac026e92 |
59
.github/ISSUE_TEMPLATE/bug.yml
vendored
59
.github/ISSUE_TEMPLATE/bug.yml
vendored
@@ -7,29 +7,74 @@ body:
|
||||
attributes:
|
||||
value: |
|
||||
Thanks for taking the time to fill out this bug report!
|
||||
Please provide as much detail as possible to help us understand and reproduce the issue.
|
||||
|
||||
- type: textarea
|
||||
id: what-happened
|
||||
attributes:
|
||||
label: What happened?
|
||||
description: Also tell us, what did you expect to happen?
|
||||
placeholder: Tell us what you see!
|
||||
value: "I was doing THIS, when THAT happened. I was expecting THAT_OTHER_THING to happen instead."
|
||||
value: "Please provide all the steps to reproduce the bug. I was doing THIS, when THAT happened. I was expecting THAT_OTHER_THING to happen instead."
|
||||
validations:
|
||||
required: true
|
||||
- type: checkboxes
|
||||
|
||||
- type: dropdown
|
||||
id: os
|
||||
attributes:
|
||||
label: Operating System
|
||||
options:
|
||||
- macOS - Silicon (arm64)
|
||||
- macOS - Intel (amd64)
|
||||
- Linux - amd64
|
||||
- Linux - arm64
|
||||
- Windows
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: os-version
|
||||
attributes:
|
||||
label: OS Version
|
||||
description: Please provide details about your OS version by running one of the following commands.
|
||||
placeholder: |
|
||||
macOS: `sw_vers`
|
||||
Linux: `uname -a` or `cat /etc/os-release`
|
||||
Windows: `ver`
|
||||
render: shell
|
||||
|
||||
- type: dropdown
|
||||
id: installation
|
||||
attributes:
|
||||
label: How did you install Fabric?
|
||||
description: "Please select the method you used to install Fabric. You can find this information in the [Installation section of the README](https://github.com/ksylvan/fabric/blob/main/README.md#installation)."
|
||||
options:
|
||||
- Release Binary - Windows
|
||||
- Release Binary - macOS (arm64)
|
||||
- Release Binary - macOS (amd64)
|
||||
- Release Binary - Linux (amd64)
|
||||
- Release Binary - Linux (arm64)
|
||||
- Package Manager - Homebrew (macOS)
|
||||
- Package Manager - AUR (Arch Linux)
|
||||
- From Source
|
||||
- Other
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: version
|
||||
attributes:
|
||||
label: Version check
|
||||
description: Please make sure you were using the latest version of this project available in the `main` branch.
|
||||
options:
|
||||
- label: Yes I was.
|
||||
required: true
|
||||
label: Version
|
||||
description: Please copy and paste the output of `fabric --version` (or `fabric-ai --version` if you installed it via brew) here.
|
||||
render: text
|
||||
|
||||
- type: textarea
|
||||
id: logs
|
||||
attributes:
|
||||
label: Relevant log output
|
||||
description: Please copy and paste any relevant log output. This will be automatically formatted into code, so no need for backticks.
|
||||
render: shell
|
||||
|
||||
- type: textarea
|
||||
id: screens
|
||||
attributes:
|
||||
|
||||
213
README.md
213
README.md
@@ -93,6 +93,9 @@ Keep in mind that many of these were recorded when Fabric was Python-based, so r
|
||||
- [Just use the Patterns](#just-use-the-patterns)
|
||||
- [Prompt Strategies](#prompt-strategies)
|
||||
- [Custom Patterns](#custom-patterns)
|
||||
- [Setting Up Custom Patterns](#setting-up-custom-patterns)
|
||||
- [Using Custom Patterns](#using-custom-patterns)
|
||||
- [How It Works](#how-it-works)
|
||||
- [Helper Apps](#helper-apps)
|
||||
- [`to_pdf`](#to_pdf)
|
||||
- [`to_pdf` Installation](#to_pdf-installation)
|
||||
@@ -114,13 +117,11 @@ Keep in mind that many of these were recorded when Fabric was Python-based, so r
|
||||
>
|
||||
> July 4, 2025
|
||||
>
|
||||
> - Fabric now supports web search using the `--search` and `--search-location` flags
|
||||
> - Web search is available for both Anthropic and OpenAI providers
|
||||
> - Previous plugin-level search configurations have been removed in favor of the new flag-based approach.
|
||||
> - If you used the previous approach, consider cleaning up your `~/.config/fabric/.env` file, removing the unused `ANTHROPIC_WEB_SEARCH_TOOL_ENABLED` and `ANTHROPIC_WEB_SEARCH_TOOL_LOCATION` variables.
|
||||
> - Fabric now supports image generation using the `--image-file` flag with OpenAI models
|
||||
> - Image generation works with both text prompts and input images (via `--attachment`) for image editing tasks
|
||||
>
|
||||
> - **Web Search**: Fabric now supports web search for Anthropic and OpenAI models using the `--search` and `--search-location` flags. This replaces the previous plugin-based search, so you may want to remove the old `ANTHROPIC_WEB_SEARCH_TOOL_*` variables from your `~/.config/fabric/.env` file.
|
||||
> - **Image Generation**: Fabric now has powerful image generation capabilities with OpenAI.
|
||||
> - Generate images from text prompts and save them using `--image-file`.
|
||||
> - Edit existing images by providing an input image with `--attachment`.
|
||||
> - Control image `size`, `quality`, `compression`, and `background` with the new `--image-*` flags.
|
||||
>
|
||||
>June 17, 2025
|
||||
>
|
||||
@@ -292,88 +293,88 @@ yt() {
|
||||
|
||||
You can add the below code for the equivalent aliases inside PowerShell by running `notepad $PROFILE` inside a PowerShell window:
|
||||
|
||||
```powershell
|
||||
# Path to the patterns directory
|
||||
$patternsPath = Join-Path $HOME ".config/fabric/patterns"
|
||||
foreach ($patternDir in Get-ChildItem -Path $patternsPath -Directory) {
|
||||
$patternName = $patternDir.Name
|
||||
```powershell
|
||||
# Path to the patterns directory
|
||||
$patternsPath = Join-Path $HOME ".config/fabric/patterns"
|
||||
foreach ($patternDir in Get-ChildItem -Path $patternsPath -Directory) {
|
||||
$patternName = $patternDir.Name
|
||||
|
||||
# Dynamically define a function for each pattern
|
||||
$functionDefinition = @"
|
||||
function $patternName {
|
||||
[CmdletBinding()]
|
||||
param(
|
||||
[Parameter(ValueFromPipeline = `$true)]
|
||||
[string] `$InputObject,
|
||||
# Dynamically define a function for each pattern
|
||||
$functionDefinition = @"
|
||||
function $patternName {
|
||||
[CmdletBinding()]
|
||||
param(
|
||||
[Parameter(ValueFromPipeline = `$true)]
|
||||
[string] `$InputObject,
|
||||
|
||||
[Parameter(ValueFromRemainingArguments = `$true)]
|
||||
[String[]] `$patternArgs
|
||||
)
|
||||
[Parameter(ValueFromRemainingArguments = `$true)]
|
||||
[String[]] `$patternArgs
|
||||
)
|
||||
|
||||
begin {
|
||||
# Initialize an array to collect pipeline input
|
||||
`$collector = @()
|
||||
}
|
||||
|
||||
process {
|
||||
# Collect pipeline input objects
|
||||
if (`$InputObject) {
|
||||
`$collector += `$InputObject
|
||||
}
|
||||
}
|
||||
|
||||
end {
|
||||
# Join all pipeline input into a single string, separated by newlines
|
||||
`$pipelineContent = `$collector -join "`n"
|
||||
|
||||
# If there's pipeline input, include it in the call to fabric
|
||||
if (`$pipelineContent) {
|
||||
`$pipelineContent | fabric --pattern $patternName `$patternArgs
|
||||
} else {
|
||||
# No pipeline input; just call fabric with the additional args
|
||||
fabric --pattern $patternName `$patternArgs
|
||||
}
|
||||
}
|
||||
}
|
||||
"@
|
||||
# Add the function to the current session
|
||||
Invoke-Expression $functionDefinition
|
||||
begin {
|
||||
# Initialize an array to collect pipeline input
|
||||
`$collector = @()
|
||||
}
|
||||
|
||||
# Define the 'yt' function as well
|
||||
function yt {
|
||||
[CmdletBinding()]
|
||||
param(
|
||||
[Parameter()]
|
||||
[Alias("timestamps")]
|
||||
[switch]$t,
|
||||
|
||||
[Parameter(Position = 0, ValueFromPipeline = $true)]
|
||||
[string]$videoLink
|
||||
)
|
||||
|
||||
begin {
|
||||
$transcriptFlag = "--transcript"
|
||||
if ($t) {
|
||||
$transcriptFlag = "--transcript-with-timestamps"
|
||||
}
|
||||
}
|
||||
|
||||
process {
|
||||
if (-not $videoLink) {
|
||||
Write-Error "Usage: yt [-t | --timestamps] youtube-link"
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
end {
|
||||
if ($videoLink) {
|
||||
# Execute and allow output to flow through the pipeline
|
||||
fabric -y $videoLink $transcriptFlag
|
||||
}
|
||||
process {
|
||||
# Collect pipeline input objects
|
||||
if (`$InputObject) {
|
||||
`$collector += `$InputObject
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
end {
|
||||
# Join all pipeline input into a single string, separated by newlines
|
||||
`$pipelineContent = `$collector -join "`n"
|
||||
|
||||
# If there's pipeline input, include it in the call to fabric
|
||||
if (`$pipelineContent) {
|
||||
`$pipelineContent | fabric --pattern $patternName `$patternArgs
|
||||
} else {
|
||||
# No pipeline input; just call fabric with the additional args
|
||||
fabric --pattern $patternName `$patternArgs
|
||||
}
|
||||
}
|
||||
}
|
||||
"@
|
||||
# Add the function to the current session
|
||||
Invoke-Expression $functionDefinition
|
||||
}
|
||||
|
||||
# Define the 'yt' function as well
|
||||
function yt {
|
||||
[CmdletBinding()]
|
||||
param(
|
||||
[Parameter()]
|
||||
[Alias("timestamps")]
|
||||
[switch]$t,
|
||||
|
||||
[Parameter(Position = 0, ValueFromPipeline = $true)]
|
||||
[string]$videoLink
|
||||
)
|
||||
|
||||
begin {
|
||||
$transcriptFlag = "--transcript"
|
||||
if ($t) {
|
||||
$transcriptFlag = "--transcript-with-timestamps"
|
||||
}
|
||||
}
|
||||
|
||||
process {
|
||||
if (-not $videoLink) {
|
||||
Write-Error "Usage: yt [-t | --timestamps] youtube-link"
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
end {
|
||||
if ($videoLink) {
|
||||
# Execute and allow output to flow through the pipeline
|
||||
fabric -y $videoLink $transcriptFlag
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
This also creates a `yt` alias that allows you to use `yt https://www.youtube.com/watch?v=4b0iet22VIk` to get transcripts, comments, and metadata.
|
||||
|
||||
@@ -559,6 +560,11 @@ Application Options:
|
||||
--search Enable web search tool for supported models (Anthropic, OpenAI)
|
||||
--search-location= Set location for web search results (e.g., 'America/Los_Angeles')
|
||||
--image-file= Save generated image to specified file path (e.g., 'output.png')
|
||||
--image-size= Image dimensions: 1024x1024, 1536x1024, 1024x1536, auto (default: auto)
|
||||
--image-quality= Image quality: low, medium, high, auto (default: auto)
|
||||
--image-compression= Compression level 0-100 for JPEG/WebP formats (default: not set)
|
||||
--image-background= Background type: opaque, transparent (default: opaque, only for
|
||||
PNG/WebP)
|
||||
|
||||
Help Options:
|
||||
-h, --help Show this help message
|
||||
@@ -649,11 +655,48 @@ Use `fabric -S` and select the option to install the strategies in your `~/.conf
|
||||
|
||||
You may want to use Fabric to create your own custom Patterns—but not share them with others. No problem!
|
||||
|
||||
Just make a directory in `~/.config/custompatterns/` (or wherever) and put your `.md` files in there.
|
||||
Fabric now supports a dedicated custom patterns directory that keeps your personal patterns separate from the built-in ones. This means your custom patterns won't be overwritten when you update Fabric's built-in patterns.
|
||||
|
||||
When you're ready to use them, copy them into `~/.config/fabric/patterns/`
|
||||
### Setting Up Custom Patterns
|
||||
|
||||
You can then use them like any other Patterns, but they won't be public unless you explicitly submit them as Pull Requests to the Fabric project. So don't worry—they're private to you.
|
||||
1. Run the Fabric setup:
|
||||
|
||||
```bash
|
||||
fabric --setup
|
||||
```
|
||||
|
||||
2. Select the "Custom Patterns" option from the Tools menu and enter your desired directory path (e.g., `~/my-custom-patterns`)
|
||||
|
||||
3. Fabric will automatically create the directory if it does not exist.
|
||||
|
||||
### Using Custom Patterns
|
||||
|
||||
1. Create your custom pattern directory structure:
|
||||
|
||||
```bash
|
||||
mkdir -p ~/my-custom-patterns/my-analyzer
|
||||
```
|
||||
|
||||
2. Create your pattern file
|
||||
|
||||
```bash
|
||||
echo "You are an expert analyzer of ..." > ~/my-custom-patterns/my-analyzer/system.md
|
||||
```
|
||||
|
||||
3. **Use your custom pattern:**
|
||||
|
||||
```bash
|
||||
fabric --pattern my-analyzer "analyze this text"
|
||||
```
|
||||
|
||||
### How It Works
|
||||
|
||||
- **Priority System**: Custom patterns take precedence over built-in patterns with the same name
|
||||
- **Seamless Integration**: Custom patterns appear in `fabric --listpatterns` alongside built-in ones
|
||||
- **Update Safe**: Your custom patterns are never affected by `fabric --updatepatterns`
|
||||
- **Private by Default**: Custom patterns remain private unless you explicitly share them
|
||||
|
||||
Your custom patterns are completely private and won't be affected by Fabric updates!
|
||||
|
||||
## Helper Apps
|
||||
|
||||
|
||||
@@ -270,7 +270,11 @@ func Cli(version string) (err error) {
|
||||
if chatReq.Language == "" {
|
||||
chatReq.Language = registry.Language.DefaultLanguage.Value
|
||||
}
|
||||
if session, err = chatter.Send(chatReq, currentFlags.BuildChatOptions()); err != nil {
|
||||
var chatOptions *common.ChatOptions
|
||||
if chatOptions, err = currentFlags.BuildChatOptions(); err != nil {
|
||||
return
|
||||
}
|
||||
if session, err = chatter.Send(chatReq, chatOptions); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
126
cli/flags.go
126
cli/flags.go
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -14,7 +15,7 @@ import (
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"github.com/jessevdk/go-flags"
|
||||
"golang.org/x/text/language"
|
||||
"gopkg.in/yaml.v2"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Flags create flags struct. the users flags go into this, this will be passed to the chat struct in cli
|
||||
@@ -77,6 +78,10 @@ type Flags struct {
|
||||
Search bool `long:"search" description:"Enable web search tool for supported models (Anthropic, OpenAI)"`
|
||||
SearchLocation string `long:"search-location" description:"Set location for web search results (e.g., 'America/Los_Angeles')"`
|
||||
ImageFile string `long:"image-file" description:"Save generated image to specified file path (e.g., 'output.png')"`
|
||||
ImageSize string `long:"image-size" description:"Image dimensions: 1024x1024, 1536x1024, 1024x1536, auto (default: auto)"`
|
||||
ImageQuality string `long:"image-quality" description:"Image quality: low, medium, high, auto (default: auto)"`
|
||||
ImageCompression int `long:"image-compression" description:"Compression level 0-100 for JPEG/WebP formats (default: not set)"`
|
||||
ImageBackground string `long:"image-background" description:"Background type: opaque, transparent (default: opaque, only for PNG/WebP)"`
|
||||
}
|
||||
|
||||
var debug = false
|
||||
@@ -257,8 +262,121 @@ func readStdin() (ret string, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (o *Flags) BuildChatOptions() (ret *common.ChatOptions) {
|
||||
// validateImageFile validates the image file path and extension
|
||||
func validateImageFile(imagePath string) error {
|
||||
if imagePath == "" {
|
||||
return nil // No validation needed if no image file specified
|
||||
}
|
||||
|
||||
// Check if file already exists
|
||||
if _, err := os.Stat(imagePath); err == nil {
|
||||
return fmt.Errorf("image file already exists: %s", imagePath)
|
||||
}
|
||||
|
||||
// Check file extension
|
||||
ext := strings.ToLower(filepath.Ext(imagePath))
|
||||
validExtensions := []string{".png", ".jpeg", ".jpg", ".webp"}
|
||||
|
||||
for _, validExt := range validExtensions {
|
||||
if ext == validExt {
|
||||
return nil // Valid extension found
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("invalid image file extension '%s'. Supported formats: .png, .jpeg, .jpg, .webp", ext)
|
||||
}
|
||||
|
||||
// validateImageParameters validates image generation parameters
|
||||
func validateImageParameters(imagePath, size, quality, background string, compression int) error {
|
||||
if imagePath == "" {
|
||||
// Check if any image parameters are specified without --image-file
|
||||
if size != "" || quality != "" || background != "" || compression != 0 {
|
||||
return fmt.Errorf("image parameters (--image-size, --image-quality, --image-background, --image-compression) can only be used with --image-file")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate size
|
||||
if size != "" {
|
||||
validSizes := []string{"1024x1024", "1536x1024", "1024x1536", "auto"}
|
||||
valid := false
|
||||
for _, validSize := range validSizes {
|
||||
if size == validSize {
|
||||
valid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !valid {
|
||||
return fmt.Errorf("invalid image size '%s'. Supported sizes: 1024x1024, 1536x1024, 1024x1536, auto", size)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate quality
|
||||
if quality != "" {
|
||||
validQualities := []string{"low", "medium", "high", "auto"}
|
||||
valid := false
|
||||
for _, validQuality := range validQualities {
|
||||
if quality == validQuality {
|
||||
valid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !valid {
|
||||
return fmt.Errorf("invalid image quality '%s'. Supported qualities: low, medium, high, auto", quality)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate background
|
||||
if background != "" {
|
||||
validBackgrounds := []string{"opaque", "transparent"}
|
||||
valid := false
|
||||
for _, validBackground := range validBackgrounds {
|
||||
if background == validBackground {
|
||||
valid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !valid {
|
||||
return fmt.Errorf("invalid image background '%s'. Supported backgrounds: opaque, transparent", background)
|
||||
}
|
||||
}
|
||||
|
||||
// Get file format for format-specific validations
|
||||
ext := strings.ToLower(filepath.Ext(imagePath))
|
||||
|
||||
// Validate compression (only for jpeg/webp)
|
||||
if compression != 0 { // 0 means not set
|
||||
if ext != ".jpg" && ext != ".jpeg" && ext != ".webp" {
|
||||
return fmt.Errorf("image compression can only be used with JPEG and WebP formats, not %s", ext)
|
||||
}
|
||||
if compression < 0 || compression > 100 {
|
||||
return fmt.Errorf("image compression must be between 0 and 100, got %d", compression)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate background transparency (only for png/webp)
|
||||
if background == "transparent" {
|
||||
if ext != ".png" && ext != ".webp" {
|
||||
return fmt.Errorf("transparent background can only be used with PNG and WebP formats, not %s", ext)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *Flags) BuildChatOptions() (ret *common.ChatOptions, err error) {
|
||||
// Validate image file if specified
|
||||
if err = validateImageFile(o.ImageFile); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Validate image parameters
|
||||
if err = validateImageParameters(o.ImageFile, o.ImageSize, o.ImageQuality, o.ImageBackground, o.ImageCompression); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret = &common.ChatOptions{
|
||||
Model: o.Model,
|
||||
Temperature: o.Temperature,
|
||||
TopP: o.TopP,
|
||||
PresencePenalty: o.PresencePenalty,
|
||||
@@ -269,6 +387,10 @@ func (o *Flags) BuildChatOptions() (ret *common.ChatOptions) {
|
||||
Search: o.Search,
|
||||
SearchLocation: o.SearchLocation,
|
||||
ImageFile: o.ImageFile,
|
||||
ImageSize: o.ImageSize,
|
||||
ImageQuality: o.ImageQuality,
|
||||
ImageCompression: o.ImageCompression,
|
||||
ImageBackground: o.ImageBackground,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -64,7 +65,8 @@ func TestBuildChatOptions(t *testing.T) {
|
||||
Raw: false,
|
||||
Seed: 1,
|
||||
}
|
||||
options := flags.BuildChatOptions()
|
||||
options, err := flags.BuildChatOptions()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedOptions, options)
|
||||
}
|
||||
|
||||
@@ -84,7 +86,8 @@ func TestBuildChatOptionsDefaultSeed(t *testing.T) {
|
||||
Raw: false,
|
||||
Seed: 0,
|
||||
}
|
||||
options := flags.BuildChatOptions()
|
||||
options, err := flags.BuildChatOptions()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedOptions, options)
|
||||
}
|
||||
|
||||
@@ -164,3 +167,269 @@ model: 123 # should be string
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestValidateImageFile(t *testing.T) {
|
||||
t.Run("Empty path should be valid", func(t *testing.T) {
|
||||
err := validateImageFile("")
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Valid extensions should pass", func(t *testing.T) {
|
||||
validExtensions := []string{".png", ".jpeg", ".jpg", ".webp"}
|
||||
for _, ext := range validExtensions {
|
||||
filename := "/tmp/test" + ext
|
||||
err := validateImageFile(filename)
|
||||
assert.NoError(t, err, "Extension %s should be valid", ext)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid extensions should fail", func(t *testing.T) {
|
||||
invalidExtensions := []string{".gif", ".bmp", ".tiff", ".svg", ".txt", ""}
|
||||
for _, ext := range invalidExtensions {
|
||||
filename := "/tmp/test" + ext
|
||||
err := validateImageFile(filename)
|
||||
assert.Error(t, err, "Extension %s should be invalid", ext)
|
||||
assert.Contains(t, err.Error(), "invalid image file extension")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Existing file should fail", func(t *testing.T) {
|
||||
// Create a temporary file
|
||||
tempFile, err := os.CreateTemp("", "test*.png")
|
||||
assert.NoError(t, err)
|
||||
defer os.Remove(tempFile.Name())
|
||||
tempFile.Close()
|
||||
|
||||
// Validation should fail because file exists
|
||||
err = validateImageFile(tempFile.Name())
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "image file already exists")
|
||||
})
|
||||
|
||||
t.Run("Non-existing file with valid extension should pass", func(t *testing.T) {
|
||||
nonExistentFile := filepath.Join(os.TempDir(), "non_existent_file.png")
|
||||
// Make sure the file doesn't exist
|
||||
os.Remove(nonExistentFile)
|
||||
|
||||
err := validateImageFile(nonExistentFile)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuildChatOptionsWithImageFileValidation(t *testing.T) {
|
||||
t.Run("Valid image file should pass", func(t *testing.T) {
|
||||
flags := &Flags{
|
||||
ImageFile: "/tmp/output.png",
|
||||
}
|
||||
|
||||
options, err := flags.BuildChatOptions()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "/tmp/output.png", options.ImageFile)
|
||||
})
|
||||
|
||||
t.Run("Invalid extension should fail", func(t *testing.T) {
|
||||
flags := &Flags{
|
||||
ImageFile: "/tmp/output.gif",
|
||||
}
|
||||
|
||||
options, err := flags.BuildChatOptions()
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, options)
|
||||
assert.Contains(t, err.Error(), "invalid image file extension")
|
||||
})
|
||||
|
||||
t.Run("Existing file should fail", func(t *testing.T) {
|
||||
// Create a temporary file
|
||||
tempFile, err := os.CreateTemp("", "existing*.png")
|
||||
assert.NoError(t, err)
|
||||
defer os.Remove(tempFile.Name())
|
||||
tempFile.Close()
|
||||
|
||||
flags := &Flags{
|
||||
ImageFile: tempFile.Name(),
|
||||
}
|
||||
|
||||
options, err := flags.BuildChatOptions()
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, options)
|
||||
assert.Contains(t, err.Error(), "image file already exists")
|
||||
})
|
||||
}
|
||||
|
||||
func TestValidateImageParameters(t *testing.T) {
|
||||
t.Run("No image file and no parameters should pass", func(t *testing.T) {
|
||||
err := validateImageParameters("", "", "", "", 0)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Image parameters without image file should fail", func(t *testing.T) {
|
||||
// Test each parameter individually
|
||||
err := validateImageParameters("", "1024x1024", "", "", 0)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "image parameters")
|
||||
assert.Contains(t, err.Error(), "can only be used with --image-file")
|
||||
|
||||
err = validateImageParameters("", "", "high", "", 0)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "image parameters")
|
||||
|
||||
err = validateImageParameters("", "", "", "transparent", 0)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "image parameters")
|
||||
|
||||
err = validateImageParameters("", "", "", "", 50)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "image parameters")
|
||||
|
||||
// Test multiple parameters
|
||||
err = validateImageParameters("", "1024x1024", "high", "transparent", 50)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "image parameters")
|
||||
})
|
||||
|
||||
t.Run("Valid size values should pass", func(t *testing.T) {
|
||||
validSizes := []string{"1024x1024", "1536x1024", "1024x1536", "auto"}
|
||||
for _, size := range validSizes {
|
||||
err := validateImageParameters("/tmp/test.png", size, "", "", 0)
|
||||
assert.NoError(t, err, "Size %s should be valid", size)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid size should fail", func(t *testing.T) {
|
||||
err := validateImageParameters("/tmp/test.png", "invalid", "", "", 0)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid image size")
|
||||
})
|
||||
|
||||
t.Run("Valid quality values should pass", func(t *testing.T) {
|
||||
validQualities := []string{"low", "medium", "high", "auto"}
|
||||
for _, quality := range validQualities {
|
||||
err := validateImageParameters("/tmp/test.png", "", quality, "", 0)
|
||||
assert.NoError(t, err, "Quality %s should be valid", quality)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid quality should fail", func(t *testing.T) {
|
||||
err := validateImageParameters("/tmp/test.png", "", "invalid", "", 0)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid image quality")
|
||||
})
|
||||
|
||||
t.Run("Valid background values should pass", func(t *testing.T) {
|
||||
validBackgrounds := []string{"opaque", "transparent"}
|
||||
for _, background := range validBackgrounds {
|
||||
err := validateImageParameters("/tmp/test.png", "", "", background, 0)
|
||||
assert.NoError(t, err, "Background %s should be valid", background)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid background should fail", func(t *testing.T) {
|
||||
err := validateImageParameters("/tmp/test.png", "", "", "invalid", 0)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid image background")
|
||||
})
|
||||
|
||||
t.Run("Compression for JPEG should pass", func(t *testing.T) {
|
||||
err := validateImageParameters("/tmp/test.jpg", "", "", "", 75)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Compression for WebP should pass", func(t *testing.T) {
|
||||
err := validateImageParameters("/tmp/test.webp", "", "", "", 50)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Compression for PNG should fail", func(t *testing.T) {
|
||||
err := validateImageParameters("/tmp/test.png", "", "", "", 75)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "image compression can only be used with JPEG and WebP formats")
|
||||
})
|
||||
|
||||
t.Run("Invalid compression range should fail", func(t *testing.T) {
|
||||
err := validateImageParameters("/tmp/test.jpg", "", "", "", 150)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "image compression must be between 0 and 100")
|
||||
|
||||
err = validateImageParameters("/tmp/test.jpg", "", "", "", -10)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "image compression must be between 0 and 100")
|
||||
})
|
||||
|
||||
t.Run("Transparent background for PNG should pass", func(t *testing.T) {
|
||||
err := validateImageParameters("/tmp/test.png", "", "", "transparent", 0)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Transparent background for WebP should pass", func(t *testing.T) {
|
||||
err := validateImageParameters("/tmp/test.webp", "", "", "transparent", 0)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Transparent background for JPEG should fail", func(t *testing.T) {
|
||||
err := validateImageParameters("/tmp/test.jpg", "", "", "transparent", 0)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "transparent background can only be used with PNG and WebP formats")
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuildChatOptionsWithImageParameters(t *testing.T) {
|
||||
t.Run("Valid image parameters should pass", func(t *testing.T) {
|
||||
flags := &Flags{
|
||||
ImageFile: "/tmp/test.png",
|
||||
ImageSize: "1024x1024",
|
||||
ImageQuality: "high",
|
||||
ImageBackground: "transparent",
|
||||
ImageCompression: 0, // Not set for PNG
|
||||
}
|
||||
|
||||
options, err := flags.BuildChatOptions()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, options)
|
||||
assert.Equal(t, "/tmp/test.png", options.ImageFile)
|
||||
assert.Equal(t, "1024x1024", options.ImageSize)
|
||||
assert.Equal(t, "high", options.ImageQuality)
|
||||
assert.Equal(t, "transparent", options.ImageBackground)
|
||||
assert.Equal(t, 0, options.ImageCompression)
|
||||
})
|
||||
|
||||
t.Run("Invalid image parameters should fail", func(t *testing.T) {
|
||||
flags := &Flags{
|
||||
ImageFile: "/tmp/test.png",
|
||||
ImageSize: "invalid",
|
||||
ImageQuality: "high",
|
||||
ImageBackground: "transparent",
|
||||
}
|
||||
|
||||
options, err := flags.BuildChatOptions()
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, options)
|
||||
assert.Contains(t, err.Error(), "invalid image size")
|
||||
})
|
||||
|
||||
t.Run("JPEG with compression should pass", func(t *testing.T) {
|
||||
flags := &Flags{
|
||||
ImageFile: "/tmp/test.jpg",
|
||||
ImageSize: "1536x1024",
|
||||
ImageQuality: "medium",
|
||||
ImageBackground: "opaque",
|
||||
ImageCompression: 80,
|
||||
}
|
||||
|
||||
options, err := flags.BuildChatOptions()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, options)
|
||||
assert.Equal(t, 80, options.ImageCompression)
|
||||
})
|
||||
|
||||
t.Run("Image parameters without image file should fail in BuildChatOptions", func(t *testing.T) {
|
||||
flags := &Flags{
|
||||
ImageSize: "1024x1024", // Image parameter without ImageFile
|
||||
}
|
||||
|
||||
options, err := flags.BuildChatOptions()
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, options)
|
||||
assert.Contains(t, err.Error(), "image parameters")
|
||||
assert.Contains(t, err.Error(), "can only be used with --image-file")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -29,6 +29,10 @@ type ChatOptions struct {
|
||||
Search bool
|
||||
SearchLocation string
|
||||
ImageFile string
|
||||
ImageSize string
|
||||
ImageQuality string
|
||||
ImageCompression int
|
||||
ImageBackground string
|
||||
}
|
||||
|
||||
// NormalizeMessages remove empty messages and ensure messages order user-assist-user
|
||||
|
||||
124
common/oauth_storage.go
Normal file
124
common/oauth_storage.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
|
||||
// OAuthToken represents stored OAuth token information
|
||||
type OAuthToken struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
TokenType string `json:"token_type"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
// IsExpired checks if the token is expired or will expire within the buffer time
|
||||
func (t *OAuthToken) IsExpired(bufferMinutes int) bool {
|
||||
if t.ExpiresAt == 0 {
|
||||
return true
|
||||
}
|
||||
bufferTime := time.Duration(bufferMinutes) * time.Minute
|
||||
return time.Now().Add(bufferTime).Unix() >= t.ExpiresAt
|
||||
}
|
||||
|
||||
// OAuthStorage handles persistent storage of OAuth tokens
|
||||
type OAuthStorage struct {
|
||||
configDir string
|
||||
}
|
||||
|
||||
// NewOAuthStorage creates a new OAuth storage instance
|
||||
func NewOAuthStorage() (*OAuthStorage, error) {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user home directory: %w", err)
|
||||
}
|
||||
|
||||
configDir := filepath.Join(homeDir, ".config", "fabric")
|
||||
|
||||
// Ensure config directory exists
|
||||
if err := os.MkdirAll(configDir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create config directory: %w", err)
|
||||
}
|
||||
|
||||
return &OAuthStorage{configDir: configDir}, nil
|
||||
}
|
||||
|
||||
// GetTokenPath returns the file path for a provider's OAuth token
|
||||
func (s *OAuthStorage) GetTokenPath(provider string) string {
|
||||
return filepath.Join(s.configDir, fmt.Sprintf(".%s_oauth", provider))
|
||||
}
|
||||
|
||||
// SaveToken saves an OAuth token to disk with proper permissions
|
||||
func (s *OAuthStorage) SaveToken(provider string, token *OAuthToken) error {
|
||||
tokenPath := s.GetTokenPath(provider)
|
||||
|
||||
// Marshal token to JSON
|
||||
data, err := json.MarshalIndent(token, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal token: %w", err)
|
||||
}
|
||||
|
||||
// Write to temporary file first for atomic operation
|
||||
tempPath := tokenPath + ".tmp"
|
||||
if err := os.WriteFile(tempPath, data, 0600); err != nil {
|
||||
return fmt.Errorf("failed to write token file: %w", err)
|
||||
}
|
||||
|
||||
// Atomic rename
|
||||
if err := os.Rename(tempPath, tokenPath); err != nil {
|
||||
os.Remove(tempPath) // Clean up temp file
|
||||
return fmt.Errorf("failed to save token file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadToken loads an OAuth token from disk
|
||||
func (s *OAuthStorage) LoadToken(provider string) (*OAuthToken, error) {
|
||||
tokenPath := s.GetTokenPath(provider)
|
||||
|
||||
// Check if file exists
|
||||
if _, err := os.Stat(tokenPath); os.IsNotExist(err) {
|
||||
return nil, nil // No token stored
|
||||
}
|
||||
|
||||
// Read token file
|
||||
data, err := os.ReadFile(tokenPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read token file: %w", err)
|
||||
}
|
||||
|
||||
// Unmarshal token
|
||||
var token OAuthToken
|
||||
if err := json.Unmarshal(data, &token); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token file: %w", err)
|
||||
}
|
||||
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// DeleteToken removes a stored OAuth token
|
||||
func (s *OAuthStorage) DeleteToken(provider string) error {
|
||||
tokenPath := s.GetTokenPath(provider)
|
||||
|
||||
if err := os.Remove(tokenPath); err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to delete token file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasValidToken checks if a valid (non-expired) token exists for a provider
|
||||
func (s *OAuthStorage) HasValidToken(provider string, bufferMinutes int) bool {
|
||||
token, err := s.LoadToken(provider)
|
||||
if err != nil || token == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return !token.IsExpired(bufferMinutes)
|
||||
}
|
||||
232
common/oauth_storage_test.go
Normal file
232
common/oauth_storage_test.go
Normal file
@@ -0,0 +1,232 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestOAuthToken_IsExpired(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
expiresAt int64
|
||||
bufferMinutes int
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "token not expired",
|
||||
expiresAt: time.Now().Unix() + 3600, // 1 hour from now
|
||||
bufferMinutes: 5,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "token expired",
|
||||
expiresAt: time.Now().Unix() - 3600, // 1 hour ago
|
||||
bufferMinutes: 5,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "token expires within buffer",
|
||||
expiresAt: time.Now().Unix() + 120, // 2 minutes from now
|
||||
bufferMinutes: 5,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "zero expiry time",
|
||||
expiresAt: 0,
|
||||
bufferMinutes: 5,
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
token := &OAuthToken{ExpiresAt: tt.expiresAt}
|
||||
if got := token.IsExpired(tt.bufferMinutes); got != tt.expected {
|
||||
t.Errorf("IsExpired() = %v, want %v", got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthStorage_SaveAndLoadToken(t *testing.T) {
|
||||
// Create temporary directory for testing
|
||||
tempDir, err := os.MkdirTemp("", "fabric_oauth_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// Create storage with custom config dir
|
||||
storage := &OAuthStorage{configDir: tempDir}
|
||||
|
||||
// Test token
|
||||
token := &OAuthToken{
|
||||
AccessToken: "test_access_token",
|
||||
RefreshToken: "test_refresh_token",
|
||||
ExpiresAt: time.Now().Unix() + 3600,
|
||||
TokenType: "Bearer",
|
||||
Scope: "test_scope",
|
||||
}
|
||||
|
||||
// Test saving token
|
||||
err = storage.SaveToken("test_provider", token)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save token: %v", err)
|
||||
}
|
||||
|
||||
// Verify file exists and has correct permissions
|
||||
tokenPath := storage.GetTokenPath("test_provider")
|
||||
info, err := os.Stat(tokenPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Token file not created: %v", err)
|
||||
}
|
||||
if info.Mode().Perm() != 0600 {
|
||||
t.Errorf("Token file has wrong permissions: %v, want 0600", info.Mode().Perm())
|
||||
}
|
||||
|
||||
// Test loading token
|
||||
loadedToken, err := storage.LoadToken("test_provider")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load token: %v", err)
|
||||
}
|
||||
if loadedToken == nil {
|
||||
t.Fatal("Loaded token is nil")
|
||||
}
|
||||
|
||||
// Verify token data
|
||||
if loadedToken.AccessToken != token.AccessToken {
|
||||
t.Errorf("AccessToken mismatch: got %v, want %v", loadedToken.AccessToken, token.AccessToken)
|
||||
}
|
||||
if loadedToken.RefreshToken != token.RefreshToken {
|
||||
t.Errorf("RefreshToken mismatch: got %v, want %v", loadedToken.RefreshToken, token.RefreshToken)
|
||||
}
|
||||
if loadedToken.ExpiresAt != token.ExpiresAt {
|
||||
t.Errorf("ExpiresAt mismatch: got %v, want %v", loadedToken.ExpiresAt, token.ExpiresAt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthStorage_LoadNonExistentToken(t *testing.T) {
|
||||
// Create temporary directory for testing
|
||||
tempDir, err := os.MkdirTemp("", "fabric_oauth_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
storage := &OAuthStorage{configDir: tempDir}
|
||||
|
||||
// Try to load non-existent token
|
||||
token, err := storage.LoadToken("nonexistent")
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error loading non-existent token: %v", err)
|
||||
}
|
||||
if token != nil {
|
||||
t.Error("Expected nil token for non-existent provider")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthStorage_DeleteToken(t *testing.T) {
|
||||
// Create temporary directory for testing
|
||||
tempDir, err := os.MkdirTemp("", "fabric_oauth_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
storage := &OAuthStorage{configDir: tempDir}
|
||||
|
||||
// Create and save a token
|
||||
token := &OAuthToken{
|
||||
AccessToken: "test_token",
|
||||
RefreshToken: "test_refresh",
|
||||
ExpiresAt: time.Now().Unix() + 3600,
|
||||
}
|
||||
err = storage.SaveToken("test_provider", token)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save token: %v", err)
|
||||
}
|
||||
|
||||
// Verify token exists
|
||||
tokenPath := storage.GetTokenPath("test_provider")
|
||||
if _, err := os.Stat(tokenPath); os.IsNotExist(err) {
|
||||
t.Fatal("Token file should exist before deletion")
|
||||
}
|
||||
|
||||
// Delete token
|
||||
err = storage.DeleteToken("test_provider")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to delete token: %v", err)
|
||||
}
|
||||
|
||||
// Verify token is deleted
|
||||
if _, err := os.Stat(tokenPath); !os.IsNotExist(err) {
|
||||
t.Error("Token file should not exist after deletion")
|
||||
}
|
||||
|
||||
// Test deleting non-existent token (should not error)
|
||||
err = storage.DeleteToken("nonexistent")
|
||||
if err != nil {
|
||||
t.Errorf("Deleting non-existent token should not error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthStorage_HasValidToken(t *testing.T) {
|
||||
// Create temporary directory for testing
|
||||
tempDir, err := os.MkdirTemp("", "fabric_oauth_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
storage := &OAuthStorage{configDir: tempDir}
|
||||
|
||||
// Test with no token
|
||||
if storage.HasValidToken("test_provider", 5) {
|
||||
t.Error("Should return false when no token exists")
|
||||
}
|
||||
|
||||
// Save valid token
|
||||
validToken := &OAuthToken{
|
||||
AccessToken: "valid_token",
|
||||
RefreshToken: "refresh_token",
|
||||
ExpiresAt: time.Now().Unix() + 3600, // 1 hour from now
|
||||
}
|
||||
err = storage.SaveToken("test_provider", validToken)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save valid token: %v", err)
|
||||
}
|
||||
|
||||
// Test with valid token
|
||||
if !storage.HasValidToken("test_provider", 5) {
|
||||
t.Error("Should return true for valid token")
|
||||
}
|
||||
|
||||
// Save expired token
|
||||
expiredToken := &OAuthToken{
|
||||
AccessToken: "expired_token",
|
||||
RefreshToken: "refresh_token",
|
||||
ExpiresAt: time.Now().Unix() - 3600, // 1 hour ago
|
||||
}
|
||||
err = storage.SaveToken("expired_provider", expiredToken)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save expired token: %v", err)
|
||||
}
|
||||
|
||||
// Test with expired token
|
||||
if storage.HasValidToken("expired_provider", 5) {
|
||||
t.Error("Should return false for expired token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthStorage_GetTokenPath(t *testing.T) {
|
||||
storage := &OAuthStorage{configDir: "/test/config"}
|
||||
|
||||
expected := filepath.Join("/test/config", ".test_provider_oauth")
|
||||
actual := storage.GetTokenPath("test_provider")
|
||||
|
||||
if actual != expected {
|
||||
t.Errorf("GetTokenPath() = %v, want %v", actual, expected)
|
||||
}
|
||||
}
|
||||
@@ -98,7 +98,11 @@ _fabric() {
|
||||
'(--version)--version[Print current version]' \
|
||||
'(--search)--search[Enable web search tool for supported models (Anthropic, OpenAI)]' \
|
||||
'(--search-location)--search-location[Set location for web search results]:location:' \
|
||||
'(--image-file)--image-file[Save generated image to specified file path]:image file:_files -g "*.png *.jpg *.jpeg *.gif *.bmp"' \
|
||||
'(--image-file)--image-file[Save generated image to specified file path]:image file:_files -g "*.png *.webp *.jpeg *.jpg"' \
|
||||
'(--image-size)--image-size[Image dimensions]:size:(1024x1024 1536x1024 1024x1536 auto)' \
|
||||
'(--image-quality)--image-quality[Image quality]:quality:(low medium high auto)' \
|
||||
'(--image-compression)--image-compression[Compression level 0-100 for JPEG/WebP formats]:compression:' \
|
||||
'(--image-background)--image-background[Background type]:background:(opaque transparent)' \
|
||||
'(--listextensions)--listextensions[List all registered extensions]' \
|
||||
'(--addextension)--addextension[Register a new extension from config file path]:config file:_files -g "*.yaml *.yml"' \
|
||||
'(--rmextension)--rmextension[Remove a registered extension by name]:extension:_fabric_extensions' \
|
||||
|
||||
@@ -13,7 +13,7 @@ _fabric() {
|
||||
_get_comp_words_by_ref -n : cur prev words cword
|
||||
|
||||
# Define all possible options/flags
|
||||
local opts="--pattern -p --variable -v --context -C --session --attachment -a --setup -S --temperature -t --topp -T --stream -s --presencepenalty -P --raw -r --frequencypenalty -F --listpatterns -l --listmodels -L --listcontexts -x --listsessions -X --updatepatterns -U --copy -c --model -m --modelContextLength --output -o --output-session --latest -n --changeDefaultModel -d --youtube -y --playlist --transcript --transcript-with-timestamps --comments --metadata --language -g --scrape_url -u --scrape_question -q --seed -e --wipecontext -w --wipesession -W --printcontext --printsession --readability --input-has-vars --dry-run --serve --serveOllama --address --api-key --config --search --search-location --image-file --version --listextensions --addextension --rmextension --strategy --liststrategies --listvendors --shell-complete-list --help -h"
|
||||
local opts="--pattern -p --variable -v --context -C --session --attachment -a --setup -S --temperature -t --topp -T --stream -s --presencepenalty -P --raw -r --frequencypenalty -F --listpatterns -l --listmodels -L --listcontexts -x --listsessions -X --updatepatterns -U --copy -c --model -m --modelContextLength --output -o --output-session --latest -n --changeDefaultModel -d --youtube -y --playlist --transcript --transcript-with-timestamps --comments --metadata --language -g --scrape_url -u --scrape_question -q --seed -e --wipecontext -w --wipesession -W --printcontext --printsession --readability --input-has-vars --dry-run --serve --serveOllama --address --api-key --config --search --search-location --image-file --image-size --image-quality --image-compression --image-background --version --listextensions --addextension --rmextension --strategy --liststrategies --listvendors --shell-complete-list --help -h"
|
||||
|
||||
# Helper function for dynamic completions
|
||||
_fabric_get_list() {
|
||||
@@ -67,8 +67,21 @@ _fabric() {
|
||||
_filedir
|
||||
return 0
|
||||
;;
|
||||
# Image generation options with specific values
|
||||
--image-size)
|
||||
COMPREPLY=($(compgen -W "1024x1024 1536x1024 1024x1536 auto" -- "$cur"))
|
||||
return 0
|
||||
;;
|
||||
--image-quality)
|
||||
COMPREPLY=($(compgen -W "low medium high auto" -- "$cur"))
|
||||
return 0
|
||||
;;
|
||||
--image-background)
|
||||
COMPREPLY=($(compgen -W "opaque transparent" -- "$cur"))
|
||||
return 0
|
||||
;;
|
||||
# Options requiring simple arguments (no specific completion logic here)
|
||||
-v | --variable | -t | --temperature | -T | --topp | -P | --presencepenalty | -F | --frequencypenalty | --modelContextLength | -n | --latest | -y | --youtube | -g | --language | -u | --scrape_url | -q | --scrape_question | -e | --seed | --address | --api-key | --search-location)
|
||||
-v | --variable | -t | --temperature | -T | --topp | -P | --presencepenalty | -F | --frequencypenalty | --modelContextLength | -n | --latest | -y | --youtube | -g | --language | -u | --scrape_url | -q | --scrape_question | -e | --seed | --address | --api-key | --search-location | --image-compression)
|
||||
# No specific completion suggestions, user types the value
|
||||
return 0
|
||||
;;
|
||||
|
||||
@@ -61,7 +61,11 @@ complete -c fabric -l address -d "The address to bind the REST API (default: :80
|
||||
complete -c fabric -l api-key -d "API key used to secure server routes"
|
||||
complete -c fabric -l config -d "Path to YAML config file" -r -a "*.yaml *.yml"
|
||||
complete -c fabric -l search-location -d "Set location for web search results (e.g., 'America/Los_Angeles')"
|
||||
complete -c fabric -l image-file -d "Save generated image to specified file path (e.g., 'output.png')" -r -a "*.png *.jpg *.jpeg *.gif *.bmp"
|
||||
complete -c fabric -l image-file -d "Save generated image to specified file path (e.g., 'output.png')" -r -a "*.png *.webp *.jpeg *.jpg"
|
||||
complete -c fabric -l image-size -d "Image dimensions: 1024x1024, 1536x1024, 1024x1536, auto (default: auto)" -a "1024x1024 1536x1024 1024x1536 auto"
|
||||
complete -c fabric -l image-quality -d "Image quality: low, medium, high, auto (default: auto)" -a "low medium high auto"
|
||||
complete -c fabric -l image-compression -d "Compression level 0-100 for JPEG/WebP formats (default: not set)" -r
|
||||
complete -c fabric -l image-background -d "Background type: opaque, transparent (default: opaque, only for PNG/WebP)" -a "opaque transparent"
|
||||
complete -c fabric -l addextension -d "Register a new extension from config file path" -r -a "*.yaml *.yml"
|
||||
complete -c fabric -l rmextension -d "Remove a registered extension by name" -a "(__fabric_get_extensions)"
|
||||
complete -c fabric -l strategy -d "Choose a strategy from the available strategies" -a "(__fabric_get_strategies)"
|
||||
|
||||
@@ -31,6 +31,7 @@ import (
|
||||
"github.com/danielmiessler/fabric/plugins/db/fsdb"
|
||||
"github.com/danielmiessler/fabric/plugins/template"
|
||||
"github.com/danielmiessler/fabric/plugins/tools"
|
||||
"github.com/danielmiessler/fabric/plugins/tools/custom_patterns"
|
||||
"github.com/danielmiessler/fabric/plugins/tools/jina"
|
||||
"github.com/danielmiessler/fabric/plugins/tools/lang"
|
||||
"github.com/danielmiessler/fabric/plugins/tools/youtube"
|
||||
@@ -69,6 +70,7 @@ func NewPluginRegistry(db *fsdb.Db) (ret *PluginRegistry, err error) {
|
||||
VendorManager: ai.NewVendorsManager(),
|
||||
VendorsAll: ai.NewVendorsManager(),
|
||||
PatternsLoader: tools.NewPatternsLoader(db.Patterns),
|
||||
CustomPatterns: custom_patterns.NewCustomPatterns(),
|
||||
YouTube: youtube.NewYouTube(),
|
||||
Language: lang.NewLanguage(),
|
||||
Jina: jina.NewClient(),
|
||||
@@ -138,6 +140,7 @@ type PluginRegistry struct {
|
||||
VendorsAll *ai.VendorsManager
|
||||
Defaults *tools.Defaults
|
||||
PatternsLoader *tools.PatternsLoader
|
||||
CustomPatterns *custom_patterns.CustomPatterns
|
||||
YouTube *youtube.YouTube
|
||||
Language *lang.Language
|
||||
Jina *jina.Client
|
||||
@@ -151,6 +154,7 @@ func (o *PluginRegistry) SaveEnvFile() (err error) {
|
||||
|
||||
o.Defaults.Settings.FillEnvFileContent(&envFileContent)
|
||||
o.PatternsLoader.SetupFillEnvFileContent(&envFileContent)
|
||||
o.CustomPatterns.SetupFillEnvFileContent(&envFileContent)
|
||||
o.Strategies.SetupFillEnvFileContent(&envFileContent)
|
||||
|
||||
for _, vendor := range o.VendorManager.Vendors {
|
||||
@@ -183,7 +187,7 @@ func (o *PluginRegistry) Setup() (err error) {
|
||||
return vendor
|
||||
})...)
|
||||
|
||||
groupsPlugins.AddGroupItems("Tools", o.Defaults, o.Jina, o.Language, o.PatternsLoader, o.Strategies, o.YouTube)
|
||||
groupsPlugins.AddGroupItems("Tools", o.CustomPatterns, o.Defaults, o.Jina, o.Language, o.PatternsLoader, o.Strategies, o.YouTube)
|
||||
|
||||
for {
|
||||
groupsPlugins.Print(false)
|
||||
@@ -239,7 +243,7 @@ func (o *PluginRegistry) SetupVendor(vendorName string) (err error) {
|
||||
func (o *PluginRegistry) ConfigureVendors() {
|
||||
o.VendorManager.Clear()
|
||||
for _, vendor := range o.VendorsAll.Vendors {
|
||||
if vendorErr := vendor.Configure(); vendorErr == nil {
|
||||
if vendorErr := vendor.Configure(); vendorErr == nil && vendor.IsConfigured() {
|
||||
o.VendorManager.AddVendors(vendor)
|
||||
}
|
||||
}
|
||||
|
||||
3
go.mod
3
go.mod
@@ -25,9 +25,9 @@ require (
|
||||
github.com/samber/lo v1.50.0
|
||||
github.com/sgaunet/perplexity-go/v2 v2.8.0
|
||||
github.com/stretchr/testify v1.10.0
|
||||
golang.org/x/oauth2 v0.30.0
|
||||
golang.org/x/text v0.26.0
|
||||
google.golang.org/api v0.236.0
|
||||
gopkg.in/yaml.v2 v2.4.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
@@ -109,7 +109,6 @@ require (
|
||||
golang.org/x/arch v0.18.0 // indirect
|
||||
golang.org/x/crypto v0.39.0 // indirect
|
||||
golang.org/x/net v0.41.0 // indirect
|
||||
golang.org/x/oauth2 v0.30.0 // indirect
|
||||
golang.org/x/sync v0.15.0 // indirect
|
||||
golang.org/x/sys v0.33.0 // indirect
|
||||
golang.org/x/time v0.12.0 // indirect
|
||||
|
||||
1
go.sum
1
go.sum
@@ -354,7 +354,6 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EV
|
||||
gopkg.in/warnings.v0 v0.1.2 h1:wFXVbFY8DY5/xOe1ECiWdKCzZlxgshcYVNkBHstARME=
|
||||
gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
|
||||
@@ -325,9 +325,6 @@ schema = 3
|
||||
[mod."gopkg.in/warnings.v0"]
|
||||
version = "v0.1.2"
|
||||
hash = "sha256-ATVL9yEmgYbkJ1DkltDGRn/auGAjqGOfjQyBYyUo8s8="
|
||||
[mod."gopkg.in/yaml.v2"]
|
||||
version = "v2.4.0"
|
||||
hash = "sha256-uVEGglIedjOIGZzHW4YwN1VoRSTK8o0eGZqzd+TNdd0="
|
||||
[mod."gopkg.in/yaml.v3"]
|
||||
version = "v3.0.1"
|
||||
hash = "sha256-FqL9TKYJ0XkNwJFnq9j0VvJ5ZUU1RvH/52h/f5bkYAU="
|
||||
|
||||
@@ -1 +1 @@
|
||||
"1.4.227"
|
||||
"1.4.238"
|
||||
|
||||
@@ -25,6 +25,7 @@ Your goal is to output a JSON object called tags, with the following tags applie
|
||||
- **productivity** - Efficiency, time management, workflows
|
||||
- **writing** - Writing craft, process, tips
|
||||
- **creativity** - Creative process, artistic expression
|
||||
- **tutorial** - Technical or non-technical guides, how-tos
|
||||
|
||||
# STEPS
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package anthropic
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
@@ -18,6 +19,8 @@ const webSearchToolName = "web_search"
|
||||
const webSearchToolType = "web_search_20250305"
|
||||
const sourcesHeader = "## Sources"
|
||||
|
||||
const vendorTokenIdentifier = "claude"
|
||||
|
||||
func NewClient() (ret *Client) {
|
||||
vendorName := "Anthropic"
|
||||
ret = &Client{}
|
||||
@@ -30,7 +33,8 @@ func NewClient() (ret *Client) {
|
||||
|
||||
ret.ApiBaseURL = ret.AddSetupQuestion("API Base URL", false)
|
||||
ret.ApiBaseURL.Value = defaultBaseUrl
|
||||
ret.ApiKey = ret.PluginBase.AddSetupQuestion("API key", true)
|
||||
ret.UseOAuth = ret.AddSetupQuestionBool("Use OAuth login", false)
|
||||
ret.ApiKey = ret.PluginBase.AddSetupQuestion("API key", false)
|
||||
|
||||
ret.maxTokens = 4096
|
||||
ret.defaultRequiredUserMessage = "Hi"
|
||||
@@ -46,10 +50,43 @@ func NewClient() (ret *Client) {
|
||||
return
|
||||
}
|
||||
|
||||
// IsConfigured returns true if either the API key or OAuth is configured
|
||||
func (an *Client) IsConfigured() bool {
|
||||
// Check if API key is configured
|
||||
if an.ApiKey.Value != "" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if OAuth is enabled and has a valid token
|
||||
if plugins.ParseBoolElseFalse(an.UseOAuth.Value) {
|
||||
storage, err := common.NewOAuthStorage()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// If no valid token exists, automatically run OAuth flow
|
||||
if !storage.HasValidToken(vendorTokenIdentifier, 5) {
|
||||
fmt.Println("OAuth enabled but no valid token found. Starting authentication...")
|
||||
_, err := RunOAuthFlow()
|
||||
if err != nil {
|
||||
fmt.Printf("OAuth authentication failed: %v\n", err)
|
||||
return false
|
||||
}
|
||||
// After successful OAuth flow, check again
|
||||
return storage.HasValidToken("claude", 5)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
*plugins.PluginBase
|
||||
ApiBaseURL *plugins.SetupQuestion
|
||||
ApiKey *plugins.SetupQuestion
|
||||
UseOAuth *plugins.SetupQuestion
|
||||
|
||||
maxTokens int
|
||||
defaultRequiredUserMessage string
|
||||
@@ -58,24 +95,50 @@ type Client struct {
|
||||
client anthropic.Client
|
||||
}
|
||||
|
||||
func (an *Client) configure() (err error) {
|
||||
if an.ApiBaseURL.Value != "" {
|
||||
baseURL := an.ApiBaseURL.Value
|
||||
func (an *Client) Setup() (err error) {
|
||||
if err = an.PluginBase.Ask(an.Name); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// As of 2.0beta1, using v2 API endpoint.
|
||||
// https://github.com/anthropics/anthropic-sdk-go/blob/main/CHANGELOG.md#020-beta1-2025-03-25
|
||||
if strings.Contains(baseURL, "-") && !strings.HasSuffix(baseURL, "/v2") {
|
||||
baseURL = strings.TrimSuffix(baseURL, "/")
|
||||
baseURL = baseURL + "/v2"
|
||||
if plugins.ParseBoolElseFalse(an.UseOAuth.Value) {
|
||||
// Check if we have a valid stored token
|
||||
storage, err := common.NewOAuthStorage()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
an.client = anthropic.NewClient(
|
||||
option.WithAPIKey(an.ApiKey.Value),
|
||||
option.WithBaseURL(baseURL),
|
||||
)
|
||||
} else {
|
||||
an.client = anthropic.NewClient(option.WithAPIKey(an.ApiKey.Value))
|
||||
if !storage.HasValidToken("claude", 5) {
|
||||
// No valid token, run OAuth flow
|
||||
if _, err = RunOAuthFlow(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = an.configure()
|
||||
return
|
||||
}
|
||||
|
||||
func (an *Client) configure() (err error) {
|
||||
opts := []option.RequestOption{}
|
||||
|
||||
if an.ApiBaseURL.Value != "" {
|
||||
opts = append(opts, option.WithBaseURL(an.ApiBaseURL.Value))
|
||||
}
|
||||
|
||||
if plugins.ParseBoolElseFalse(an.UseOAuth.Value) {
|
||||
// For OAuth, use Bearer token with custom headers
|
||||
// Create custom HTTP client that adds OAuth Bearer token and beta header
|
||||
baseTransport := &http.Transport{}
|
||||
httpClient := &http.Client{
|
||||
Transport: NewOAuthTransport(an, baseTransport),
|
||||
}
|
||||
opts = append(opts, option.WithHTTPClient(httpClient))
|
||||
} else {
|
||||
opts = append(opts, option.WithAPIKey(an.ApiKey.Value))
|
||||
}
|
||||
|
||||
an.client = anthropic.NewClient(opts...)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -124,6 +187,17 @@ func (an *Client) buildMessageParams(msgs []anthropic.MessageParam, opts *common
|
||||
Messages: msgs,
|
||||
}
|
||||
|
||||
// Add Claude Code spoofing system message for OAuth authentication
|
||||
if plugins.ParseBoolElseFalse(an.UseOAuth.Value) {
|
||||
params.System = []anthropic.TextBlockParam{
|
||||
{
|
||||
Type: "text",
|
||||
Text: "You are Claude Code, Anthropic's official CLI for Claude.",
|
||||
},
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if opts.Search {
|
||||
// Build the web-search tool definition:
|
||||
webTool := anthropic.WebSearchTool20250305Param{
|
||||
@@ -207,6 +281,9 @@ func (an *Client) toMessages(msgs []*chat.ChatCompletionMessage) (ret []anthropi
|
||||
|
||||
var anthropicMessages []anthropic.MessageParam
|
||||
var systemContent string
|
||||
|
||||
// Note: Claude Code spoofing is now handled in buildMessageParams
|
||||
|
||||
isFirstUserMessage := true
|
||||
lastRoleWasUser := false
|
||||
|
||||
|
||||
300
plugins/ai/anthropic/oauth.go
Normal file
300
plugins/ai/anthropic/oauth.go
Normal file
@@ -0,0 +1,300 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// OAuth configuration constants
|
||||
const (
|
||||
oauthClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
||||
oauthAuthURL = "https://claude.ai/oauth/authorize"
|
||||
oauthTokenURL = "https://console.anthropic.com/v1/oauth/token"
|
||||
oauthRedirectURL = "https://console.anthropic.com/oauth/code/callback"
|
||||
)
|
||||
|
||||
// OAuthTransport is a custom HTTP transport that adds OAuth Bearer token and beta header
|
||||
type OAuthTransport struct {
|
||||
client *Client
|
||||
base http.RoundTripper
|
||||
}
|
||||
|
||||
// RoundTrip implements the http.RoundTripper interface
|
||||
func (t *OAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Clone the request to avoid modifying the original
|
||||
newReq := req.Clone(req.Context())
|
||||
|
||||
// Get current token (may refresh if needed)
|
||||
token, err := t.getValidToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get valid OAuth token: %w", err)
|
||||
}
|
||||
|
||||
// Add OAuth Bearer token
|
||||
newReq.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
// Add the anthropic-beta header for OAuth
|
||||
newReq.Header.Set("anthropic-beta", "oauth-2025-04-20")
|
||||
|
||||
// Set User-Agent to match AI SDK exactly
|
||||
newReq.Header.Set("User-Agent", "ai-sdk/anthropic")
|
||||
|
||||
// Remove x-api-key header if present (OAuth doesn't use it)
|
||||
newReq.Header.Del("x-api-key")
|
||||
|
||||
return t.base.RoundTrip(newReq)
|
||||
}
|
||||
|
||||
// getValidToken returns a valid access token, refreshing if necessary
|
||||
func (t *OAuthTransport) getValidToken() (string, error) {
|
||||
storage, err := common.NewOAuthStorage()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create OAuth storage: %w", err)
|
||||
}
|
||||
|
||||
// Load stored token
|
||||
token, err := storage.LoadToken("claude")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to load stored token: %w", err)
|
||||
}
|
||||
// If no token exists, run OAuth flow
|
||||
if token == nil {
|
||||
fmt.Println("No OAuth token found, initiating authentication...")
|
||||
newAccessToken, err := RunOAuthFlow()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to authenticate: %w", err)
|
||||
}
|
||||
return newAccessToken, nil
|
||||
}
|
||||
|
||||
// Check if token needs refresh (5 minute buffer)
|
||||
if token.IsExpired(5) {
|
||||
fmt.Println("OAuth token expired, refreshing...")
|
||||
newAccessToken, err := RefreshToken()
|
||||
if err != nil {
|
||||
// If refresh fails, try re-authentication
|
||||
fmt.Println("Token refresh failed, re-authenticating...")
|
||||
newAccessToken, err = RunOAuthFlow()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to refresh or re-authenticate: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return newAccessToken, nil
|
||||
}
|
||||
|
||||
return token.AccessToken, nil
|
||||
}
|
||||
|
||||
// NewOAuthTransport creates a new OAuth transport for the given client
|
||||
func NewOAuthTransport(client *Client, base http.RoundTripper) *OAuthTransport {
|
||||
return &OAuthTransport{
|
||||
client: client,
|
||||
base: base,
|
||||
}
|
||||
}
|
||||
|
||||
// generatePKCE generates PKCE code verifier and challenge
|
||||
func generatePKCE() (verifier, challenge string, err error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err = rand.Read(b); err != nil {
|
||||
return
|
||||
}
|
||||
verifier = base64.RawURLEncoding.EncodeToString(b)
|
||||
sum := sha256.Sum256([]byte(verifier))
|
||||
challenge = base64.RawURLEncoding.EncodeToString(sum[:])
|
||||
return
|
||||
}
|
||||
|
||||
// openBrowser attempts to open the given URL in the default browser
|
||||
func openBrowser(url string) {
|
||||
commands := [][]string{{"xdg-open", url}, {"open", url}, {"cmd", "/c", "start", url}}
|
||||
for _, cmd := range commands {
|
||||
if exec.Command(cmd[0], cmd[1:]...).Start() == nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RunOAuthFlow executes the complete OAuth authorization flow
|
||||
func RunOAuthFlow() (token string, err error) {
|
||||
verifier, challenge, err := generatePKCE()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
cfg := oauth2.Config{
|
||||
ClientID: oauthClientID,
|
||||
Endpoint: oauth2.Endpoint{AuthURL: oauthAuthURL, TokenURL: oauthTokenURL},
|
||||
RedirectURL: oauthRedirectURL,
|
||||
Scopes: []string{"org:create_api_key", "user:profile", "user:inference"},
|
||||
}
|
||||
|
||||
authURL := cfg.AuthCodeURL(verifier,
|
||||
oauth2.SetAuthURLParam("code_challenge", challenge),
|
||||
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
||||
oauth2.SetAuthURLParam("code", "true"),
|
||||
oauth2.SetAuthURLParam("state", verifier),
|
||||
)
|
||||
|
||||
fmt.Println("Open the following URL in your browser. Fabric would like to authorize:")
|
||||
fmt.Println(authURL)
|
||||
openBrowser(authURL)
|
||||
fmt.Print("Paste the authorization code here: ")
|
||||
var code string
|
||||
fmt.Scanln(&code)
|
||||
parts := strings.SplitN(code, "#", 2)
|
||||
state := verifier
|
||||
if len(parts) == 2 {
|
||||
state = parts[1]
|
||||
}
|
||||
|
||||
// Manual token exchange to match opencode implementation
|
||||
tokenReq := map[string]string{
|
||||
"code": parts[0],
|
||||
"state": state,
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": oauthClientID,
|
||||
"redirect_uri": oauthRedirectURL,
|
||||
"code_verifier": verifier,
|
||||
}
|
||||
|
||||
token, err = exchangeToken(tokenReq)
|
||||
return
|
||||
}
|
||||
|
||||
// exchangeToken exchanges authorization code for access token
|
||||
func exchangeToken(params map[string]string) (token string, err error) {
|
||||
reqBody, err := json.Marshal(params)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := http.Post(oauthTokenURL, "application/json", bytes.NewBuffer(reqBody))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
err = fmt.Errorf("token exchange failed: %s - %s", resp.Status, string(body))
|
||||
return
|
||||
}
|
||||
|
||||
var result struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
TokenType string `json:"token_type"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
if err = json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Save the complete token information
|
||||
storage, err := common.NewOAuthStorage()
|
||||
if err != nil {
|
||||
return result.AccessToken, fmt.Errorf("failed to create OAuth storage: %w", err)
|
||||
}
|
||||
|
||||
oauthToken := &common.OAuthToken{
|
||||
AccessToken: result.AccessToken,
|
||||
RefreshToken: result.RefreshToken,
|
||||
ExpiresAt: time.Now().Unix() + int64(result.ExpiresIn),
|
||||
TokenType: result.TokenType,
|
||||
Scope: result.Scope,
|
||||
}
|
||||
|
||||
if err = storage.SaveToken("claude", oauthToken); err != nil {
|
||||
return result.AccessToken, fmt.Errorf("failed to save OAuth token: %w", err)
|
||||
}
|
||||
|
||||
token = result.AccessToken
|
||||
return
|
||||
}
|
||||
|
||||
// RefreshToken refreshes an expired OAuth token using the refresh token
|
||||
func RefreshToken() (string, error) {
|
||||
storage, err := common.NewOAuthStorage()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create OAuth storage: %w", err)
|
||||
}
|
||||
|
||||
// Load existing token
|
||||
token, err := storage.LoadToken("claude")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to load stored token: %w", err)
|
||||
}
|
||||
if token == nil || token.RefreshToken == "" {
|
||||
return "", fmt.Errorf("no refresh token available")
|
||||
}
|
||||
|
||||
// Prepare refresh request
|
||||
refreshReq := map[string]string{
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": token.RefreshToken,
|
||||
"client_id": oauthClientID,
|
||||
}
|
||||
|
||||
reqBody, err := json.Marshal(refreshReq)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal refresh request: %w", err)
|
||||
}
|
||||
|
||||
// Make refresh request
|
||||
resp, err := http.Post(oauthTokenURL, "application/json", bytes.NewBuffer(reqBody))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("refresh request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return "", fmt.Errorf("token refresh failed: %s - %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
TokenType string `json:"token_type"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
if err = json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return "", fmt.Errorf("failed to parse refresh response: %w", err)
|
||||
}
|
||||
|
||||
// Update stored token
|
||||
newToken := &common.OAuthToken{
|
||||
AccessToken: result.AccessToken,
|
||||
RefreshToken: result.RefreshToken,
|
||||
ExpiresAt: time.Now().Unix() + int64(result.ExpiresIn),
|
||||
TokenType: result.TokenType,
|
||||
Scope: result.Scope,
|
||||
}
|
||||
|
||||
// Use existing refresh token if new one not provided
|
||||
if newToken.RefreshToken == "" {
|
||||
newToken.RefreshToken = token.RefreshToken
|
||||
}
|
||||
|
||||
if err = storage.SaveToken("claude", newToken); err != nil {
|
||||
return "", fmt.Errorf("failed to save refreshed token: %w", err)
|
||||
}
|
||||
|
||||
return result.AccessToken, nil
|
||||
}
|
||||
@@ -69,7 +69,9 @@ func (o *Client) buildChatCompletionParams(
|
||||
|
||||
if !opts.Raw {
|
||||
ret.Temperature = openai.Float(opts.Temperature)
|
||||
ret.TopP = openai.Float(opts.TopP)
|
||||
if opts.TopP != 0 {
|
||||
ret.TopP = openai.Float(opts.TopP)
|
||||
}
|
||||
if opts.MaxTokens != 0 {
|
||||
ret.MaxTokens = openai.Int(int64(opts.MaxTokens))
|
||||
}
|
||||
|
||||
@@ -128,6 +128,11 @@ func (o *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, o
|
||||
}
|
||||
|
||||
func (o *Client) sendResponses(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions) (ret string, err error) {
|
||||
// Validate model supports image generation if image file is specified
|
||||
if opts.ImageFile != "" && !supportsImageGeneration(opts.Model) {
|
||||
return "", fmt.Errorf("model '%s' does not support image generation. Supported models: %s", opts.Model, strings.Join(ImageGenerationSupportedModels, ", "))
|
||||
}
|
||||
|
||||
req := o.buildResponseParams(msgs, opts)
|
||||
|
||||
var resp *responses.Response
|
||||
@@ -216,7 +221,9 @@ func (o *Client) buildResponseParams(
|
||||
|
||||
if !opts.Raw {
|
||||
ret.Temperature = openai.Float(opts.Temperature)
|
||||
ret.TopP = openai.Float(opts.TopP)
|
||||
if opts.TopP != 0 {
|
||||
ret.TopP = openai.Float(opts.TopP)
|
||||
}
|
||||
if opts.MaxTokens != 0 {
|
||||
ret.MaxOutputTokens = openai.Int(int64(opts.MaxTokens))
|
||||
}
|
||||
|
||||
@@ -8,8 +8,10 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"github.com/openai/openai-go/packages/param"
|
||||
"github.com/openai/openai-go/responses"
|
||||
)
|
||||
|
||||
@@ -17,19 +19,82 @@ import (
|
||||
const ImageGenerationResponseType = "image_generation_call"
|
||||
const ImageGenerationToolType = "image_generation"
|
||||
|
||||
// ImageGenerationSupportedModels lists all models that support image generation
|
||||
var ImageGenerationSupportedModels = []string{
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
"gpt-4.1",
|
||||
"gpt-4.1-mini",
|
||||
"gpt-4.1-nano",
|
||||
"o3",
|
||||
}
|
||||
|
||||
// supportsImageGeneration checks if the given model supports the image_generation tool
|
||||
func supportsImageGeneration(model string) bool {
|
||||
for _, supportedModel := range ImageGenerationSupportedModels {
|
||||
if model == supportedModel {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// getOutputFormatFromExtension determines the API output format based on file extension
|
||||
func getOutputFormatFromExtension(imagePath string) string {
|
||||
if imagePath == "" {
|
||||
return "png" // Default format
|
||||
}
|
||||
|
||||
ext := strings.ToLower(filepath.Ext(imagePath))
|
||||
switch ext {
|
||||
case ".png":
|
||||
return "png"
|
||||
case ".webp":
|
||||
return "webp"
|
||||
case ".jpg":
|
||||
return "jpeg"
|
||||
case ".jpeg":
|
||||
return "jpeg"
|
||||
default:
|
||||
return "png" // Default fallback
|
||||
}
|
||||
}
|
||||
|
||||
// addImageGenerationTool adds the image generation tool to the request if needed
|
||||
func (o *Client) addImageGenerationTool(opts *common.ChatOptions, tools []responses.ToolUnionParam) []responses.ToolUnionParam {
|
||||
// Check if the request seems to be asking for image generation
|
||||
if o.shouldUseImageGeneration(opts) {
|
||||
outputFormat := getOutputFormatFromExtension(opts.ImageFile)
|
||||
|
||||
// Build the image generation tool with user parameters
|
||||
imageGenTool := responses.ToolUnionParam{
|
||||
OfImageGeneration: &responses.ToolImageGenerationParam{
|
||||
Type: ImageGenerationToolType,
|
||||
Model: "gpt-image-1",
|
||||
OutputFormat: "png",
|
||||
Quality: "auto",
|
||||
Size: "auto",
|
||||
OutputFormat: outputFormat,
|
||||
},
|
||||
}
|
||||
|
||||
// Set quality if specified by user (otherwise let OpenAI use default)
|
||||
if opts.ImageQuality != "" {
|
||||
imageGenTool.OfImageGeneration.Quality = opts.ImageQuality
|
||||
}
|
||||
|
||||
// Set size if specified by user (otherwise let OpenAI use default)
|
||||
if opts.ImageSize != "" {
|
||||
imageGenTool.OfImageGeneration.Size = opts.ImageSize
|
||||
}
|
||||
|
||||
// Set background if specified by user (otherwise let OpenAI use default)
|
||||
if opts.ImageBackground != "" {
|
||||
imageGenTool.OfImageGeneration.Background = opts.ImageBackground
|
||||
}
|
||||
|
||||
// Set compression if specified by user (only for jpeg/webp)
|
||||
if opts.ImageCompression != 0 {
|
||||
imageGenTool.OfImageGeneration.OutputCompression = param.NewOpt(int64(opts.ImageCompression))
|
||||
}
|
||||
|
||||
tools = append(tools, imageGenTool)
|
||||
}
|
||||
return tools
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/danielmiessler/fabric/chat"
|
||||
@@ -112,3 +114,331 @@ func TestBuildResponseParams_WithBothSearchAndImage(t *testing.T) {
|
||||
assert.True(t, hasSearchTool, "Should have web search tool")
|
||||
assert.True(t, hasImageTool, "Should have image generation tool")
|
||||
}
|
||||
|
||||
func TestGetOutputFormatFromExtension(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
imagePath string
|
||||
expectedFormat string
|
||||
}{
|
||||
{
|
||||
name: "PNG extension",
|
||||
imagePath: "/tmp/output.png",
|
||||
expectedFormat: "png",
|
||||
},
|
||||
{
|
||||
name: "WEBP extension",
|
||||
imagePath: "/tmp/output.webp",
|
||||
expectedFormat: "webp",
|
||||
},
|
||||
{
|
||||
name: "JPG extension",
|
||||
imagePath: "/tmp/output.jpg",
|
||||
expectedFormat: "jpeg",
|
||||
},
|
||||
{
|
||||
name: "JPEG extension",
|
||||
imagePath: "/tmp/output.jpeg",
|
||||
expectedFormat: "jpeg",
|
||||
},
|
||||
{
|
||||
name: "Uppercase PNG extension",
|
||||
imagePath: "/tmp/output.PNG",
|
||||
expectedFormat: "png",
|
||||
},
|
||||
{
|
||||
name: "Mixed case JPEG extension",
|
||||
imagePath: "/tmp/output.JpEg",
|
||||
expectedFormat: "jpeg",
|
||||
},
|
||||
{
|
||||
name: "Empty path",
|
||||
imagePath: "",
|
||||
expectedFormat: "png",
|
||||
},
|
||||
{
|
||||
name: "No extension",
|
||||
imagePath: "/tmp/output",
|
||||
expectedFormat: "png",
|
||||
},
|
||||
{
|
||||
name: "Unsupported extension",
|
||||
imagePath: "/tmp/output.gif",
|
||||
expectedFormat: "png",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := getOutputFormatFromExtension(tt.imagePath)
|
||||
assert.Equal(t, tt.expectedFormat, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddImageGenerationToolWithDynamicFormat(t *testing.T) {
|
||||
client := NewClient()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
imageFile string
|
||||
expectedFormat string
|
||||
}{
|
||||
{
|
||||
name: "PNG file",
|
||||
imageFile: "/tmp/output.png",
|
||||
expectedFormat: "png",
|
||||
},
|
||||
{
|
||||
name: "WEBP file",
|
||||
imageFile: "/tmp/output.webp",
|
||||
expectedFormat: "webp",
|
||||
},
|
||||
{
|
||||
name: "JPG file",
|
||||
imageFile: "/tmp/output.jpg",
|
||||
expectedFormat: "jpeg",
|
||||
},
|
||||
{
|
||||
name: "JPEG file",
|
||||
imageFile: "/tmp/output.jpeg",
|
||||
expectedFormat: "jpeg",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
opts := &common.ChatOptions{
|
||||
ImageFile: tt.imageFile,
|
||||
}
|
||||
|
||||
tools := client.addImageGenerationTool(opts, []responses.ToolUnionParam{})
|
||||
|
||||
assert.Len(t, tools, 1, "Should have one tool")
|
||||
assert.NotNil(t, tools[0].OfImageGeneration, "Should be image generation tool")
|
||||
assert.Equal(t, tt.expectedFormat, tools[0].OfImageGeneration.OutputFormat, "Output format should match file extension")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSupportsImageGeneration(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "gpt-4o supports image generation",
|
||||
model: "gpt-4o",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "gpt-4o-mini supports image generation",
|
||||
model: "gpt-4o-mini",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "gpt-4.1 supports image generation",
|
||||
model: "gpt-4.1",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "gpt-4.1-mini supports image generation",
|
||||
model: "gpt-4.1-mini",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "gpt-4.1-nano supports image generation",
|
||||
model: "gpt-4.1-nano",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "o3 supports image generation",
|
||||
model: "o3",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "o1 does not support image generation",
|
||||
model: "o1",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "o1-mini does not support image generation",
|
||||
model: "o1-mini",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "o3-mini does not support image generation",
|
||||
model: "o3-mini",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "gpt-4 does not support image generation",
|
||||
model: "gpt-4",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "gpt-3.5-turbo does not support image generation",
|
||||
model: "gpt-3.5-turbo",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "empty model does not support image generation",
|
||||
model: "",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := supportsImageGeneration(tt.model)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelValidationLogic(t *testing.T) {
|
||||
t.Run("Unsupported model with image file should return validation error", func(t *testing.T) {
|
||||
opts := &common.ChatOptions{
|
||||
Model: "o1-mini",
|
||||
ImageFile: "/tmp/output.png",
|
||||
}
|
||||
|
||||
// Test the validation logic directly
|
||||
if opts.ImageFile != "" && !supportsImageGeneration(opts.Model) {
|
||||
err := fmt.Errorf("model '%s' does not support image generation. Supported models: %s", opts.Model, strings.Join(ImageGenerationSupportedModels, ", "))
|
||||
|
||||
assert.Contains(t, err.Error(), "does not support image generation")
|
||||
assert.Contains(t, err.Error(), "o1-mini")
|
||||
assert.Contains(t, err.Error(), "Supported models:")
|
||||
} else {
|
||||
t.Error("Expected validation to trigger")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Supported model with image file should not trigger validation", func(t *testing.T) {
|
||||
opts := &common.ChatOptions{
|
||||
Model: "gpt-4o",
|
||||
ImageFile: "/tmp/output.png",
|
||||
}
|
||||
|
||||
// Test the validation logic directly
|
||||
shouldFail := opts.ImageFile != "" && !supportsImageGeneration(opts.Model)
|
||||
assert.False(t, shouldFail, "Validation should not trigger for supported model")
|
||||
})
|
||||
|
||||
t.Run("Unsupported model without image file should not trigger validation", func(t *testing.T) {
|
||||
opts := &common.ChatOptions{
|
||||
Model: "o1-mini",
|
||||
ImageFile: "", // No image file
|
||||
}
|
||||
|
||||
// Test the validation logic directly
|
||||
shouldFail := opts.ImageFile != "" && !supportsImageGeneration(opts.Model)
|
||||
assert.False(t, shouldFail, "Validation should not trigger when no image file is specified")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAddImageGenerationToolWithUserParameters(t *testing.T) {
|
||||
client := NewClient()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
opts *common.ChatOptions
|
||||
expected map[string]interface{}
|
||||
}{
|
||||
{
|
||||
name: "All parameters specified",
|
||||
opts: &common.ChatOptions{
|
||||
ImageFile: "/tmp/test.png",
|
||||
ImageSize: "1536x1024",
|
||||
ImageQuality: "high",
|
||||
ImageBackground: "transparent",
|
||||
ImageCompression: 0, // Not applicable for PNG
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"size": "1536x1024",
|
||||
"quality": "high",
|
||||
"background": "transparent",
|
||||
"output_format": "png",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "JPEG with compression",
|
||||
opts: &common.ChatOptions{
|
||||
ImageFile: "/tmp/test.jpg",
|
||||
ImageSize: "1024x1024",
|
||||
ImageQuality: "medium",
|
||||
ImageBackground: "opaque",
|
||||
ImageCompression: 75,
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"size": "1024x1024",
|
||||
"quality": "medium",
|
||||
"background": "opaque",
|
||||
"output_format": "jpeg",
|
||||
"output_compression": int64(75),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Only some parameters specified",
|
||||
opts: &common.ChatOptions{
|
||||
ImageFile: "/tmp/test.webp",
|
||||
ImageQuality: "low",
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"quality": "low",
|
||||
"output_format": "webp",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "No parameters specified (defaults)",
|
||||
opts: &common.ChatOptions{
|
||||
ImageFile: "/tmp/test.png",
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"output_format": "png",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tools := client.addImageGenerationTool(tt.opts, []responses.ToolUnionParam{})
|
||||
|
||||
assert.Len(t, tools, 1)
|
||||
assert.NotNil(t, tools[0].OfImageGeneration)
|
||||
|
||||
tool := tools[0].OfImageGeneration
|
||||
|
||||
// Check required fields
|
||||
assert.Equal(t, "gpt-image-1", tool.Model)
|
||||
assert.Equal(t, tt.expected["output_format"], tool.OutputFormat)
|
||||
|
||||
// Check optional fields
|
||||
if expectedSize, ok := tt.expected["size"]; ok {
|
||||
assert.Equal(t, expectedSize, tool.Size)
|
||||
} else {
|
||||
assert.Empty(t, tool.Size, "Size should not be set when not specified")
|
||||
}
|
||||
|
||||
if expectedQuality, ok := tt.expected["quality"]; ok {
|
||||
assert.Equal(t, expectedQuality, tool.Quality)
|
||||
} else {
|
||||
assert.Empty(t, tool.Quality, "Quality should not be set when not specified")
|
||||
}
|
||||
|
||||
if expectedBackground, ok := tt.expected["background"]; ok {
|
||||
assert.Equal(t, expectedBackground, tool.Background)
|
||||
} else {
|
||||
assert.Empty(t, tool.Background, "Background should not be set when not specified")
|
||||
}
|
||||
|
||||
if expectedCompression, ok := tt.expected["output_compression"]; ok {
|
||||
assert.Equal(t, expectedCompression, tool.OutputCompression.Value)
|
||||
} else {
|
||||
assert.Equal(t, int64(0), tool.OutputCompression.Value, "Compression should not be set when not specified")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
@@ -19,6 +20,7 @@ func NewDb(dir string) (db *Db) {
|
||||
StorageEntity: &StorageEntity{Label: "Patterns", Dir: db.FilePath("patterns"), ItemIsDir: true},
|
||||
SystemPatternFile: "system.md",
|
||||
UniquePatternsFilePath: db.FilePath("unique_patterns.txt"),
|
||||
CustomPatternsDir: "", // Will be set after loading .env file
|
||||
}
|
||||
|
||||
db.Sessions = &SessionsEntity{
|
||||
@@ -49,6 +51,18 @@ func (o *Db) Configure() (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
// Set custom patterns directory after loading .env file
|
||||
customPatternsDir := os.Getenv("CUSTOM_PATTERNS_DIRECTORY")
|
||||
if customPatternsDir != "" {
|
||||
// Expand home directory if needed
|
||||
if strings.HasPrefix(customPatternsDir, "~/") {
|
||||
if homeDir, err := os.UserHomeDir(); err == nil {
|
||||
customPatternsDir = filepath.Join(homeDir, customPatternsDir[2:])
|
||||
}
|
||||
}
|
||||
o.Patterns.CustomPatternsDir = customPatternsDir
|
||||
}
|
||||
|
||||
if err = o.Patterns.Configure(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
@@ -16,6 +17,7 @@ type PatternsEntity struct {
|
||||
*StorageEntity
|
||||
SystemPatternFile string
|
||||
UniquePatternsFilePath string
|
||||
CustomPatternsDir string
|
||||
}
|
||||
|
||||
// Pattern represents a single pattern with its metadata
|
||||
@@ -43,7 +45,7 @@ func (o *PatternsEntity) GetApplyVariables(
|
||||
}
|
||||
|
||||
// Use the resolved absolute path to get the pattern
|
||||
pattern, err = o.getFromFile(absPath)
|
||||
pattern, _ = o.getFromFile(absPath)
|
||||
} else {
|
||||
// Otherwise, get the pattern from the database
|
||||
pattern, err = o.getFromDB(source)
|
||||
@@ -89,6 +91,19 @@ func (o *PatternsEntity) applyVariables(
|
||||
|
||||
// retrieves a pattern from the database by name
|
||||
func (o *PatternsEntity) getFromDB(name string) (ret *Pattern, err error) {
|
||||
// First check custom patterns directory if it exists
|
||||
if o.CustomPatternsDir != "" {
|
||||
customPatternPath := filepath.Join(o.CustomPatternsDir, name, o.SystemPatternFile)
|
||||
if pattern, customErr := os.ReadFile(customPatternPath); customErr == nil {
|
||||
ret = &Pattern{
|
||||
Name: name,
|
||||
Pattern: string(pattern),
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to main patterns directory
|
||||
patternPath := filepath.Join(o.Dir, name, o.SystemPatternFile)
|
||||
|
||||
var pattern []byte
|
||||
@@ -145,6 +160,71 @@ func (o *PatternsEntity) getFromFile(pathStr string) (pattern *Pattern, err erro
|
||||
return
|
||||
}
|
||||
|
||||
// GetNames overrides StorageEntity.GetNames to include custom patterns directory
|
||||
func (o *PatternsEntity) GetNames() (ret []string, err error) {
|
||||
// Get names from main patterns directory
|
||||
mainNames, err := o.StorageEntity.GetNames()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create a map to track unique pattern names (custom patterns override main ones)
|
||||
nameMap := make(map[string]bool)
|
||||
for _, name := range mainNames {
|
||||
nameMap[name] = true
|
||||
}
|
||||
|
||||
// Get names from custom patterns directory if it exists
|
||||
if o.CustomPatternsDir != "" {
|
||||
// Create a temporary StorageEntity for the custom directory
|
||||
customStorage := &StorageEntity{
|
||||
Dir: o.CustomPatternsDir,
|
||||
ItemIsDir: o.StorageEntity.ItemIsDir,
|
||||
FileExtension: o.StorageEntity.FileExtension,
|
||||
}
|
||||
|
||||
customNames, customErr := customStorage.GetNames()
|
||||
if customErr == nil {
|
||||
// Add custom patterns, they will override main patterns with same name
|
||||
for _, name := range customNames {
|
||||
nameMap[name] = true
|
||||
}
|
||||
}
|
||||
// Ignore errors from custom directory (it might not exist)
|
||||
}
|
||||
|
||||
// Convert map keys back to slice
|
||||
ret = make([]string, 0, len(nameMap))
|
||||
for name := range nameMap {
|
||||
ret = append(ret, name)
|
||||
}
|
||||
|
||||
// Sort the patterns alphabetically
|
||||
sort.Strings(ret)
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// ListNames overrides StorageEntity.ListNames to use PatternsEntity.GetNames
|
||||
func (o *PatternsEntity) ListNames(shellCompleteList bool) (err error) {
|
||||
var names []string
|
||||
if names, err = o.GetNames(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if len(names) == 0 {
|
||||
if !shellCompleteList {
|
||||
fmt.Printf("\nNo %v\n", o.StorageEntity.Label)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
for _, item := range names {
|
||||
fmt.Printf("%s\n", item)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Get required for Storage interface
|
||||
func (o *PatternsEntity) Get(name string) (*Pattern, error) {
|
||||
// Use GetPattern with no variables
|
||||
|
||||
@@ -162,3 +162,123 @@ func TestPatternsEntity_Save(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, content, data)
|
||||
}
|
||||
|
||||
func TestPatternsEntity_CustomPatterns(t *testing.T) {
|
||||
// Create main patterns directory
|
||||
mainDir, err := os.MkdirTemp("", "test-main-patterns-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(mainDir)
|
||||
|
||||
// Create custom patterns directory
|
||||
customDir, err := os.MkdirTemp("", "test-custom-patterns-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(customDir)
|
||||
|
||||
entity := &PatternsEntity{
|
||||
StorageEntity: &StorageEntity{
|
||||
Dir: mainDir,
|
||||
Label: "patterns",
|
||||
ItemIsDir: true,
|
||||
},
|
||||
SystemPatternFile: "system.md",
|
||||
CustomPatternsDir: customDir,
|
||||
}
|
||||
|
||||
// Create a pattern in main directory
|
||||
createTestPattern(t, &PatternsEntity{
|
||||
StorageEntity: &StorageEntity{
|
||||
Dir: mainDir,
|
||||
Label: "patterns",
|
||||
ItemIsDir: true,
|
||||
},
|
||||
SystemPatternFile: "system.md",
|
||||
}, "main-pattern", "Main pattern content")
|
||||
|
||||
// Create a pattern in custom directory
|
||||
createTestPattern(t, &PatternsEntity{
|
||||
StorageEntity: &StorageEntity{
|
||||
Dir: customDir,
|
||||
Label: "patterns",
|
||||
ItemIsDir: true,
|
||||
},
|
||||
SystemPatternFile: "system.md",
|
||||
}, "custom-pattern", "Custom pattern content")
|
||||
|
||||
// Create a pattern with same name in both directories (custom should override)
|
||||
createTestPattern(t, &PatternsEntity{
|
||||
StorageEntity: &StorageEntity{
|
||||
Dir: mainDir,
|
||||
Label: "patterns",
|
||||
ItemIsDir: true,
|
||||
},
|
||||
SystemPatternFile: "system.md",
|
||||
}, "shared-pattern", "Main shared pattern")
|
||||
|
||||
createTestPattern(t, &PatternsEntity{
|
||||
StorageEntity: &StorageEntity{
|
||||
Dir: customDir,
|
||||
Label: "patterns",
|
||||
ItemIsDir: true,
|
||||
},
|
||||
SystemPatternFile: "system.md",
|
||||
}, "shared-pattern", "Custom shared pattern")
|
||||
|
||||
// Test GetNames includes both directories
|
||||
names, err := entity.GetNames()
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, names, "main-pattern")
|
||||
assert.Contains(t, names, "custom-pattern")
|
||||
assert.Contains(t, names, "shared-pattern")
|
||||
|
||||
// Test that custom pattern overrides main pattern
|
||||
pattern, err := entity.getFromDB("shared-pattern")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Custom shared pattern", pattern.Pattern)
|
||||
|
||||
// Test that main pattern is accessible when not overridden
|
||||
pattern, err = entity.getFromDB("main-pattern")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Main pattern content", pattern.Pattern)
|
||||
|
||||
// Test that custom pattern is accessible
|
||||
pattern, err = entity.getFromDB("custom-pattern")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Custom pattern content", pattern.Pattern)
|
||||
}
|
||||
|
||||
func TestPatternsEntity_CustomPatternsEmpty(t *testing.T) {
|
||||
// Test behavior when custom patterns directory is empty or doesn't exist
|
||||
mainDir, err := os.MkdirTemp("", "test-main-patterns-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(mainDir)
|
||||
|
||||
entity := &PatternsEntity{
|
||||
StorageEntity: &StorageEntity{
|
||||
Dir: mainDir,
|
||||
Label: "patterns",
|
||||
ItemIsDir: true,
|
||||
},
|
||||
SystemPatternFile: "system.md",
|
||||
CustomPatternsDir: "/nonexistent/directory",
|
||||
}
|
||||
|
||||
// Create a pattern in main directory
|
||||
createTestPattern(t, &PatternsEntity{
|
||||
StorageEntity: &StorageEntity{
|
||||
Dir: mainDir,
|
||||
Label: "patterns",
|
||||
ItemIsDir: true,
|
||||
},
|
||||
SystemPatternFile: "system.md",
|
||||
}, "main-pattern", "Main pattern content")
|
||||
|
||||
// Test GetNames works even with nonexistent custom directory
|
||||
names, err := entity.GetNames()
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, names, "main-pattern")
|
||||
|
||||
// Test that main pattern is accessible
|
||||
pattern, err := entity.getFromDB("main-pattern")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Main pattern content", pattern.Pattern)
|
||||
}
|
||||
|
||||
66
plugins/tools/custom_patterns/custom_patterns.go
Normal file
66
plugins/tools/custom_patterns/custom_patterns.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package custom_patterns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/danielmiessler/fabric/plugins"
|
||||
)
|
||||
|
||||
func NewCustomPatterns() (ret *CustomPatterns) {
|
||||
label := "Custom Patterns"
|
||||
ret = &CustomPatterns{}
|
||||
|
||||
ret.PluginBase = &plugins.PluginBase{
|
||||
Name: label,
|
||||
SetupDescription: "Custom Patterns - Set directory for your custom patterns (optional)",
|
||||
EnvNamePrefix: plugins.BuildEnvVariablePrefix(label),
|
||||
ConfigureCustom: ret.configure,
|
||||
}
|
||||
|
||||
ret.CustomPatternsDir = ret.AddSetupQuestionCustom("Directory", false,
|
||||
"Enter the path to your custom patterns directory (leave empty to skip)")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
type CustomPatterns struct {
|
||||
*plugins.PluginBase
|
||||
CustomPatternsDir *plugins.SetupQuestion
|
||||
}
|
||||
|
||||
func (o *CustomPatterns) configure() error {
|
||||
if o.CustomPatternsDir.Value != "" {
|
||||
// Expand home directory if needed
|
||||
if strings.HasPrefix(o.CustomPatternsDir.Value, "~/") {
|
||||
if homeDir, err := os.UserHomeDir(); err == nil {
|
||||
o.CustomPatternsDir.Value = filepath.Join(homeDir, o.CustomPatternsDir.Value[2:])
|
||||
}
|
||||
}
|
||||
|
||||
// Convert to absolute path
|
||||
if absPath, err := filepath.Abs(o.CustomPatternsDir.Value); err == nil {
|
||||
o.CustomPatternsDir.Value = absPath
|
||||
}
|
||||
|
||||
// Check if directory exists, create only if it doesn't
|
||||
if _, err := os.Stat(o.CustomPatternsDir.Value); os.IsNotExist(err) {
|
||||
if err := os.MkdirAll(o.CustomPatternsDir.Value, 0755); err != nil {
|
||||
// Log the error but don't clear the value - let it persist in env file
|
||||
fmt.Printf("Warning: Could not create custom patterns directory %s: %v\n", o.CustomPatternsDir.Value, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsConfigured returns true if a custom patterns directory has been set
|
||||
func (o *CustomPatterns) IsConfigured() bool {
|
||||
// First configure to load values from environment variables
|
||||
o.Configure()
|
||||
// Check if the plugin has been configured with a directory
|
||||
return o.CustomPatternsDir.Value != ""
|
||||
}
|
||||
79
plugins/tools/custom_patterns/custom_patterns_test.go
Normal file
79
plugins/tools/custom_patterns/custom_patterns_test.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package custom_patterns
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewCustomPatterns(t *testing.T) {
|
||||
plugin := NewCustomPatterns()
|
||||
|
||||
assert.NotNil(t, plugin)
|
||||
assert.Equal(t, "Custom Patterns", plugin.GetName())
|
||||
assert.Equal(t, "Custom Patterns - Set directory for your custom patterns (optional)", plugin.GetSetupDescription())
|
||||
assert.False(t, plugin.IsConfigured()) // Should not be configured initially
|
||||
}
|
||||
func TestCustomPatterns_Configure(t *testing.T) {
|
||||
plugin := NewCustomPatterns()
|
||||
|
||||
// Test with empty directory (should work)
|
||||
plugin.CustomPatternsDir.Value = ""
|
||||
err := plugin.configure()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Test with home directory expansion
|
||||
plugin.CustomPatternsDir.Value = "~/test-patterns"
|
||||
err = plugin.configure()
|
||||
assert.NoError(t, err)
|
||||
|
||||
homeDir, _ := os.UserHomeDir()
|
||||
expectedPath := filepath.Join(homeDir, "test-patterns")
|
||||
absExpected, _ := filepath.Abs(expectedPath)
|
||||
assert.Equal(t, absExpected, plugin.CustomPatternsDir.Value)
|
||||
|
||||
// Clean up
|
||||
os.RemoveAll(plugin.CustomPatternsDir.Value)
|
||||
}
|
||||
|
||||
func TestCustomPatterns_ConfigureWithTempDir(t *testing.T) {
|
||||
plugin := NewCustomPatterns()
|
||||
|
||||
// Test with a temporary directory
|
||||
tmpDir, err := os.MkdirTemp("", "test-custom-patterns-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
plugin.CustomPatternsDir.Value = tmpDir
|
||||
err = plugin.configure()
|
||||
assert.NoError(t, err)
|
||||
|
||||
absPath, _ := filepath.Abs(tmpDir)
|
||||
assert.Equal(t, absPath, plugin.CustomPatternsDir.Value)
|
||||
|
||||
// Verify directory exists
|
||||
info, err := os.Stat(plugin.CustomPatternsDir.Value)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, info.IsDir())
|
||||
|
||||
// Should be configured now
|
||||
assert.True(t, plugin.IsConfigured())
|
||||
}
|
||||
|
||||
func TestCustomPatterns_IsConfigured(t *testing.T) {
|
||||
plugin := NewCustomPatterns()
|
||||
|
||||
// Initially not configured
|
||||
assert.False(t, plugin.IsConfigured())
|
||||
|
||||
// Set a directory
|
||||
plugin.CustomPatternsDir.Value = "/some/path"
|
||||
assert.True(t, plugin.IsConfigured())
|
||||
|
||||
// Clear the directory
|
||||
plugin.CustomPatternsDir.Value = ""
|
||||
assert.False(t, plugin.IsConfigured())
|
||||
}
|
||||
@@ -57,17 +57,18 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
|
||||
}
|
||||
|
||||
config := map[string]string{
|
||||
"openai": os.Getenv("OPENAI_API_KEY"),
|
||||
"anthropic": os.Getenv("ANTHROPIC_API_KEY"),
|
||||
"groq": os.Getenv("GROQ_API_KEY"),
|
||||
"mistral": os.Getenv("MISTRAL_API_KEY"),
|
||||
"gemini": os.Getenv("GEMINI_API_KEY"),
|
||||
"ollama": os.Getenv("OLLAMA_URL"),
|
||||
"openrouter": os.Getenv("OPENROUTER_API_KEY"),
|
||||
"silicon": os.Getenv("SILICON_API_KEY"),
|
||||
"deepseek": os.Getenv("DEEPSEEK_API_KEY"),
|
||||
"grokai": os.Getenv("GROKAI_API_KEY"),
|
||||
"lmstudio": os.Getenv("LM_STUDIO_API_BASE_URL"),
|
||||
"openai": os.Getenv("OPENAI_API_KEY"),
|
||||
"anthropic": os.Getenv("ANTHROPIC_API_KEY"),
|
||||
"anthropic_use_oauth_login": os.Getenv("ANTHROPIC_USE_OAUTH_LOGIN"),
|
||||
"groq": os.Getenv("GROQ_API_KEY"),
|
||||
"mistral": os.Getenv("MISTRAL_API_KEY"),
|
||||
"gemini": os.Getenv("GEMINI_API_KEY"),
|
||||
"ollama": os.Getenv("OLLAMA_URL"),
|
||||
"openrouter": os.Getenv("OPENROUTER_API_KEY"),
|
||||
"silicon": os.Getenv("SILICON_API_KEY"),
|
||||
"deepseek": os.Getenv("DEEPSEEK_API_KEY"),
|
||||
"grokai": os.Getenv("GROKAI_API_KEY"),
|
||||
"lmstudio": os.Getenv("LM_STUDIO_API_BASE_URL"),
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, config)
|
||||
@@ -80,17 +81,18 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||||
}
|
||||
|
||||
var config struct {
|
||||
OpenAIApiKey string `json:"openai_api_key"`
|
||||
AnthropicApiKey string `json:"anthropic_api_key"`
|
||||
GroqApiKey string `json:"groq_api_key"`
|
||||
MistralApiKey string `json:"mistral_api_key"`
|
||||
GeminiApiKey string `json:"gemini_api_key"`
|
||||
OllamaURL string `json:"ollama_url"`
|
||||
OpenRouterApiKey string `json:"openrouter_api_key"`
|
||||
SiliconApiKey string `json:"silicon_api_key"`
|
||||
DeepSeekApiKey string `json:"deepseek_api_key"`
|
||||
GrokaiApiKey string `json:"grokai_api_key"`
|
||||
LMStudioURL string `json:"lm_studio_base_url"`
|
||||
OpenAIApiKey string `json:"openai_api_key"`
|
||||
AnthropicApiKey string `json:"anthropic_api_key"`
|
||||
AnthropicUseAuthToken string `json:"anthropic_use_auth_token"`
|
||||
GroqApiKey string `json:"groq_api_key"`
|
||||
MistralApiKey string `json:"mistral_api_key"`
|
||||
GeminiApiKey string `json:"gemini_api_key"`
|
||||
OllamaURL string `json:"ollama_url"`
|
||||
OpenRouterApiKey string `json:"openrouter_api_key"`
|
||||
SiliconApiKey string `json:"silicon_api_key"`
|
||||
DeepSeekApiKey string `json:"deepseek_api_key"`
|
||||
GrokaiApiKey string `json:"grokai_api_key"`
|
||||
LMStudioURL string `json:"lm_studio_base_url"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&config); err != nil {
|
||||
@@ -99,17 +101,18 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||||
}
|
||||
|
||||
envVars := map[string]string{
|
||||
"OPENAI_API_KEY": config.OpenAIApiKey,
|
||||
"ANTHROPIC_API_KEY": config.AnthropicApiKey,
|
||||
"GROQ_API_KEY": config.GroqApiKey,
|
||||
"MISTRAL_API_KEY": config.MistralApiKey,
|
||||
"GEMINI_API_KEY": config.GeminiApiKey,
|
||||
"OLLAMA_URL": config.OllamaURL,
|
||||
"OPENROUTER_API_KEY": config.OpenRouterApiKey,
|
||||
"SILICON_API_KEY": config.SiliconApiKey,
|
||||
"DEEPSEEK_API_KEY": config.DeepSeekApiKey,
|
||||
"GROKAI_API_KEY": config.GrokaiApiKey,
|
||||
"LM_STUDIO_API_BASE_URL": config.LMStudioURL,
|
||||
"OPENAI_API_KEY": config.OpenAIApiKey,
|
||||
"ANTHROPIC_API_KEY": config.AnthropicApiKey,
|
||||
"ANTHROPIC_USE_OAUTH_LOGIN": config.AnthropicUseAuthToken,
|
||||
"GROQ_API_KEY": config.GroqApiKey,
|
||||
"MISTRAL_API_KEY": config.MistralApiKey,
|
||||
"GEMINI_API_KEY": config.GeminiApiKey,
|
||||
"OLLAMA_URL": config.OllamaURL,
|
||||
"OPENROUTER_API_KEY": config.OpenRouterApiKey,
|
||||
"SILICON_API_KEY": config.SiliconApiKey,
|
||||
"DEEPSEEK_API_KEY": config.DeepSeekApiKey,
|
||||
"GROKAI_API_KEY": config.GrokaiApiKey,
|
||||
"LM_STUDIO_API_BASE_URL": config.LMStudioURL,
|
||||
}
|
||||
|
||||
var envContent strings.Builder
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
package main
|
||||
|
||||
var version = "v1.4.227"
|
||||
var version = "v1.4.238"
|
||||
|
||||
Reference in New Issue
Block a user