mirror of
https://github.com/danielmiessler/Fabric.git
synced 2026-01-09 22:38:10 -05:00
Compare commits
53 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7c7ceca264 | ||
|
|
c19d7ccd9d | ||
|
|
bd0c5f730e | ||
|
|
5900dac58f | ||
|
|
237219c3cc | ||
|
|
26fd700098 | ||
|
|
6bd926dd0f | ||
|
|
16ac519415 | ||
|
|
a32cc5fa01 | ||
|
|
26b5bb2e9e | ||
|
|
b751d323b1 | ||
|
|
d081fd269c | ||
|
|
369a0a850d | ||
|
|
8dc5343ee6 | ||
|
|
eda552dac5 | ||
|
|
f13a56685b | ||
|
|
2f9afe0247 | ||
|
|
1ec525ad97 | ||
|
|
b7dc6748e0 | ||
|
|
f1b612d828 | ||
|
|
eac5a104f2 | ||
|
|
4bff88fae3 | ||
|
|
acf1be71ce | ||
|
|
236a3c5f38 | ||
|
|
b2418984f8 | ||
|
|
152d74d160 | ||
|
|
4e16bbccd8 | ||
|
|
60174f41a4 | ||
|
|
ad4683952e | ||
|
|
86a044735b | ||
|
|
58583114cb | ||
|
|
36524cd2e4 | ||
|
|
e59156ac2b | ||
|
|
1eac026e92 | ||
|
|
17d863fd57 | ||
|
|
7c9dbfd343 | ||
|
|
d9260bdf26 | ||
|
|
63a0cfeb1e | ||
|
|
12fc6e2000 | ||
|
|
fe5900a5dc | ||
|
|
1b6b8e3d72 | ||
|
|
c85301cb1f | ||
|
|
7cc8226339 | ||
|
|
fc8c4babf8 | ||
|
|
bd809a1f94 | ||
|
|
50aec6291b | ||
|
|
f927fdf40f | ||
|
|
918862ef57 | ||
|
|
d9b8bc3233 | ||
|
|
da29b8e388 | ||
|
|
5e6d4110fa | ||
|
|
4bb090694b | ||
|
|
d232222787 |
80
README.md
80
README.md
@@ -15,9 +15,7 @@ Fabric is graciously supported by…
|
||||
[](https://deepwiki.com/danielmiessler/fabric)
|
||||
|
||||
<div align="center">
|
||||
<p class="align center">
|
||||
<h4><code>fabric</code> is an open-source framework for augmenting humans using AI.</h4>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
[Updates](#updates) •
|
||||
@@ -41,9 +39,9 @@ Since the start of modern AI in late 2022 we've seen an **_extraordinary_** numb
|
||||
|
||||
It's all really exciting and powerful, but _it's not easy to integrate this functionality into our lives._
|
||||
|
||||
<p class="align center">
|
||||
<div class="align center">
|
||||
<h4>In other words, AI doesn't have a capabilities problem—it has an <em>integration</em> problem.</h4>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
**Fabric was created to address this by creating and organizing the fundamental units of AI—the prompts themselves!**
|
||||
|
||||
@@ -95,6 +93,9 @@ Keep in mind that many of these were recorded when Fabric was Python-based, so r
|
||||
- [Just use the Patterns](#just-use-the-patterns)
|
||||
- [Prompt Strategies](#prompt-strategies)
|
||||
- [Custom Patterns](#custom-patterns)
|
||||
- [Setting Up Custom Patterns](#setting-up-custom-patterns)
|
||||
- [Using Custom Patterns](#using-custom-patterns)
|
||||
- [How It Works](#how-it-works)
|
||||
- [Helper Apps](#helper-apps)
|
||||
- [`to_pdf`](#to_pdf)
|
||||
- [`to_pdf` Installation](#to_pdf-installation)
|
||||
@@ -114,6 +115,14 @@ Keep in mind that many of these were recorded when Fabric was Python-based, so r
|
||||
|
||||
> [!NOTE]
|
||||
>
|
||||
> July 4, 2025
|
||||
>
|
||||
> - **Web Search**: Fabric now supports web search for Anthropic and OpenAI models using the `--search` and `--search-location` flags. This replaces the previous plugin-based search, so you may want to remove the old `ANTHROPIC_WEB_SEARCH_TOOL_*` variables from your `~/.config/fabric/.env` file.
|
||||
> - **Image Generation**: Fabric now has powerful image generation capabilities with OpenAI.
|
||||
> - Generate images from text prompts and save them using `--image-file`.
|
||||
> - Edit existing images by providing an input image with `--attachment`.
|
||||
> - Control image `size`, `quality`, `compression`, and `background` with the new `--image-*` flags.
|
||||
>
|
||||
>June 17, 2025
|
||||
>
|
||||
>- Fabric now supports Perplexity AI. Configure it by using `fabric -S` to add your Perplexity AI API Key,
|
||||
@@ -485,7 +494,6 @@ fabric -h
|
||||
```
|
||||
|
||||
```plaintext
|
||||
|
||||
Usage:
|
||||
fabric [OPTIONS]
|
||||
|
||||
@@ -500,7 +508,9 @@ Application Options:
|
||||
-T, --topp= Set top P (default: 0.9)
|
||||
-s, --stream Stream
|
||||
-P, --presencepenalty= Set presence penalty (default: 0.0)
|
||||
-r, --raw Use the defaults of the model without sending chat options (like temperature etc.) and use the user role instead of the system role for patterns.
|
||||
-r, --raw Use the defaults of the model without sending chat options (like
|
||||
temperature etc.) and use the user role instead of the system role for
|
||||
patterns.
|
||||
-F, --frequencypenalty= Set frequency penalty (default: 0.0)
|
||||
-l, --listpatterns List all patterns
|
||||
-L, --listmodels List all available models
|
||||
@@ -514,9 +524,12 @@ Application Options:
|
||||
--output-session Output the entire session (also a temporary one) to the output file
|
||||
-n, --latest= Number of latest patterns to list (default: 0)
|
||||
-d, --changeDefaultModel Change default model
|
||||
-y, --youtube= YouTube video or play list "URL" to grab transcript, comments from it and send to chat or print it put to the console and store it in the output file
|
||||
-y, --youtube= YouTube video or play list "URL" to grab transcript, comments from it
|
||||
and send to chat or print it put to the console and store it in the
|
||||
output file
|
||||
--playlist Prefer playlist over video if both ids are present in the URL
|
||||
--transcript Grab transcript from YouTube video and send to chat (it is used per default).
|
||||
--transcript Grab transcript from YouTube video and send to chat (it is used per
|
||||
default).
|
||||
--transcript-with-timestamps Grab transcript from YouTube video with timestamps and send to chat
|
||||
--comments Grab comments from YouTube video and send to chat
|
||||
--metadata Output video metadata
|
||||
@@ -544,6 +557,14 @@ Application Options:
|
||||
--liststrategies List all strategies
|
||||
--listvendors List all vendors
|
||||
--shell-complete-list Output raw list without headers/formatting (for shell completion)
|
||||
--search Enable web search tool for supported models (Anthropic, OpenAI)
|
||||
--search-location= Set location for web search results (e.g., 'America/Los_Angeles')
|
||||
--image-file= Save generated image to specified file path (e.g., 'output.png')
|
||||
--image-size= Image dimensions: 1024x1024, 1536x1024, 1024x1536, auto (default: auto)
|
||||
--image-quality= Image quality: low, medium, high, auto (default: auto)
|
||||
--image-compression= Compression level 0-100 for JPEG/WebP formats (default: not set)
|
||||
--image-background= Background type: opaque, transparent (default: opaque, only for
|
||||
PNG/WebP)
|
||||
|
||||
Help Options:
|
||||
-h, --help Show this help message
|
||||
@@ -634,11 +655,48 @@ Use `fabric -S` and select the option to install the strategies in your `~/.conf
|
||||
|
||||
You may want to use Fabric to create your own custom Patterns—but not share them with others. No problem!
|
||||
|
||||
Just make a directory in `~/.config/custompatterns/` (or wherever) and put your `.md` files in there.
|
||||
Fabric now supports a dedicated custom patterns directory that keeps your personal patterns separate from the built-in ones. This means your custom patterns won't be overwritten when you update Fabric's built-in patterns.
|
||||
|
||||
When you're ready to use them, copy them into `~/.config/fabric/patterns/`
|
||||
### Setting Up Custom Patterns
|
||||
|
||||
You can then use them like any other Patterns, but they won't be public unless you explicitly submit them as Pull Requests to the Fabric project. So don't worry—they're private to you.
|
||||
1. Run the Fabric setup:
|
||||
|
||||
```bash
|
||||
fabric --setup
|
||||
```
|
||||
|
||||
2. Select the "Custom Patterns" option from the Tools menu and enter your desired directory path (e.g., `~/my-custom-patterns`)
|
||||
|
||||
3. Fabric will automatically create the directory if it does not exist.
|
||||
|
||||
### Using Custom Patterns
|
||||
|
||||
1. Create your custom pattern directory structure:
|
||||
|
||||
```bash
|
||||
mkdir -p ~/my-custom-patterns/my-analyzer
|
||||
```
|
||||
|
||||
2. Create your pattern file
|
||||
|
||||
```bash
|
||||
echo "You are an expert analyzer of ..." > ~/my-custom-patterns/my-analyzer/system.md
|
||||
```
|
||||
|
||||
3. **Use your custom pattern:**
|
||||
|
||||
```bash
|
||||
fabric --pattern my-analyzer "analyze this text"
|
||||
```
|
||||
|
||||
### How It Works
|
||||
|
||||
- **Priority System**: Custom patterns take precedence over built-in patterns with the same name
|
||||
- **Seamless Integration**: Custom patterns appear in `fabric --listpatterns` alongside built-in ones
|
||||
- **Update Safe**: Your custom patterns are never affected by `fabric --updatepatterns`
|
||||
- **Private by Default**: Custom patterns remain private unless you explicitly share them
|
||||
|
||||
Your custom patterns are completely private and won't be affected by Fabric updates!
|
||||
|
||||
## Helper Apps
|
||||
|
||||
|
||||
@@ -270,7 +270,11 @@ func Cli(version string) (err error) {
|
||||
if chatReq.Language == "" {
|
||||
chatReq.Language = registry.Language.DefaultLanguage.Value
|
||||
}
|
||||
if session, err = chatter.Send(chatReq, currentFlags.BuildChatOptions()); err != nil {
|
||||
var chatOptions *common.ChatOptions
|
||||
if chatOptions, err = currentFlags.BuildChatOptions(); err != nil {
|
||||
return
|
||||
}
|
||||
if session, err = chatter.Send(chatReq, chatOptions); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
132
cli/flags.go
132
cli/flags.go
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -14,7 +15,7 @@ import (
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"github.com/jessevdk/go-flags"
|
||||
"golang.org/x/text/language"
|
||||
"gopkg.in/yaml.v2"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Flags create flags struct. the users flags go into this, this will be passed to the chat struct in cli
|
||||
@@ -74,6 +75,13 @@ type Flags struct {
|
||||
ListStrategies bool `long:"liststrategies" description:"List all strategies"`
|
||||
ListVendors bool `long:"listvendors" description:"List all vendors"`
|
||||
ShellCompleteOutput bool `long:"shell-complete-list" description:"Output raw list without headers/formatting (for shell completion)"`
|
||||
Search bool `long:"search" description:"Enable web search tool for supported models (Anthropic, OpenAI)"`
|
||||
SearchLocation string `long:"search-location" description:"Set location for web search results (e.g., 'America/Los_Angeles')"`
|
||||
ImageFile string `long:"image-file" description:"Save generated image to specified file path (e.g., 'output.png')"`
|
||||
ImageSize string `long:"image-size" description:"Image dimensions: 1024x1024, 1536x1024, 1024x1536, auto (default: auto)"`
|
||||
ImageQuality string `long:"image-quality" description:"Image quality: low, medium, high, auto (default: auto)"`
|
||||
ImageCompression int `long:"image-compression" description:"Compression level 0-100 for JPEG/WebP formats (default: not set)"`
|
||||
ImageBackground string `long:"image-background" description:"Background type: opaque, transparent (default: opaque, only for PNG/WebP)"`
|
||||
}
|
||||
|
||||
var debug = false
|
||||
@@ -254,8 +262,121 @@ func readStdin() (ret string, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (o *Flags) BuildChatOptions() (ret *common.ChatOptions) {
|
||||
// validateImageFile validates the image file path and extension
|
||||
func validateImageFile(imagePath string) error {
|
||||
if imagePath == "" {
|
||||
return nil // No validation needed if no image file specified
|
||||
}
|
||||
|
||||
// Check if file already exists
|
||||
if _, err := os.Stat(imagePath); err == nil {
|
||||
return fmt.Errorf("image file already exists: %s", imagePath)
|
||||
}
|
||||
|
||||
// Check file extension
|
||||
ext := strings.ToLower(filepath.Ext(imagePath))
|
||||
validExtensions := []string{".png", ".jpeg", ".jpg", ".webp"}
|
||||
|
||||
for _, validExt := range validExtensions {
|
||||
if ext == validExt {
|
||||
return nil // Valid extension found
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("invalid image file extension '%s'. Supported formats: .png, .jpeg, .jpg, .webp", ext)
|
||||
}
|
||||
|
||||
// validateImageParameters validates image generation parameters
|
||||
func validateImageParameters(imagePath, size, quality, background string, compression int) error {
|
||||
if imagePath == "" {
|
||||
// Check if any image parameters are specified without --image-file
|
||||
if size != "" || quality != "" || background != "" || compression != 0 {
|
||||
return fmt.Errorf("image parameters (--image-size, --image-quality, --image-background, --image-compression) can only be used with --image-file")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate size
|
||||
if size != "" {
|
||||
validSizes := []string{"1024x1024", "1536x1024", "1024x1536", "auto"}
|
||||
valid := false
|
||||
for _, validSize := range validSizes {
|
||||
if size == validSize {
|
||||
valid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !valid {
|
||||
return fmt.Errorf("invalid image size '%s'. Supported sizes: 1024x1024, 1536x1024, 1024x1536, auto", size)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate quality
|
||||
if quality != "" {
|
||||
validQualities := []string{"low", "medium", "high", "auto"}
|
||||
valid := false
|
||||
for _, validQuality := range validQualities {
|
||||
if quality == validQuality {
|
||||
valid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !valid {
|
||||
return fmt.Errorf("invalid image quality '%s'. Supported qualities: low, medium, high, auto", quality)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate background
|
||||
if background != "" {
|
||||
validBackgrounds := []string{"opaque", "transparent"}
|
||||
valid := false
|
||||
for _, validBackground := range validBackgrounds {
|
||||
if background == validBackground {
|
||||
valid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !valid {
|
||||
return fmt.Errorf("invalid image background '%s'. Supported backgrounds: opaque, transparent", background)
|
||||
}
|
||||
}
|
||||
|
||||
// Get file format for format-specific validations
|
||||
ext := strings.ToLower(filepath.Ext(imagePath))
|
||||
|
||||
// Validate compression (only for jpeg/webp)
|
||||
if compression != 0 { // 0 means not set
|
||||
if ext != ".jpg" && ext != ".jpeg" && ext != ".webp" {
|
||||
return fmt.Errorf("image compression can only be used with JPEG and WebP formats, not %s", ext)
|
||||
}
|
||||
if compression < 0 || compression > 100 {
|
||||
return fmt.Errorf("image compression must be between 0 and 100, got %d", compression)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate background transparency (only for png/webp)
|
||||
if background == "transparent" {
|
||||
if ext != ".png" && ext != ".webp" {
|
||||
return fmt.Errorf("transparent background can only be used with PNG and WebP formats, not %s", ext)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *Flags) BuildChatOptions() (ret *common.ChatOptions, err error) {
|
||||
// Validate image file if specified
|
||||
if err = validateImageFile(o.ImageFile); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Validate image parameters
|
||||
if err = validateImageParameters(o.ImageFile, o.ImageSize, o.ImageQuality, o.ImageBackground, o.ImageCompression); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret = &common.ChatOptions{
|
||||
Model: o.Model,
|
||||
Temperature: o.Temperature,
|
||||
TopP: o.TopP,
|
||||
PresencePenalty: o.PresencePenalty,
|
||||
@@ -263,6 +384,13 @@ func (o *Flags) BuildChatOptions() (ret *common.ChatOptions) {
|
||||
Raw: o.Raw,
|
||||
Seed: o.Seed,
|
||||
ModelContextLength: o.ModelContextLength,
|
||||
Search: o.Search,
|
||||
SearchLocation: o.SearchLocation,
|
||||
ImageFile: o.ImageFile,
|
||||
ImageSize: o.ImageSize,
|
||||
ImageQuality: o.ImageQuality,
|
||||
ImageCompression: o.ImageCompression,
|
||||
ImageBackground: o.ImageBackground,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -64,7 +65,8 @@ func TestBuildChatOptions(t *testing.T) {
|
||||
Raw: false,
|
||||
Seed: 1,
|
||||
}
|
||||
options := flags.BuildChatOptions()
|
||||
options, err := flags.BuildChatOptions()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedOptions, options)
|
||||
}
|
||||
|
||||
@@ -84,7 +86,8 @@ func TestBuildChatOptionsDefaultSeed(t *testing.T) {
|
||||
Raw: false,
|
||||
Seed: 0,
|
||||
}
|
||||
options := flags.BuildChatOptions()
|
||||
options, err := flags.BuildChatOptions()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedOptions, options)
|
||||
}
|
||||
|
||||
@@ -164,3 +167,269 @@ model: 123 # should be string
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestValidateImageFile(t *testing.T) {
|
||||
t.Run("Empty path should be valid", func(t *testing.T) {
|
||||
err := validateImageFile("")
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Valid extensions should pass", func(t *testing.T) {
|
||||
validExtensions := []string{".png", ".jpeg", ".jpg", ".webp"}
|
||||
for _, ext := range validExtensions {
|
||||
filename := "/tmp/test" + ext
|
||||
err := validateImageFile(filename)
|
||||
assert.NoError(t, err, "Extension %s should be valid", ext)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid extensions should fail", func(t *testing.T) {
|
||||
invalidExtensions := []string{".gif", ".bmp", ".tiff", ".svg", ".txt", ""}
|
||||
for _, ext := range invalidExtensions {
|
||||
filename := "/tmp/test" + ext
|
||||
err := validateImageFile(filename)
|
||||
assert.Error(t, err, "Extension %s should be invalid", ext)
|
||||
assert.Contains(t, err.Error(), "invalid image file extension")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Existing file should fail", func(t *testing.T) {
|
||||
// Create a temporary file
|
||||
tempFile, err := os.CreateTemp("", "test*.png")
|
||||
assert.NoError(t, err)
|
||||
defer os.Remove(tempFile.Name())
|
||||
tempFile.Close()
|
||||
|
||||
// Validation should fail because file exists
|
||||
err = validateImageFile(tempFile.Name())
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "image file already exists")
|
||||
})
|
||||
|
||||
t.Run("Non-existing file with valid extension should pass", func(t *testing.T) {
|
||||
nonExistentFile := filepath.Join(os.TempDir(), "non_existent_file.png")
|
||||
// Make sure the file doesn't exist
|
||||
os.Remove(nonExistentFile)
|
||||
|
||||
err := validateImageFile(nonExistentFile)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuildChatOptionsWithImageFileValidation(t *testing.T) {
|
||||
t.Run("Valid image file should pass", func(t *testing.T) {
|
||||
flags := &Flags{
|
||||
ImageFile: "/tmp/output.png",
|
||||
}
|
||||
|
||||
options, err := flags.BuildChatOptions()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "/tmp/output.png", options.ImageFile)
|
||||
})
|
||||
|
||||
t.Run("Invalid extension should fail", func(t *testing.T) {
|
||||
flags := &Flags{
|
||||
ImageFile: "/tmp/output.gif",
|
||||
}
|
||||
|
||||
options, err := flags.BuildChatOptions()
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, options)
|
||||
assert.Contains(t, err.Error(), "invalid image file extension")
|
||||
})
|
||||
|
||||
t.Run("Existing file should fail", func(t *testing.T) {
|
||||
// Create a temporary file
|
||||
tempFile, err := os.CreateTemp("", "existing*.png")
|
||||
assert.NoError(t, err)
|
||||
defer os.Remove(tempFile.Name())
|
||||
tempFile.Close()
|
||||
|
||||
flags := &Flags{
|
||||
ImageFile: tempFile.Name(),
|
||||
}
|
||||
|
||||
options, err := flags.BuildChatOptions()
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, options)
|
||||
assert.Contains(t, err.Error(), "image file already exists")
|
||||
})
|
||||
}
|
||||
|
||||
func TestValidateImageParameters(t *testing.T) {
|
||||
t.Run("No image file and no parameters should pass", func(t *testing.T) {
|
||||
err := validateImageParameters("", "", "", "", 0)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Image parameters without image file should fail", func(t *testing.T) {
|
||||
// Test each parameter individually
|
||||
err := validateImageParameters("", "1024x1024", "", "", 0)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "image parameters")
|
||||
assert.Contains(t, err.Error(), "can only be used with --image-file")
|
||||
|
||||
err = validateImageParameters("", "", "high", "", 0)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "image parameters")
|
||||
|
||||
err = validateImageParameters("", "", "", "transparent", 0)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "image parameters")
|
||||
|
||||
err = validateImageParameters("", "", "", "", 50)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "image parameters")
|
||||
|
||||
// Test multiple parameters
|
||||
err = validateImageParameters("", "1024x1024", "high", "transparent", 50)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "image parameters")
|
||||
})
|
||||
|
||||
t.Run("Valid size values should pass", func(t *testing.T) {
|
||||
validSizes := []string{"1024x1024", "1536x1024", "1024x1536", "auto"}
|
||||
for _, size := range validSizes {
|
||||
err := validateImageParameters("/tmp/test.png", size, "", "", 0)
|
||||
assert.NoError(t, err, "Size %s should be valid", size)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid size should fail", func(t *testing.T) {
|
||||
err := validateImageParameters("/tmp/test.png", "invalid", "", "", 0)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid image size")
|
||||
})
|
||||
|
||||
t.Run("Valid quality values should pass", func(t *testing.T) {
|
||||
validQualities := []string{"low", "medium", "high", "auto"}
|
||||
for _, quality := range validQualities {
|
||||
err := validateImageParameters("/tmp/test.png", "", quality, "", 0)
|
||||
assert.NoError(t, err, "Quality %s should be valid", quality)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid quality should fail", func(t *testing.T) {
|
||||
err := validateImageParameters("/tmp/test.png", "", "invalid", "", 0)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid image quality")
|
||||
})
|
||||
|
||||
t.Run("Valid background values should pass", func(t *testing.T) {
|
||||
validBackgrounds := []string{"opaque", "transparent"}
|
||||
for _, background := range validBackgrounds {
|
||||
err := validateImageParameters("/tmp/test.png", "", "", background, 0)
|
||||
assert.NoError(t, err, "Background %s should be valid", background)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid background should fail", func(t *testing.T) {
|
||||
err := validateImageParameters("/tmp/test.png", "", "", "invalid", 0)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid image background")
|
||||
})
|
||||
|
||||
t.Run("Compression for JPEG should pass", func(t *testing.T) {
|
||||
err := validateImageParameters("/tmp/test.jpg", "", "", "", 75)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Compression for WebP should pass", func(t *testing.T) {
|
||||
err := validateImageParameters("/tmp/test.webp", "", "", "", 50)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Compression for PNG should fail", func(t *testing.T) {
|
||||
err := validateImageParameters("/tmp/test.png", "", "", "", 75)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "image compression can only be used with JPEG and WebP formats")
|
||||
})
|
||||
|
||||
t.Run("Invalid compression range should fail", func(t *testing.T) {
|
||||
err := validateImageParameters("/tmp/test.jpg", "", "", "", 150)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "image compression must be between 0 and 100")
|
||||
|
||||
err = validateImageParameters("/tmp/test.jpg", "", "", "", -10)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "image compression must be between 0 and 100")
|
||||
})
|
||||
|
||||
t.Run("Transparent background for PNG should pass", func(t *testing.T) {
|
||||
err := validateImageParameters("/tmp/test.png", "", "", "transparent", 0)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Transparent background for WebP should pass", func(t *testing.T) {
|
||||
err := validateImageParameters("/tmp/test.webp", "", "", "transparent", 0)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Transparent background for JPEG should fail", func(t *testing.T) {
|
||||
err := validateImageParameters("/tmp/test.jpg", "", "", "transparent", 0)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "transparent background can only be used with PNG and WebP formats")
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuildChatOptionsWithImageParameters(t *testing.T) {
|
||||
t.Run("Valid image parameters should pass", func(t *testing.T) {
|
||||
flags := &Flags{
|
||||
ImageFile: "/tmp/test.png",
|
||||
ImageSize: "1024x1024",
|
||||
ImageQuality: "high",
|
||||
ImageBackground: "transparent",
|
||||
ImageCompression: 0, // Not set for PNG
|
||||
}
|
||||
|
||||
options, err := flags.BuildChatOptions()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, options)
|
||||
assert.Equal(t, "/tmp/test.png", options.ImageFile)
|
||||
assert.Equal(t, "1024x1024", options.ImageSize)
|
||||
assert.Equal(t, "high", options.ImageQuality)
|
||||
assert.Equal(t, "transparent", options.ImageBackground)
|
||||
assert.Equal(t, 0, options.ImageCompression)
|
||||
})
|
||||
|
||||
t.Run("Invalid image parameters should fail", func(t *testing.T) {
|
||||
flags := &Flags{
|
||||
ImageFile: "/tmp/test.png",
|
||||
ImageSize: "invalid",
|
||||
ImageQuality: "high",
|
||||
ImageBackground: "transparent",
|
||||
}
|
||||
|
||||
options, err := flags.BuildChatOptions()
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, options)
|
||||
assert.Contains(t, err.Error(), "invalid image size")
|
||||
})
|
||||
|
||||
t.Run("JPEG with compression should pass", func(t *testing.T) {
|
||||
flags := &Flags{
|
||||
ImageFile: "/tmp/test.jpg",
|
||||
ImageSize: "1536x1024",
|
||||
ImageQuality: "medium",
|
||||
ImageBackground: "opaque",
|
||||
ImageCompression: 80,
|
||||
}
|
||||
|
||||
options, err := flags.BuildChatOptions()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, options)
|
||||
assert.Equal(t, 80, options.ImageCompression)
|
||||
})
|
||||
|
||||
t.Run("Image parameters without image file should fail in BuildChatOptions", func(t *testing.T) {
|
||||
flags := &Flags{
|
||||
ImageSize: "1024x1024", // Image parameter without ImageFile
|
||||
}
|
||||
|
||||
options, err := flags.BuildChatOptions()
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, options)
|
||||
assert.Contains(t, err.Error(), "image parameters")
|
||||
assert.Contains(t, err.Error(), "can only be used with --image-file")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -26,6 +26,13 @@ type ChatOptions struct {
|
||||
Seed int
|
||||
ModelContextLength int
|
||||
MaxTokens int
|
||||
Search bool
|
||||
SearchLocation string
|
||||
ImageFile string
|
||||
ImageSize string
|
||||
ImageQuality string
|
||||
ImageCompression int
|
||||
ImageBackground string
|
||||
}
|
||||
|
||||
// NormalizeMessages remove empty messages and ensure messages order user-assist-user
|
||||
|
||||
124
common/oauth_storage.go
Normal file
124
common/oauth_storage.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
|
||||
// OAuthToken represents stored OAuth token information
|
||||
type OAuthToken struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
TokenType string `json:"token_type"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
// IsExpired checks if the token is expired or will expire within the buffer time
|
||||
func (t *OAuthToken) IsExpired(bufferMinutes int) bool {
|
||||
if t.ExpiresAt == 0 {
|
||||
return true
|
||||
}
|
||||
bufferTime := time.Duration(bufferMinutes) * time.Minute
|
||||
return time.Now().Add(bufferTime).Unix() >= t.ExpiresAt
|
||||
}
|
||||
|
||||
// OAuthStorage handles persistent storage of OAuth tokens
|
||||
type OAuthStorage struct {
|
||||
configDir string
|
||||
}
|
||||
|
||||
// NewOAuthStorage creates a new OAuth storage instance
|
||||
func NewOAuthStorage() (*OAuthStorage, error) {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user home directory: %w", err)
|
||||
}
|
||||
|
||||
configDir := filepath.Join(homeDir, ".config", "fabric")
|
||||
|
||||
// Ensure config directory exists
|
||||
if err := os.MkdirAll(configDir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create config directory: %w", err)
|
||||
}
|
||||
|
||||
return &OAuthStorage{configDir: configDir}, nil
|
||||
}
|
||||
|
||||
// GetTokenPath returns the file path for a provider's OAuth token
|
||||
func (s *OAuthStorage) GetTokenPath(provider string) string {
|
||||
return filepath.Join(s.configDir, fmt.Sprintf(".%s_oauth", provider))
|
||||
}
|
||||
|
||||
// SaveToken saves an OAuth token to disk with proper permissions
|
||||
func (s *OAuthStorage) SaveToken(provider string, token *OAuthToken) error {
|
||||
tokenPath := s.GetTokenPath(provider)
|
||||
|
||||
// Marshal token to JSON
|
||||
data, err := json.MarshalIndent(token, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal token: %w", err)
|
||||
}
|
||||
|
||||
// Write to temporary file first for atomic operation
|
||||
tempPath := tokenPath + ".tmp"
|
||||
if err := os.WriteFile(tempPath, data, 0600); err != nil {
|
||||
return fmt.Errorf("failed to write token file: %w", err)
|
||||
}
|
||||
|
||||
// Atomic rename
|
||||
if err := os.Rename(tempPath, tokenPath); err != nil {
|
||||
os.Remove(tempPath) // Clean up temp file
|
||||
return fmt.Errorf("failed to save token file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadToken loads an OAuth token from disk
|
||||
func (s *OAuthStorage) LoadToken(provider string) (*OAuthToken, error) {
|
||||
tokenPath := s.GetTokenPath(provider)
|
||||
|
||||
// Check if file exists
|
||||
if _, err := os.Stat(tokenPath); os.IsNotExist(err) {
|
||||
return nil, nil // No token stored
|
||||
}
|
||||
|
||||
// Read token file
|
||||
data, err := os.ReadFile(tokenPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read token file: %w", err)
|
||||
}
|
||||
|
||||
// Unmarshal token
|
||||
var token OAuthToken
|
||||
if err := json.Unmarshal(data, &token); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token file: %w", err)
|
||||
}
|
||||
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// DeleteToken removes a stored OAuth token
|
||||
func (s *OAuthStorage) DeleteToken(provider string) error {
|
||||
tokenPath := s.GetTokenPath(provider)
|
||||
|
||||
if err := os.Remove(tokenPath); err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to delete token file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasValidToken checks if a valid (non-expired) token exists for a provider
|
||||
func (s *OAuthStorage) HasValidToken(provider string, bufferMinutes int) bool {
|
||||
token, err := s.LoadToken(provider)
|
||||
if err != nil || token == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return !token.IsExpired(bufferMinutes)
|
||||
}
|
||||
232
common/oauth_storage_test.go
Normal file
232
common/oauth_storage_test.go
Normal file
@@ -0,0 +1,232 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestOAuthToken_IsExpired(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
expiresAt int64
|
||||
bufferMinutes int
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "token not expired",
|
||||
expiresAt: time.Now().Unix() + 3600, // 1 hour from now
|
||||
bufferMinutes: 5,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "token expired",
|
||||
expiresAt: time.Now().Unix() - 3600, // 1 hour ago
|
||||
bufferMinutes: 5,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "token expires within buffer",
|
||||
expiresAt: time.Now().Unix() + 120, // 2 minutes from now
|
||||
bufferMinutes: 5,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "zero expiry time",
|
||||
expiresAt: 0,
|
||||
bufferMinutes: 5,
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
token := &OAuthToken{ExpiresAt: tt.expiresAt}
|
||||
if got := token.IsExpired(tt.bufferMinutes); got != tt.expected {
|
||||
t.Errorf("IsExpired() = %v, want %v", got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthStorage_SaveAndLoadToken(t *testing.T) {
|
||||
// Create temporary directory for testing
|
||||
tempDir, err := os.MkdirTemp("", "fabric_oauth_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// Create storage with custom config dir
|
||||
storage := &OAuthStorage{configDir: tempDir}
|
||||
|
||||
// Test token
|
||||
token := &OAuthToken{
|
||||
AccessToken: "test_access_token",
|
||||
RefreshToken: "test_refresh_token",
|
||||
ExpiresAt: time.Now().Unix() + 3600,
|
||||
TokenType: "Bearer",
|
||||
Scope: "test_scope",
|
||||
}
|
||||
|
||||
// Test saving token
|
||||
err = storage.SaveToken("test_provider", token)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save token: %v", err)
|
||||
}
|
||||
|
||||
// Verify file exists and has correct permissions
|
||||
tokenPath := storage.GetTokenPath("test_provider")
|
||||
info, err := os.Stat(tokenPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Token file not created: %v", err)
|
||||
}
|
||||
if info.Mode().Perm() != 0600 {
|
||||
t.Errorf("Token file has wrong permissions: %v, want 0600", info.Mode().Perm())
|
||||
}
|
||||
|
||||
// Test loading token
|
||||
loadedToken, err := storage.LoadToken("test_provider")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load token: %v", err)
|
||||
}
|
||||
if loadedToken == nil {
|
||||
t.Fatal("Loaded token is nil")
|
||||
}
|
||||
|
||||
// Verify token data
|
||||
if loadedToken.AccessToken != token.AccessToken {
|
||||
t.Errorf("AccessToken mismatch: got %v, want %v", loadedToken.AccessToken, token.AccessToken)
|
||||
}
|
||||
if loadedToken.RefreshToken != token.RefreshToken {
|
||||
t.Errorf("RefreshToken mismatch: got %v, want %v", loadedToken.RefreshToken, token.RefreshToken)
|
||||
}
|
||||
if loadedToken.ExpiresAt != token.ExpiresAt {
|
||||
t.Errorf("ExpiresAt mismatch: got %v, want %v", loadedToken.ExpiresAt, token.ExpiresAt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthStorage_LoadNonExistentToken(t *testing.T) {
|
||||
// Create temporary directory for testing
|
||||
tempDir, err := os.MkdirTemp("", "fabric_oauth_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
storage := &OAuthStorage{configDir: tempDir}
|
||||
|
||||
// Try to load non-existent token
|
||||
token, err := storage.LoadToken("nonexistent")
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error loading non-existent token: %v", err)
|
||||
}
|
||||
if token != nil {
|
||||
t.Error("Expected nil token for non-existent provider")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthStorage_DeleteToken(t *testing.T) {
|
||||
// Create temporary directory for testing
|
||||
tempDir, err := os.MkdirTemp("", "fabric_oauth_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
storage := &OAuthStorage{configDir: tempDir}
|
||||
|
||||
// Create and save a token
|
||||
token := &OAuthToken{
|
||||
AccessToken: "test_token",
|
||||
RefreshToken: "test_refresh",
|
||||
ExpiresAt: time.Now().Unix() + 3600,
|
||||
}
|
||||
err = storage.SaveToken("test_provider", token)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save token: %v", err)
|
||||
}
|
||||
|
||||
// Verify token exists
|
||||
tokenPath := storage.GetTokenPath("test_provider")
|
||||
if _, err := os.Stat(tokenPath); os.IsNotExist(err) {
|
||||
t.Fatal("Token file should exist before deletion")
|
||||
}
|
||||
|
||||
// Delete token
|
||||
err = storage.DeleteToken("test_provider")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to delete token: %v", err)
|
||||
}
|
||||
|
||||
// Verify token is deleted
|
||||
if _, err := os.Stat(tokenPath); !os.IsNotExist(err) {
|
||||
t.Error("Token file should not exist after deletion")
|
||||
}
|
||||
|
||||
// Test deleting non-existent token (should not error)
|
||||
err = storage.DeleteToken("nonexistent")
|
||||
if err != nil {
|
||||
t.Errorf("Deleting non-existent token should not error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthStorage_HasValidToken(t *testing.T) {
|
||||
// Create temporary directory for testing
|
||||
tempDir, err := os.MkdirTemp("", "fabric_oauth_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
storage := &OAuthStorage{configDir: tempDir}
|
||||
|
||||
// Test with no token
|
||||
if storage.HasValidToken("test_provider", 5) {
|
||||
t.Error("Should return false when no token exists")
|
||||
}
|
||||
|
||||
// Save valid token
|
||||
validToken := &OAuthToken{
|
||||
AccessToken: "valid_token",
|
||||
RefreshToken: "refresh_token",
|
||||
ExpiresAt: time.Now().Unix() + 3600, // 1 hour from now
|
||||
}
|
||||
err = storage.SaveToken("test_provider", validToken)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save valid token: %v", err)
|
||||
}
|
||||
|
||||
// Test with valid token
|
||||
if !storage.HasValidToken("test_provider", 5) {
|
||||
t.Error("Should return true for valid token")
|
||||
}
|
||||
|
||||
// Save expired token
|
||||
expiredToken := &OAuthToken{
|
||||
AccessToken: "expired_token",
|
||||
RefreshToken: "refresh_token",
|
||||
ExpiresAt: time.Now().Unix() - 3600, // 1 hour ago
|
||||
}
|
||||
err = storage.SaveToken("expired_provider", expiredToken)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save expired token: %v", err)
|
||||
}
|
||||
|
||||
// Test with expired token
|
||||
if storage.HasValidToken("expired_provider", 5) {
|
||||
t.Error("Should return false for expired token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthStorage_GetTokenPath(t *testing.T) {
|
||||
storage := &OAuthStorage{configDir: "/test/config"}
|
||||
|
||||
expected := filepath.Join("/test/config", ".test_provider_oauth")
|
||||
actual := storage.GetTokenPath("test_provider")
|
||||
|
||||
if actual != expected {
|
||||
t.Errorf("GetTokenPath() = %v, want %v", actual, expected)
|
||||
}
|
||||
}
|
||||
@@ -96,6 +96,13 @@ _fabric() {
|
||||
'(--api-key)--api-key[API key used to secure server routes]:api-key:' \
|
||||
'(--config)--config[Path to YAML config file]:config file:_files -g "*.yaml *.yml"' \
|
||||
'(--version)--version[Print current version]' \
|
||||
'(--search)--search[Enable web search tool for supported models (Anthropic, OpenAI)]' \
|
||||
'(--search-location)--search-location[Set location for web search results]:location:' \
|
||||
'(--image-file)--image-file[Save generated image to specified file path]:image file:_files -g "*.png *.webp *.jpeg *.jpg"' \
|
||||
'(--image-size)--image-size[Image dimensions]:size:(1024x1024 1536x1024 1024x1536 auto)' \
|
||||
'(--image-quality)--image-quality[Image quality]:quality:(low medium high auto)' \
|
||||
'(--image-compression)--image-compression[Compression level 0-100 for JPEG/WebP formats]:compression:' \
|
||||
'(--image-background)--image-background[Background type]:background:(opaque transparent)' \
|
||||
'(--listextensions)--listextensions[List all registered extensions]' \
|
||||
'(--addextension)--addextension[Register a new extension from config file path]:config file:_files -g "*.yaml *.yml"' \
|
||||
'(--rmextension)--rmextension[Remove a registered extension by name]:extension:_fabric_extensions' \
|
||||
|
||||
@@ -13,7 +13,7 @@ _fabric() {
|
||||
_get_comp_words_by_ref -n : cur prev words cword
|
||||
|
||||
# Define all possible options/flags
|
||||
local opts="--pattern -p --variable -v --context -C --session --attachment -a --setup -S --temperature -t --topp -T --stream -s --presencepenalty -P --raw -r --frequencypenalty -F --listpatterns -l --listmodels -L --listcontexts -x --listsessions -X --updatepatterns -U --copy -c --model -m --modelContextLength --output -o --output-session --latest -n --changeDefaultModel -d --youtube -y --playlist --transcript --transcript-with-timestamps --comments --metadata --language -g --scrape_url -u --scrape_question -q --seed -e --wipecontext -w --wipesession -W --printcontext --printsession --readability --input-has-vars --dry-run --serve --serveOllama --address --api-key --config --version --listextensions --addextension --rmextension --strategy --liststrategies --listvendors --shell-complete-list --help -h"
|
||||
local opts="--pattern -p --variable -v --context -C --session --attachment -a --setup -S --temperature -t --topp -T --stream -s --presencepenalty -P --raw -r --frequencypenalty -F --listpatterns -l --listmodels -L --listcontexts -x --listsessions -X --updatepatterns -U --copy -c --model -m --modelContextLength --output -o --output-session --latest -n --changeDefaultModel -d --youtube -y --playlist --transcript --transcript-with-timestamps --comments --metadata --language -g --scrape_url -u --scrape_question -q --seed -e --wipecontext -w --wipesession -W --printcontext --printsession --readability --input-has-vars --dry-run --serve --serveOllama --address --api-key --config --search --search-location --image-file --image-size --image-quality --image-compression --image-background --version --listextensions --addextension --rmextension --strategy --liststrategies --listvendors --shell-complete-list --help -h"
|
||||
|
||||
# Helper function for dynamic completions
|
||||
_fabric_get_list() {
|
||||
@@ -63,12 +63,25 @@ _fabric() {
|
||||
return 0
|
||||
;;
|
||||
# Options requiring file/directory paths
|
||||
-a | --attachment | -o | --output | --config | --addextension)
|
||||
-a | --attachment | -o | --output | --config | --addextension | --image-file)
|
||||
_filedir
|
||||
return 0
|
||||
;;
|
||||
# Image generation options with specific values
|
||||
--image-size)
|
||||
COMPREPLY=($(compgen -W "1024x1024 1536x1024 1024x1536 auto" -- "$cur"))
|
||||
return 0
|
||||
;;
|
||||
--image-quality)
|
||||
COMPREPLY=($(compgen -W "low medium high auto" -- "$cur"))
|
||||
return 0
|
||||
;;
|
||||
--image-background)
|
||||
COMPREPLY=($(compgen -W "opaque transparent" -- "$cur"))
|
||||
return 0
|
||||
;;
|
||||
# Options requiring simple arguments (no specific completion logic here)
|
||||
-v | --variable | -t | --temperature | -T | --topp | -P | --presencepenalty | -F | --frequencypenalty | --modelContextLength | -n | --latest | -y | --youtube | -g | --language | -u | --scrape_url | -q | --scrape_question | -e | --seed | --address | --api-key)
|
||||
-v | --variable | -t | --temperature | -T | --topp | -P | --presencepenalty | -F | --frequencypenalty | --modelContextLength | -n | --latest | -y | --youtube | -g | --language | -u | --scrape_url | -q | --scrape_question | -e | --seed | --address | --api-key | --search-location | --image-compression)
|
||||
# No specific completion suggestions, user types the value
|
||||
return 0
|
||||
;;
|
||||
|
||||
@@ -60,6 +60,12 @@ complete -c fabric -l printsession -d "Print session" -a "(__fabric_get_sessions
|
||||
complete -c fabric -l address -d "The address to bind the REST API (default: :8080)"
|
||||
complete -c fabric -l api-key -d "API key used to secure server routes"
|
||||
complete -c fabric -l config -d "Path to YAML config file" -r -a "*.yaml *.yml"
|
||||
complete -c fabric -l search-location -d "Set location for web search results (e.g., 'America/Los_Angeles')"
|
||||
complete -c fabric -l image-file -d "Save generated image to specified file path (e.g., 'output.png')" -r -a "*.png *.webp *.jpeg *.jpg"
|
||||
complete -c fabric -l image-size -d "Image dimensions: 1024x1024, 1536x1024, 1024x1536, auto (default: auto)" -a "1024x1024 1536x1024 1024x1536 auto"
|
||||
complete -c fabric -l image-quality -d "Image quality: low, medium, high, auto (default: auto)" -a "low medium high auto"
|
||||
complete -c fabric -l image-compression -d "Compression level 0-100 for JPEG/WebP formats (default: not set)" -r
|
||||
complete -c fabric -l image-background -d "Background type: opaque, transparent (default: opaque, only for PNG/WebP)" -a "opaque transparent"
|
||||
complete -c fabric -l addextension -d "Register a new extension from config file path" -r -a "*.yaml *.yml"
|
||||
complete -c fabric -l rmextension -d "Remove a registered extension by name" -a "(__fabric_get_extensions)"
|
||||
complete -c fabric -l strategy -d "Choose a strategy from the available strategies" -a "(__fabric_get_strategies)"
|
||||
@@ -84,6 +90,7 @@ complete -c fabric -l metadata -d "Output video metadata"
|
||||
complete -c fabric -l readability -d "Convert HTML input into a clean, readable view"
|
||||
complete -c fabric -l input-has-vars -d "Apply variables to user input"
|
||||
complete -c fabric -l dry-run -d "Show what would be sent to the model without actually sending it"
|
||||
complete -c fabric -l search -d "Enable web search tool for supported models (Anthropic, OpenAI)"
|
||||
complete -c fabric -l serve -d "Serve the Fabric Rest API"
|
||||
complete -c fabric -l serveOllama -d "Serve the Fabric Rest API with ollama endpoints"
|
||||
complete -c fabric -l version -d "Print current version"
|
||||
|
||||
@@ -31,6 +31,7 @@ import (
|
||||
"github.com/danielmiessler/fabric/plugins/db/fsdb"
|
||||
"github.com/danielmiessler/fabric/plugins/template"
|
||||
"github.com/danielmiessler/fabric/plugins/tools"
|
||||
"github.com/danielmiessler/fabric/plugins/tools/custom_patterns"
|
||||
"github.com/danielmiessler/fabric/plugins/tools/jina"
|
||||
"github.com/danielmiessler/fabric/plugins/tools/lang"
|
||||
"github.com/danielmiessler/fabric/plugins/tools/youtube"
|
||||
@@ -69,6 +70,7 @@ func NewPluginRegistry(db *fsdb.Db) (ret *PluginRegistry, err error) {
|
||||
VendorManager: ai.NewVendorsManager(),
|
||||
VendorsAll: ai.NewVendorsManager(),
|
||||
PatternsLoader: tools.NewPatternsLoader(db.Patterns),
|
||||
CustomPatterns: custom_patterns.NewCustomPatterns(),
|
||||
YouTube: youtube.NewYouTube(),
|
||||
Language: lang.NewLanguage(),
|
||||
Jina: jina.NewClient(),
|
||||
@@ -138,6 +140,7 @@ type PluginRegistry struct {
|
||||
VendorsAll *ai.VendorsManager
|
||||
Defaults *tools.Defaults
|
||||
PatternsLoader *tools.PatternsLoader
|
||||
CustomPatterns *custom_patterns.CustomPatterns
|
||||
YouTube *youtube.YouTube
|
||||
Language *lang.Language
|
||||
Jina *jina.Client
|
||||
@@ -151,6 +154,7 @@ func (o *PluginRegistry) SaveEnvFile() (err error) {
|
||||
|
||||
o.Defaults.Settings.FillEnvFileContent(&envFileContent)
|
||||
o.PatternsLoader.SetupFillEnvFileContent(&envFileContent)
|
||||
o.CustomPatterns.SetupFillEnvFileContent(&envFileContent)
|
||||
o.Strategies.SetupFillEnvFileContent(&envFileContent)
|
||||
|
||||
for _, vendor := range o.VendorManager.Vendors {
|
||||
@@ -183,7 +187,7 @@ func (o *PluginRegistry) Setup() (err error) {
|
||||
return vendor
|
||||
})...)
|
||||
|
||||
groupsPlugins.AddGroupItems("Tools", o.Defaults, o.Jina, o.Language, o.PatternsLoader, o.Strategies, o.YouTube)
|
||||
groupsPlugins.AddGroupItems("Tools", o.CustomPatterns, o.Defaults, o.Jina, o.Language, o.PatternsLoader, o.Strategies, o.YouTube)
|
||||
|
||||
for {
|
||||
groupsPlugins.Print(false)
|
||||
|
||||
3
go.mod
3
go.mod
@@ -25,9 +25,9 @@ require (
|
||||
github.com/samber/lo v1.50.0
|
||||
github.com/sgaunet/perplexity-go/v2 v2.8.0
|
||||
github.com/stretchr/testify v1.10.0
|
||||
golang.org/x/oauth2 v0.30.0
|
||||
golang.org/x/text v0.26.0
|
||||
google.golang.org/api v0.236.0
|
||||
gopkg.in/yaml.v2 v2.4.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
@@ -109,7 +109,6 @@ require (
|
||||
golang.org/x/arch v0.18.0 // indirect
|
||||
golang.org/x/crypto v0.39.0 // indirect
|
||||
golang.org/x/net v0.41.0 // indirect
|
||||
golang.org/x/oauth2 v0.30.0 // indirect
|
||||
golang.org/x/sync v0.15.0 // indirect
|
||||
golang.org/x/sys v0.33.0 // indirect
|
||||
golang.org/x/time v0.12.0 // indirect
|
||||
|
||||
1
go.sum
1
go.sum
@@ -354,7 +354,6 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EV
|
||||
gopkg.in/warnings.v0 v0.1.2 h1:wFXVbFY8DY5/xOe1ECiWdKCzZlxgshcYVNkBHstARME=
|
||||
gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
|
||||
@@ -325,9 +325,6 @@ schema = 3
|
||||
[mod."gopkg.in/warnings.v0"]
|
||||
version = "v0.1.2"
|
||||
hash = "sha256-ATVL9yEmgYbkJ1DkltDGRn/auGAjqGOfjQyBYyUo8s8="
|
||||
[mod."gopkg.in/yaml.v2"]
|
||||
version = "v2.4.0"
|
||||
hash = "sha256-uVEGglIedjOIGZzHW4YwN1VoRSTK8o0eGZqzd+TNdd0="
|
||||
[mod."gopkg.in/yaml.v3"]
|
||||
version = "v3.0.1"
|
||||
hash = "sha256-FqL9TKYJ0XkNwJFnq9j0VvJ5ZUU1RvH/52h/f5bkYAU="
|
||||
|
||||
@@ -1 +1 @@
|
||||
"1.4.224"
|
||||
"1.4.234"
|
||||
|
||||
49
patterns/apply_ul_tags/system.md
Normal file
49
patterns/apply_ul_tags/system.md
Normal file
@@ -0,0 +1,49 @@
|
||||
# IDENTITY
|
||||
|
||||
You are a superintelligent expert on content of all forms, with deep understanding of which topics, categories, themes, and tags apply to any piece of content.
|
||||
|
||||
# GOAL
|
||||
|
||||
Your goal is to output a JSON object called tags, with the following tags applied if the content is significantly about their topic.
|
||||
|
||||
- **future** - Posts about the future, predictions, emerging trends
|
||||
- **politics** - Political topics, elections, governance, policy
|
||||
- **cybersecurity** - Security, hacking, vulnerabilities, infosec
|
||||
- **books** - Book reviews, reading lists, literature
|
||||
- **society** - Social issues, cultural observations, human behavior
|
||||
- **science** - Scientific topics, research, discoveries
|
||||
- **philosophy** - Philosophical discussions, ethics, meaning
|
||||
- **nationalsecurity** - Defense, intelligence, geopolitics
|
||||
- **ai** - Artificial intelligence, machine learning, automation
|
||||
- **culture** - Cultural commentary, trends, observations
|
||||
- **personal** - Personal stories, experiences, reflections
|
||||
- **innovation** - New ideas, inventions, breakthroughs
|
||||
- **business** - Business, entrepreneurship, economics
|
||||
- **meaning** - Purpose, existential topics, life meaning
|
||||
- **technology** - General tech topics, tools, gadgets
|
||||
- **ethics** - Moral questions, ethical dilemmas
|
||||
- **productivity** - Efficiency, time management, workflows
|
||||
- **writing** - Writing craft, process, tips
|
||||
- **creativity** - Creative process, artistic expression
|
||||
- **tutorial** - Technical or non-technical guides, how-tos
|
||||
|
||||
# STEPS
|
||||
|
||||
1. Deeply understand the content and its themes and categories and topics.
|
||||
2. Evaluate the list of tags above.
|
||||
3. Determine which tags apply to the content.
|
||||
4. Output the "tags" JSON object.
|
||||
|
||||
# NOTES
|
||||
|
||||
- It's ok, and quite normal, for multiple tags to apply—which is why this is tags and not categories
|
||||
- All AI posts should have the technology tag, and that's ok. But not all technology posts are about AI, and therefore the AI tag needs to be evaluated separately. That goes for all potentially nested or conflicted tags.
|
||||
- Be a bit conservative in applying tags. If a piece of content is only tangentially related to a tag, don't include it.
|
||||
|
||||
# OUTPUT INSTRUCTIONS
|
||||
|
||||
- Output ONLY the JSON object, and nothing else.
|
||||
|
||||
- That means DO NOT OUTPUT the ```json format indicator. ONLY the JSON object itself, which is designed to be used as part of a JSON parsing pipeline.
|
||||
|
||||
|
||||
@@ -3,10 +3,9 @@ package anthropic
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/samber/lo"
|
||||
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
"github.com/anthropics/anthropic-sdk-go/option"
|
||||
"github.com/danielmiessler/fabric/chat"
|
||||
@@ -16,6 +15,10 @@ import (
|
||||
|
||||
const defaultBaseUrl = "https://api.anthropic.com/"
|
||||
|
||||
const webSearchToolName = "web_search"
|
||||
const webSearchToolType = "web_search_20250305"
|
||||
const sourcesHeader = "## Sources"
|
||||
|
||||
func NewClient() (ret *Client) {
|
||||
vendorName := "Anthropic"
|
||||
ret = &Client{}
|
||||
@@ -28,10 +31,12 @@ func NewClient() (ret *Client) {
|
||||
|
||||
ret.ApiBaseURL = ret.AddSetupQuestion("API Base URL", false)
|
||||
ret.ApiBaseURL.Value = defaultBaseUrl
|
||||
ret.ApiKey = ret.PluginBase.AddSetupQuestion("API key", true)
|
||||
ret.UseWebTool = ret.AddSetupQuestionBool("Web Search Tool Enabled", false)
|
||||
ret.WebToolLocation = ret.AddSetupQuestionCustom("Web Search Tool Location", false,
|
||||
"Enter your approximate timezone location for web search (e.g., 'America/Los_Angeles', see https://en.wikipedia.org/wiki/List_of_tz_database_time_zones).")
|
||||
ret.UseOAuth = ret.AddSetupQuestionBool("Use OAuth login", false)
|
||||
if plugins.ParseBoolElseFalse(ret.UseOAuth.Value) {
|
||||
ret.ApiKey = ret.PluginBase.AddSetupQuestion("API key", false)
|
||||
} else {
|
||||
ret.ApiKey = ret.PluginBase.AddSetupQuestion("API key", true)
|
||||
}
|
||||
|
||||
ret.maxTokens = 4096
|
||||
ret.defaultRequiredUserMessage = "Hi"
|
||||
@@ -49,10 +54,9 @@ func NewClient() (ret *Client) {
|
||||
|
||||
type Client struct {
|
||||
*plugins.PluginBase
|
||||
ApiBaseURL *plugins.SetupQuestion
|
||||
ApiKey *plugins.SetupQuestion
|
||||
UseWebTool *plugins.SetupQuestion
|
||||
WebToolLocation *plugins.SetupQuestion
|
||||
ApiBaseURL *plugins.SetupQuestion
|
||||
ApiKey *plugins.SetupQuestion
|
||||
UseOAuth *plugins.SetupQuestion
|
||||
|
||||
maxTokens int
|
||||
defaultRequiredUserMessage string
|
||||
@@ -61,24 +65,50 @@ type Client struct {
|
||||
client anthropic.Client
|
||||
}
|
||||
|
||||
func (an *Client) configure() (err error) {
|
||||
if an.ApiBaseURL.Value != "" {
|
||||
baseURL := an.ApiBaseURL.Value
|
||||
func (an *Client) Setup() (err error) {
|
||||
if err = an.PluginBase.Ask(an.Name); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// As of 2.0beta1, using v2 API endpoint.
|
||||
// https://github.com/anthropics/anthropic-sdk-go/blob/main/CHANGELOG.md#020-beta1-2025-03-25
|
||||
if strings.Contains(baseURL, "-") && !strings.HasSuffix(baseURL, "/v2") {
|
||||
baseURL = strings.TrimSuffix(baseURL, "/")
|
||||
baseURL = baseURL + "/v2"
|
||||
if plugins.ParseBoolElseFalse(an.UseOAuth.Value) {
|
||||
// Check if we have a valid stored token
|
||||
storage, err := common.NewOAuthStorage()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
an.client = anthropic.NewClient(
|
||||
option.WithAPIKey(an.ApiKey.Value),
|
||||
option.WithBaseURL(baseURL),
|
||||
)
|
||||
} else {
|
||||
an.client = anthropic.NewClient(option.WithAPIKey(an.ApiKey.Value))
|
||||
if !storage.HasValidToken("claude", 5) {
|
||||
// No valid token, run OAuth flow
|
||||
if _, err = RunOAuthFlow(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = an.configure()
|
||||
return
|
||||
}
|
||||
|
||||
func (an *Client) configure() (err error) {
|
||||
opts := []option.RequestOption{}
|
||||
|
||||
if an.ApiBaseURL.Value != "" {
|
||||
opts = append(opts, option.WithBaseURL(an.ApiBaseURL.Value))
|
||||
}
|
||||
|
||||
if plugins.ParseBoolElseFalse(an.UseOAuth.Value) {
|
||||
// For OAuth, use Bearer token with custom headers
|
||||
// Create custom HTTP client that adds OAuth Bearer token and beta header
|
||||
baseTransport := &http.Transport{}
|
||||
httpClient := &http.Client{
|
||||
Transport: NewOAuthTransport(an, baseTransport),
|
||||
}
|
||||
opts = append(opts, option.WithHTTPClient(httpClient))
|
||||
} else {
|
||||
opts = append(opts, option.WithAPIKey(an.ApiKey.Value))
|
||||
}
|
||||
|
||||
an.client = anthropic.NewClient(opts...)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -127,20 +157,28 @@ func (an *Client) buildMessageParams(msgs []anthropic.MessageParam, opts *common
|
||||
Messages: msgs,
|
||||
}
|
||||
|
||||
if plugins.ParseBoolElseFalse(an.UseWebTool.Value) {
|
||||
// Build the web-search tool definition:
|
||||
webTool := anthropic.WebSearchTool20250305Param{
|
||||
Name: "web_search", // string literal instead of constant
|
||||
Type: "web_search_20250305", // string literal instead of constant
|
||||
CacheControl: anthropic.NewCacheControlEphemeralParam(),
|
||||
// Optional: restrict domains or max uses
|
||||
// AllowedDomains: []string{"wikipedia.org", "openai.com"},
|
||||
// MaxUses: anthropic.Opt[int64](5),
|
||||
// Add Claude Code spoofing system message for OAuth authentication
|
||||
if plugins.ParseBoolElseFalse(an.UseOAuth.Value) {
|
||||
params.System = []anthropic.TextBlockParam{
|
||||
{
|
||||
Type: "text",
|
||||
Text: "You are Claude Code, Anthropic's official CLI for Claude.",
|
||||
},
|
||||
}
|
||||
|
||||
if an.WebToolLocation.Value != "" {
|
||||
}
|
||||
|
||||
if opts.Search {
|
||||
// Build the web-search tool definition:
|
||||
webTool := anthropic.WebSearchTool20250305Param{
|
||||
Name: webSearchToolName,
|
||||
Type: webSearchToolType,
|
||||
CacheControl: anthropic.NewCacheControlEphemeralParam(),
|
||||
}
|
||||
|
||||
if opts.SearchLocation != "" {
|
||||
webTool.UserLocation.Type = "approximate"
|
||||
webTool.UserLocation.Timezone = anthropic.Opt(an.WebToolLocation.Value)
|
||||
webTool.UserLocation.Timezone = anthropic.Opt(opts.SearchLocation)
|
||||
}
|
||||
|
||||
// Wrap it in the union:
|
||||
@@ -165,13 +203,42 @@ func (an *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage,
|
||||
return
|
||||
}
|
||||
|
||||
texts := lo.FilterMap(message.Content, func(block anthropic.ContentBlockUnion, _ int) (ret string, ok bool) {
|
||||
if ok = block.Type == "text" && block.Text != ""; ok {
|
||||
ret = block.Text
|
||||
var textParts []string
|
||||
var citations []string
|
||||
citationMap := make(map[string]bool) // To avoid duplicate citations
|
||||
|
||||
for _, block := range message.Content {
|
||||
if block.Type == "text" && block.Text != "" {
|
||||
textParts = append(textParts, block.Text)
|
||||
|
||||
// Extract citations from this text block
|
||||
for _, citation := range block.Citations {
|
||||
if citation.Type == "web_search_result_location" {
|
||||
citationKey := citation.URL + "|" + citation.Title
|
||||
if !citationMap[citationKey] {
|
||||
citationMap[citationKey] = true
|
||||
citationText := fmt.Sprintf("- [%s](%s)", citation.Title, citation.URL)
|
||||
if citation.CitedText != "" {
|
||||
citationText += fmt.Sprintf(" - \"%s\"", citation.CitedText)
|
||||
}
|
||||
citations = append(citations, citationText)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
})
|
||||
ret = strings.Join(texts, "")
|
||||
}
|
||||
|
||||
var resultBuilder strings.Builder
|
||||
resultBuilder.WriteString(strings.Join(textParts, ""))
|
||||
|
||||
// Append citations if any were found
|
||||
if len(citations) > 0 {
|
||||
resultBuilder.WriteString("\n\n")
|
||||
resultBuilder.WriteString(sourcesHeader)
|
||||
resultBuilder.WriteString("\n\n")
|
||||
resultBuilder.WriteString(strings.Join(citations, "\n"))
|
||||
}
|
||||
ret = resultBuilder.String()
|
||||
|
||||
return
|
||||
}
|
||||
@@ -184,6 +251,9 @@ func (an *Client) toMessages(msgs []*chat.ChatCompletionMessage) (ret []anthropi
|
||||
|
||||
var anthropicMessages []anthropic.MessageParam
|
||||
var systemContent string
|
||||
|
||||
// Note: Claude Code spoofing is now handled in buildMessageParams
|
||||
|
||||
isFirstUserMessage := true
|
||||
lastRoleWasUser := false
|
||||
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
)
|
||||
|
||||
// Test generated using Keploy
|
||||
@@ -63,3 +67,192 @@ func TestClient_ListModels_ReturnsCorrectModels(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildMessageParams_WithoutSearch(t *testing.T) {
|
||||
client := NewClient()
|
||||
opts := &common.ChatOptions{
|
||||
Model: "claude-3-5-sonnet-latest",
|
||||
Temperature: 0.7,
|
||||
Search: false,
|
||||
}
|
||||
|
||||
messages := []anthropic.MessageParam{
|
||||
anthropic.NewUserMessage(anthropic.NewTextBlock("Hello")),
|
||||
}
|
||||
|
||||
params := client.buildMessageParams(messages, opts)
|
||||
|
||||
if params.Tools != nil {
|
||||
t.Error("Expected no tools when search is disabled, got tools")
|
||||
}
|
||||
|
||||
if params.Model != anthropic.Model(opts.Model) {
|
||||
t.Errorf("Expected model %s, got %s", opts.Model, params.Model)
|
||||
}
|
||||
|
||||
if params.Temperature.Value != opts.Temperature {
|
||||
t.Errorf("Expected temperature %f, got %f", opts.Temperature, params.Temperature.Value)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildMessageParams_WithSearch(t *testing.T) {
|
||||
client := NewClient()
|
||||
opts := &common.ChatOptions{
|
||||
Model: "claude-3-5-sonnet-latest",
|
||||
Temperature: 0.7,
|
||||
Search: true,
|
||||
}
|
||||
|
||||
messages := []anthropic.MessageParam{
|
||||
anthropic.NewUserMessage(anthropic.NewTextBlock("What's the weather today?")),
|
||||
}
|
||||
|
||||
params := client.buildMessageParams(messages, opts)
|
||||
|
||||
if params.Tools == nil {
|
||||
t.Fatal("Expected tools when search is enabled, got nil")
|
||||
}
|
||||
|
||||
if len(params.Tools) != 1 {
|
||||
t.Errorf("Expected 1 tool, got %d", len(params.Tools))
|
||||
}
|
||||
|
||||
webTool := params.Tools[0].OfWebSearchTool20250305
|
||||
if webTool == nil {
|
||||
t.Fatal("Expected web search tool, got nil")
|
||||
}
|
||||
|
||||
if webTool.Name != "web_search" {
|
||||
t.Errorf("Expected tool name 'web_search', got %s", webTool.Name)
|
||||
}
|
||||
|
||||
if webTool.Type != "web_search_20250305" {
|
||||
t.Errorf("Expected tool type 'web_search_20250305', got %s", webTool.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildMessageParams_WithSearchAndLocation(t *testing.T) {
|
||||
client := NewClient()
|
||||
opts := &common.ChatOptions{
|
||||
Model: "claude-3-5-sonnet-latest",
|
||||
Temperature: 0.7,
|
||||
Search: true,
|
||||
SearchLocation: "America/Los_Angeles",
|
||||
}
|
||||
|
||||
messages := []anthropic.MessageParam{
|
||||
anthropic.NewUserMessage(anthropic.NewTextBlock("What's the weather in San Francisco?")),
|
||||
}
|
||||
|
||||
params := client.buildMessageParams(messages, opts)
|
||||
|
||||
if params.Tools == nil {
|
||||
t.Fatal("Expected tools when search is enabled, got nil")
|
||||
}
|
||||
|
||||
webTool := params.Tools[0].OfWebSearchTool20250305
|
||||
if webTool == nil {
|
||||
t.Fatal("Expected web search tool, got nil")
|
||||
}
|
||||
|
||||
if webTool.UserLocation.Type != "approximate" {
|
||||
t.Errorf("Expected location type 'approximate', got %s", webTool.UserLocation.Type)
|
||||
}
|
||||
|
||||
if webTool.UserLocation.Timezone.Value != opts.SearchLocation {
|
||||
t.Errorf("Expected timezone %s, got %s", opts.SearchLocation, webTool.UserLocation.Timezone.Value)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCitationFormatting(t *testing.T) {
|
||||
// Test the citation formatting logic by creating a mock message with citations
|
||||
message := &anthropic.Message{
|
||||
Content: []anthropic.ContentBlockUnion{
|
||||
{
|
||||
Type: "text",
|
||||
Text: "Based on recent research, artificial intelligence is advancing rapidly.",
|
||||
Citations: []anthropic.TextCitationUnion{
|
||||
{
|
||||
Type: "web_search_result_location",
|
||||
URL: "https://example.com/ai-research",
|
||||
Title: "AI Research Advances 2025",
|
||||
CitedText: "artificial intelligence is advancing rapidly",
|
||||
},
|
||||
{
|
||||
Type: "web_search_result_location",
|
||||
URL: "https://another-source.com/tech-news",
|
||||
Title: "Technology News Today",
|
||||
CitedText: "recent developments in AI",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "text",
|
||||
Text: " Machine learning models are becoming more sophisticated.",
|
||||
Citations: []anthropic.TextCitationUnion{
|
||||
{
|
||||
Type: "web_search_result_location",
|
||||
URL: "https://example.com/ai-research", // Duplicate URL should be deduplicated
|
||||
Title: "AI Research Advances 2025",
|
||||
CitedText: "machine learning models",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Extract text and citations using the same logic as the Send method
|
||||
var textParts []string
|
||||
var citations []string
|
||||
citationMap := make(map[string]bool)
|
||||
|
||||
for _, block := range message.Content {
|
||||
if block.Type == "text" && block.Text != "" {
|
||||
textParts = append(textParts, block.Text)
|
||||
|
||||
for _, citation := range block.Citations {
|
||||
if citation.Type == "web_search_result_location" {
|
||||
citationKey := citation.URL + "|" + citation.Title
|
||||
if !citationMap[citationKey] {
|
||||
citationMap[citationKey] = true
|
||||
citationText := "- [" + citation.Title + "](" + citation.URL + ")"
|
||||
if citation.CitedText != "" {
|
||||
citationText += " - \"" + citation.CitedText + "\""
|
||||
}
|
||||
citations = append(citations, citationText)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result := strings.Join(textParts, "")
|
||||
if len(citations) > 0 {
|
||||
result += "\n\n## Sources\n\n" + strings.Join(citations, "\n")
|
||||
}
|
||||
|
||||
// Verify the result contains the expected text
|
||||
expectedText := "Based on recent research, artificial intelligence is advancing rapidly. Machine learning models are becoming more sophisticated."
|
||||
if !strings.Contains(result, expectedText) {
|
||||
t.Errorf("Expected result to contain text: %s", expectedText)
|
||||
}
|
||||
|
||||
// Verify citations are included
|
||||
if !strings.Contains(result, "## Sources") {
|
||||
t.Error("Expected result to contain Sources section")
|
||||
}
|
||||
|
||||
if !strings.Contains(result, "[AI Research Advances 2025](https://example.com/ai-research)") {
|
||||
t.Error("Expected result to contain first citation")
|
||||
}
|
||||
|
||||
if !strings.Contains(result, "[Technology News Today](https://another-source.com/tech-news)") {
|
||||
t.Error("Expected result to contain second citation")
|
||||
}
|
||||
|
||||
// Verify deduplication - should only have 2 unique citations, not 3
|
||||
citationCount := strings.Count(result, "- [")
|
||||
if citationCount != 2 {
|
||||
t.Errorf("Expected 2 unique citations, got %d", citationCount)
|
||||
}
|
||||
}
|
||||
|
||||
300
plugins/ai/anthropic/oauth.go
Normal file
300
plugins/ai/anthropic/oauth.go
Normal file
@@ -0,0 +1,300 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// OAuth configuration constants
|
||||
const (
|
||||
oauthClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
||||
oauthAuthURL = "https://claude.ai/oauth/authorize"
|
||||
oauthTokenURL = "https://console.anthropic.com/v1/oauth/token"
|
||||
oauthRedirectURL = "https://console.anthropic.com/oauth/code/callback"
|
||||
)
|
||||
|
||||
// OAuthTransport is a custom HTTP transport that adds OAuth Bearer token and beta header
|
||||
type OAuthTransport struct {
|
||||
client *Client
|
||||
base http.RoundTripper
|
||||
}
|
||||
|
||||
// RoundTrip implements the http.RoundTripper interface
|
||||
func (t *OAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Clone the request to avoid modifying the original
|
||||
newReq := req.Clone(req.Context())
|
||||
|
||||
// Get current token (may refresh if needed)
|
||||
token, err := t.getValidToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get valid OAuth token: %w", err)
|
||||
}
|
||||
|
||||
// Add OAuth Bearer token
|
||||
newReq.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
// Add the anthropic-beta header for OAuth
|
||||
newReq.Header.Set("anthropic-beta", "oauth-2025-04-20")
|
||||
|
||||
// Set User-Agent to match AI SDK exactly
|
||||
newReq.Header.Set("User-Agent", "ai-sdk/anthropic")
|
||||
|
||||
// Remove x-api-key header if present (OAuth doesn't use it)
|
||||
newReq.Header.Del("x-api-key")
|
||||
|
||||
return t.base.RoundTrip(newReq)
|
||||
}
|
||||
|
||||
// getValidToken returns a valid access token, refreshing if necessary
|
||||
func (t *OAuthTransport) getValidToken() (string, error) {
|
||||
storage, err := common.NewOAuthStorage()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create OAuth storage: %w", err)
|
||||
}
|
||||
|
||||
// Load stored token
|
||||
token, err := storage.LoadToken("claude")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to load stored token: %w", err)
|
||||
}
|
||||
// If no token exists, run OAuth flow
|
||||
if token == nil {
|
||||
fmt.Println("No OAuth token found, initiating authentication...")
|
||||
newAccessToken, err := RunOAuthFlow()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to authenticate: %w", err)
|
||||
}
|
||||
return newAccessToken, nil
|
||||
}
|
||||
|
||||
// Check if token needs refresh (5 minute buffer)
|
||||
if token.IsExpired(5) {
|
||||
fmt.Println("OAuth token expired, refreshing...")
|
||||
newAccessToken, err := RefreshToken()
|
||||
if err != nil {
|
||||
// If refresh fails, try re-authentication
|
||||
fmt.Println("Token refresh failed, re-authenticating...")
|
||||
newAccessToken, err = RunOAuthFlow()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to refresh or re-authenticate: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return newAccessToken, nil
|
||||
}
|
||||
|
||||
return token.AccessToken, nil
|
||||
}
|
||||
|
||||
// NewOAuthTransport creates a new OAuth transport for the given client
|
||||
func NewOAuthTransport(client *Client, base http.RoundTripper) *OAuthTransport {
|
||||
return &OAuthTransport{
|
||||
client: client,
|
||||
base: base,
|
||||
}
|
||||
}
|
||||
|
||||
// generatePKCE generates PKCE code verifier and challenge
|
||||
func generatePKCE() (verifier, challenge string, err error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err = rand.Read(b); err != nil {
|
||||
return
|
||||
}
|
||||
verifier = base64.RawURLEncoding.EncodeToString(b)
|
||||
sum := sha256.Sum256([]byte(verifier))
|
||||
challenge = base64.RawURLEncoding.EncodeToString(sum[:])
|
||||
return
|
||||
}
|
||||
|
||||
// openBrowser attempts to open the given URL in the default browser
|
||||
func openBrowser(url string) {
|
||||
commands := [][]string{{"xdg-open", url}, {"open", url}, {"cmd", "/c", "start", url}}
|
||||
for _, cmd := range commands {
|
||||
if exec.Command(cmd[0], cmd[1:]...).Start() == nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RunOAuthFlow executes the complete OAuth authorization flow
|
||||
func RunOAuthFlow() (token string, err error) {
|
||||
verifier, challenge, err := generatePKCE()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
cfg := oauth2.Config{
|
||||
ClientID: oauthClientID,
|
||||
Endpoint: oauth2.Endpoint{AuthURL: oauthAuthURL, TokenURL: oauthTokenURL},
|
||||
RedirectURL: oauthRedirectURL,
|
||||
Scopes: []string{"org:create_api_key", "user:profile", "user:inference"},
|
||||
}
|
||||
|
||||
authURL := cfg.AuthCodeURL(verifier,
|
||||
oauth2.SetAuthURLParam("code_challenge", challenge),
|
||||
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
||||
oauth2.SetAuthURLParam("code", "true"),
|
||||
oauth2.SetAuthURLParam("state", verifier),
|
||||
)
|
||||
|
||||
fmt.Println("Open the following URL in your browser. Fabric would like to authorize:")
|
||||
fmt.Println(authURL)
|
||||
openBrowser(authURL)
|
||||
fmt.Print("Paste the authorization code here: ")
|
||||
var code string
|
||||
fmt.Scanln(&code)
|
||||
parts := strings.SplitN(code, "#", 2)
|
||||
state := verifier
|
||||
if len(parts) == 2 {
|
||||
state = parts[1]
|
||||
}
|
||||
|
||||
// Manual token exchange to match opencode implementation
|
||||
tokenReq := map[string]string{
|
||||
"code": parts[0],
|
||||
"state": state,
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": oauthClientID,
|
||||
"redirect_uri": oauthRedirectURL,
|
||||
"code_verifier": verifier,
|
||||
}
|
||||
|
||||
token, err = exchangeToken(tokenReq)
|
||||
return
|
||||
}
|
||||
|
||||
// exchangeToken exchanges authorization code for access token
|
||||
func exchangeToken(params map[string]string) (token string, err error) {
|
||||
reqBody, err := json.Marshal(params)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := http.Post(oauthTokenURL, "application/json", bytes.NewBuffer(reqBody))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
err = fmt.Errorf("token exchange failed: %s - %s", resp.Status, string(body))
|
||||
return
|
||||
}
|
||||
|
||||
var result struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
TokenType string `json:"token_type"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
if err = json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Save the complete token information
|
||||
storage, err := common.NewOAuthStorage()
|
||||
if err != nil {
|
||||
return result.AccessToken, fmt.Errorf("failed to create OAuth storage: %w", err)
|
||||
}
|
||||
|
||||
oauthToken := &common.OAuthToken{
|
||||
AccessToken: result.AccessToken,
|
||||
RefreshToken: result.RefreshToken,
|
||||
ExpiresAt: time.Now().Unix() + int64(result.ExpiresIn),
|
||||
TokenType: result.TokenType,
|
||||
Scope: result.Scope,
|
||||
}
|
||||
|
||||
if err = storage.SaveToken("claude", oauthToken); err != nil {
|
||||
return result.AccessToken, fmt.Errorf("failed to save OAuth token: %w", err)
|
||||
}
|
||||
|
||||
token = result.AccessToken
|
||||
return
|
||||
}
|
||||
|
||||
// RefreshToken refreshes an expired OAuth token using the refresh token
|
||||
func RefreshToken() (string, error) {
|
||||
storage, err := common.NewOAuthStorage()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create OAuth storage: %w", err)
|
||||
}
|
||||
|
||||
// Load existing token
|
||||
token, err := storage.LoadToken("claude")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to load stored token: %w", err)
|
||||
}
|
||||
if token == nil || token.RefreshToken == "" {
|
||||
return "", fmt.Errorf("no refresh token available")
|
||||
}
|
||||
|
||||
// Prepare refresh request
|
||||
refreshReq := map[string]string{
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": token.RefreshToken,
|
||||
"client_id": oauthClientID,
|
||||
}
|
||||
|
||||
reqBody, err := json.Marshal(refreshReq)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal refresh request: %w", err)
|
||||
}
|
||||
|
||||
// Make refresh request
|
||||
resp, err := http.Post(oauthTokenURL, "application/json", bytes.NewBuffer(reqBody))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("refresh request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return "", fmt.Errorf("token refresh failed: %s - %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
TokenType string `json:"token_type"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
if err = json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return "", fmt.Errorf("failed to parse refresh response: %w", err)
|
||||
}
|
||||
|
||||
// Update stored token
|
||||
newToken := &common.OAuthToken{
|
||||
AccessToken: result.AccessToken,
|
||||
RefreshToken: result.RefreshToken,
|
||||
ExpiresAt: time.Now().Unix() + int64(result.ExpiresIn),
|
||||
TokenType: result.TokenType,
|
||||
Scope: result.Scope,
|
||||
}
|
||||
|
||||
// Use existing refresh token if new one not provided
|
||||
if newToken.RefreshToken == "" {
|
||||
newToken.RefreshToken = token.RefreshToken
|
||||
}
|
||||
|
||||
if err = storage.SaveToken("claude", newToken); err != nil {
|
||||
return "", fmt.Errorf("failed to save refreshed token: %w", err)
|
||||
}
|
||||
|
||||
return result.AccessToken, nil
|
||||
}
|
||||
@@ -76,6 +76,15 @@ func (c *Client) formatOptions(opts *common.ChatOptions) string {
|
||||
if opts.ModelContextLength != 0 {
|
||||
builder.WriteString(fmt.Sprintf("ModelContextLength: %d\n", opts.ModelContextLength))
|
||||
}
|
||||
if opts.Search {
|
||||
builder.WriteString("Search: enabled\n")
|
||||
if opts.SearchLocation != "" {
|
||||
builder.WriteString(fmt.Sprintf("SearchLocation: %s\n", opts.SearchLocation))
|
||||
}
|
||||
}
|
||||
if opts.ImageFile != "" {
|
||||
builder.WriteString(fmt.Sprintf("ImageFile: %s\n", opts.ImageFile))
|
||||
}
|
||||
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
@@ -127,12 +128,23 @@ func (o *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, o
|
||||
}
|
||||
|
||||
func (o *Client) sendResponses(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions) (ret string, err error) {
|
||||
// Validate model supports image generation if image file is specified
|
||||
if opts.ImageFile != "" && !supportsImageGeneration(opts.Model) {
|
||||
return "", fmt.Errorf("model '%s' does not support image generation. Supported models: %s", opts.Model, strings.Join(ImageGenerationSupportedModels, ", "))
|
||||
}
|
||||
|
||||
req := o.buildResponseParams(msgs, opts)
|
||||
|
||||
var resp *responses.Response
|
||||
if resp, err = o.ApiClient.Responses.New(ctx, req); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Extract and save images if requested
|
||||
if err = o.extractAndSaveImages(resp, opts); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ret = o.extractText(resp)
|
||||
return
|
||||
}
|
||||
@@ -182,6 +194,31 @@ func (o *Client) buildResponseParams(
|
||||
},
|
||||
}
|
||||
|
||||
// Add tools if enabled
|
||||
var tools []responses.ToolUnionParam
|
||||
|
||||
// Add web search tool if enabled
|
||||
if opts.Search {
|
||||
webSearchTool := responses.ToolParamOfWebSearchPreview("web_search_preview")
|
||||
|
||||
// Add user location if provided
|
||||
if opts.SearchLocation != "" {
|
||||
webSearchTool.OfWebSearchPreview.UserLocation = responses.WebSearchToolUserLocationParam{
|
||||
Type: "approximate",
|
||||
Timezone: openai.String(opts.SearchLocation),
|
||||
}
|
||||
}
|
||||
|
||||
tools = append(tools, webSearchTool)
|
||||
}
|
||||
|
||||
// Add image generation tool if needed
|
||||
tools = o.addImageGenerationTool(opts, tools)
|
||||
|
||||
if len(tools) > 0 {
|
||||
ret.Tools = tools
|
||||
}
|
||||
|
||||
if !opts.Raw {
|
||||
ret.Temperature = openai.Float(opts.Temperature)
|
||||
ret.TopP = openai.Float(opts.TopP)
|
||||
@@ -232,15 +269,41 @@ func convertMessage(msg chat.ChatCompletionMessage) responses.ResponseInputItemU
|
||||
}
|
||||
|
||||
func (o *Client) extractText(resp *responses.Response) (ret string) {
|
||||
var textParts []string
|
||||
var citations []string
|
||||
citationMap := make(map[string]bool) // To avoid duplicate citations
|
||||
|
||||
for _, item := range resp.Output {
|
||||
if item.Type == "message" {
|
||||
for _, c := range item.Content {
|
||||
if c.Type == "output_text" {
|
||||
ret += c.AsOutputText().Text
|
||||
outputText := c.AsOutputText()
|
||||
textParts = append(textParts, outputText.Text)
|
||||
|
||||
// Extract citations from annotations
|
||||
for _, annotation := range outputText.Annotations {
|
||||
if annotation.Type == "url_citation" {
|
||||
urlCitation := annotation.AsURLCitation()
|
||||
citationKey := urlCitation.URL + "|" + urlCitation.Title
|
||||
if !citationMap[citationKey] {
|
||||
citationMap[citationKey] = true
|
||||
citationText := fmt.Sprintf("- [%s](%s)", urlCitation.Title, urlCitation.URL)
|
||||
citations = append(citations, citationText)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
ret = strings.Join(textParts, "")
|
||||
|
||||
// Append citations if any were found
|
||||
if len(citations) > 0 {
|
||||
ret += "\n\n## Sources\n\n" + strings.Join(citations, "\n")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
146
plugins/ai/openai/openai_image.go
Normal file
146
plugins/ai/openai/openai_image.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package openai
|
||||
|
||||
// This file contains helper methods for image generation and processing
|
||||
// using OpenAI's Responses API and Image API.
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"github.com/openai/openai-go/packages/param"
|
||||
"github.com/openai/openai-go/responses"
|
||||
)
|
||||
|
||||
// ImageGenerationResponseType is the type used for image generation calls in responses
|
||||
const ImageGenerationResponseType = "image_generation_call"
|
||||
const ImageGenerationToolType = "image_generation"
|
||||
|
||||
// ImageGenerationSupportedModels lists all models that support image generation
|
||||
var ImageGenerationSupportedModels = []string{
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
"gpt-4.1",
|
||||
"gpt-4.1-mini",
|
||||
"gpt-4.1-nano",
|
||||
"o3",
|
||||
}
|
||||
|
||||
// supportsImageGeneration checks if the given model supports the image_generation tool
|
||||
func supportsImageGeneration(model string) bool {
|
||||
for _, supportedModel := range ImageGenerationSupportedModels {
|
||||
if model == supportedModel {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// getOutputFormatFromExtension determines the API output format based on file extension
|
||||
func getOutputFormatFromExtension(imagePath string) string {
|
||||
if imagePath == "" {
|
||||
return "png" // Default format
|
||||
}
|
||||
|
||||
ext := strings.ToLower(filepath.Ext(imagePath))
|
||||
switch ext {
|
||||
case ".png":
|
||||
return "png"
|
||||
case ".webp":
|
||||
return "webp"
|
||||
case ".jpg":
|
||||
return "jpeg"
|
||||
case ".jpeg":
|
||||
return "jpeg"
|
||||
default:
|
||||
return "png" // Default fallback
|
||||
}
|
||||
}
|
||||
|
||||
// addImageGenerationTool adds the image generation tool to the request if needed
|
||||
func (o *Client) addImageGenerationTool(opts *common.ChatOptions, tools []responses.ToolUnionParam) []responses.ToolUnionParam {
|
||||
// Check if the request seems to be asking for image generation
|
||||
if o.shouldUseImageGeneration(opts) {
|
||||
outputFormat := getOutputFormatFromExtension(opts.ImageFile)
|
||||
|
||||
// Build the image generation tool with user parameters
|
||||
imageGenTool := responses.ToolUnionParam{
|
||||
OfImageGeneration: &responses.ToolImageGenerationParam{
|
||||
Type: ImageGenerationToolType,
|
||||
Model: "gpt-image-1",
|
||||
OutputFormat: outputFormat,
|
||||
},
|
||||
}
|
||||
|
||||
// Set quality if specified by user (otherwise let OpenAI use default)
|
||||
if opts.ImageQuality != "" {
|
||||
imageGenTool.OfImageGeneration.Quality = opts.ImageQuality
|
||||
}
|
||||
|
||||
// Set size if specified by user (otherwise let OpenAI use default)
|
||||
if opts.ImageSize != "" {
|
||||
imageGenTool.OfImageGeneration.Size = opts.ImageSize
|
||||
}
|
||||
|
||||
// Set background if specified by user (otherwise let OpenAI use default)
|
||||
if opts.ImageBackground != "" {
|
||||
imageGenTool.OfImageGeneration.Background = opts.ImageBackground
|
||||
}
|
||||
|
||||
// Set compression if specified by user (only for jpeg/webp)
|
||||
if opts.ImageCompression != 0 {
|
||||
imageGenTool.OfImageGeneration.OutputCompression = param.NewOpt(int64(opts.ImageCompression))
|
||||
}
|
||||
|
||||
tools = append(tools, imageGenTool)
|
||||
}
|
||||
return tools
|
||||
}
|
||||
|
||||
// shouldUseImageGeneration determines if image generation should be enabled
|
||||
// This is a heuristic based on the presence of --image-file flag
|
||||
func (o *Client) shouldUseImageGeneration(opts *common.ChatOptions) bool {
|
||||
return opts.ImageFile != ""
|
||||
}
|
||||
|
||||
// extractAndSaveImages extracts generated images from the response and saves them
|
||||
func (o *Client) extractAndSaveImages(resp *responses.Response, opts *common.ChatOptions) error {
|
||||
if opts.ImageFile == "" {
|
||||
return nil // No image file specified, skip saving
|
||||
}
|
||||
|
||||
// Extract image data from response
|
||||
for _, item := range resp.Output {
|
||||
if item.Type == ImageGenerationResponseType {
|
||||
imageCall := item.AsImageGenerationCall()
|
||||
if imageCall.Status == "completed" && imageCall.Result != "" {
|
||||
// Decode base64 image data
|
||||
imageData, err := base64.StdEncoding.DecodeString(imageCall.Result)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode image data: %w", err)
|
||||
}
|
||||
|
||||
// Ensure directory exists
|
||||
dir := filepath.Dir(opts.ImageFile)
|
||||
if dir != "." {
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create directory %s: %w", dir, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Save image to file
|
||||
if err := os.WriteFile(opts.ImageFile, imageData, 0644); err != nil {
|
||||
return fmt.Errorf("failed to save image to %s: %w", opts.ImageFile, err)
|
||||
}
|
||||
|
||||
fmt.Printf("Image saved to: %s\n", opts.ImageFile)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
444
plugins/ai/openai/openai_image_test.go
Normal file
444
plugins/ai/openai/openai_image_test.go
Normal file
@@ -0,0 +1,444 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/danielmiessler/fabric/chat"
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"github.com/openai/openai-go/responses"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestShouldUseImageGeneration(t *testing.T) {
|
||||
client := NewClient()
|
||||
|
||||
// Test with image file specified
|
||||
opts := &common.ChatOptions{
|
||||
ImageFile: "output.png",
|
||||
}
|
||||
assert.True(t, client.shouldUseImageGeneration(opts), "Should use image generation when image file is specified")
|
||||
|
||||
// Test without image file
|
||||
opts = &common.ChatOptions{
|
||||
ImageFile: "",
|
||||
}
|
||||
assert.False(t, client.shouldUseImageGeneration(opts), "Should not use image generation when no image file is specified")
|
||||
}
|
||||
|
||||
func TestAddImageGenerationTool(t *testing.T) {
|
||||
client := NewClient()
|
||||
|
||||
// Test with image generation enabled
|
||||
opts := &common.ChatOptions{
|
||||
ImageFile: "output.png",
|
||||
}
|
||||
tools := []responses.ToolUnionParam{}
|
||||
result := client.addImageGenerationTool(opts, tools)
|
||||
|
||||
assert.Len(t, result, 1, "Should add one image generation tool")
|
||||
assert.NotNil(t, result[0].OfImageGeneration, "Should have image generation tool")
|
||||
assert.Equal(t, "image_generation", string(result[0].OfImageGeneration.Type))
|
||||
assert.Equal(t, "gpt-image-1", result[0].OfImageGeneration.Model)
|
||||
assert.Equal(t, "png", result[0].OfImageGeneration.OutputFormat)
|
||||
|
||||
// Test without image generation
|
||||
opts = &common.ChatOptions{
|
||||
ImageFile: "",
|
||||
}
|
||||
tools = []responses.ToolUnionParam{}
|
||||
result = client.addImageGenerationTool(opts, tools)
|
||||
|
||||
assert.Len(t, result, 0, "Should not add image generation tool when not needed")
|
||||
}
|
||||
|
||||
func TestBuildResponseParams_WithImageGeneration(t *testing.T) {
|
||||
client := NewClient()
|
||||
opts := &common.ChatOptions{
|
||||
Model: "gpt-image-1",
|
||||
ImageFile: "output.png",
|
||||
}
|
||||
|
||||
msgs := []*chat.ChatCompletionMessage{
|
||||
{Role: "user", Content: "Generate an image of a cat"},
|
||||
}
|
||||
|
||||
params := client.buildResponseParams(msgs, opts)
|
||||
|
||||
assert.NotNil(t, params.Tools, "Expected tools when image generation is enabled")
|
||||
|
||||
// Should have image generation tool
|
||||
hasImageTool := false
|
||||
for _, tool := range params.Tools {
|
||||
if tool.OfImageGeneration != nil {
|
||||
hasImageTool = true
|
||||
assert.Equal(t, "image_generation", string(tool.OfImageGeneration.Type))
|
||||
assert.Equal(t, "gpt-image-1", tool.OfImageGeneration.Model)
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, hasImageTool, "Should have image generation tool")
|
||||
}
|
||||
|
||||
func TestBuildResponseParams_WithBothSearchAndImage(t *testing.T) {
|
||||
client := NewClient()
|
||||
opts := &common.ChatOptions{
|
||||
Model: "gpt-image-1",
|
||||
Search: true,
|
||||
SearchLocation: "America/Los_Angeles",
|
||||
ImageFile: "output.png",
|
||||
}
|
||||
|
||||
msgs := []*chat.ChatCompletionMessage{
|
||||
{Role: "user", Content: "Search for cat images and generate one"},
|
||||
}
|
||||
|
||||
params := client.buildResponseParams(msgs, opts)
|
||||
|
||||
assert.NotNil(t, params.Tools, "Expected tools when both search and image generation are enabled")
|
||||
assert.Len(t, params.Tools, 2, "Should have both search and image generation tools")
|
||||
|
||||
hasSearchTool := false
|
||||
hasImageTool := false
|
||||
|
||||
for _, tool := range params.Tools {
|
||||
if tool.OfWebSearchPreview != nil {
|
||||
hasSearchTool = true
|
||||
}
|
||||
if tool.OfImageGeneration != nil {
|
||||
hasImageTool = true
|
||||
}
|
||||
}
|
||||
|
||||
assert.True(t, hasSearchTool, "Should have web search tool")
|
||||
assert.True(t, hasImageTool, "Should have image generation tool")
|
||||
}
|
||||
|
||||
func TestGetOutputFormatFromExtension(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
imagePath string
|
||||
expectedFormat string
|
||||
}{
|
||||
{
|
||||
name: "PNG extension",
|
||||
imagePath: "/tmp/output.png",
|
||||
expectedFormat: "png",
|
||||
},
|
||||
{
|
||||
name: "WEBP extension",
|
||||
imagePath: "/tmp/output.webp",
|
||||
expectedFormat: "webp",
|
||||
},
|
||||
{
|
||||
name: "JPG extension",
|
||||
imagePath: "/tmp/output.jpg",
|
||||
expectedFormat: "jpeg",
|
||||
},
|
||||
{
|
||||
name: "JPEG extension",
|
||||
imagePath: "/tmp/output.jpeg",
|
||||
expectedFormat: "jpeg",
|
||||
},
|
||||
{
|
||||
name: "Uppercase PNG extension",
|
||||
imagePath: "/tmp/output.PNG",
|
||||
expectedFormat: "png",
|
||||
},
|
||||
{
|
||||
name: "Mixed case JPEG extension",
|
||||
imagePath: "/tmp/output.JpEg",
|
||||
expectedFormat: "jpeg",
|
||||
},
|
||||
{
|
||||
name: "Empty path",
|
||||
imagePath: "",
|
||||
expectedFormat: "png",
|
||||
},
|
||||
{
|
||||
name: "No extension",
|
||||
imagePath: "/tmp/output",
|
||||
expectedFormat: "png",
|
||||
},
|
||||
{
|
||||
name: "Unsupported extension",
|
||||
imagePath: "/tmp/output.gif",
|
||||
expectedFormat: "png",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := getOutputFormatFromExtension(tt.imagePath)
|
||||
assert.Equal(t, tt.expectedFormat, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddImageGenerationToolWithDynamicFormat(t *testing.T) {
|
||||
client := NewClient()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
imageFile string
|
||||
expectedFormat string
|
||||
}{
|
||||
{
|
||||
name: "PNG file",
|
||||
imageFile: "/tmp/output.png",
|
||||
expectedFormat: "png",
|
||||
},
|
||||
{
|
||||
name: "WEBP file",
|
||||
imageFile: "/tmp/output.webp",
|
||||
expectedFormat: "webp",
|
||||
},
|
||||
{
|
||||
name: "JPG file",
|
||||
imageFile: "/tmp/output.jpg",
|
||||
expectedFormat: "jpeg",
|
||||
},
|
||||
{
|
||||
name: "JPEG file",
|
||||
imageFile: "/tmp/output.jpeg",
|
||||
expectedFormat: "jpeg",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
opts := &common.ChatOptions{
|
||||
ImageFile: tt.imageFile,
|
||||
}
|
||||
|
||||
tools := client.addImageGenerationTool(opts, []responses.ToolUnionParam{})
|
||||
|
||||
assert.Len(t, tools, 1, "Should have one tool")
|
||||
assert.NotNil(t, tools[0].OfImageGeneration, "Should be image generation tool")
|
||||
assert.Equal(t, tt.expectedFormat, tools[0].OfImageGeneration.OutputFormat, "Output format should match file extension")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSupportsImageGeneration(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "gpt-4o supports image generation",
|
||||
model: "gpt-4o",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "gpt-4o-mini supports image generation",
|
||||
model: "gpt-4o-mini",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "gpt-4.1 supports image generation",
|
||||
model: "gpt-4.1",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "gpt-4.1-mini supports image generation",
|
||||
model: "gpt-4.1-mini",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "gpt-4.1-nano supports image generation",
|
||||
model: "gpt-4.1-nano",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "o3 supports image generation",
|
||||
model: "o3",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "o1 does not support image generation",
|
||||
model: "o1",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "o1-mini does not support image generation",
|
||||
model: "o1-mini",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "o3-mini does not support image generation",
|
||||
model: "o3-mini",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "gpt-4 does not support image generation",
|
||||
model: "gpt-4",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "gpt-3.5-turbo does not support image generation",
|
||||
model: "gpt-3.5-turbo",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "empty model does not support image generation",
|
||||
model: "",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := supportsImageGeneration(tt.model)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelValidationLogic(t *testing.T) {
|
||||
t.Run("Unsupported model with image file should return validation error", func(t *testing.T) {
|
||||
opts := &common.ChatOptions{
|
||||
Model: "o1-mini",
|
||||
ImageFile: "/tmp/output.png",
|
||||
}
|
||||
|
||||
// Test the validation logic directly
|
||||
if opts.ImageFile != "" && !supportsImageGeneration(opts.Model) {
|
||||
err := fmt.Errorf("model '%s' does not support image generation. Supported models: %s", opts.Model, strings.Join(ImageGenerationSupportedModels, ", "))
|
||||
|
||||
assert.Contains(t, err.Error(), "does not support image generation")
|
||||
assert.Contains(t, err.Error(), "o1-mini")
|
||||
assert.Contains(t, err.Error(), "Supported models:")
|
||||
} else {
|
||||
t.Error("Expected validation to trigger")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Supported model with image file should not trigger validation", func(t *testing.T) {
|
||||
opts := &common.ChatOptions{
|
||||
Model: "gpt-4o",
|
||||
ImageFile: "/tmp/output.png",
|
||||
}
|
||||
|
||||
// Test the validation logic directly
|
||||
shouldFail := opts.ImageFile != "" && !supportsImageGeneration(opts.Model)
|
||||
assert.False(t, shouldFail, "Validation should not trigger for supported model")
|
||||
})
|
||||
|
||||
t.Run("Unsupported model without image file should not trigger validation", func(t *testing.T) {
|
||||
opts := &common.ChatOptions{
|
||||
Model: "o1-mini",
|
||||
ImageFile: "", // No image file
|
||||
}
|
||||
|
||||
// Test the validation logic directly
|
||||
shouldFail := opts.ImageFile != "" && !supportsImageGeneration(opts.Model)
|
||||
assert.False(t, shouldFail, "Validation should not trigger when no image file is specified")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAddImageGenerationToolWithUserParameters(t *testing.T) {
|
||||
client := NewClient()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
opts *common.ChatOptions
|
||||
expected map[string]interface{}
|
||||
}{
|
||||
{
|
||||
name: "All parameters specified",
|
||||
opts: &common.ChatOptions{
|
||||
ImageFile: "/tmp/test.png",
|
||||
ImageSize: "1536x1024",
|
||||
ImageQuality: "high",
|
||||
ImageBackground: "transparent",
|
||||
ImageCompression: 0, // Not applicable for PNG
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"size": "1536x1024",
|
||||
"quality": "high",
|
||||
"background": "transparent",
|
||||
"output_format": "png",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "JPEG with compression",
|
||||
opts: &common.ChatOptions{
|
||||
ImageFile: "/tmp/test.jpg",
|
||||
ImageSize: "1024x1024",
|
||||
ImageQuality: "medium",
|
||||
ImageBackground: "opaque",
|
||||
ImageCompression: 75,
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"size": "1024x1024",
|
||||
"quality": "medium",
|
||||
"background": "opaque",
|
||||
"output_format": "jpeg",
|
||||
"output_compression": int64(75),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Only some parameters specified",
|
||||
opts: &common.ChatOptions{
|
||||
ImageFile: "/tmp/test.webp",
|
||||
ImageQuality: "low",
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"quality": "low",
|
||||
"output_format": "webp",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "No parameters specified (defaults)",
|
||||
opts: &common.ChatOptions{
|
||||
ImageFile: "/tmp/test.png",
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"output_format": "png",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tools := client.addImageGenerationTool(tt.opts, []responses.ToolUnionParam{})
|
||||
|
||||
assert.Len(t, tools, 1)
|
||||
assert.NotNil(t, tools[0].OfImageGeneration)
|
||||
|
||||
tool := tools[0].OfImageGeneration
|
||||
|
||||
// Check required fields
|
||||
assert.Equal(t, "gpt-image-1", tool.Model)
|
||||
assert.Equal(t, tt.expected["output_format"], tool.OutputFormat)
|
||||
|
||||
// Check optional fields
|
||||
if expectedSize, ok := tt.expected["size"]; ok {
|
||||
assert.Equal(t, expectedSize, tool.Size)
|
||||
} else {
|
||||
assert.Empty(t, tool.Size, "Size should not be set when not specified")
|
||||
}
|
||||
|
||||
if expectedQuality, ok := tt.expected["quality"]; ok {
|
||||
assert.Equal(t, expectedQuality, tool.Quality)
|
||||
} else {
|
||||
assert.Empty(t, tool.Quality, "Quality should not be set when not specified")
|
||||
}
|
||||
|
||||
if expectedBackground, ok := tt.expected["background"]; ok {
|
||||
assert.Equal(t, expectedBackground, tool.Background)
|
||||
} else {
|
||||
assert.Empty(t, tool.Background, "Background should not be set when not specified")
|
||||
}
|
||||
|
||||
if expectedCompression, ok := tt.expected["output_compression"]; ok {
|
||||
assert.Equal(t, expectedCompression, tool.OutputCompression.Value)
|
||||
} else {
|
||||
assert.Equal(t, int64(0), tool.OutputCompression.Value, "Compression should not be set when not specified")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,11 +1,13 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/danielmiessler/fabric/chat"
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
openai "github.com/openai/openai-go"
|
||||
"github.com/openai/openai-go/responses"
|
||||
"github.com/openai/openai-go/shared"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
@@ -60,3 +62,116 @@ func TestBuildResponseRequestNoMaxTokens(t *testing.T) {
|
||||
assert.Equal(t, openai.Float(opts.TopP), request.TopP)
|
||||
assert.False(t, request.MaxOutputTokens.Valid())
|
||||
}
|
||||
|
||||
func TestBuildResponseParams_WithoutSearch(t *testing.T) {
|
||||
client := NewClient()
|
||||
opts := &common.ChatOptions{
|
||||
Model: "gpt-4o",
|
||||
Temperature: 0.7,
|
||||
Search: false,
|
||||
}
|
||||
|
||||
msgs := []*chat.ChatCompletionMessage{
|
||||
{Role: "user", Content: "Hello"},
|
||||
}
|
||||
|
||||
params := client.buildResponseParams(msgs, opts)
|
||||
|
||||
assert.Nil(t, params.Tools, "Expected no tools when search is disabled")
|
||||
assert.Equal(t, shared.ResponsesModel(opts.Model), params.Model)
|
||||
assert.Equal(t, openai.Float(opts.Temperature), params.Temperature)
|
||||
}
|
||||
|
||||
func TestBuildResponseParams_WithSearch(t *testing.T) {
|
||||
client := NewClient()
|
||||
opts := &common.ChatOptions{
|
||||
Model: "gpt-4o",
|
||||
Temperature: 0.7,
|
||||
Search: true,
|
||||
}
|
||||
|
||||
msgs := []*chat.ChatCompletionMessage{
|
||||
{Role: "user", Content: "What's the weather today?"},
|
||||
}
|
||||
|
||||
params := client.buildResponseParams(msgs, opts)
|
||||
|
||||
assert.NotNil(t, params.Tools, "Expected tools when search is enabled")
|
||||
assert.Len(t, params.Tools, 1, "Expected exactly one tool")
|
||||
|
||||
tool := params.Tools[0]
|
||||
assert.NotNil(t, tool.OfWebSearchPreview, "Expected web search tool")
|
||||
assert.Equal(t, responses.WebSearchToolType("web_search_preview"), tool.OfWebSearchPreview.Type)
|
||||
}
|
||||
|
||||
func TestBuildResponseParams_WithSearchAndLocation(t *testing.T) {
|
||||
client := NewClient()
|
||||
opts := &common.ChatOptions{
|
||||
Model: "gpt-4o",
|
||||
Temperature: 0.7,
|
||||
Search: true,
|
||||
SearchLocation: "America/Los_Angeles",
|
||||
}
|
||||
|
||||
msgs := []*chat.ChatCompletionMessage{
|
||||
{Role: "user", Content: "What's the weather in San Francisco?"},
|
||||
}
|
||||
|
||||
params := client.buildResponseParams(msgs, opts)
|
||||
|
||||
assert.NotNil(t, params.Tools, "Expected tools when search is enabled")
|
||||
tool := params.Tools[0]
|
||||
assert.NotNil(t, tool.OfWebSearchPreview, "Expected web search tool")
|
||||
|
||||
userLocation := tool.OfWebSearchPreview.UserLocation
|
||||
assert.Equal(t, "approximate", string(userLocation.Type))
|
||||
assert.True(t, userLocation.Timezone.Valid(), "Expected timezone to be set")
|
||||
assert.Equal(t, opts.SearchLocation, userLocation.Timezone.Value)
|
||||
}
|
||||
|
||||
func TestCitationFormatting(t *testing.T) {
|
||||
// Test the citation formatting logic by simulating the citation extraction
|
||||
var textParts []string
|
||||
var citations []string
|
||||
citationMap := make(map[string]bool)
|
||||
|
||||
// Simulate text content
|
||||
textParts = append(textParts, "Based on recent research, artificial intelligence is advancing rapidly.")
|
||||
|
||||
// Simulate citations (as they would be extracted from OpenAI response)
|
||||
mockCitations := []struct {
|
||||
URL string
|
||||
Title string
|
||||
}{
|
||||
{"https://example.com/ai-research", "AI Research Advances 2025"},
|
||||
{"https://another-source.com/tech-news", "Technology News Today"},
|
||||
{"https://example.com/ai-research", "AI Research Advances 2025"}, // Duplicate to test deduplication
|
||||
}
|
||||
|
||||
for _, citation := range mockCitations {
|
||||
citationKey := citation.URL + "|" + citation.Title
|
||||
if !citationMap[citationKey] {
|
||||
citationMap[citationKey] = true
|
||||
citationText := "- [" + citation.Title + "](" + citation.URL + ")"
|
||||
citations = append(citations, citationText)
|
||||
}
|
||||
}
|
||||
|
||||
result := strings.Join(textParts, "")
|
||||
if len(citations) > 0 {
|
||||
result += "\n\n## Sources\n\n" + strings.Join(citations, "\n")
|
||||
}
|
||||
|
||||
// Verify the result contains the expected text
|
||||
expectedText := "Based on recent research, artificial intelligence is advancing rapidly."
|
||||
assert.Contains(t, result, expectedText, "Expected result to contain original text")
|
||||
|
||||
// Verify citations are included
|
||||
assert.Contains(t, result, "## Sources", "Expected result to contain Sources section")
|
||||
assert.Contains(t, result, "[AI Research Advances 2025](https://example.com/ai-research)", "Expected result to contain first citation")
|
||||
assert.Contains(t, result, "[Technology News Today](https://another-source.com/tech-news)", "Expected result to contain second citation")
|
||||
|
||||
// Verify deduplication - should only have 2 unique citations, not 3
|
||||
citationCount := strings.Count(result, "- [")
|
||||
assert.Equal(t, 2, citationCount, "Expected 2 unique citations")
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
@@ -19,6 +20,7 @@ func NewDb(dir string) (db *Db) {
|
||||
StorageEntity: &StorageEntity{Label: "Patterns", Dir: db.FilePath("patterns"), ItemIsDir: true},
|
||||
SystemPatternFile: "system.md",
|
||||
UniquePatternsFilePath: db.FilePath("unique_patterns.txt"),
|
||||
CustomPatternsDir: "", // Will be set after loading .env file
|
||||
}
|
||||
|
||||
db.Sessions = &SessionsEntity{
|
||||
@@ -49,6 +51,18 @@ func (o *Db) Configure() (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
// Set custom patterns directory after loading .env file
|
||||
customPatternsDir := os.Getenv("CUSTOM_PATTERNS_DIRECTORY")
|
||||
if customPatternsDir != "" {
|
||||
// Expand home directory if needed
|
||||
if strings.HasPrefix(customPatternsDir, "~/") {
|
||||
if homeDir, err := os.UserHomeDir(); err == nil {
|
||||
customPatternsDir = filepath.Join(homeDir, customPatternsDir[2:])
|
||||
}
|
||||
}
|
||||
o.Patterns.CustomPatternsDir = customPatternsDir
|
||||
}
|
||||
|
||||
if err = o.Patterns.Configure(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
@@ -16,6 +17,7 @@ type PatternsEntity struct {
|
||||
*StorageEntity
|
||||
SystemPatternFile string
|
||||
UniquePatternsFilePath string
|
||||
CustomPatternsDir string
|
||||
}
|
||||
|
||||
// Pattern represents a single pattern with its metadata
|
||||
@@ -43,7 +45,7 @@ func (o *PatternsEntity) GetApplyVariables(
|
||||
}
|
||||
|
||||
// Use the resolved absolute path to get the pattern
|
||||
pattern, err = o.getFromFile(absPath)
|
||||
pattern, _ = o.getFromFile(absPath)
|
||||
} else {
|
||||
// Otherwise, get the pattern from the database
|
||||
pattern, err = o.getFromDB(source)
|
||||
@@ -89,6 +91,19 @@ func (o *PatternsEntity) applyVariables(
|
||||
|
||||
// retrieves a pattern from the database by name
|
||||
func (o *PatternsEntity) getFromDB(name string) (ret *Pattern, err error) {
|
||||
// First check custom patterns directory if it exists
|
||||
if o.CustomPatternsDir != "" {
|
||||
customPatternPath := filepath.Join(o.CustomPatternsDir, name, o.SystemPatternFile)
|
||||
if pattern, customErr := os.ReadFile(customPatternPath); customErr == nil {
|
||||
ret = &Pattern{
|
||||
Name: name,
|
||||
Pattern: string(pattern),
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to main patterns directory
|
||||
patternPath := filepath.Join(o.Dir, name, o.SystemPatternFile)
|
||||
|
||||
var pattern []byte
|
||||
@@ -145,6 +160,71 @@ func (o *PatternsEntity) getFromFile(pathStr string) (pattern *Pattern, err erro
|
||||
return
|
||||
}
|
||||
|
||||
// GetNames overrides StorageEntity.GetNames to include custom patterns directory
|
||||
func (o *PatternsEntity) GetNames() (ret []string, err error) {
|
||||
// Get names from main patterns directory
|
||||
mainNames, err := o.StorageEntity.GetNames()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create a map to track unique pattern names (custom patterns override main ones)
|
||||
nameMap := make(map[string]bool)
|
||||
for _, name := range mainNames {
|
||||
nameMap[name] = true
|
||||
}
|
||||
|
||||
// Get names from custom patterns directory if it exists
|
||||
if o.CustomPatternsDir != "" {
|
||||
// Create a temporary StorageEntity for the custom directory
|
||||
customStorage := &StorageEntity{
|
||||
Dir: o.CustomPatternsDir,
|
||||
ItemIsDir: o.StorageEntity.ItemIsDir,
|
||||
FileExtension: o.StorageEntity.FileExtension,
|
||||
}
|
||||
|
||||
customNames, customErr := customStorage.GetNames()
|
||||
if customErr == nil {
|
||||
// Add custom patterns, they will override main patterns with same name
|
||||
for _, name := range customNames {
|
||||
nameMap[name] = true
|
||||
}
|
||||
}
|
||||
// Ignore errors from custom directory (it might not exist)
|
||||
}
|
||||
|
||||
// Convert map keys back to slice
|
||||
ret = make([]string, 0, len(nameMap))
|
||||
for name := range nameMap {
|
||||
ret = append(ret, name)
|
||||
}
|
||||
|
||||
// Sort the patterns alphabetically
|
||||
sort.Strings(ret)
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// ListNames overrides StorageEntity.ListNames to use PatternsEntity.GetNames
|
||||
func (o *PatternsEntity) ListNames(shellCompleteList bool) (err error) {
|
||||
var names []string
|
||||
if names, err = o.GetNames(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if len(names) == 0 {
|
||||
if !shellCompleteList {
|
||||
fmt.Printf("\nNo %v\n", o.StorageEntity.Label)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
for _, item := range names {
|
||||
fmt.Printf("%s\n", item)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Get required for Storage interface
|
||||
func (o *PatternsEntity) Get(name string) (*Pattern, error) {
|
||||
// Use GetPattern with no variables
|
||||
|
||||
@@ -162,3 +162,123 @@ func TestPatternsEntity_Save(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, content, data)
|
||||
}
|
||||
|
||||
func TestPatternsEntity_CustomPatterns(t *testing.T) {
|
||||
// Create main patterns directory
|
||||
mainDir, err := os.MkdirTemp("", "test-main-patterns-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(mainDir)
|
||||
|
||||
// Create custom patterns directory
|
||||
customDir, err := os.MkdirTemp("", "test-custom-patterns-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(customDir)
|
||||
|
||||
entity := &PatternsEntity{
|
||||
StorageEntity: &StorageEntity{
|
||||
Dir: mainDir,
|
||||
Label: "patterns",
|
||||
ItemIsDir: true,
|
||||
},
|
||||
SystemPatternFile: "system.md",
|
||||
CustomPatternsDir: customDir,
|
||||
}
|
||||
|
||||
// Create a pattern in main directory
|
||||
createTestPattern(t, &PatternsEntity{
|
||||
StorageEntity: &StorageEntity{
|
||||
Dir: mainDir,
|
||||
Label: "patterns",
|
||||
ItemIsDir: true,
|
||||
},
|
||||
SystemPatternFile: "system.md",
|
||||
}, "main-pattern", "Main pattern content")
|
||||
|
||||
// Create a pattern in custom directory
|
||||
createTestPattern(t, &PatternsEntity{
|
||||
StorageEntity: &StorageEntity{
|
||||
Dir: customDir,
|
||||
Label: "patterns",
|
||||
ItemIsDir: true,
|
||||
},
|
||||
SystemPatternFile: "system.md",
|
||||
}, "custom-pattern", "Custom pattern content")
|
||||
|
||||
// Create a pattern with same name in both directories (custom should override)
|
||||
createTestPattern(t, &PatternsEntity{
|
||||
StorageEntity: &StorageEntity{
|
||||
Dir: mainDir,
|
||||
Label: "patterns",
|
||||
ItemIsDir: true,
|
||||
},
|
||||
SystemPatternFile: "system.md",
|
||||
}, "shared-pattern", "Main shared pattern")
|
||||
|
||||
createTestPattern(t, &PatternsEntity{
|
||||
StorageEntity: &StorageEntity{
|
||||
Dir: customDir,
|
||||
Label: "patterns",
|
||||
ItemIsDir: true,
|
||||
},
|
||||
SystemPatternFile: "system.md",
|
||||
}, "shared-pattern", "Custom shared pattern")
|
||||
|
||||
// Test GetNames includes both directories
|
||||
names, err := entity.GetNames()
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, names, "main-pattern")
|
||||
assert.Contains(t, names, "custom-pattern")
|
||||
assert.Contains(t, names, "shared-pattern")
|
||||
|
||||
// Test that custom pattern overrides main pattern
|
||||
pattern, err := entity.getFromDB("shared-pattern")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Custom shared pattern", pattern.Pattern)
|
||||
|
||||
// Test that main pattern is accessible when not overridden
|
||||
pattern, err = entity.getFromDB("main-pattern")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Main pattern content", pattern.Pattern)
|
||||
|
||||
// Test that custom pattern is accessible
|
||||
pattern, err = entity.getFromDB("custom-pattern")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Custom pattern content", pattern.Pattern)
|
||||
}
|
||||
|
||||
func TestPatternsEntity_CustomPatternsEmpty(t *testing.T) {
|
||||
// Test behavior when custom patterns directory is empty or doesn't exist
|
||||
mainDir, err := os.MkdirTemp("", "test-main-patterns-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(mainDir)
|
||||
|
||||
entity := &PatternsEntity{
|
||||
StorageEntity: &StorageEntity{
|
||||
Dir: mainDir,
|
||||
Label: "patterns",
|
||||
ItemIsDir: true,
|
||||
},
|
||||
SystemPatternFile: "system.md",
|
||||
CustomPatternsDir: "/nonexistent/directory",
|
||||
}
|
||||
|
||||
// Create a pattern in main directory
|
||||
createTestPattern(t, &PatternsEntity{
|
||||
StorageEntity: &StorageEntity{
|
||||
Dir: mainDir,
|
||||
Label: "patterns",
|
||||
ItemIsDir: true,
|
||||
},
|
||||
SystemPatternFile: "system.md",
|
||||
}, "main-pattern", "Main pattern content")
|
||||
|
||||
// Test GetNames works even with nonexistent custom directory
|
||||
names, err := entity.GetNames()
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, names, "main-pattern")
|
||||
|
||||
// Test that main pattern is accessible
|
||||
pattern, err := entity.getFromDB("main-pattern")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Main pattern content", pattern.Pattern)
|
||||
}
|
||||
|
||||
@@ -152,7 +152,6 @@ func (o *Setting) FillEnvFileContent(buffer *bytes.Buffer) {
|
||||
}
|
||||
buffer.WriteString("\n")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func ParseBoolElseFalse(val string) (ret bool) {
|
||||
@@ -279,7 +278,6 @@ func (o Settings) FillEnvFileContent(buffer *bytes.Buffer) {
|
||||
for _, setting := range o {
|
||||
setting.FillEnvFileContent(buffer)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type SetupQuestions []*SetupQuestion
|
||||
|
||||
64
plugins/tools/custom_patterns/custom_patterns.go
Normal file
64
plugins/tools/custom_patterns/custom_patterns.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package custom_patterns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/danielmiessler/fabric/plugins"
|
||||
)
|
||||
|
||||
func NewCustomPatterns() (ret *CustomPatterns) {
|
||||
label := "Custom Patterns"
|
||||
ret = &CustomPatterns{}
|
||||
|
||||
ret.PluginBase = &plugins.PluginBase{
|
||||
Name: label,
|
||||
SetupDescription: "Custom Patterns - Set directory for your custom patterns (optional)",
|
||||
EnvNamePrefix: plugins.BuildEnvVariablePrefix(label),
|
||||
ConfigureCustom: ret.configure,
|
||||
}
|
||||
|
||||
ret.CustomPatternsDir = ret.AddSetupQuestionCustom("Directory", false,
|
||||
"Enter the path to your custom patterns directory (leave empty to skip)")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
type CustomPatterns struct {
|
||||
*plugins.PluginBase
|
||||
CustomPatternsDir *plugins.SetupQuestion
|
||||
}
|
||||
|
||||
func (o *CustomPatterns) configure() error {
|
||||
if o.CustomPatternsDir.Value != "" {
|
||||
// Expand home directory if needed
|
||||
if strings.HasPrefix(o.CustomPatternsDir.Value, "~/") {
|
||||
if homeDir, err := os.UserHomeDir(); err == nil {
|
||||
o.CustomPatternsDir.Value = filepath.Join(homeDir, o.CustomPatternsDir.Value[2:])
|
||||
}
|
||||
}
|
||||
|
||||
// Convert to absolute path
|
||||
if absPath, err := filepath.Abs(o.CustomPatternsDir.Value); err == nil {
|
||||
o.CustomPatternsDir.Value = absPath
|
||||
}
|
||||
|
||||
// Check if directory exists, create only if it doesn't
|
||||
if _, err := os.Stat(o.CustomPatternsDir.Value); os.IsNotExist(err) {
|
||||
if err := os.MkdirAll(o.CustomPatternsDir.Value, 0755); err != nil {
|
||||
// Log the error but don't clear the value - let it persist in env file
|
||||
fmt.Printf("Warning: Could not create custom patterns directory %s: %v\n", o.CustomPatternsDir.Value, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsConfigured returns true if a custom patterns directory has been set
|
||||
func (o *CustomPatterns) IsConfigured() bool {
|
||||
// Check if the plugin has been configured with a directory
|
||||
return o.CustomPatternsDir.Value != ""
|
||||
}
|
||||
79
plugins/tools/custom_patterns/custom_patterns_test.go
Normal file
79
plugins/tools/custom_patterns/custom_patterns_test.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package custom_patterns
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewCustomPatterns(t *testing.T) {
|
||||
plugin := NewCustomPatterns()
|
||||
|
||||
assert.NotNil(t, plugin)
|
||||
assert.Equal(t, "Custom Patterns", plugin.GetName())
|
||||
assert.Equal(t, "Custom Patterns - Set directory for your custom patterns (optional)", plugin.GetSetupDescription())
|
||||
assert.False(t, plugin.IsConfigured()) // Should not be configured initially
|
||||
}
|
||||
func TestCustomPatterns_Configure(t *testing.T) {
|
||||
plugin := NewCustomPatterns()
|
||||
|
||||
// Test with empty directory (should work)
|
||||
plugin.CustomPatternsDir.Value = ""
|
||||
err := plugin.configure()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Test with home directory expansion
|
||||
plugin.CustomPatternsDir.Value = "~/test-patterns"
|
||||
err = plugin.configure()
|
||||
assert.NoError(t, err)
|
||||
|
||||
homeDir, _ := os.UserHomeDir()
|
||||
expectedPath := filepath.Join(homeDir, "test-patterns")
|
||||
absExpected, _ := filepath.Abs(expectedPath)
|
||||
assert.Equal(t, absExpected, plugin.CustomPatternsDir.Value)
|
||||
|
||||
// Clean up
|
||||
os.RemoveAll(plugin.CustomPatternsDir.Value)
|
||||
}
|
||||
|
||||
func TestCustomPatterns_ConfigureWithTempDir(t *testing.T) {
|
||||
plugin := NewCustomPatterns()
|
||||
|
||||
// Test with a temporary directory
|
||||
tmpDir, err := os.MkdirTemp("", "test-custom-patterns-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
plugin.CustomPatternsDir.Value = tmpDir
|
||||
err = plugin.configure()
|
||||
assert.NoError(t, err)
|
||||
|
||||
absPath, _ := filepath.Abs(tmpDir)
|
||||
assert.Equal(t, absPath, plugin.CustomPatternsDir.Value)
|
||||
|
||||
// Verify directory exists
|
||||
info, err := os.Stat(plugin.CustomPatternsDir.Value)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, info.IsDir())
|
||||
|
||||
// Should be configured now
|
||||
assert.True(t, plugin.IsConfigured())
|
||||
}
|
||||
|
||||
func TestCustomPatterns_IsConfigured(t *testing.T) {
|
||||
plugin := NewCustomPatterns()
|
||||
|
||||
// Initially not configured
|
||||
assert.False(t, plugin.IsConfigured())
|
||||
|
||||
// Set a directory
|
||||
plugin.CustomPatternsDir.Value = "/some/path"
|
||||
assert.True(t, plugin.IsConfigured())
|
||||
|
||||
// Clear the directory
|
||||
plugin.CustomPatternsDir.Value = ""
|
||||
assert.False(t, plugin.IsConfigured())
|
||||
}
|
||||
@@ -57,17 +57,18 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
|
||||
}
|
||||
|
||||
config := map[string]string{
|
||||
"openai": os.Getenv("OPENAI_API_KEY"),
|
||||
"anthropic": os.Getenv("ANTHROPIC_API_KEY"),
|
||||
"groq": os.Getenv("GROQ_API_KEY"),
|
||||
"mistral": os.Getenv("MISTRAL_API_KEY"),
|
||||
"gemini": os.Getenv("GEMINI_API_KEY"),
|
||||
"ollama": os.Getenv("OLLAMA_URL"),
|
||||
"openrouter": os.Getenv("OPENROUTER_API_KEY"),
|
||||
"silicon": os.Getenv("SILICON_API_KEY"),
|
||||
"deepseek": os.Getenv("DEEPSEEK_API_KEY"),
|
||||
"grokai": os.Getenv("GROKAI_API_KEY"),
|
||||
"lmstudio": os.Getenv("LM_STUDIO_API_BASE_URL"),
|
||||
"openai": os.Getenv("OPENAI_API_KEY"),
|
||||
"anthropic": os.Getenv("ANTHROPIC_API_KEY"),
|
||||
"anthropic_use_oauth_login": os.Getenv("ANTHROPIC_USE_OAUTH_LOGIN"),
|
||||
"groq": os.Getenv("GROQ_API_KEY"),
|
||||
"mistral": os.Getenv("MISTRAL_API_KEY"),
|
||||
"gemini": os.Getenv("GEMINI_API_KEY"),
|
||||
"ollama": os.Getenv("OLLAMA_URL"),
|
||||
"openrouter": os.Getenv("OPENROUTER_API_KEY"),
|
||||
"silicon": os.Getenv("SILICON_API_KEY"),
|
||||
"deepseek": os.Getenv("DEEPSEEK_API_KEY"),
|
||||
"grokai": os.Getenv("GROKAI_API_KEY"),
|
||||
"lmstudio": os.Getenv("LM_STUDIO_API_BASE_URL"),
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, config)
|
||||
@@ -80,17 +81,18 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||||
}
|
||||
|
||||
var config struct {
|
||||
OpenAIApiKey string `json:"openai_api_key"`
|
||||
AnthropicApiKey string `json:"anthropic_api_key"`
|
||||
GroqApiKey string `json:"groq_api_key"`
|
||||
MistralApiKey string `json:"mistral_api_key"`
|
||||
GeminiApiKey string `json:"gemini_api_key"`
|
||||
OllamaURL string `json:"ollama_url"`
|
||||
OpenRouterApiKey string `json:"openrouter_api_key"`
|
||||
SiliconApiKey string `json:"silicon_api_key"`
|
||||
DeepSeekApiKey string `json:"deepseek_api_key"`
|
||||
GrokaiApiKey string `json:"grokai_api_key"`
|
||||
LMStudioURL string `json:"lm_studio_base_url"`
|
||||
OpenAIApiKey string `json:"openai_api_key"`
|
||||
AnthropicApiKey string `json:"anthropic_api_key"`
|
||||
AnthropicUseAuthToken string `json:"anthropic_use_auth_token"`
|
||||
GroqApiKey string `json:"groq_api_key"`
|
||||
MistralApiKey string `json:"mistral_api_key"`
|
||||
GeminiApiKey string `json:"gemini_api_key"`
|
||||
OllamaURL string `json:"ollama_url"`
|
||||
OpenRouterApiKey string `json:"openrouter_api_key"`
|
||||
SiliconApiKey string `json:"silicon_api_key"`
|
||||
DeepSeekApiKey string `json:"deepseek_api_key"`
|
||||
GrokaiApiKey string `json:"grokai_api_key"`
|
||||
LMStudioURL string `json:"lm_studio_base_url"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&config); err != nil {
|
||||
@@ -99,17 +101,18 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||||
}
|
||||
|
||||
envVars := map[string]string{
|
||||
"OPENAI_API_KEY": config.OpenAIApiKey,
|
||||
"ANTHROPIC_API_KEY": config.AnthropicApiKey,
|
||||
"GROQ_API_KEY": config.GroqApiKey,
|
||||
"MISTRAL_API_KEY": config.MistralApiKey,
|
||||
"GEMINI_API_KEY": config.GeminiApiKey,
|
||||
"OLLAMA_URL": config.OllamaURL,
|
||||
"OPENROUTER_API_KEY": config.OpenRouterApiKey,
|
||||
"SILICON_API_KEY": config.SiliconApiKey,
|
||||
"DEEPSEEK_API_KEY": config.DeepSeekApiKey,
|
||||
"GROKAI_API_KEY": config.GrokaiApiKey,
|
||||
"LM_STUDIO_API_BASE_URL": config.LMStudioURL,
|
||||
"OPENAI_API_KEY": config.OpenAIApiKey,
|
||||
"ANTHROPIC_API_KEY": config.AnthropicApiKey,
|
||||
"ANTHROPIC_USE_OAUTH_LOGIN": config.AnthropicUseAuthToken,
|
||||
"GROQ_API_KEY": config.GroqApiKey,
|
||||
"MISTRAL_API_KEY": config.MistralApiKey,
|
||||
"GEMINI_API_KEY": config.GeminiApiKey,
|
||||
"OLLAMA_URL": config.OllamaURL,
|
||||
"OPENROUTER_API_KEY": config.OpenRouterApiKey,
|
||||
"SILICON_API_KEY": config.SiliconApiKey,
|
||||
"DEEPSEEK_API_KEY": config.DeepSeekApiKey,
|
||||
"GROKAI_API_KEY": config.GrokaiApiKey,
|
||||
"LM_STUDIO_API_BASE_URL": config.LMStudioURL,
|
||||
}
|
||||
|
||||
var envContent strings.Builder
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
package main
|
||||
|
||||
var version = "v1.4.224"
|
||||
var version = "v1.4.234"
|
||||
|
||||
Reference in New Issue
Block a user