mirror of
https://github.com/danielmiessler/Fabric.git
synced 2026-01-09 22:38:10 -05:00
Compare commits
80 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a40bacaf34 | ||
|
|
969b85380c | ||
|
|
e8fe4434db | ||
|
|
7c7ceca264 | ||
|
|
c19d7ccd9d | ||
|
|
bd0c5f730e | ||
|
|
5900dac58f | ||
|
|
237219c3cc | ||
|
|
26fd700098 | ||
|
|
6bd926dd0f | ||
|
|
16ac519415 | ||
|
|
a32cc5fa01 | ||
|
|
26b5bb2e9e | ||
|
|
b751d323b1 | ||
|
|
d081fd269c | ||
|
|
369a0a850d | ||
|
|
8dc5343ee6 | ||
|
|
eda552dac5 | ||
|
|
f13a56685b | ||
|
|
2f9afe0247 | ||
|
|
1ec525ad97 | ||
|
|
b7dc6748e0 | ||
|
|
f1b612d828 | ||
|
|
eac5a104f2 | ||
|
|
4bff88fae3 | ||
|
|
acf1be71ce | ||
|
|
236a3c5f38 | ||
|
|
b2418984f8 | ||
|
|
152d74d160 | ||
|
|
4e16bbccd8 | ||
|
|
60174f41a4 | ||
|
|
ad4683952e | ||
|
|
86a044735b | ||
|
|
58583114cb | ||
|
|
36524cd2e4 | ||
|
|
e59156ac2b | ||
|
|
1eac026e92 | ||
|
|
17d863fd57 | ||
|
|
7c9dbfd343 | ||
|
|
d9260bdf26 | ||
|
|
63a0cfeb1e | ||
|
|
12fc6e2000 | ||
|
|
fe5900a5dc | ||
|
|
1b6b8e3d72 | ||
|
|
c85301cb1f | ||
|
|
7cc8226339 | ||
|
|
fc8c4babf8 | ||
|
|
bd809a1f94 | ||
|
|
50aec6291b | ||
|
|
f927fdf40f | ||
|
|
918862ef57 | ||
|
|
d9b8bc3233 | ||
|
|
da29b8e388 | ||
|
|
5e6d4110fa | ||
|
|
4bb090694b | ||
|
|
d232222787 | ||
|
|
095890a556 | ||
|
|
64c1fe18ef | ||
|
|
1cea32a677 | ||
|
|
49658a3214 | ||
|
|
f236cab276 | ||
|
|
5e0aaa1f93 | ||
|
|
eb16806931 | ||
|
|
474dd786a4 | ||
|
|
edad63df19 | ||
|
|
c7eb7439ef | ||
|
|
23d678d62f | ||
|
|
de5260a661 | ||
|
|
baeadc2270 | ||
|
|
5b4cec81c3 | ||
|
|
eda5531087 | ||
|
|
66925d188a | ||
|
|
6179742e79 | ||
|
|
d8fc6940f0 | ||
|
|
44f7e8dfef | ||
|
|
c5ada714ff | ||
|
|
80c4807f7e | ||
|
|
b4126b6798 | ||
|
|
f2ffa64af9 | ||
|
|
09e01eddf4 |
3
.github/workflows/release.yml
vendored
3
.github/workflows/release.yml
vendored
@@ -2,7 +2,7 @@ name: Go Release
|
||||
|
||||
on:
|
||||
repository_dispatch:
|
||||
types: [ tag_created ]
|
||||
types: [tag_created]
|
||||
push:
|
||||
tags:
|
||||
- "v*"
|
||||
@@ -108,6 +108,7 @@ jobs:
|
||||
Add-Content -Path $env:GITHUB_ENV -Value "latest_tag=$latest_tag"
|
||||
|
||||
- name: Create release if it doesn't exist
|
||||
shell: bash
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
|
||||
@@ -5,6 +5,7 @@ This document explains the complete workflow for managing pattern descriptions a
|
||||
## System Overview
|
||||
|
||||
The pattern system follows this hierarchy:
|
||||
|
||||
1. `~/.config/fabric/patterns/` directory: The source of truth for available patterns
|
||||
2. `pattern_extracts.json`: Contains first 500 words of each pattern for reference
|
||||
3. `pattern_descriptions.json`: Stores pattern metadata (descriptions and tags)
|
||||
@@ -13,17 +14,21 @@ The pattern system follows this hierarchy:
|
||||
## Pattern Processing Workflow
|
||||
|
||||
### 1. Adding New Patterns
|
||||
|
||||
- Add patterns to `~/.config/fabric/patterns/`
|
||||
- Run extract_patterns.py to process new additions:
|
||||
|
||||
```bash
|
||||
python extract_patterns.py
|
||||
|
||||
The Python Script automatically:
|
||||
|
||||
- Creates pattern extracts for reference
|
||||
- Adds placeholder entries in descriptions file
|
||||
- Syncs to web interface
|
||||
|
||||
### 2. Pattern Extract Creation
|
||||
|
||||
The script extracts first 500 words from each pattern's system.md file to:
|
||||
|
||||
- Provide context for writing descriptions
|
||||
@@ -31,8 +36,8 @@ The script extracts first 500 words from each pattern's system.md file to:
|
||||
- Aid in pattern categorization
|
||||
|
||||
### 3. Description and Tag Management
|
||||
Pattern descriptions and tags are managed in pattern_descriptions.json:
|
||||
|
||||
Pattern descriptions and tags are managed in pattern_descriptions.json:
|
||||
|
||||
{
|
||||
"patterns": [
|
||||
@@ -44,20 +49,21 @@ Pattern descriptions and tags are managed in pattern_descriptions.json:
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
## Completing Pattern Metadata
|
||||
|
||||
### Writing Descriptions
|
||||
|
||||
1. Check pattern_descriptions.json for "[Description pending]" entries
|
||||
2. Reference pattern_extracts.json for context
|
||||
|
||||
3. How to update Pattern short descriptions (one sentence).
|
||||
3. How to update Pattern short descriptions (one sentence).
|
||||
|
||||
You can update your descriptions in pattern_descriptions.json manually or using LLM assistance (preferred approach).
|
||||
You can update your descriptions in pattern_descriptions.json manually or using LLM assistance (preferred approach).
|
||||
|
||||
Tell AI to look for "Description pending" entries in this file and write a short description based on the extract info in the pattern_extracts.json file. You can also ask your LLM to add tags for those newly added patterns, using other patterns tag assignments as example.
|
||||
Tell AI to look for "Description pending" entries in this file and write a short description based on the extract info in the pattern_extracts.json file. You can also ask your LLM to add tags for those newly added patterns, using other patterns tag assignments as example.
|
||||
|
||||
### Managing Tags
|
||||
|
||||
1. Add appropriate tags to new patterns
|
||||
2. Update existing tags as needed
|
||||
3. Tags are stored as arrays: ["TAG1", "TAG2"]
|
||||
@@ -67,6 +73,7 @@ Tell AI to look for "Description pending" entries in this file and write a short
|
||||
## File Synchronization
|
||||
|
||||
The script maintains synchronization between:
|
||||
|
||||
- Local pattern_descriptions.json
|
||||
- Web interface copy in static/data/
|
||||
- No manual file copying needed
|
||||
@@ -91,6 +98,7 @@ The script maintains synchronization between:
|
||||
## Troubleshooting
|
||||
|
||||
If patterns are not showing in the web interface:
|
||||
|
||||
1. Verify pattern_descriptions.json format
|
||||
2. Check web static copy exists
|
||||
3. Ensure proper file permissions
|
||||
@@ -108,17 +116,3 @@ fabric/
|
||||
└── static/
|
||||
└── data/
|
||||
└── pattern_descriptions.json # Web interface copy
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
106
Pattern_Descriptions/extract_patterns.py
Normal file → Executable file
106
Pattern_Descriptions/extract_patterns.py
Normal file → Executable file
@@ -1,81 +1,96 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""Extracts pattern information from the ~/.config/fabric/patterns directory,
|
||||
creates JSON files for pattern extracts and descriptions, and updates web static files.
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
import shutil
|
||||
|
||||
|
||||
def load_existing_file(filepath):
|
||||
"""Load existing JSON file or return default structure"""
|
||||
if os.path.exists(filepath):
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
try:
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except json.JSONDecodeError:
|
||||
print(
|
||||
f"Warning: Malformed JSON in {filepath}. Starting with an empty list."
|
||||
)
|
||||
return {"patterns": []}
|
||||
return {"patterns": []}
|
||||
|
||||
|
||||
def get_pattern_extract(pattern_path):
|
||||
"""Extract first 500 words from pattern's system.md file"""
|
||||
system_md_path = os.path.join(pattern_path, "system.md")
|
||||
with open(system_md_path, 'r', encoding='utf-8') as f:
|
||||
content = ' '.join(f.read().split()[:500])
|
||||
with open(system_md_path, "r", encoding="utf-8") as f:
|
||||
content = " ".join(f.read().split()[:500])
|
||||
return content
|
||||
|
||||
|
||||
def extract_pattern_info():
|
||||
"""Extract pattern information from the patterns directory"""
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
patterns_dir = os.path.expanduser("~/.config/fabric/patterns")
|
||||
print(f"\nScanning patterns directory: {patterns_dir}")
|
||||
|
||||
|
||||
extracts_path = os.path.join(script_dir, "pattern_extracts.json")
|
||||
descriptions_path = os.path.join(script_dir, "pattern_descriptions.json")
|
||||
|
||||
|
||||
existing_extracts = load_existing_file(extracts_path)
|
||||
existing_descriptions = load_existing_file(descriptions_path)
|
||||
|
||||
|
||||
existing_extract_names = {p["patternName"] for p in existing_extracts["patterns"]}
|
||||
existing_description_names = {p["patternName"] for p in existing_descriptions["patterns"]}
|
||||
existing_description_names = {
|
||||
p["patternName"] for p in existing_descriptions["patterns"]
|
||||
}
|
||||
print(f"Found existing patterns: {len(existing_extract_names)}")
|
||||
|
||||
|
||||
new_extracts = []
|
||||
new_descriptions = []
|
||||
|
||||
|
||||
|
||||
for dirname in sorted(os.listdir(patterns_dir)):
|
||||
# Only log new pattern processing
|
||||
if dirname not in existing_extract_names:
|
||||
print(f"Processing new pattern: {dirname}")
|
||||
|
||||
|
||||
pattern_path = os.path.join(patterns_dir, dirname)
|
||||
system_md_path = os.path.join(pattern_path, "system.md")
|
||||
print(f"Checking system.md at: {system_md_path}")
|
||||
|
||||
|
||||
if os.path.isdir(pattern_path) and os.path.exists(system_md_path):
|
||||
print(f"Valid pattern directory found: {dirname}")
|
||||
if dirname not in existing_extract_names:
|
||||
print(f"Processing new pattern: {dirname}")
|
||||
|
||||
try:
|
||||
if dirname not in existing_extract_names:
|
||||
print(f"Creating new extract for: {dirname}")
|
||||
pattern_extract = get_pattern_extract(pattern_path) # Pass directory path
|
||||
new_extracts.append({
|
||||
"patternName": dirname,
|
||||
"pattern_extract": pattern_extract
|
||||
})
|
||||
|
||||
pattern_extract = get_pattern_extract(
|
||||
pattern_path
|
||||
) # Pass directory path
|
||||
new_extracts.append(
|
||||
{"patternName": dirname, "pattern_extract": pattern_extract}
|
||||
)
|
||||
|
||||
if dirname not in existing_description_names:
|
||||
print(f"Creating new description for: {dirname}")
|
||||
new_descriptions.append({
|
||||
"patternName": dirname,
|
||||
"description": "[Description pending]",
|
||||
"tags": []
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
new_descriptions.append(
|
||||
{
|
||||
"patternName": dirname,
|
||||
"description": "[Description pending]",
|
||||
"tags": [],
|
||||
}
|
||||
)
|
||||
|
||||
except OSError as e:
|
||||
print(f"Error processing {dirname}: {str(e)}")
|
||||
else:
|
||||
print(f"Invalid pattern directory or missing system.md: {dirname}")
|
||||
|
||||
print(f"\nProcessing summary:")
|
||||
|
||||
print("\nProcessing summary:")
|
||||
print(f"New extracts created: {len(new_extracts)}")
|
||||
print(f"New descriptions added: {len(new_descriptions)}")
|
||||
|
||||
|
||||
existing_extracts["patterns"].extend(new_extracts)
|
||||
existing_descriptions["patterns"].extend(new_descriptions)
|
||||
|
||||
|
||||
return existing_extracts, existing_descriptions, len(new_descriptions)
|
||||
|
||||
|
||||
@@ -87,28 +102,29 @@ def update_web_static(descriptions_path):
|
||||
static_path = os.path.join(static_dir, "pattern_descriptions.json")
|
||||
shutil.copy2(descriptions_path, static_path)
|
||||
|
||||
|
||||
def save_pattern_files():
|
||||
"""Save both pattern files and sync to web"""
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
extracts_path = os.path.join(script_dir, "pattern_extracts.json")
|
||||
descriptions_path = os.path.join(script_dir, "pattern_descriptions.json")
|
||||
|
||||
|
||||
pattern_extracts, pattern_descriptions, new_count = extract_pattern_info()
|
||||
|
||||
|
||||
# Save files
|
||||
with open(extracts_path, 'w', encoding='utf-8') as f:
|
||||
with open(extracts_path, "w", encoding="utf-8") as f:
|
||||
json.dump(pattern_extracts, f, indent=2, ensure_ascii=False)
|
||||
|
||||
with open(descriptions_path, 'w', encoding='utf-8') as f:
|
||||
|
||||
with open(descriptions_path, "w", encoding="utf-8") as f:
|
||||
json.dump(pattern_descriptions, f, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
# Update web static
|
||||
update_web_static(descriptions_path)
|
||||
|
||||
print(f"\nProcessing complete:")
|
||||
|
||||
print("\nProcessing complete:")
|
||||
print(f"Total patterns: {len(pattern_descriptions['patterns'])}")
|
||||
print(f"New patterns added: {new_count}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
save_pattern_files()
|
||||
|
||||
|
||||
@@ -1710,7 +1710,7 @@
|
||||
},
|
||||
{
|
||||
"patternName": "analyze_bill_short",
|
||||
"description": "Consended - Analyze a legislative bill and implications.",
|
||||
"description": "Condensed - Analyze a legislative bill and implications.",
|
||||
"tags": [
|
||||
"ANALYSIS",
|
||||
"BILL"
|
||||
@@ -1815,6 +1815,35 @@
|
||||
"WRITING",
|
||||
"CREATIVITY"
|
||||
]
|
||||
},
|
||||
{
|
||||
"patternName": "extract_alpha",
|
||||
"description": "Extracts the most novel and surprising ideas (\"alpha\") from content, inspired by information theory.",
|
||||
"tags": [
|
||||
"EXTRACT",
|
||||
"ANALYSIS",
|
||||
"CR THINKING",
|
||||
"WISDOM"
|
||||
]
|
||||
},
|
||||
{
|
||||
"patternName": "extract_mcp_servers",
|
||||
"description": "Analyzes content to identify and extract detailed information about Model Context Protocol (MCP) servers.",
|
||||
"tags": [
|
||||
"ANALYSIS",
|
||||
"EXTRACT",
|
||||
"DEVELOPMENT",
|
||||
"AI"
|
||||
]
|
||||
},
|
||||
{
|
||||
"patternName": "review_code",
|
||||
"description": "Performs a comprehensive code review, providing detailed feedback on correctness, security, and performance.",
|
||||
"tags": [
|
||||
"DEVELOPMENT",
|
||||
"REVIEW",
|
||||
"SECURITY"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -883,6 +883,18 @@
|
||||
{
|
||||
"patternName": "write_essay",
|
||||
"pattern_extract": "# Identity and Purpose You are an expert on writing clear, and illuminating essays on the topic of the input provided. # Output Instructions - Write the essay in the style of {{author_name}}, embodying all the qualities that they are known for. - Look up some example Essays by {{author_name}} (Use web search if the tool is available) - Write the essay exactly like {{author_name}} would write it as seen in the examples you find. - Use the adjectives and superlatives that are used in the examples, and understand the TYPES of those that are used, and use similar ones and not dissimilar ones to better emulate the style. - Use the same style, vocabulary level, and sentence structure as {{author_name}}. # Output Format - Output a full, publish-ready essay about the content provided using the instructions above. - Write in {{author_name}}'s natural and clear style, without embellishment. - Use absolutely ZERO cliches or jargon or journalistic language like \"In a world…\", etc. - Do not use cliches or jargon. - Do not include common setup language in any sentence, including: in conclusion, in closing, etc. - Do not output warnings or notes—just the output requested. # INPUT: INPUT:"
|
||||
},
|
||||
{
|
||||
"patternName": "extract_alpha",
|
||||
"pattern_extract": "# IDENTITY You're an expert at finding Alpha in content. # PHILOSOPHY I love the idea of Claude Shannon's information theory where basically the only real information is the stuff that's different and anything that's the same as kind of background noise. I love that idea for novelty and surprise inside of content when I think about a presentation or a talk or a podcast or an essay or anything I'm looking for the net new ideas or the new presentation of ideas for the new frameworks of how to use ideas or combine ideas so I'm looking for a way to capture that inside of content. # INSTRUCTIONS I want you to extract the 24 highest alpha ideas and thoughts and insights and recommendations in this piece of content, and I want you to output them in unformatted marked down in 8-word bullets written in the approachable style of Paul Graham. # INPUT"
|
||||
},
|
||||
{
|
||||
"patternName": "extract_mcp_servers",
|
||||
"pattern_extract": "# IDENTITY and PURPOSE You are an expert at analyzing content related to MCP (Model Context Protocol) servers. You excel at identifying and extracting mentions of MCP servers, their features, capabilities, integrations, and usage patterns. Take a step back and think step-by-step about how to achieve the best results for extracting MCP server information. # STEPS - Read and analyze the entire content carefully - Identify all mentions of MCP servers, including: - Specific MCP server names - Server capabilities and features - Integration details - Configuration examples - Use cases and applications - Installation or setup instructions - API endpoints or methods exposed - Any limitations or requirements # OUTPUT SECTIONS - Output a summary of all MCP servers mentioned with the following sections: ## SERVERS FOUND - List each MCP server found with a 15-word description - Include the server name and its primary purpose - Use bullet points for each server ## SERVER DETAILS For each server found, provide: - **Server Name**: The official name - **Purpose**: Main functionality in 25 words or less - **Key Features**: Up to 5 main features as bullet points - **Integration**: How it integrates with systems (if mentioned) - **Configuration**: Any configuration details mentioned - **Requirements**: Dependencies or requirements (if specified) ## USAGE EXAMPLES - Extract any code snippets or usage examples - Include configuration files or setup instructions - Present each example with context ## INSIGHTS - Provide 3-5 insights about the MCP servers mentioned - Focus on patterns, trends, or notable characteristics - Each insight should be a 20-word bullet point # OUTPUT INSTRUCTIONS - Output in clean, readable Markdown - Use proper heading hierarchy - Include code blocks with appropriate language tags - Do not include warnings or notes about the content - If no MCP servers are found, simply state \"No MCP servers mentioned in the content\" - Ensure all server names are accurately captured - Preserve technical details and specifications # INPUT: INPUT:"
|
||||
},
|
||||
{
|
||||
"patternName": "review_code",
|
||||
"pattern_extract": "# Code Review Task ## ROLE AND GOAL You are a Principal Software Engineer, renowned for your meticulous attention to detail and your ability to provide clear, constructive, and educational code reviews. Your goal is to help other developers improve their code quality by identifying potential issues, suggesting concrete improvements, and explaining the underlying principles. ## TASK You will be given a snippet of code or a diff. Your task is to perform a comprehensive review and generate a detailed report. ## STEPS 1. **Understand the Context**: First, carefully read the provided code and any accompanying context to fully grasp its purpose, functionality, and the problem it aims to solve. 2. **Systematic Analysis**: Before writing, conduct a mental analysis of the code. Evaluate it against the following key aspects. Do not write this analysis in the output; use it to form your review. * **Correctness**: Are there bugs, logic errors, or race conditions? * **Security**: Are there any potential vulnerabilities (e.g., injection attacks, improper handling of sensitive data)? * **Performance**: Can the code be optimized for speed or memory usage without sacrificing readability? * **Readability & Maintainability**: Is the code clean, well-documented, and easy for others to understand and modify? * **Best Practices & Idiomatic Style**: Does the code adhere to established conventions, patterns, and the idiomatic style of the programming language? * **Error Handling & Edge Cases**: Are errors handled gracefully? Have all relevant edge cases been considered? 3. **Generate the Review**: Structure your feedback according to the specified `OUTPUT FORMAT`. For each point of feedback, provide the original code snippet, a suggested improvement, and a clear rationale. ## OUTPUT FORMAT Your review must be in Markdown and follow this exact structure: --- ### Overall Assessment A brief, high-level summary of the code's quality. Mention its strengths and the primary areas for improvement. ### **Prioritized Recommendations** A numbered list of the most important changes, ordered from most to least critical. 1. (Most critical change) 2. (Second most critical change) 3. ... ### **Detailed Feedback** For each issue you identified, provide a detailed breakdown in the following format. --- **[ISSUE TITLE]** - (e.g., `Security`, `Readability`, `Performance`) **Original Code:** ```[language] // The specific lines of code with the issue ``` **Suggested Improvement:** ```[language] // The revised, improved code ``` **Rationale:** A clear and concise explanation of why the change is recommended. Reference best practices, design patterns, or potential risks. If you use advanced concepts, briefly explain them. --- (Repeat this section for each issue) ## EXAMPLE Here is an example of a review for a simple Python function: --- ### **Overall Assessment** The function correctly fetches user data, but it can be made more robust and efficient. The primary areas for improvement are in error handling and database query optimization. ### **Prioritized Recommendations** 1. Avoid making database queries inside a loop to prevent performance issues (N+1 query problem). 2. Add specific error handling for when a user is not found. ### **Detailed Feedback** --- **[PERFORMANCE]** - N+1 Database Query **Original Code:**"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
80
README.md
80
README.md
@@ -15,9 +15,7 @@ Fabric is graciously supported by…
|
||||
[](https://deepwiki.com/danielmiessler/fabric)
|
||||
|
||||
<div align="center">
|
||||
<p class="align center">
|
||||
<h4><code>fabric</code> is an open-source framework for augmenting humans using AI.</h4>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
[Updates](#updates) •
|
||||
@@ -41,9 +39,9 @@ Since the start of modern AI in late 2022 we've seen an **_extraordinary_** numb
|
||||
|
||||
It's all really exciting and powerful, but _it's not easy to integrate this functionality into our lives._
|
||||
|
||||
<p class="align center">
|
||||
<div class="align center">
|
||||
<h4>In other words, AI doesn't have a capabilities problem—it has an <em>integration</em> problem.</h4>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
**Fabric was created to address this by creating and organizing the fundamental units of AI—the prompts themselves!**
|
||||
|
||||
@@ -95,6 +93,9 @@ Keep in mind that many of these were recorded when Fabric was Python-based, so r
|
||||
- [Just use the Patterns](#just-use-the-patterns)
|
||||
- [Prompt Strategies](#prompt-strategies)
|
||||
- [Custom Patterns](#custom-patterns)
|
||||
- [Setting Up Custom Patterns](#setting-up-custom-patterns)
|
||||
- [Using Custom Patterns](#using-custom-patterns)
|
||||
- [How It Works](#how-it-works)
|
||||
- [Helper Apps](#helper-apps)
|
||||
- [`to_pdf`](#to_pdf)
|
||||
- [`to_pdf` Installation](#to_pdf-installation)
|
||||
@@ -114,6 +115,14 @@ Keep in mind that many of these were recorded when Fabric was Python-based, so r
|
||||
|
||||
> [!NOTE]
|
||||
>
|
||||
> July 4, 2025
|
||||
>
|
||||
> - **Web Search**: Fabric now supports web search for Anthropic and OpenAI models using the `--search` and `--search-location` flags. This replaces the previous plugin-based search, so you may want to remove the old `ANTHROPIC_WEB_SEARCH_TOOL_*` variables from your `~/.config/fabric/.env` file.
|
||||
> - **Image Generation**: Fabric now has powerful image generation capabilities with OpenAI.
|
||||
> - Generate images from text prompts and save them using `--image-file`.
|
||||
> - Edit existing images by providing an input image with `--attachment`.
|
||||
> - Control image `size`, `quality`, `compression`, and `background` with the new `--image-*` flags.
|
||||
>
|
||||
>June 17, 2025
|
||||
>
|
||||
>- Fabric now supports Perplexity AI. Configure it by using `fabric -S` to add your Perplexity AI API Key,
|
||||
@@ -485,7 +494,6 @@ fabric -h
|
||||
```
|
||||
|
||||
```plaintext
|
||||
|
||||
Usage:
|
||||
fabric [OPTIONS]
|
||||
|
||||
@@ -500,7 +508,9 @@ Application Options:
|
||||
-T, --topp= Set top P (default: 0.9)
|
||||
-s, --stream Stream
|
||||
-P, --presencepenalty= Set presence penalty (default: 0.0)
|
||||
-r, --raw Use the defaults of the model without sending chat options (like temperature etc.) and use the user role instead of the system role for patterns.
|
||||
-r, --raw Use the defaults of the model without sending chat options (like
|
||||
temperature etc.) and use the user role instead of the system role for
|
||||
patterns.
|
||||
-F, --frequencypenalty= Set frequency penalty (default: 0.0)
|
||||
-l, --listpatterns List all patterns
|
||||
-L, --listmodels List all available models
|
||||
@@ -514,9 +524,12 @@ Application Options:
|
||||
--output-session Output the entire session (also a temporary one) to the output file
|
||||
-n, --latest= Number of latest patterns to list (default: 0)
|
||||
-d, --changeDefaultModel Change default model
|
||||
-y, --youtube= YouTube video or play list "URL" to grab transcript, comments from it and send to chat or print it put to the console and store it in the output file
|
||||
-y, --youtube= YouTube video or play list "URL" to grab transcript, comments from it
|
||||
and send to chat or print it put to the console and store it in the
|
||||
output file
|
||||
--playlist Prefer playlist over video if both ids are present in the URL
|
||||
--transcript Grab transcript from YouTube video and send to chat (it is used per default).
|
||||
--transcript Grab transcript from YouTube video and send to chat (it is used per
|
||||
default).
|
||||
--transcript-with-timestamps Grab transcript from YouTube video with timestamps and send to chat
|
||||
--comments Grab comments from YouTube video and send to chat
|
||||
--metadata Output video metadata
|
||||
@@ -544,6 +557,14 @@ Application Options:
|
||||
--liststrategies List all strategies
|
||||
--listvendors List all vendors
|
||||
--shell-complete-list Output raw list without headers/formatting (for shell completion)
|
||||
--search Enable web search tool for supported models (Anthropic, OpenAI)
|
||||
--search-location= Set location for web search results (e.g., 'America/Los_Angeles')
|
||||
--image-file= Save generated image to specified file path (e.g., 'output.png')
|
||||
--image-size= Image dimensions: 1024x1024, 1536x1024, 1024x1536, auto (default: auto)
|
||||
--image-quality= Image quality: low, medium, high, auto (default: auto)
|
||||
--image-compression= Compression level 0-100 for JPEG/WebP formats (default: not set)
|
||||
--image-background= Background type: opaque, transparent (default: opaque, only for
|
||||
PNG/WebP)
|
||||
|
||||
Help Options:
|
||||
-h, --help Show this help message
|
||||
@@ -634,11 +655,48 @@ Use `fabric -S` and select the option to install the strategies in your `~/.conf
|
||||
|
||||
You may want to use Fabric to create your own custom Patterns—but not share them with others. No problem!
|
||||
|
||||
Just make a directory in `~/.config/custompatterns/` (or wherever) and put your `.md` files in there.
|
||||
Fabric now supports a dedicated custom patterns directory that keeps your personal patterns separate from the built-in ones. This means your custom patterns won't be overwritten when you update Fabric's built-in patterns.
|
||||
|
||||
When you're ready to use them, copy them into `~/.config/fabric/patterns/`
|
||||
### Setting Up Custom Patterns
|
||||
|
||||
You can then use them like any other Patterns, but they won't be public unless you explicitly submit them as Pull Requests to the Fabric project. So don't worry—they're private to you.
|
||||
1. Run the Fabric setup:
|
||||
|
||||
```bash
|
||||
fabric --setup
|
||||
```
|
||||
|
||||
2. Select the "Custom Patterns" option from the Tools menu and enter your desired directory path (e.g., `~/my-custom-patterns`)
|
||||
|
||||
3. Fabric will automatically create the directory if it does not exist.
|
||||
|
||||
### Using Custom Patterns
|
||||
|
||||
1. Create your custom pattern directory structure:
|
||||
|
||||
```bash
|
||||
mkdir -p ~/my-custom-patterns/my-analyzer
|
||||
```
|
||||
|
||||
2. Create your pattern file
|
||||
|
||||
```bash
|
||||
echo "You are an expert analyzer of ..." > ~/my-custom-patterns/my-analyzer/system.md
|
||||
```
|
||||
|
||||
3. **Use your custom pattern:**
|
||||
|
||||
```bash
|
||||
fabric --pattern my-analyzer "analyze this text"
|
||||
```
|
||||
|
||||
### How It Works
|
||||
|
||||
- **Priority System**: Custom patterns take precedence over built-in patterns with the same name
|
||||
- **Seamless Integration**: Custom patterns appear in `fabric --listpatterns` alongside built-in ones
|
||||
- **Update Safe**: Your custom patterns are never affected by `fabric --updatepatterns`
|
||||
- **Private by Default**: Custom patterns remain private unless you explicitly share them
|
||||
|
||||
Your custom patterns are completely private and won't be affected by Fabric updates!
|
||||
|
||||
## Helper Apps
|
||||
|
||||
|
||||
132
chat/chat.go
Normal file
132
chat/chat.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
)
|
||||
|
||||
const (
|
||||
ChatMessageRoleSystem = "system"
|
||||
ChatMessageRoleUser = "user"
|
||||
ChatMessageRoleAssistant = "assistant"
|
||||
ChatMessageRoleFunction = "function"
|
||||
ChatMessageRoleTool = "tool"
|
||||
ChatMessageRoleDeveloper = "developer"
|
||||
)
|
||||
|
||||
var ErrContentFieldsMisused = errors.New("can't use both Content and MultiContent properties simultaneously")
|
||||
|
||||
type ChatMessagePartType string
|
||||
|
||||
const (
|
||||
ChatMessagePartTypeText ChatMessagePartType = "text"
|
||||
ChatMessagePartTypeImageURL ChatMessagePartType = "image_url"
|
||||
)
|
||||
|
||||
type ChatMessageImageURL struct {
|
||||
URL string `json:"url,omitempty"`
|
||||
}
|
||||
|
||||
type ChatMessagePart struct {
|
||||
Type ChatMessagePartType `json:"type,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
ImageURL *ChatMessageImageURL `json:"image_url,omitempty"`
|
||||
}
|
||||
|
||||
type FunctionCall struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Arguments string `json:"arguments,omitempty"`
|
||||
}
|
||||
|
||||
type ToolType string
|
||||
|
||||
const (
|
||||
ToolTypeFunction ToolType = "function"
|
||||
)
|
||||
|
||||
type ToolCall struct {
|
||||
Index *int `json:"index,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Type ToolType `json:"type"`
|
||||
Function FunctionCall `json:"function"`
|
||||
}
|
||||
|
||||
type ChatCompletionMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content,omitempty"`
|
||||
Refusal string `json:"refusal,omitempty"`
|
||||
MultiContent []ChatMessagePart `json:"-"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
FunctionCall *FunctionCall `json:"function_call,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
}
|
||||
|
||||
func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) {
|
||||
if m.Content != "" && m.MultiContent != nil {
|
||||
return nil, ErrContentFieldsMisused
|
||||
}
|
||||
if len(m.MultiContent) > 0 {
|
||||
msg := struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"-"`
|
||||
Refusal string `json:"refusal,omitempty"`
|
||||
MultiContent []ChatMessagePart `json:"content,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
FunctionCall *FunctionCall `json:"function_call,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
}(m)
|
||||
return json.Marshal(msg)
|
||||
}
|
||||
|
||||
msg := struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content,omitempty"`
|
||||
Refusal string `json:"refusal,omitempty"`
|
||||
MultiContent []ChatMessagePart `json:"-"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
FunctionCall *FunctionCall `json:"function_call,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
}(m)
|
||||
return json.Marshal(msg)
|
||||
}
|
||||
|
||||
func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error {
|
||||
msg := struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
Refusal string `json:"refusal,omitempty"`
|
||||
MultiContent []ChatMessagePart
|
||||
Name string `json:"name,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
FunctionCall *FunctionCall `json:"function_call,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
}{}
|
||||
|
||||
if err := json.Unmarshal(bs, &msg); err == nil {
|
||||
*m = ChatCompletionMessage(msg)
|
||||
return nil
|
||||
}
|
||||
multiMsg := struct {
|
||||
Role string `json:"role"`
|
||||
Content string
|
||||
Refusal string `json:"refusal,omitempty"`
|
||||
MultiContent []ChatMessagePart `json:"content"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
FunctionCall *FunctionCall `json:"function_call,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
}{}
|
||||
if err := json.Unmarshal(bs, &multiMsg); err != nil {
|
||||
return err
|
||||
}
|
||||
*m = ChatCompletionMessage(multiMsg)
|
||||
return nil
|
||||
}
|
||||
@@ -270,7 +270,11 @@ func Cli(version string) (err error) {
|
||||
if chatReq.Language == "" {
|
||||
chatReq.Language = registry.Language.DefaultLanguage.Value
|
||||
}
|
||||
if session, err = chatter.Send(chatReq, currentFlags.BuildChatOptions()); err != nil {
|
||||
var chatOptions *common.ChatOptions
|
||||
if chatOptions, err = currentFlags.BuildChatOptions(); err != nil {
|
||||
return
|
||||
}
|
||||
if session, err = chatter.Send(chatReq, chatOptions); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
154
cli/flags.go
154
cli/flags.go
@@ -6,15 +6,16 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/danielmiessler/fabric/chat"
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"github.com/jessevdk/go-flags"
|
||||
goopenai "github.com/sashabaranov/go-openai"
|
||||
"golang.org/x/text/language"
|
||||
"gopkg.in/yaml.v2"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Flags create flags struct. the users flags go into this, this will be passed to the chat struct in cli
|
||||
@@ -74,6 +75,13 @@ type Flags struct {
|
||||
ListStrategies bool `long:"liststrategies" description:"List all strategies"`
|
||||
ListVendors bool `long:"listvendors" description:"List all vendors"`
|
||||
ShellCompleteOutput bool `long:"shell-complete-list" description:"Output raw list without headers/formatting (for shell completion)"`
|
||||
Search bool `long:"search" description:"Enable web search tool for supported models (Anthropic, OpenAI)"`
|
||||
SearchLocation string `long:"search-location" description:"Set location for web search results (e.g., 'America/Los_Angeles')"`
|
||||
ImageFile string `long:"image-file" description:"Save generated image to specified file path (e.g., 'output.png')"`
|
||||
ImageSize string `long:"image-size" description:"Image dimensions: 1024x1024, 1536x1024, 1024x1536, auto (default: auto)"`
|
||||
ImageQuality string `long:"image-quality" description:"Image quality: low, medium, high, auto (default: auto)"`
|
||||
ImageCompression int `long:"image-compression" description:"Compression level 0-100 for JPEG/WebP formats (default: not set)"`
|
||||
ImageBackground string `long:"image-background" description:"Background type: opaque, transparent (default: opaque, only for PNG/WebP)"`
|
||||
}
|
||||
|
||||
var debug = false
|
||||
@@ -254,8 +262,121 @@ func readStdin() (ret string, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (o *Flags) BuildChatOptions() (ret *common.ChatOptions) {
|
||||
// validateImageFile validates the image file path and extension
|
||||
func validateImageFile(imagePath string) error {
|
||||
if imagePath == "" {
|
||||
return nil // No validation needed if no image file specified
|
||||
}
|
||||
|
||||
// Check if file already exists
|
||||
if _, err := os.Stat(imagePath); err == nil {
|
||||
return fmt.Errorf("image file already exists: %s", imagePath)
|
||||
}
|
||||
|
||||
// Check file extension
|
||||
ext := strings.ToLower(filepath.Ext(imagePath))
|
||||
validExtensions := []string{".png", ".jpeg", ".jpg", ".webp"}
|
||||
|
||||
for _, validExt := range validExtensions {
|
||||
if ext == validExt {
|
||||
return nil // Valid extension found
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("invalid image file extension '%s'. Supported formats: .png, .jpeg, .jpg, .webp", ext)
|
||||
}
|
||||
|
||||
// validateImageParameters validates image generation parameters
|
||||
func validateImageParameters(imagePath, size, quality, background string, compression int) error {
|
||||
if imagePath == "" {
|
||||
// Check if any image parameters are specified without --image-file
|
||||
if size != "" || quality != "" || background != "" || compression != 0 {
|
||||
return fmt.Errorf("image parameters (--image-size, --image-quality, --image-background, --image-compression) can only be used with --image-file")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate size
|
||||
if size != "" {
|
||||
validSizes := []string{"1024x1024", "1536x1024", "1024x1536", "auto"}
|
||||
valid := false
|
||||
for _, validSize := range validSizes {
|
||||
if size == validSize {
|
||||
valid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !valid {
|
||||
return fmt.Errorf("invalid image size '%s'. Supported sizes: 1024x1024, 1536x1024, 1024x1536, auto", size)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate quality
|
||||
if quality != "" {
|
||||
validQualities := []string{"low", "medium", "high", "auto"}
|
||||
valid := false
|
||||
for _, validQuality := range validQualities {
|
||||
if quality == validQuality {
|
||||
valid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !valid {
|
||||
return fmt.Errorf("invalid image quality '%s'. Supported qualities: low, medium, high, auto", quality)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate background
|
||||
if background != "" {
|
||||
validBackgrounds := []string{"opaque", "transparent"}
|
||||
valid := false
|
||||
for _, validBackground := range validBackgrounds {
|
||||
if background == validBackground {
|
||||
valid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !valid {
|
||||
return fmt.Errorf("invalid image background '%s'. Supported backgrounds: opaque, transparent", background)
|
||||
}
|
||||
}
|
||||
|
||||
// Get file format for format-specific validations
|
||||
ext := strings.ToLower(filepath.Ext(imagePath))
|
||||
|
||||
// Validate compression (only for jpeg/webp)
|
||||
if compression != 0 { // 0 means not set
|
||||
if ext != ".jpg" && ext != ".jpeg" && ext != ".webp" {
|
||||
return fmt.Errorf("image compression can only be used with JPEG and WebP formats, not %s", ext)
|
||||
}
|
||||
if compression < 0 || compression > 100 {
|
||||
return fmt.Errorf("image compression must be between 0 and 100, got %d", compression)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate background transparency (only for png/webp)
|
||||
if background == "transparent" {
|
||||
if ext != ".png" && ext != ".webp" {
|
||||
return fmt.Errorf("transparent background can only be used with PNG and WebP formats, not %s", ext)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *Flags) BuildChatOptions() (ret *common.ChatOptions, err error) {
|
||||
// Validate image file if specified
|
||||
if err = validateImageFile(o.ImageFile); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Validate image parameters
|
||||
if err = validateImageParameters(o.ImageFile, o.ImageSize, o.ImageQuality, o.ImageBackground, o.ImageCompression); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret = &common.ChatOptions{
|
||||
Model: o.Model,
|
||||
Temperature: o.Temperature,
|
||||
TopP: o.TopP,
|
||||
PresencePenalty: o.PresencePenalty,
|
||||
@@ -263,6 +384,13 @@ func (o *Flags) BuildChatOptions() (ret *common.ChatOptions) {
|
||||
Raw: o.Raw,
|
||||
Seed: o.Seed,
|
||||
ModelContextLength: o.ModelContextLength,
|
||||
Search: o.Search,
|
||||
SearchLocation: o.SearchLocation,
|
||||
ImageFile: o.ImageFile,
|
||||
ImageSize: o.ImageSize,
|
||||
ImageQuality: o.ImageQuality,
|
||||
ImageCompression: o.ImageCompression,
|
||||
ImageBackground: o.ImageBackground,
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -278,15 +406,15 @@ func (o *Flags) BuildChatRequest(Meta string) (ret *common.ChatRequest, err erro
|
||||
Meta: Meta,
|
||||
}
|
||||
|
||||
var message *goopenai.ChatCompletionMessage
|
||||
var message *chat.ChatCompletionMessage
|
||||
if len(o.Attachments) > 0 {
|
||||
message = &goopenai.ChatCompletionMessage{
|
||||
Role: goopenai.ChatMessageRoleUser,
|
||||
message = &chat.ChatCompletionMessage{
|
||||
Role: chat.ChatMessageRoleUser,
|
||||
}
|
||||
|
||||
if o.Message != "" {
|
||||
message.MultiContent = append(message.MultiContent, goopenai.ChatMessagePart{
|
||||
Type: goopenai.ChatMessagePartTypeText,
|
||||
message.MultiContent = append(message.MultiContent, chat.ChatMessagePart{
|
||||
Type: chat.ChatMessagePartTypeText,
|
||||
Text: strings.TrimSpace(o.Message),
|
||||
})
|
||||
}
|
||||
@@ -309,16 +437,16 @@ func (o *Flags) BuildChatRequest(Meta string) (ret *common.ChatRequest, err erro
|
||||
dataURL := fmt.Sprintf("data:%s;base64,%s", mimeType, base64Image)
|
||||
url = &dataURL
|
||||
}
|
||||
message.MultiContent = append(message.MultiContent, goopenai.ChatMessagePart{
|
||||
Type: goopenai.ChatMessagePartTypeImageURL,
|
||||
ImageURL: &goopenai.ChatMessageImageURL{
|
||||
message.MultiContent = append(message.MultiContent, chat.ChatMessagePart{
|
||||
Type: chat.ChatMessagePartTypeImageURL,
|
||||
ImageURL: &chat.ChatMessageImageURL{
|
||||
URL: *url,
|
||||
},
|
||||
})
|
||||
}
|
||||
} else if o.Message != "" {
|
||||
message = &goopenai.ChatCompletionMessage{
|
||||
Role: goopenai.ChatMessageRoleUser,
|
||||
message = &chat.ChatCompletionMessage{
|
||||
Role: chat.ChatMessageRoleUser,
|
||||
Content: strings.TrimSpace(o.Message),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -64,7 +65,8 @@ func TestBuildChatOptions(t *testing.T) {
|
||||
Raw: false,
|
||||
Seed: 1,
|
||||
}
|
||||
options := flags.BuildChatOptions()
|
||||
options, err := flags.BuildChatOptions()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedOptions, options)
|
||||
}
|
||||
|
||||
@@ -84,7 +86,8 @@ func TestBuildChatOptionsDefaultSeed(t *testing.T) {
|
||||
Raw: false,
|
||||
Seed: 0,
|
||||
}
|
||||
options := flags.BuildChatOptions()
|
||||
options, err := flags.BuildChatOptions()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedOptions, options)
|
||||
}
|
||||
|
||||
@@ -164,3 +167,269 @@ model: 123 # should be string
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestValidateImageFile(t *testing.T) {
|
||||
t.Run("Empty path should be valid", func(t *testing.T) {
|
||||
err := validateImageFile("")
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Valid extensions should pass", func(t *testing.T) {
|
||||
validExtensions := []string{".png", ".jpeg", ".jpg", ".webp"}
|
||||
for _, ext := range validExtensions {
|
||||
filename := "/tmp/test" + ext
|
||||
err := validateImageFile(filename)
|
||||
assert.NoError(t, err, "Extension %s should be valid", ext)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid extensions should fail", func(t *testing.T) {
|
||||
invalidExtensions := []string{".gif", ".bmp", ".tiff", ".svg", ".txt", ""}
|
||||
for _, ext := range invalidExtensions {
|
||||
filename := "/tmp/test" + ext
|
||||
err := validateImageFile(filename)
|
||||
assert.Error(t, err, "Extension %s should be invalid", ext)
|
||||
assert.Contains(t, err.Error(), "invalid image file extension")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Existing file should fail", func(t *testing.T) {
|
||||
// Create a temporary file
|
||||
tempFile, err := os.CreateTemp("", "test*.png")
|
||||
assert.NoError(t, err)
|
||||
defer os.Remove(tempFile.Name())
|
||||
tempFile.Close()
|
||||
|
||||
// Validation should fail because file exists
|
||||
err = validateImageFile(tempFile.Name())
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "image file already exists")
|
||||
})
|
||||
|
||||
t.Run("Non-existing file with valid extension should pass", func(t *testing.T) {
|
||||
nonExistentFile := filepath.Join(os.TempDir(), "non_existent_file.png")
|
||||
// Make sure the file doesn't exist
|
||||
os.Remove(nonExistentFile)
|
||||
|
||||
err := validateImageFile(nonExistentFile)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuildChatOptionsWithImageFileValidation(t *testing.T) {
|
||||
t.Run("Valid image file should pass", func(t *testing.T) {
|
||||
flags := &Flags{
|
||||
ImageFile: "/tmp/output.png",
|
||||
}
|
||||
|
||||
options, err := flags.BuildChatOptions()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "/tmp/output.png", options.ImageFile)
|
||||
})
|
||||
|
||||
t.Run("Invalid extension should fail", func(t *testing.T) {
|
||||
flags := &Flags{
|
||||
ImageFile: "/tmp/output.gif",
|
||||
}
|
||||
|
||||
options, err := flags.BuildChatOptions()
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, options)
|
||||
assert.Contains(t, err.Error(), "invalid image file extension")
|
||||
})
|
||||
|
||||
t.Run("Existing file should fail", func(t *testing.T) {
|
||||
// Create a temporary file
|
||||
tempFile, err := os.CreateTemp("", "existing*.png")
|
||||
assert.NoError(t, err)
|
||||
defer os.Remove(tempFile.Name())
|
||||
tempFile.Close()
|
||||
|
||||
flags := &Flags{
|
||||
ImageFile: tempFile.Name(),
|
||||
}
|
||||
|
||||
options, err := flags.BuildChatOptions()
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, options)
|
||||
assert.Contains(t, err.Error(), "image file already exists")
|
||||
})
|
||||
}
|
||||
|
||||
func TestValidateImageParameters(t *testing.T) {
|
||||
t.Run("No image file and no parameters should pass", func(t *testing.T) {
|
||||
err := validateImageParameters("", "", "", "", 0)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Image parameters without image file should fail", func(t *testing.T) {
|
||||
// Test each parameter individually
|
||||
err := validateImageParameters("", "1024x1024", "", "", 0)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "image parameters")
|
||||
assert.Contains(t, err.Error(), "can only be used with --image-file")
|
||||
|
||||
err = validateImageParameters("", "", "high", "", 0)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "image parameters")
|
||||
|
||||
err = validateImageParameters("", "", "", "transparent", 0)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "image parameters")
|
||||
|
||||
err = validateImageParameters("", "", "", "", 50)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "image parameters")
|
||||
|
||||
// Test multiple parameters
|
||||
err = validateImageParameters("", "1024x1024", "high", "transparent", 50)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "image parameters")
|
||||
})
|
||||
|
||||
t.Run("Valid size values should pass", func(t *testing.T) {
|
||||
validSizes := []string{"1024x1024", "1536x1024", "1024x1536", "auto"}
|
||||
for _, size := range validSizes {
|
||||
err := validateImageParameters("/tmp/test.png", size, "", "", 0)
|
||||
assert.NoError(t, err, "Size %s should be valid", size)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid size should fail", func(t *testing.T) {
|
||||
err := validateImageParameters("/tmp/test.png", "invalid", "", "", 0)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid image size")
|
||||
})
|
||||
|
||||
t.Run("Valid quality values should pass", func(t *testing.T) {
|
||||
validQualities := []string{"low", "medium", "high", "auto"}
|
||||
for _, quality := range validQualities {
|
||||
err := validateImageParameters("/tmp/test.png", "", quality, "", 0)
|
||||
assert.NoError(t, err, "Quality %s should be valid", quality)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid quality should fail", func(t *testing.T) {
|
||||
err := validateImageParameters("/tmp/test.png", "", "invalid", "", 0)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid image quality")
|
||||
})
|
||||
|
||||
t.Run("Valid background values should pass", func(t *testing.T) {
|
||||
validBackgrounds := []string{"opaque", "transparent"}
|
||||
for _, background := range validBackgrounds {
|
||||
err := validateImageParameters("/tmp/test.png", "", "", background, 0)
|
||||
assert.NoError(t, err, "Background %s should be valid", background)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid background should fail", func(t *testing.T) {
|
||||
err := validateImageParameters("/tmp/test.png", "", "", "invalid", 0)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid image background")
|
||||
})
|
||||
|
||||
t.Run("Compression for JPEG should pass", func(t *testing.T) {
|
||||
err := validateImageParameters("/tmp/test.jpg", "", "", "", 75)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Compression for WebP should pass", func(t *testing.T) {
|
||||
err := validateImageParameters("/tmp/test.webp", "", "", "", 50)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Compression for PNG should fail", func(t *testing.T) {
|
||||
err := validateImageParameters("/tmp/test.png", "", "", "", 75)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "image compression can only be used with JPEG and WebP formats")
|
||||
})
|
||||
|
||||
t.Run("Invalid compression range should fail", func(t *testing.T) {
|
||||
err := validateImageParameters("/tmp/test.jpg", "", "", "", 150)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "image compression must be between 0 and 100")
|
||||
|
||||
err = validateImageParameters("/tmp/test.jpg", "", "", "", -10)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "image compression must be between 0 and 100")
|
||||
})
|
||||
|
||||
t.Run("Transparent background for PNG should pass", func(t *testing.T) {
|
||||
err := validateImageParameters("/tmp/test.png", "", "", "transparent", 0)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Transparent background for WebP should pass", func(t *testing.T) {
|
||||
err := validateImageParameters("/tmp/test.webp", "", "", "transparent", 0)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Transparent background for JPEG should fail", func(t *testing.T) {
|
||||
err := validateImageParameters("/tmp/test.jpg", "", "", "transparent", 0)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "transparent background can only be used with PNG and WebP formats")
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuildChatOptionsWithImageParameters(t *testing.T) {
|
||||
t.Run("Valid image parameters should pass", func(t *testing.T) {
|
||||
flags := &Flags{
|
||||
ImageFile: "/tmp/test.png",
|
||||
ImageSize: "1024x1024",
|
||||
ImageQuality: "high",
|
||||
ImageBackground: "transparent",
|
||||
ImageCompression: 0, // Not set for PNG
|
||||
}
|
||||
|
||||
options, err := flags.BuildChatOptions()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, options)
|
||||
assert.Equal(t, "/tmp/test.png", options.ImageFile)
|
||||
assert.Equal(t, "1024x1024", options.ImageSize)
|
||||
assert.Equal(t, "high", options.ImageQuality)
|
||||
assert.Equal(t, "transparent", options.ImageBackground)
|
||||
assert.Equal(t, 0, options.ImageCompression)
|
||||
})
|
||||
|
||||
t.Run("Invalid image parameters should fail", func(t *testing.T) {
|
||||
flags := &Flags{
|
||||
ImageFile: "/tmp/test.png",
|
||||
ImageSize: "invalid",
|
||||
ImageQuality: "high",
|
||||
ImageBackground: "transparent",
|
||||
}
|
||||
|
||||
options, err := flags.BuildChatOptions()
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, options)
|
||||
assert.Contains(t, err.Error(), "invalid image size")
|
||||
})
|
||||
|
||||
t.Run("JPEG with compression should pass", func(t *testing.T) {
|
||||
flags := &Flags{
|
||||
ImageFile: "/tmp/test.jpg",
|
||||
ImageSize: "1536x1024",
|
||||
ImageQuality: "medium",
|
||||
ImageBackground: "opaque",
|
||||
ImageCompression: 80,
|
||||
}
|
||||
|
||||
options, err := flags.BuildChatOptions()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, options)
|
||||
assert.Equal(t, 80, options.ImageCompression)
|
||||
})
|
||||
|
||||
t.Run("Image parameters without image file should fail in BuildChatOptions", func(t *testing.T) {
|
||||
flags := &Flags{
|
||||
ImageSize: "1024x1024", // Image parameter without ImageFile
|
||||
}
|
||||
|
||||
options, err := flags.BuildChatOptions()
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, options)
|
||||
assert.Contains(t, err.Error(), "image parameters")
|
||||
assert.Contains(t, err.Error(), "can only be used with --image-file")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package common
|
||||
|
||||
import goopenai "github.com/sashabaranov/go-openai"
|
||||
import "github.com/danielmiessler/fabric/chat"
|
||||
|
||||
const ChatMessageRoleMeta = "meta"
|
||||
|
||||
@@ -9,7 +9,7 @@ type ChatRequest struct {
|
||||
SessionName string
|
||||
PatternName string
|
||||
PatternVariables map[string]string
|
||||
Message *goopenai.ChatCompletionMessage
|
||||
Message *chat.ChatCompletionMessage
|
||||
Language string
|
||||
Meta string
|
||||
InputHasVars bool
|
||||
@@ -26,10 +26,17 @@ type ChatOptions struct {
|
||||
Seed int
|
||||
ModelContextLength int
|
||||
MaxTokens int
|
||||
Search bool
|
||||
SearchLocation string
|
||||
ImageFile string
|
||||
ImageSize string
|
||||
ImageQuality string
|
||||
ImageCompression int
|
||||
ImageBackground string
|
||||
}
|
||||
|
||||
// NormalizeMessages remove empty messages and ensure messages order user-assist-user
|
||||
func NormalizeMessages(msgs []*goopenai.ChatCompletionMessage, defaultUserMessage string) (ret []*goopenai.ChatCompletionMessage) {
|
||||
func NormalizeMessages(msgs []*chat.ChatCompletionMessage, defaultUserMessage string) (ret []*chat.ChatCompletionMessage) {
|
||||
// Iterate over messages to enforce the odd position rule for user messages
|
||||
fullMessageIndex := 0
|
||||
for _, message := range msgs {
|
||||
@@ -39,8 +46,8 @@ func NormalizeMessages(msgs []*goopenai.ChatCompletionMessage, defaultUserMessag
|
||||
}
|
||||
|
||||
// Ensure, that each odd position shall be a user message
|
||||
if fullMessageIndex%2 == 0 && message.Role != goopenai.ChatMessageRoleUser {
|
||||
ret = append(ret, &goopenai.ChatCompletionMessage{Role: goopenai.ChatMessageRoleUser, Content: defaultUserMessage})
|
||||
if fullMessageIndex%2 == 0 && message.Role != chat.ChatMessageRoleUser {
|
||||
ret = append(ret, &chat.ChatCompletionMessage{Role: chat.ChatMessageRoleUser, Content: defaultUserMessage})
|
||||
fullMessageIndex++
|
||||
}
|
||||
ret = append(ret, message)
|
||||
|
||||
@@ -3,23 +3,23 @@ package common
|
||||
import (
|
||||
"testing"
|
||||
|
||||
goopenai "github.com/sashabaranov/go-openai"
|
||||
"github.com/danielmiessler/fabric/chat"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNormalizeMessages(t *testing.T) {
|
||||
msgs := []*goopenai.ChatCompletionMessage{
|
||||
{Role: goopenai.ChatMessageRoleUser, Content: "Hello"},
|
||||
{Role: goopenai.ChatMessageRoleAssistant, Content: "Hi there!"},
|
||||
{Role: goopenai.ChatMessageRoleUser, Content: ""},
|
||||
{Role: goopenai.ChatMessageRoleUser, Content: ""},
|
||||
{Role: goopenai.ChatMessageRoleUser, Content: "How are you?"},
|
||||
msgs := []*chat.ChatCompletionMessage{
|
||||
{Role: chat.ChatMessageRoleUser, Content: "Hello"},
|
||||
{Role: chat.ChatMessageRoleAssistant, Content: "Hi there!"},
|
||||
{Role: chat.ChatMessageRoleUser, Content: ""},
|
||||
{Role: chat.ChatMessageRoleUser, Content: ""},
|
||||
{Role: chat.ChatMessageRoleUser, Content: "How are you?"},
|
||||
}
|
||||
|
||||
expected := []*goopenai.ChatCompletionMessage{
|
||||
{Role: goopenai.ChatMessageRoleUser, Content: "Hello"},
|
||||
{Role: goopenai.ChatMessageRoleAssistant, Content: "Hi there!"},
|
||||
{Role: goopenai.ChatMessageRoleUser, Content: "How are you?"},
|
||||
expected := []*chat.ChatCompletionMessage{
|
||||
{Role: chat.ChatMessageRoleUser, Content: "Hello"},
|
||||
{Role: chat.ChatMessageRoleAssistant, Content: "Hi there!"},
|
||||
{Role: chat.ChatMessageRoleUser, Content: "How are you?"},
|
||||
}
|
||||
|
||||
actual := NormalizeMessages(msgs, "default")
|
||||
|
||||
124
common/oauth_storage.go
Normal file
124
common/oauth_storage.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
|
||||
// OAuthToken represents stored OAuth token information
|
||||
type OAuthToken struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
TokenType string `json:"token_type"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
// IsExpired checks if the token is expired or will expire within the buffer time
|
||||
func (t *OAuthToken) IsExpired(bufferMinutes int) bool {
|
||||
if t.ExpiresAt == 0 {
|
||||
return true
|
||||
}
|
||||
bufferTime := time.Duration(bufferMinutes) * time.Minute
|
||||
return time.Now().Add(bufferTime).Unix() >= t.ExpiresAt
|
||||
}
|
||||
|
||||
// OAuthStorage handles persistent storage of OAuth tokens
|
||||
type OAuthStorage struct {
|
||||
configDir string
|
||||
}
|
||||
|
||||
// NewOAuthStorage creates a new OAuth storage instance
|
||||
func NewOAuthStorage() (*OAuthStorage, error) {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user home directory: %w", err)
|
||||
}
|
||||
|
||||
configDir := filepath.Join(homeDir, ".config", "fabric")
|
||||
|
||||
// Ensure config directory exists
|
||||
if err := os.MkdirAll(configDir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create config directory: %w", err)
|
||||
}
|
||||
|
||||
return &OAuthStorage{configDir: configDir}, nil
|
||||
}
|
||||
|
||||
// GetTokenPath returns the file path for a provider's OAuth token
|
||||
func (s *OAuthStorage) GetTokenPath(provider string) string {
|
||||
return filepath.Join(s.configDir, fmt.Sprintf(".%s_oauth", provider))
|
||||
}
|
||||
|
||||
// SaveToken saves an OAuth token to disk with proper permissions
|
||||
func (s *OAuthStorage) SaveToken(provider string, token *OAuthToken) error {
|
||||
tokenPath := s.GetTokenPath(provider)
|
||||
|
||||
// Marshal token to JSON
|
||||
data, err := json.MarshalIndent(token, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal token: %w", err)
|
||||
}
|
||||
|
||||
// Write to temporary file first for atomic operation
|
||||
tempPath := tokenPath + ".tmp"
|
||||
if err := os.WriteFile(tempPath, data, 0600); err != nil {
|
||||
return fmt.Errorf("failed to write token file: %w", err)
|
||||
}
|
||||
|
||||
// Atomic rename
|
||||
if err := os.Rename(tempPath, tokenPath); err != nil {
|
||||
os.Remove(tempPath) // Clean up temp file
|
||||
return fmt.Errorf("failed to save token file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadToken loads an OAuth token from disk
|
||||
func (s *OAuthStorage) LoadToken(provider string) (*OAuthToken, error) {
|
||||
tokenPath := s.GetTokenPath(provider)
|
||||
|
||||
// Check if file exists
|
||||
if _, err := os.Stat(tokenPath); os.IsNotExist(err) {
|
||||
return nil, nil // No token stored
|
||||
}
|
||||
|
||||
// Read token file
|
||||
data, err := os.ReadFile(tokenPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read token file: %w", err)
|
||||
}
|
||||
|
||||
// Unmarshal token
|
||||
var token OAuthToken
|
||||
if err := json.Unmarshal(data, &token); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token file: %w", err)
|
||||
}
|
||||
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// DeleteToken removes a stored OAuth token
|
||||
func (s *OAuthStorage) DeleteToken(provider string) error {
|
||||
tokenPath := s.GetTokenPath(provider)
|
||||
|
||||
if err := os.Remove(tokenPath); err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to delete token file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasValidToken checks if a valid (non-expired) token exists for a provider
|
||||
func (s *OAuthStorage) HasValidToken(provider string, bufferMinutes int) bool {
|
||||
token, err := s.LoadToken(provider)
|
||||
if err != nil || token == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return !token.IsExpired(bufferMinutes)
|
||||
}
|
||||
232
common/oauth_storage_test.go
Normal file
232
common/oauth_storage_test.go
Normal file
@@ -0,0 +1,232 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestOAuthToken_IsExpired(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
expiresAt int64
|
||||
bufferMinutes int
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "token not expired",
|
||||
expiresAt: time.Now().Unix() + 3600, // 1 hour from now
|
||||
bufferMinutes: 5,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "token expired",
|
||||
expiresAt: time.Now().Unix() - 3600, // 1 hour ago
|
||||
bufferMinutes: 5,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "token expires within buffer",
|
||||
expiresAt: time.Now().Unix() + 120, // 2 minutes from now
|
||||
bufferMinutes: 5,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "zero expiry time",
|
||||
expiresAt: 0,
|
||||
bufferMinutes: 5,
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
token := &OAuthToken{ExpiresAt: tt.expiresAt}
|
||||
if got := token.IsExpired(tt.bufferMinutes); got != tt.expected {
|
||||
t.Errorf("IsExpired() = %v, want %v", got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthStorage_SaveAndLoadToken(t *testing.T) {
|
||||
// Create temporary directory for testing
|
||||
tempDir, err := os.MkdirTemp("", "fabric_oauth_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// Create storage with custom config dir
|
||||
storage := &OAuthStorage{configDir: tempDir}
|
||||
|
||||
// Test token
|
||||
token := &OAuthToken{
|
||||
AccessToken: "test_access_token",
|
||||
RefreshToken: "test_refresh_token",
|
||||
ExpiresAt: time.Now().Unix() + 3600,
|
||||
TokenType: "Bearer",
|
||||
Scope: "test_scope",
|
||||
}
|
||||
|
||||
// Test saving token
|
||||
err = storage.SaveToken("test_provider", token)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save token: %v", err)
|
||||
}
|
||||
|
||||
// Verify file exists and has correct permissions
|
||||
tokenPath := storage.GetTokenPath("test_provider")
|
||||
info, err := os.Stat(tokenPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Token file not created: %v", err)
|
||||
}
|
||||
if info.Mode().Perm() != 0600 {
|
||||
t.Errorf("Token file has wrong permissions: %v, want 0600", info.Mode().Perm())
|
||||
}
|
||||
|
||||
// Test loading token
|
||||
loadedToken, err := storage.LoadToken("test_provider")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load token: %v", err)
|
||||
}
|
||||
if loadedToken == nil {
|
||||
t.Fatal("Loaded token is nil")
|
||||
}
|
||||
|
||||
// Verify token data
|
||||
if loadedToken.AccessToken != token.AccessToken {
|
||||
t.Errorf("AccessToken mismatch: got %v, want %v", loadedToken.AccessToken, token.AccessToken)
|
||||
}
|
||||
if loadedToken.RefreshToken != token.RefreshToken {
|
||||
t.Errorf("RefreshToken mismatch: got %v, want %v", loadedToken.RefreshToken, token.RefreshToken)
|
||||
}
|
||||
if loadedToken.ExpiresAt != token.ExpiresAt {
|
||||
t.Errorf("ExpiresAt mismatch: got %v, want %v", loadedToken.ExpiresAt, token.ExpiresAt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthStorage_LoadNonExistentToken(t *testing.T) {
|
||||
// Create temporary directory for testing
|
||||
tempDir, err := os.MkdirTemp("", "fabric_oauth_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
storage := &OAuthStorage{configDir: tempDir}
|
||||
|
||||
// Try to load non-existent token
|
||||
token, err := storage.LoadToken("nonexistent")
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error loading non-existent token: %v", err)
|
||||
}
|
||||
if token != nil {
|
||||
t.Error("Expected nil token for non-existent provider")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthStorage_DeleteToken(t *testing.T) {
|
||||
// Create temporary directory for testing
|
||||
tempDir, err := os.MkdirTemp("", "fabric_oauth_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
storage := &OAuthStorage{configDir: tempDir}
|
||||
|
||||
// Create and save a token
|
||||
token := &OAuthToken{
|
||||
AccessToken: "test_token",
|
||||
RefreshToken: "test_refresh",
|
||||
ExpiresAt: time.Now().Unix() + 3600,
|
||||
}
|
||||
err = storage.SaveToken("test_provider", token)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save token: %v", err)
|
||||
}
|
||||
|
||||
// Verify token exists
|
||||
tokenPath := storage.GetTokenPath("test_provider")
|
||||
if _, err := os.Stat(tokenPath); os.IsNotExist(err) {
|
||||
t.Fatal("Token file should exist before deletion")
|
||||
}
|
||||
|
||||
// Delete token
|
||||
err = storage.DeleteToken("test_provider")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to delete token: %v", err)
|
||||
}
|
||||
|
||||
// Verify token is deleted
|
||||
if _, err := os.Stat(tokenPath); !os.IsNotExist(err) {
|
||||
t.Error("Token file should not exist after deletion")
|
||||
}
|
||||
|
||||
// Test deleting non-existent token (should not error)
|
||||
err = storage.DeleteToken("nonexistent")
|
||||
if err != nil {
|
||||
t.Errorf("Deleting non-existent token should not error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthStorage_HasValidToken(t *testing.T) {
|
||||
// Create temporary directory for testing
|
||||
tempDir, err := os.MkdirTemp("", "fabric_oauth_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
storage := &OAuthStorage{configDir: tempDir}
|
||||
|
||||
// Test with no token
|
||||
if storage.HasValidToken("test_provider", 5) {
|
||||
t.Error("Should return false when no token exists")
|
||||
}
|
||||
|
||||
// Save valid token
|
||||
validToken := &OAuthToken{
|
||||
AccessToken: "valid_token",
|
||||
RefreshToken: "refresh_token",
|
||||
ExpiresAt: time.Now().Unix() + 3600, // 1 hour from now
|
||||
}
|
||||
err = storage.SaveToken("test_provider", validToken)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save valid token: %v", err)
|
||||
}
|
||||
|
||||
// Test with valid token
|
||||
if !storage.HasValidToken("test_provider", 5) {
|
||||
t.Error("Should return true for valid token")
|
||||
}
|
||||
|
||||
// Save expired token
|
||||
expiredToken := &OAuthToken{
|
||||
AccessToken: "expired_token",
|
||||
RefreshToken: "refresh_token",
|
||||
ExpiresAt: time.Now().Unix() - 3600, // 1 hour ago
|
||||
}
|
||||
err = storage.SaveToken("expired_provider", expiredToken)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save expired token: %v", err)
|
||||
}
|
||||
|
||||
// Test with expired token
|
||||
if storage.HasValidToken("expired_provider", 5) {
|
||||
t.Error("Should return false for expired token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthStorage_GetTokenPath(t *testing.T) {
|
||||
storage := &OAuthStorage{configDir: "/test/config"}
|
||||
|
||||
expected := filepath.Join("/test/config", ".test_provider_oauth")
|
||||
actual := storage.GetTokenPath("test_provider")
|
||||
|
||||
if actual != expected {
|
||||
t.Errorf("GetTokenPath() = %v, want %v", actual, expected)
|
||||
}
|
||||
}
|
||||
@@ -96,6 +96,13 @@ _fabric() {
|
||||
'(--api-key)--api-key[API key used to secure server routes]:api-key:' \
|
||||
'(--config)--config[Path to YAML config file]:config file:_files -g "*.yaml *.yml"' \
|
||||
'(--version)--version[Print current version]' \
|
||||
'(--search)--search[Enable web search tool for supported models (Anthropic, OpenAI)]' \
|
||||
'(--search-location)--search-location[Set location for web search results]:location:' \
|
||||
'(--image-file)--image-file[Save generated image to specified file path]:image file:_files -g "*.png *.webp *.jpeg *.jpg"' \
|
||||
'(--image-size)--image-size[Image dimensions]:size:(1024x1024 1536x1024 1024x1536 auto)' \
|
||||
'(--image-quality)--image-quality[Image quality]:quality:(low medium high auto)' \
|
||||
'(--image-compression)--image-compression[Compression level 0-100 for JPEG/WebP formats]:compression:' \
|
||||
'(--image-background)--image-background[Background type]:background:(opaque transparent)' \
|
||||
'(--listextensions)--listextensions[List all registered extensions]' \
|
||||
'(--addextension)--addextension[Register a new extension from config file path]:config file:_files -g "*.yaml *.yml"' \
|
||||
'(--rmextension)--rmextension[Remove a registered extension by name]:extension:_fabric_extensions' \
|
||||
|
||||
@@ -13,7 +13,7 @@ _fabric() {
|
||||
_get_comp_words_by_ref -n : cur prev words cword
|
||||
|
||||
# Define all possible options/flags
|
||||
local opts="--pattern -p --variable -v --context -C --session --attachment -a --setup -S --temperature -t --topp -T --stream -s --presencepenalty -P --raw -r --frequencypenalty -F --listpatterns -l --listmodels -L --listcontexts -x --listsessions -X --updatepatterns -U --copy -c --model -m --modelContextLength --output -o --output-session --latest -n --changeDefaultModel -d --youtube -y --playlist --transcript --transcript-with-timestamps --comments --metadata --language -g --scrape_url -u --scrape_question -q --seed -e --wipecontext -w --wipesession -W --printcontext --printsession --readability --input-has-vars --dry-run --serve --serveOllama --address --api-key --config --version --listextensions --addextension --rmextension --strategy --liststrategies --listvendors --shell-complete-list --help -h"
|
||||
local opts="--pattern -p --variable -v --context -C --session --attachment -a --setup -S --temperature -t --topp -T --stream -s --presencepenalty -P --raw -r --frequencypenalty -F --listpatterns -l --listmodels -L --listcontexts -x --listsessions -X --updatepatterns -U --copy -c --model -m --modelContextLength --output -o --output-session --latest -n --changeDefaultModel -d --youtube -y --playlist --transcript --transcript-with-timestamps --comments --metadata --language -g --scrape_url -u --scrape_question -q --seed -e --wipecontext -w --wipesession -W --printcontext --printsession --readability --input-has-vars --dry-run --serve --serveOllama --address --api-key --config --search --search-location --image-file --image-size --image-quality --image-compression --image-background --version --listextensions --addextension --rmextension --strategy --liststrategies --listvendors --shell-complete-list --help -h"
|
||||
|
||||
# Helper function for dynamic completions
|
||||
_fabric_get_list() {
|
||||
@@ -63,12 +63,25 @@ _fabric() {
|
||||
return 0
|
||||
;;
|
||||
# Options requiring file/directory paths
|
||||
-a | --attachment | -o | --output | --config | --addextension)
|
||||
-a | --attachment | -o | --output | --config | --addextension | --image-file)
|
||||
_filedir
|
||||
return 0
|
||||
;;
|
||||
# Image generation options with specific values
|
||||
--image-size)
|
||||
COMPREPLY=($(compgen -W "1024x1024 1536x1024 1024x1536 auto" -- "$cur"))
|
||||
return 0
|
||||
;;
|
||||
--image-quality)
|
||||
COMPREPLY=($(compgen -W "low medium high auto" -- "$cur"))
|
||||
return 0
|
||||
;;
|
||||
--image-background)
|
||||
COMPREPLY=($(compgen -W "opaque transparent" -- "$cur"))
|
||||
return 0
|
||||
;;
|
||||
# Options requiring simple arguments (no specific completion logic here)
|
||||
-v | --variable | -t | --temperature | -T | --topp | -P | --presencepenalty | -F | --frequencypenalty | --modelContextLength | -n | --latest | -y | --youtube | -g | --language | -u | --scrape_url | -q | --scrape_question | -e | --seed | --address | --api-key)
|
||||
-v | --variable | -t | --temperature | -T | --topp | -P | --presencepenalty | -F | --frequencypenalty | --modelContextLength | -n | --latest | -y | --youtube | -g | --language | -u | --scrape_url | -q | --scrape_question | -e | --seed | --address | --api-key | --search-location | --image-compression)
|
||||
# No specific completion suggestions, user types the value
|
||||
return 0
|
||||
;;
|
||||
|
||||
@@ -60,6 +60,12 @@ complete -c fabric -l printsession -d "Print session" -a "(__fabric_get_sessions
|
||||
complete -c fabric -l address -d "The address to bind the REST API (default: :8080)"
|
||||
complete -c fabric -l api-key -d "API key used to secure server routes"
|
||||
complete -c fabric -l config -d "Path to YAML config file" -r -a "*.yaml *.yml"
|
||||
complete -c fabric -l search-location -d "Set location for web search results (e.g., 'America/Los_Angeles')"
|
||||
complete -c fabric -l image-file -d "Save generated image to specified file path (e.g., 'output.png')" -r -a "*.png *.webp *.jpeg *.jpg"
|
||||
complete -c fabric -l image-size -d "Image dimensions: 1024x1024, 1536x1024, 1024x1536, auto (default: auto)" -a "1024x1024 1536x1024 1024x1536 auto"
|
||||
complete -c fabric -l image-quality -d "Image quality: low, medium, high, auto (default: auto)" -a "low medium high auto"
|
||||
complete -c fabric -l image-compression -d "Compression level 0-100 for JPEG/WebP formats (default: not set)" -r
|
||||
complete -c fabric -l image-background -d "Background type: opaque, transparent (default: opaque, only for PNG/WebP)" -a "opaque transparent"
|
||||
complete -c fabric -l addextension -d "Register a new extension from config file path" -r -a "*.yaml *.yml"
|
||||
complete -c fabric -l rmextension -d "Remove a registered extension by name" -a "(__fabric_get_extensions)"
|
||||
complete -c fabric -l strategy -d "Choose a strategy from the available strategies" -a "(__fabric_get_strategies)"
|
||||
@@ -84,6 +90,7 @@ complete -c fabric -l metadata -d "Output video metadata"
|
||||
complete -c fabric -l readability -d "Convert HTML input into a clean, readable view"
|
||||
complete -c fabric -l input-has-vars -d "Apply variables to user input"
|
||||
complete -c fabric -l dry-run -d "Show what would be sent to the model without actually sending it"
|
||||
complete -c fabric -l search -d "Enable web search tool for supported models (Anthropic, OpenAI)"
|
||||
complete -c fabric -l serve -d "Serve the Fabric Rest API"
|
||||
complete -c fabric -l serveOllama -d "Serve the Fabric Rest API with ollama endpoints"
|
||||
complete -c fabric -l version -d "Print current version"
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
goopenai "github.com/sashabaranov/go-openai"
|
||||
"github.com/danielmiessler/fabric/chat"
|
||||
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"github.com/danielmiessler/fabric/plugins/ai"
|
||||
@@ -110,7 +110,7 @@ func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (s
|
||||
message = summary
|
||||
}
|
||||
|
||||
session.Append(&goopenai.ChatCompletionMessage{Role: goopenai.ChatMessageRoleAssistant, Content: message})
|
||||
session.Append(&chat.ChatCompletionMessage{Role: chat.ChatMessageRoleAssistant, Content: message})
|
||||
|
||||
if session.Name != "" {
|
||||
err = o.db.Sessions.SaveSession(session)
|
||||
@@ -131,7 +131,7 @@ func (o *Chatter) BuildSession(request *common.ChatRequest, raw bool) (session *
|
||||
}
|
||||
|
||||
if request.Meta != "" {
|
||||
session.Append(&goopenai.ChatCompletionMessage{Role: common.ChatMessageRoleMeta, Content: request.Meta})
|
||||
session.Append(&chat.ChatCompletionMessage{Role: common.ChatMessageRoleMeta, Content: request.Meta})
|
||||
}
|
||||
|
||||
// if a context name is provided, retrieve it from the database
|
||||
@@ -149,8 +149,8 @@ func (o *Chatter) BuildSession(request *common.ChatRequest, raw bool) (session *
|
||||
// Double curly braces {{variable}} indicate template substitution
|
||||
// Ensure we have a message before processing
|
||||
if request.Message == nil {
|
||||
request.Message = &goopenai.ChatCompletionMessage{
|
||||
Role: goopenai.ChatMessageRoleUser,
|
||||
request.Message = &chat.ChatCompletionMessage{
|
||||
Role: chat.ChatMessageRoleUser,
|
||||
Content: " ",
|
||||
}
|
||||
}
|
||||
@@ -206,26 +206,26 @@ func (o *Chatter) BuildSession(request *common.ChatRequest, raw bool) (session *
|
||||
// Handle MultiContent properly in raw mode
|
||||
if len(request.Message.MultiContent) > 0 {
|
||||
// When we have attachments, add the text as a text part in MultiContent
|
||||
newMultiContent := []goopenai.ChatMessagePart{
|
||||
newMultiContent := []chat.ChatMessagePart{
|
||||
{
|
||||
Type: goopenai.ChatMessagePartTypeText,
|
||||
Type: chat.ChatMessagePartTypeText,
|
||||
Text: finalContent,
|
||||
},
|
||||
}
|
||||
// Add existing non-text parts (like images)
|
||||
for _, part := range request.Message.MultiContent {
|
||||
if part.Type != goopenai.ChatMessagePartTypeText {
|
||||
if part.Type != chat.ChatMessagePartTypeText {
|
||||
newMultiContent = append(newMultiContent, part)
|
||||
}
|
||||
}
|
||||
request.Message = &goopenai.ChatCompletionMessage{
|
||||
Role: goopenai.ChatMessageRoleUser,
|
||||
request.Message = &chat.ChatCompletionMessage{
|
||||
Role: chat.ChatMessageRoleUser,
|
||||
MultiContent: newMultiContent,
|
||||
}
|
||||
} else {
|
||||
// No attachments, use regular Content field
|
||||
request.Message = &goopenai.ChatCompletionMessage{
|
||||
Role: goopenai.ChatMessageRoleUser,
|
||||
request.Message = &chat.ChatCompletionMessage{
|
||||
Role: chat.ChatMessageRoleUser,
|
||||
Content: finalContent,
|
||||
}
|
||||
}
|
||||
@@ -235,7 +235,7 @@ func (o *Chatter) BuildSession(request *common.ChatRequest, raw bool) (session *
|
||||
}
|
||||
} else {
|
||||
if systemMessage != "" {
|
||||
session.Append(&goopenai.ChatCompletionMessage{Role: goopenai.ChatMessageRoleSystem, Content: systemMessage})
|
||||
session.Append(&chat.ChatCompletionMessage{Role: chat.ChatMessageRoleSystem, Content: systemMessage})
|
||||
}
|
||||
// If multi-part content, it is in the user message, and should be added.
|
||||
// Otherwise, we should only add it if we have not already used it in the systemMessage.
|
||||
|
||||
@@ -31,6 +31,7 @@ import (
|
||||
"github.com/danielmiessler/fabric/plugins/db/fsdb"
|
||||
"github.com/danielmiessler/fabric/plugins/template"
|
||||
"github.com/danielmiessler/fabric/plugins/tools"
|
||||
"github.com/danielmiessler/fabric/plugins/tools/custom_patterns"
|
||||
"github.com/danielmiessler/fabric/plugins/tools/jina"
|
||||
"github.com/danielmiessler/fabric/plugins/tools/lang"
|
||||
"github.com/danielmiessler/fabric/plugins/tools/youtube"
|
||||
@@ -69,6 +70,7 @@ func NewPluginRegistry(db *fsdb.Db) (ret *PluginRegistry, err error) {
|
||||
VendorManager: ai.NewVendorsManager(),
|
||||
VendorsAll: ai.NewVendorsManager(),
|
||||
PatternsLoader: tools.NewPatternsLoader(db.Patterns),
|
||||
CustomPatterns: custom_patterns.NewCustomPatterns(),
|
||||
YouTube: youtube.NewYouTube(),
|
||||
Language: lang.NewLanguage(),
|
||||
Jina: jina.NewClient(),
|
||||
@@ -138,6 +140,7 @@ type PluginRegistry struct {
|
||||
VendorsAll *ai.VendorsManager
|
||||
Defaults *tools.Defaults
|
||||
PatternsLoader *tools.PatternsLoader
|
||||
CustomPatterns *custom_patterns.CustomPatterns
|
||||
YouTube *youtube.YouTube
|
||||
Language *lang.Language
|
||||
Jina *jina.Client
|
||||
@@ -151,6 +154,7 @@ func (o *PluginRegistry) SaveEnvFile() (err error) {
|
||||
|
||||
o.Defaults.Settings.FillEnvFileContent(&envFileContent)
|
||||
o.PatternsLoader.SetupFillEnvFileContent(&envFileContent)
|
||||
o.CustomPatterns.SetupFillEnvFileContent(&envFileContent)
|
||||
o.Strategies.SetupFillEnvFileContent(&envFileContent)
|
||||
|
||||
for _, vendor := range o.VendorManager.Vendors {
|
||||
@@ -183,7 +187,7 @@ func (o *PluginRegistry) Setup() (err error) {
|
||||
return vendor
|
||||
})...)
|
||||
|
||||
groupsPlugins.AddGroupItems("Tools", o.Defaults, o.Jina, o.Language, o.PatternsLoader, o.Strategies, o.YouTube)
|
||||
groupsPlugins.AddGroupItems("Tools", o.CustomPatterns, o.Defaults, o.Jina, o.Language, o.PatternsLoader, o.Strategies, o.YouTube)
|
||||
|
||||
for {
|
||||
groupsPlugins.Print(false)
|
||||
|
||||
5
go.mod
5
go.mod
@@ -19,15 +19,15 @@ require (
|
||||
github.com/jessevdk/go-flags v1.6.1
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/ollama/ollama v0.9.0
|
||||
github.com/openai/openai-go v1.8.2
|
||||
github.com/otiai10/copy v1.14.1
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/samber/lo v1.50.0
|
||||
github.com/sashabaranov/go-openai v1.40.3
|
||||
github.com/sgaunet/perplexity-go/v2 v2.8.0
|
||||
github.com/stretchr/testify v1.10.0
|
||||
golang.org/x/oauth2 v0.30.0
|
||||
golang.org/x/text v0.26.0
|
||||
google.golang.org/api v0.236.0
|
||||
gopkg.in/yaml.v2 v2.4.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
@@ -109,7 +109,6 @@ require (
|
||||
golang.org/x/arch v0.18.0 // indirect
|
||||
golang.org/x/crypto v0.39.0 // indirect
|
||||
golang.org/x/net v0.41.0 // indirect
|
||||
golang.org/x/oauth2 v0.30.0 // indirect
|
||||
golang.org/x/sync v0.15.0 // indirect
|
||||
golang.org/x/sys v0.33.0 // indirect
|
||||
golang.org/x/time v0.12.0 // indirect
|
||||
|
||||
5
go.sum
5
go.sum
@@ -172,6 +172,8 @@ github.com/ollama/ollama v0.9.0 h1:GvdGhi8G/QMnFrY0TMLDy1bXua+Ify8KTkFe4ZY/OZs=
|
||||
github.com/ollama/ollama v0.9.0/go.mod h1:aio9yQ7nc4uwIbn6S0LkGEPgn8/9bNQLL1nHuH+OcD0=
|
||||
github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k=
|
||||
github.com/onsi/gomega v1.34.1/go.mod h1:kU1QgUvBDLXBJq618Xvm2LUX6rSAfRaFRTcdOeDLwwY=
|
||||
github.com/openai/openai-go v1.8.2 h1:UqSkJ1vCOPUpz9Ka5tS0324EJFEuOvMc+lA/EarJWP8=
|
||||
github.com/openai/openai-go v1.8.2/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y=
|
||||
github.com/otiai10/copy v1.14.1 h1:5/7E6qsUMBaH5AnQ0sSLzzTg1oTECmcCmT6lvF45Na8=
|
||||
github.com/otiai10/copy v1.14.1/go.mod h1:oQwrEDDOci3IM8dJF0d8+jnbfPDllW6vUjNc3DoZm9I=
|
||||
github.com/otiai10/mint v1.6.3 h1:87qsV/aw1F5as1eH1zS/yqHY85ANKVMgkDrf9rcxbQs=
|
||||
@@ -189,8 +191,6 @@ github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0t
|
||||
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
|
||||
github.com/samber/lo v1.50.0 h1:XrG0xOeHs+4FQ8gJR97zDz5uOFMW7OwFWiFVzqopKgY=
|
||||
github.com/samber/lo v1.50.0/go.mod h1:RjZyNk6WSnUFRKK6EyOhsRJMqft3G+pg7dCWHQCWvsc=
|
||||
github.com/sashabaranov/go-openai v1.40.3 h1:PkOw0SK34wrvYVOuXF1HZzuTBRh992qRZHil4kG3eYE=
|
||||
github.com/sashabaranov/go-openai v1.40.3/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
|
||||
github.com/scylladb/termtables v0.0.0-20191203121021-c4c0b6d42ff4/go.mod h1:C1a7PQSMz9NShzorzCiG2fk9+xuCgLkPeCvMHYR2OWg=
|
||||
github.com/sergi/go-diff v1.4.0 h1:n/SP9D5ad1fORl+llWyN+D6qoUETXNZARKjyY2/KVCw=
|
||||
github.com/sergi/go-diff v1.4.0/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4=
|
||||
@@ -354,7 +354,6 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EV
|
||||
gopkg.in/warnings.v0 v0.1.2 h1:wFXVbFY8DY5/xOe1ECiWdKCzZlxgshcYVNkBHstARME=
|
||||
gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
|
||||
@@ -208,6 +208,9 @@ schema = 3
|
||||
[mod."github.com/ollama/ollama"]
|
||||
version = "v0.9.0"
|
||||
hash = "sha256-r2eU+kMG3tuJy2B43RXsfmeltzM9t05NEmNiJAW5qr4="
|
||||
[mod."github.com/openai/openai-go"]
|
||||
version = "v1.8.2"
|
||||
hash = "sha256-O8aV3zEj6o8kIlzlkYaTW4RzvwR3qNUBYiN8SuTM1R0="
|
||||
[mod."github.com/otiai10/copy"]
|
||||
version = "v1.14.1"
|
||||
hash = "sha256-8RR7u17SbYg9AeBXVHIv5ZMU+kHmOcx0rLUKyz6YtU0="
|
||||
@@ -229,9 +232,6 @@ schema = 3
|
||||
[mod."github.com/samber/lo"]
|
||||
version = "v1.50.0"
|
||||
hash = "sha256-KDFks82BKu39sGt0f972IyOkohV2U0r1YvsnlNLdugY="
|
||||
[mod."github.com/sashabaranov/go-openai"]
|
||||
version = "v1.40.3"
|
||||
hash = "sha256-Q2+la99lgKwcpgGHuf5p23gH5hmhYKp77YDUtbt35eo="
|
||||
[mod."github.com/sergi/go-diff"]
|
||||
version = "v1.4.0"
|
||||
hash = "sha256-rs9NKpv/qcQEMRg7CmxGdP4HGuFdBxlpWf9LbA9wS4k="
|
||||
@@ -325,9 +325,6 @@ schema = 3
|
||||
[mod."gopkg.in/warnings.v0"]
|
||||
version = "v0.1.2"
|
||||
hash = "sha256-ATVL9yEmgYbkJ1DkltDGRn/auGAjqGOfjQyBYyUo8s8="
|
||||
[mod."gopkg.in/yaml.v2"]
|
||||
version = "v2.4.0"
|
||||
hash = "sha256-uVEGglIedjOIGZzHW4YwN1VoRSTK8o0eGZqzd+TNdd0="
|
||||
[mod."gopkg.in/yaml.v3"]
|
||||
version = "v3.0.1"
|
||||
hash = "sha256-FqL9TKYJ0XkNwJFnq9j0VvJ5ZUU1RvH/52h/f5bkYAU="
|
||||
|
||||
@@ -1 +1 @@
|
||||
"1.4.220"
|
||||
"1.4.235"
|
||||
|
||||
@@ -22,19 +22,20 @@ Take a deep breath and think step by step about how to best accomplish this goal
|
||||
This must be under the heading "INSIGHTFULNESS SCORE (0 = not very interesting and insightful to 10 = very interesting and insightful)".
|
||||
- A rating of how emotional the debate was from 0 (very calm) to 5 (very emotional). This must be under the heading "EMOTIONALITY SCORE (0 (very calm) to 5 (very emotional))".
|
||||
- A list of the participants of the debate and a score of their emotionality from 0 (very calm) to 5 (very emotional). This must be under the heading "PARTICIPANTS".
|
||||
- A list of arguments attributed to participants with names and quotes. If possible, this should include external references that disprove or back up their claims.
|
||||
- A list of arguments attributed to participants with names and quotes. Each argument summary must be EXACTLY 16 words. If possible, this should include external references that disprove or back up their claims.
|
||||
It is IMPORTANT that these references are from trusted and verifiable sources that can be easily accessed. These sources have to BE REAL and NOT MADE UP. This must be under the heading "ARGUMENTS".
|
||||
If possible, provide an objective assessment of the truth of these arguments. If you assess the truth of the argument, provide some sources that back up your assessment. The material you provide should be from reliable, verifiable, and trustworthy sources. DO NOT MAKE UP SOURCES.
|
||||
- A list of agreements the participants have reached, attributed with names and quotes. This must be under the heading "AGREEMENTS".
|
||||
- A list of disagreements the participants were unable to resolve and the reasons why they remained unresolved, attributed with names and quotes. This must be under the heading "DISAGREEMENTS".
|
||||
- A list of possible misunderstandings and why they may have occurred, attributed with names and quotes. This must be under the heading "POSSIBLE MISUNDERSTANDINGS".
|
||||
- A list of learnings from the debate. This must be under the heading "LEARNINGS".
|
||||
- A list of takeaways that highlight ideas to think about, sources to explore, and actionable items. This must be under the heading "TAKEAWAYS".
|
||||
- A list of agreements the participants have reached. Each agreement summary must be EXACTLY 16 words, followed by names and quotes. This must be under the heading "AGREEMENTS".
|
||||
- A list of disagreements the participants were unable to resolve. Each disagreement summary must be EXACTLY 16 words, followed by names and quotes explaining why they remained unresolved. This must be under the heading "DISAGREEMENTS".
|
||||
- A list of possible misunderstandings. Each misunderstanding summary must be EXACTLY 16 words, followed by names and quotes explaining why they may have occurred. This must be under the heading "POSSIBLE MISUNDERSTANDINGS".
|
||||
- A list of learnings from the debate. Each learning must be EXACTLY 16 words. This must be under the heading "LEARNINGS".
|
||||
- A list of takeaways that highlight ideas to think about, sources to explore, and actionable items. Each takeaway must be EXACTLY 16 words. This must be under the heading "TAKEAWAYS".
|
||||
|
||||
# OUTPUT INSTRUCTIONS
|
||||
|
||||
- Output all sections above.
|
||||
- Use Markdown to structure your output.
|
||||
- Do not use any markdown formatting (no asterisks, no bullet points, no headers).
|
||||
- Keep all agreements, arguments, recommendations, learnings, and takeaways to EXACTLY 16 words each.
|
||||
- When providing quotes, these quotes should clearly express the points you are using them for. If necessary, use multiple quotes.
|
||||
|
||||
# INPUT:
|
||||
|
||||
49
patterns/apply_ul_tags/system.md
Normal file
49
patterns/apply_ul_tags/system.md
Normal file
@@ -0,0 +1,49 @@
|
||||
# IDENTITY
|
||||
|
||||
You are a superintelligent expert on content of all forms, with deep understanding of which topics, categories, themes, and tags apply to any piece of content.
|
||||
|
||||
# GOAL
|
||||
|
||||
Your goal is to output a JSON object called tags, with the following tags applied if the content is significantly about their topic.
|
||||
|
||||
- **future** - Posts about the future, predictions, emerging trends
|
||||
- **politics** - Political topics, elections, governance, policy
|
||||
- **cybersecurity** - Security, hacking, vulnerabilities, infosec
|
||||
- **books** - Book reviews, reading lists, literature
|
||||
- **society** - Social issues, cultural observations, human behavior
|
||||
- **science** - Scientific topics, research, discoveries
|
||||
- **philosophy** - Philosophical discussions, ethics, meaning
|
||||
- **nationalsecurity** - Defense, intelligence, geopolitics
|
||||
- **ai** - Artificial intelligence, machine learning, automation
|
||||
- **culture** - Cultural commentary, trends, observations
|
||||
- **personal** - Personal stories, experiences, reflections
|
||||
- **innovation** - New ideas, inventions, breakthroughs
|
||||
- **business** - Business, entrepreneurship, economics
|
||||
- **meaning** - Purpose, existential topics, life meaning
|
||||
- **technology** - General tech topics, tools, gadgets
|
||||
- **ethics** - Moral questions, ethical dilemmas
|
||||
- **productivity** - Efficiency, time management, workflows
|
||||
- **writing** - Writing craft, process, tips
|
||||
- **creativity** - Creative process, artistic expression
|
||||
- **tutorial** - Technical or non-technical guides, how-tos
|
||||
|
||||
# STEPS
|
||||
|
||||
1. Deeply understand the content and its themes and categories and topics.
|
||||
2. Evaluate the list of tags above.
|
||||
3. Determine which tags apply to the content.
|
||||
4. Output the "tags" JSON object.
|
||||
|
||||
# NOTES
|
||||
|
||||
- It's ok, and quite normal, for multiple tags to apply—which is why this is tags and not categories
|
||||
- All AI posts should have the technology tag, and that's ok. But not all technology posts are about AI, and therefore the AI tag needs to be evaluated separately. That goes for all potentially nested or conflicted tags.
|
||||
- Be a bit conservative in applying tags. If a piece of content is only tangentially related to a tag, don't include it.
|
||||
|
||||
# OUTPUT INSTRUCTIONS
|
||||
|
||||
- Output ONLY the JSON object, and nothing else.
|
||||
|
||||
- That means DO NOT OUTPUT the ```json format indicator. ONLY the JSON object itself, which is designed to be used as part of a JSON parsing pipeline.
|
||||
|
||||
|
||||
16
patterns/extract_alpha/system.md
Normal file
16
patterns/extract_alpha/system.md
Normal file
@@ -0,0 +1,16 @@
|
||||
# IDENTITY
|
||||
|
||||
You're an expert at finding Alpha in content.
|
||||
|
||||
# PHILOSOPHY
|
||||
|
||||
I love the idea of Claude Shannon's information theory where basically the only real information is the stuff that's different and anything that's the same as kind of background noise.
|
||||
|
||||
I love that idea for novelty and surprise inside of content when I think about a presentation or a talk or a podcast or an essay or anything I'm looking for the net new ideas or the new presentation of ideas for the new frameworks of how to use ideas or combine ideas so I'm looking for a way to capture that inside of content.
|
||||
|
||||
# INSTRUCTIONS
|
||||
|
||||
I want you to extract the 24 highest alpha ideas and thoughts and insights and recommendations in this piece of content, and I want you to output them in unformatted marked down in 8-word bullets written in the approachable style of Paul Graham.
|
||||
|
||||
# INPUT
|
||||
|
||||
140
patterns/review_code/system.md
Normal file
140
patterns/review_code/system.md
Normal file
@@ -0,0 +1,140 @@
|
||||
# Code Review Task
|
||||
|
||||
## ROLE AND GOAL
|
||||
|
||||
You are a Principal Software Engineer, renowned for your meticulous attention to detail and your ability to provide clear, constructive, and educational code reviews. Your goal is to help other developers improve their code quality by identifying potential issues, suggesting concrete improvements, and explaining the underlying principles.
|
||||
|
||||
## TASK
|
||||
|
||||
You will be given a snippet of code or a diff. Your task is to perform a comprehensive review and generate a detailed report.
|
||||
|
||||
## STEPS
|
||||
|
||||
1. **Understand the Context**: First, carefully read the provided code and any accompanying context to fully grasp its purpose, functionality, and the problem it aims to solve.
|
||||
2. **Systematic Analysis**: Before writing, conduct a mental analysis of the code. Evaluate it against the following key aspects. Do not write this analysis in the output; use it to form your review.
|
||||
* **Correctness**: Are there bugs, logic errors, or race conditions?
|
||||
* **Security**: Are there any potential vulnerabilities (e.g., injection attacks, improper handling of sensitive data)?
|
||||
* **Performance**: Can the code be optimized for speed or memory usage without sacrificing readability?
|
||||
* **Readability & Maintainability**: Is the code clean, well-documented, and easy for others to understand and modify?
|
||||
* **Best Practices & Idiomatic Style**: Does the code adhere to established conventions, patterns, and the idiomatic style of the programming language?
|
||||
* **Error Handling & Edge Cases**: Are errors handled gracefully? Have all relevant edge cases been considered?
|
||||
3. **Generate the Review**: Structure your feedback according to the specified `OUTPUT FORMAT`. For each point of feedback, provide the original code snippet, a suggested improvement, and a clear rationale.
|
||||
|
||||
## OUTPUT FORMAT
|
||||
|
||||
Your review must be in Markdown and follow this exact structure:
|
||||
|
||||
---
|
||||
|
||||
### Overall Assessment
|
||||
|
||||
A brief, high-level summary of the code's quality. Mention its strengths and the primary areas for improvement.
|
||||
|
||||
### **Prioritized Recommendations**
|
||||
|
||||
A numbered list of the most important changes, ordered from most to least critical.
|
||||
|
||||
1. (Most critical change)
|
||||
2. (Second most critical change)
|
||||
3. ...
|
||||
|
||||
### **Detailed Feedback**
|
||||
|
||||
For each issue you identified, provide a detailed breakdown in the following format.
|
||||
|
||||
---
|
||||
|
||||
**[ISSUE TITLE]** - (e.g., `Security`, `Readability`, `Performance`)
|
||||
|
||||
**Original Code:**
|
||||
|
||||
```[language]
|
||||
// The specific lines of code with the issue
|
||||
```
|
||||
|
||||
**Suggested Improvement:**
|
||||
|
||||
```[language]
|
||||
// The revised, improved code
|
||||
```
|
||||
|
||||
**Rationale:**
|
||||
A clear and concise explanation of why the change is recommended. Reference best practices, design patterns, or potential risks. If you use advanced concepts, briefly explain them.
|
||||
|
||||
---
|
||||
(Repeat this section for each issue)
|
||||
|
||||
## EXAMPLE
|
||||
|
||||
Here is an example of a review for a simple Python function:
|
||||
|
||||
---
|
||||
|
||||
### **Overall Assessment**
|
||||
|
||||
The function correctly fetches user data, but it can be made more robust and efficient. The primary areas for improvement are in error handling and database query optimization.
|
||||
|
||||
### **Prioritized Recommendations**
|
||||
|
||||
1. Avoid making database queries inside a loop to prevent performance issues (N+1 query problem).
|
||||
2. Add specific error handling for when a user is not found.
|
||||
|
||||
### **Detailed Feedback**
|
||||
|
||||
---
|
||||
|
||||
**[PERFORMANCE]** - N+1 Database Query
|
||||
|
||||
**Original Code:**
|
||||
|
||||
```python
|
||||
def get_user_emails(user_ids):
|
||||
emails = []
|
||||
for user_id in user_ids:
|
||||
user = db.query(User).filter(User.id == user_id).one()
|
||||
emails.append(user.email)
|
||||
return emails
|
||||
```
|
||||
|
||||
**Suggested Improvement:**
|
||||
|
||||
```python
|
||||
def get_user_emails(user_ids):
|
||||
if not user_ids:
|
||||
return []
|
||||
users = db.query(User).filter(User.id.in_(user_ids)).all()
|
||||
return [user.email for user in users]
|
||||
```
|
||||
|
||||
**Rationale:**
|
||||
The original code executes one database query for each `user_id` in the list. This is known as the "N+1 query problem" and performs very poorly on large lists. The suggested improvement fetches all users in a single query using `IN`, which is significantly more efficient.
|
||||
|
||||
---
|
||||
|
||||
**[CORRECTNESS]** - Lacks Specific Error Handling
|
||||
|
||||
**Original Code:**
|
||||
|
||||
```python
|
||||
user = db.query(User).filter(User.id == user_id).one()
|
||||
```
|
||||
|
||||
**Suggested Improvement:**
|
||||
|
||||
```python
|
||||
from sqlalchemy.orm.exc import NoResultFound
|
||||
|
||||
try:
|
||||
user = db.query(User).filter(User.id == user_id).one()
|
||||
except NoResultFound:
|
||||
# Handle the case where the user doesn't exist
|
||||
# e.g., log a warning, skip the user, or raise a custom exception
|
||||
continue
|
||||
```
|
||||
|
||||
**Rationale:**
|
||||
The `.one()` method will raise a `NoResultFound` exception if a user with the given ID doesn't exist, which would crash the entire function. It's better to explicitly handle this case using a try/except block to make the function more resilient.
|
||||
|
||||
---
|
||||
|
||||
## INPUT
|
||||
@@ -3,19 +3,22 @@ package anthropic
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/samber/lo"
|
||||
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
"github.com/anthropics/anthropic-sdk-go/option"
|
||||
"github.com/danielmiessler/fabric/chat"
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"github.com/danielmiessler/fabric/plugins"
|
||||
goopenai "github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
const defaultBaseUrl = "https://api.anthropic.com/"
|
||||
|
||||
const webSearchToolName = "web_search"
|
||||
const webSearchToolType = "web_search_20250305"
|
||||
const sourcesHeader = "## Sources"
|
||||
|
||||
func NewClient() (ret *Client) {
|
||||
vendorName := "Anthropic"
|
||||
ret = &Client{}
|
||||
@@ -28,10 +31,12 @@ func NewClient() (ret *Client) {
|
||||
|
||||
ret.ApiBaseURL = ret.AddSetupQuestion("API Base URL", false)
|
||||
ret.ApiBaseURL.Value = defaultBaseUrl
|
||||
ret.ApiKey = ret.PluginBase.AddSetupQuestion("API key", true)
|
||||
ret.UseWebTool = ret.AddSetupQuestionBool("Web Search Tool Enabled", false)
|
||||
ret.WebToolLocation = ret.AddSetupQuestionCustom("Web Search Tool Location", false,
|
||||
"Enter your approximate timezone location for web search (e.g., 'America/Los_Angeles', see https://en.wikipedia.org/wiki/List_of_tz_database_time_zones).")
|
||||
ret.UseOAuth = ret.AddSetupQuestionBool("Use OAuth login", false)
|
||||
if plugins.ParseBoolElseFalse(ret.UseOAuth.Value) {
|
||||
ret.ApiKey = ret.PluginBase.AddSetupQuestion("API key", false)
|
||||
} else {
|
||||
ret.ApiKey = ret.PluginBase.AddSetupQuestion("API key", true)
|
||||
}
|
||||
|
||||
ret.maxTokens = 4096
|
||||
ret.defaultRequiredUserMessage = "Hi"
|
||||
@@ -49,10 +54,9 @@ func NewClient() (ret *Client) {
|
||||
|
||||
type Client struct {
|
||||
*plugins.PluginBase
|
||||
ApiBaseURL *plugins.SetupQuestion
|
||||
ApiKey *plugins.SetupQuestion
|
||||
UseWebTool *plugins.SetupQuestion
|
||||
WebToolLocation *plugins.SetupQuestion
|
||||
ApiBaseURL *plugins.SetupQuestion
|
||||
ApiKey *plugins.SetupQuestion
|
||||
UseOAuth *plugins.SetupQuestion
|
||||
|
||||
maxTokens int
|
||||
defaultRequiredUserMessage string
|
||||
@@ -61,24 +65,50 @@ type Client struct {
|
||||
client anthropic.Client
|
||||
}
|
||||
|
||||
func (an *Client) configure() (err error) {
|
||||
if an.ApiBaseURL.Value != "" {
|
||||
baseURL := an.ApiBaseURL.Value
|
||||
func (an *Client) Setup() (err error) {
|
||||
if err = an.PluginBase.Ask(an.Name); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// As of 2.0beta1, using v2 API endpoint.
|
||||
// https://github.com/anthropics/anthropic-sdk-go/blob/main/CHANGELOG.md#020-beta1-2025-03-25
|
||||
if strings.Contains(baseURL, "-") && !strings.HasSuffix(baseURL, "/v2") {
|
||||
baseURL = strings.TrimSuffix(baseURL, "/")
|
||||
baseURL = baseURL + "/v2"
|
||||
if plugins.ParseBoolElseFalse(an.UseOAuth.Value) {
|
||||
// Check if we have a valid stored token
|
||||
storage, err := common.NewOAuthStorage()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
an.client = anthropic.NewClient(
|
||||
option.WithAPIKey(an.ApiKey.Value),
|
||||
option.WithBaseURL(baseURL),
|
||||
)
|
||||
} else {
|
||||
an.client = anthropic.NewClient(option.WithAPIKey(an.ApiKey.Value))
|
||||
if !storage.HasValidToken("claude", 5) {
|
||||
// No valid token, run OAuth flow
|
||||
if _, err = RunOAuthFlow(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = an.configure()
|
||||
return
|
||||
}
|
||||
|
||||
func (an *Client) configure() (err error) {
|
||||
opts := []option.RequestOption{}
|
||||
|
||||
if an.ApiBaseURL.Value != "" {
|
||||
opts = append(opts, option.WithBaseURL(an.ApiBaseURL.Value))
|
||||
}
|
||||
|
||||
if plugins.ParseBoolElseFalse(an.UseOAuth.Value) {
|
||||
// For OAuth, use Bearer token with custom headers
|
||||
// Create custom HTTP client that adds OAuth Bearer token and beta header
|
||||
baseTransport := &http.Transport{}
|
||||
httpClient := &http.Client{
|
||||
Transport: NewOAuthTransport(an, baseTransport),
|
||||
}
|
||||
opts = append(opts, option.WithHTTPClient(httpClient))
|
||||
} else {
|
||||
opts = append(opts, option.WithAPIKey(an.ApiKey.Value))
|
||||
}
|
||||
|
||||
an.client = anthropic.NewClient(opts...)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -87,7 +117,7 @@ func (an *Client) ListModels() (ret []string, err error) {
|
||||
}
|
||||
|
||||
func (an *Client) SendStream(
|
||||
msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions, channel chan string,
|
||||
msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions, channel chan string,
|
||||
) (err error) {
|
||||
messages := an.toMessages(msgs)
|
||||
if len(messages) == 0 {
|
||||
@@ -127,20 +157,28 @@ func (an *Client) buildMessageParams(msgs []anthropic.MessageParam, opts *common
|
||||
Messages: msgs,
|
||||
}
|
||||
|
||||
if plugins.ParseBoolElseFalse(an.UseWebTool.Value) {
|
||||
// Build the web-search tool definition:
|
||||
webTool := anthropic.WebSearchTool20250305Param{
|
||||
Name: "web_search", // string literal instead of constant
|
||||
Type: "web_search_20250305", // string literal instead of constant
|
||||
CacheControl: anthropic.NewCacheControlEphemeralParam(),
|
||||
// Optional: restrict domains or max uses
|
||||
// AllowedDomains: []string{"wikipedia.org", "openai.com"},
|
||||
// MaxUses: anthropic.Opt[int64](5),
|
||||
// Add Claude Code spoofing system message for OAuth authentication
|
||||
if plugins.ParseBoolElseFalse(an.UseOAuth.Value) {
|
||||
params.System = []anthropic.TextBlockParam{
|
||||
{
|
||||
Type: "text",
|
||||
Text: "You are Claude Code, Anthropic's official CLI for Claude.",
|
||||
},
|
||||
}
|
||||
|
||||
if an.WebToolLocation.Value != "" {
|
||||
}
|
||||
|
||||
if opts.Search {
|
||||
// Build the web-search tool definition:
|
||||
webTool := anthropic.WebSearchTool20250305Param{
|
||||
Name: webSearchToolName,
|
||||
Type: webSearchToolType,
|
||||
CacheControl: anthropic.NewCacheControlEphemeralParam(),
|
||||
}
|
||||
|
||||
if opts.SearchLocation != "" {
|
||||
webTool.UserLocation.Type = "approximate"
|
||||
webTool.UserLocation.Timezone = anthropic.Opt(an.WebToolLocation.Value)
|
||||
webTool.UserLocation.Timezone = anthropic.Opt(opts.SearchLocation)
|
||||
}
|
||||
|
||||
// Wrap it in the union:
|
||||
@@ -151,7 +189,7 @@ func (an *Client) buildMessageParams(msgs []anthropic.MessageParam, opts *common
|
||||
return
|
||||
}
|
||||
|
||||
func (an *Client) Send(ctx context.Context, msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions) (
|
||||
func (an *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions) (
|
||||
ret string, err error) {
|
||||
|
||||
messages := an.toMessages(msgs)
|
||||
@@ -165,18 +203,47 @@ func (an *Client) Send(ctx context.Context, msgs []*goopenai.ChatCompletionMessa
|
||||
return
|
||||
}
|
||||
|
||||
texts := lo.FilterMap(message.Content, func(block anthropic.ContentBlockUnion, _ int) (ret string, ok bool) {
|
||||
if ok = block.Type == "text" && block.Text != ""; ok {
|
||||
ret = block.Text
|
||||
var textParts []string
|
||||
var citations []string
|
||||
citationMap := make(map[string]bool) // To avoid duplicate citations
|
||||
|
||||
for _, block := range message.Content {
|
||||
if block.Type == "text" && block.Text != "" {
|
||||
textParts = append(textParts, block.Text)
|
||||
|
||||
// Extract citations from this text block
|
||||
for _, citation := range block.Citations {
|
||||
if citation.Type == "web_search_result_location" {
|
||||
citationKey := citation.URL + "|" + citation.Title
|
||||
if !citationMap[citationKey] {
|
||||
citationMap[citationKey] = true
|
||||
citationText := fmt.Sprintf("- [%s](%s)", citation.Title, citation.URL)
|
||||
if citation.CitedText != "" {
|
||||
citationText += fmt.Sprintf(" - \"%s\"", citation.CitedText)
|
||||
}
|
||||
citations = append(citations, citationText)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
})
|
||||
ret = strings.Join(texts, "")
|
||||
}
|
||||
|
||||
var resultBuilder strings.Builder
|
||||
resultBuilder.WriteString(strings.Join(textParts, ""))
|
||||
|
||||
// Append citations if any were found
|
||||
if len(citations) > 0 {
|
||||
resultBuilder.WriteString("\n\n")
|
||||
resultBuilder.WriteString(sourcesHeader)
|
||||
resultBuilder.WriteString("\n\n")
|
||||
resultBuilder.WriteString(strings.Join(citations, "\n"))
|
||||
}
|
||||
ret = resultBuilder.String()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (an *Client) toMessages(msgs []*goopenai.ChatCompletionMessage) (ret []anthropic.MessageParam) {
|
||||
func (an *Client) toMessages(msgs []*chat.ChatCompletionMessage) (ret []anthropic.MessageParam) {
|
||||
// Custom normalization for Anthropic:
|
||||
// - System messages become the first part of the first user message.
|
||||
// - Messages must alternate user/assistant.
|
||||
@@ -184,6 +251,9 @@ func (an *Client) toMessages(msgs []*goopenai.ChatCompletionMessage) (ret []anth
|
||||
|
||||
var anthropicMessages []anthropic.MessageParam
|
||||
var systemContent string
|
||||
|
||||
// Note: Claude Code spoofing is now handled in buildMessageParams
|
||||
|
||||
isFirstUserMessage := true
|
||||
lastRoleWasUser := false
|
||||
|
||||
@@ -193,14 +263,14 @@ func (an *Client) toMessages(msgs []*goopenai.ChatCompletionMessage) (ret []anth
|
||||
}
|
||||
|
||||
switch msg.Role {
|
||||
case goopenai.ChatMessageRoleSystem:
|
||||
case chat.ChatMessageRoleSystem:
|
||||
// Accumulate system content. It will be prepended to the first user message.
|
||||
if systemContent != "" {
|
||||
systemContent += "\\n" + msg.Content
|
||||
} else {
|
||||
systemContent = msg.Content
|
||||
}
|
||||
case goopenai.ChatMessageRoleUser:
|
||||
case chat.ChatMessageRoleUser:
|
||||
userContent := msg.Content
|
||||
if isFirstUserMessage && systemContent != "" {
|
||||
userContent = systemContent + "\\n\\n" + userContent
|
||||
@@ -213,7 +283,7 @@ func (an *Client) toMessages(msgs []*goopenai.ChatCompletionMessage) (ret []anth
|
||||
}
|
||||
anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(anthropic.NewTextBlock(userContent)))
|
||||
lastRoleWasUser = true
|
||||
case goopenai.ChatMessageRoleAssistant:
|
||||
case chat.ChatMessageRoleAssistant:
|
||||
// If the first message is an assistant message, and we have system content,
|
||||
// prepend a user message with the system content.
|
||||
if isFirstUserMessage && systemContent != "" {
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
)
|
||||
|
||||
// Test generated using Keploy
|
||||
@@ -63,3 +67,192 @@ func TestClient_ListModels_ReturnsCorrectModels(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildMessageParams_WithoutSearch(t *testing.T) {
|
||||
client := NewClient()
|
||||
opts := &common.ChatOptions{
|
||||
Model: "claude-3-5-sonnet-latest",
|
||||
Temperature: 0.7,
|
||||
Search: false,
|
||||
}
|
||||
|
||||
messages := []anthropic.MessageParam{
|
||||
anthropic.NewUserMessage(anthropic.NewTextBlock("Hello")),
|
||||
}
|
||||
|
||||
params := client.buildMessageParams(messages, opts)
|
||||
|
||||
if params.Tools != nil {
|
||||
t.Error("Expected no tools when search is disabled, got tools")
|
||||
}
|
||||
|
||||
if params.Model != anthropic.Model(opts.Model) {
|
||||
t.Errorf("Expected model %s, got %s", opts.Model, params.Model)
|
||||
}
|
||||
|
||||
if params.Temperature.Value != opts.Temperature {
|
||||
t.Errorf("Expected temperature %f, got %f", opts.Temperature, params.Temperature.Value)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildMessageParams_WithSearch(t *testing.T) {
|
||||
client := NewClient()
|
||||
opts := &common.ChatOptions{
|
||||
Model: "claude-3-5-sonnet-latest",
|
||||
Temperature: 0.7,
|
||||
Search: true,
|
||||
}
|
||||
|
||||
messages := []anthropic.MessageParam{
|
||||
anthropic.NewUserMessage(anthropic.NewTextBlock("What's the weather today?")),
|
||||
}
|
||||
|
||||
params := client.buildMessageParams(messages, opts)
|
||||
|
||||
if params.Tools == nil {
|
||||
t.Fatal("Expected tools when search is enabled, got nil")
|
||||
}
|
||||
|
||||
if len(params.Tools) != 1 {
|
||||
t.Errorf("Expected 1 tool, got %d", len(params.Tools))
|
||||
}
|
||||
|
||||
webTool := params.Tools[0].OfWebSearchTool20250305
|
||||
if webTool == nil {
|
||||
t.Fatal("Expected web search tool, got nil")
|
||||
}
|
||||
|
||||
if webTool.Name != "web_search" {
|
||||
t.Errorf("Expected tool name 'web_search', got %s", webTool.Name)
|
||||
}
|
||||
|
||||
if webTool.Type != "web_search_20250305" {
|
||||
t.Errorf("Expected tool type 'web_search_20250305', got %s", webTool.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildMessageParams_WithSearchAndLocation(t *testing.T) {
|
||||
client := NewClient()
|
||||
opts := &common.ChatOptions{
|
||||
Model: "claude-3-5-sonnet-latest",
|
||||
Temperature: 0.7,
|
||||
Search: true,
|
||||
SearchLocation: "America/Los_Angeles",
|
||||
}
|
||||
|
||||
messages := []anthropic.MessageParam{
|
||||
anthropic.NewUserMessage(anthropic.NewTextBlock("What's the weather in San Francisco?")),
|
||||
}
|
||||
|
||||
params := client.buildMessageParams(messages, opts)
|
||||
|
||||
if params.Tools == nil {
|
||||
t.Fatal("Expected tools when search is enabled, got nil")
|
||||
}
|
||||
|
||||
webTool := params.Tools[0].OfWebSearchTool20250305
|
||||
if webTool == nil {
|
||||
t.Fatal("Expected web search tool, got nil")
|
||||
}
|
||||
|
||||
if webTool.UserLocation.Type != "approximate" {
|
||||
t.Errorf("Expected location type 'approximate', got %s", webTool.UserLocation.Type)
|
||||
}
|
||||
|
||||
if webTool.UserLocation.Timezone.Value != opts.SearchLocation {
|
||||
t.Errorf("Expected timezone %s, got %s", opts.SearchLocation, webTool.UserLocation.Timezone.Value)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCitationFormatting(t *testing.T) {
|
||||
// Test the citation formatting logic by creating a mock message with citations
|
||||
message := &anthropic.Message{
|
||||
Content: []anthropic.ContentBlockUnion{
|
||||
{
|
||||
Type: "text",
|
||||
Text: "Based on recent research, artificial intelligence is advancing rapidly.",
|
||||
Citations: []anthropic.TextCitationUnion{
|
||||
{
|
||||
Type: "web_search_result_location",
|
||||
URL: "https://example.com/ai-research",
|
||||
Title: "AI Research Advances 2025",
|
||||
CitedText: "artificial intelligence is advancing rapidly",
|
||||
},
|
||||
{
|
||||
Type: "web_search_result_location",
|
||||
URL: "https://another-source.com/tech-news",
|
||||
Title: "Technology News Today",
|
||||
CitedText: "recent developments in AI",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "text",
|
||||
Text: " Machine learning models are becoming more sophisticated.",
|
||||
Citations: []anthropic.TextCitationUnion{
|
||||
{
|
||||
Type: "web_search_result_location",
|
||||
URL: "https://example.com/ai-research", // Duplicate URL should be deduplicated
|
||||
Title: "AI Research Advances 2025",
|
||||
CitedText: "machine learning models",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Extract text and citations using the same logic as the Send method
|
||||
var textParts []string
|
||||
var citations []string
|
||||
citationMap := make(map[string]bool)
|
||||
|
||||
for _, block := range message.Content {
|
||||
if block.Type == "text" && block.Text != "" {
|
||||
textParts = append(textParts, block.Text)
|
||||
|
||||
for _, citation := range block.Citations {
|
||||
if citation.Type == "web_search_result_location" {
|
||||
citationKey := citation.URL + "|" + citation.Title
|
||||
if !citationMap[citationKey] {
|
||||
citationMap[citationKey] = true
|
||||
citationText := "- [" + citation.Title + "](" + citation.URL + ")"
|
||||
if citation.CitedText != "" {
|
||||
citationText += " - \"" + citation.CitedText + "\""
|
||||
}
|
||||
citations = append(citations, citationText)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result := strings.Join(textParts, "")
|
||||
if len(citations) > 0 {
|
||||
result += "\n\n## Sources\n\n" + strings.Join(citations, "\n")
|
||||
}
|
||||
|
||||
// Verify the result contains the expected text
|
||||
expectedText := "Based on recent research, artificial intelligence is advancing rapidly. Machine learning models are becoming more sophisticated."
|
||||
if !strings.Contains(result, expectedText) {
|
||||
t.Errorf("Expected result to contain text: %s", expectedText)
|
||||
}
|
||||
|
||||
// Verify citations are included
|
||||
if !strings.Contains(result, "## Sources") {
|
||||
t.Error("Expected result to contain Sources section")
|
||||
}
|
||||
|
||||
if !strings.Contains(result, "[AI Research Advances 2025](https://example.com/ai-research)") {
|
||||
t.Error("Expected result to contain first citation")
|
||||
}
|
||||
|
||||
if !strings.Contains(result, "[Technology News Today](https://another-source.com/tech-news)") {
|
||||
t.Error("Expected result to contain second citation")
|
||||
}
|
||||
|
||||
// Verify deduplication - should only have 2 unique citations, not 3
|
||||
citationCount := strings.Count(result, "- [")
|
||||
if citationCount != 2 {
|
||||
t.Errorf("Expected 2 unique citations, got %d", citationCount)
|
||||
}
|
||||
}
|
||||
|
||||
300
plugins/ai/anthropic/oauth.go
Normal file
300
plugins/ai/anthropic/oauth.go
Normal file
@@ -0,0 +1,300 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// OAuth configuration constants
|
||||
const (
|
||||
oauthClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
||||
oauthAuthURL = "https://claude.ai/oauth/authorize"
|
||||
oauthTokenURL = "https://console.anthropic.com/v1/oauth/token"
|
||||
oauthRedirectURL = "https://console.anthropic.com/oauth/code/callback"
|
||||
)
|
||||
|
||||
// OAuthTransport is a custom HTTP transport that adds OAuth Bearer token and beta header
|
||||
type OAuthTransport struct {
|
||||
client *Client
|
||||
base http.RoundTripper
|
||||
}
|
||||
|
||||
// RoundTrip implements the http.RoundTripper interface
|
||||
func (t *OAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Clone the request to avoid modifying the original
|
||||
newReq := req.Clone(req.Context())
|
||||
|
||||
// Get current token (may refresh if needed)
|
||||
token, err := t.getValidToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get valid OAuth token: %w", err)
|
||||
}
|
||||
|
||||
// Add OAuth Bearer token
|
||||
newReq.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
// Add the anthropic-beta header for OAuth
|
||||
newReq.Header.Set("anthropic-beta", "oauth-2025-04-20")
|
||||
|
||||
// Set User-Agent to match AI SDK exactly
|
||||
newReq.Header.Set("User-Agent", "ai-sdk/anthropic")
|
||||
|
||||
// Remove x-api-key header if present (OAuth doesn't use it)
|
||||
newReq.Header.Del("x-api-key")
|
||||
|
||||
return t.base.RoundTrip(newReq)
|
||||
}
|
||||
|
||||
// getValidToken returns a valid access token, refreshing if necessary
|
||||
func (t *OAuthTransport) getValidToken() (string, error) {
|
||||
storage, err := common.NewOAuthStorage()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create OAuth storage: %w", err)
|
||||
}
|
||||
|
||||
// Load stored token
|
||||
token, err := storage.LoadToken("claude")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to load stored token: %w", err)
|
||||
}
|
||||
// If no token exists, run OAuth flow
|
||||
if token == nil {
|
||||
fmt.Println("No OAuth token found, initiating authentication...")
|
||||
newAccessToken, err := RunOAuthFlow()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to authenticate: %w", err)
|
||||
}
|
||||
return newAccessToken, nil
|
||||
}
|
||||
|
||||
// Check if token needs refresh (5 minute buffer)
|
||||
if token.IsExpired(5) {
|
||||
fmt.Println("OAuth token expired, refreshing...")
|
||||
newAccessToken, err := RefreshToken()
|
||||
if err != nil {
|
||||
// If refresh fails, try re-authentication
|
||||
fmt.Println("Token refresh failed, re-authenticating...")
|
||||
newAccessToken, err = RunOAuthFlow()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to refresh or re-authenticate: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return newAccessToken, nil
|
||||
}
|
||||
|
||||
return token.AccessToken, nil
|
||||
}
|
||||
|
||||
// NewOAuthTransport creates a new OAuth transport for the given client
|
||||
func NewOAuthTransport(client *Client, base http.RoundTripper) *OAuthTransport {
|
||||
return &OAuthTransport{
|
||||
client: client,
|
||||
base: base,
|
||||
}
|
||||
}
|
||||
|
||||
// generatePKCE generates PKCE code verifier and challenge
|
||||
func generatePKCE() (verifier, challenge string, err error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err = rand.Read(b); err != nil {
|
||||
return
|
||||
}
|
||||
verifier = base64.RawURLEncoding.EncodeToString(b)
|
||||
sum := sha256.Sum256([]byte(verifier))
|
||||
challenge = base64.RawURLEncoding.EncodeToString(sum[:])
|
||||
return
|
||||
}
|
||||
|
||||
// openBrowser attempts to open the given URL in the default browser
|
||||
func openBrowser(url string) {
|
||||
commands := [][]string{{"xdg-open", url}, {"open", url}, {"cmd", "/c", "start", url}}
|
||||
for _, cmd := range commands {
|
||||
if exec.Command(cmd[0], cmd[1:]...).Start() == nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RunOAuthFlow executes the complete OAuth authorization flow
|
||||
func RunOAuthFlow() (token string, err error) {
|
||||
verifier, challenge, err := generatePKCE()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
cfg := oauth2.Config{
|
||||
ClientID: oauthClientID,
|
||||
Endpoint: oauth2.Endpoint{AuthURL: oauthAuthURL, TokenURL: oauthTokenURL},
|
||||
RedirectURL: oauthRedirectURL,
|
||||
Scopes: []string{"org:create_api_key", "user:profile", "user:inference"},
|
||||
}
|
||||
|
||||
authURL := cfg.AuthCodeURL(verifier,
|
||||
oauth2.SetAuthURLParam("code_challenge", challenge),
|
||||
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
||||
oauth2.SetAuthURLParam("code", "true"),
|
||||
oauth2.SetAuthURLParam("state", verifier),
|
||||
)
|
||||
|
||||
fmt.Println("Open the following URL in your browser. Fabric would like to authorize:")
|
||||
fmt.Println(authURL)
|
||||
openBrowser(authURL)
|
||||
fmt.Print("Paste the authorization code here: ")
|
||||
var code string
|
||||
fmt.Scanln(&code)
|
||||
parts := strings.SplitN(code, "#", 2)
|
||||
state := verifier
|
||||
if len(parts) == 2 {
|
||||
state = parts[1]
|
||||
}
|
||||
|
||||
// Manual token exchange to match opencode implementation
|
||||
tokenReq := map[string]string{
|
||||
"code": parts[0],
|
||||
"state": state,
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": oauthClientID,
|
||||
"redirect_uri": oauthRedirectURL,
|
||||
"code_verifier": verifier,
|
||||
}
|
||||
|
||||
token, err = exchangeToken(tokenReq)
|
||||
return
|
||||
}
|
||||
|
||||
// exchangeToken exchanges authorization code for access token
|
||||
func exchangeToken(params map[string]string) (token string, err error) {
|
||||
reqBody, err := json.Marshal(params)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := http.Post(oauthTokenURL, "application/json", bytes.NewBuffer(reqBody))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
err = fmt.Errorf("token exchange failed: %s - %s", resp.Status, string(body))
|
||||
return
|
||||
}
|
||||
|
||||
var result struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
TokenType string `json:"token_type"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
if err = json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Save the complete token information
|
||||
storage, err := common.NewOAuthStorage()
|
||||
if err != nil {
|
||||
return result.AccessToken, fmt.Errorf("failed to create OAuth storage: %w", err)
|
||||
}
|
||||
|
||||
oauthToken := &common.OAuthToken{
|
||||
AccessToken: result.AccessToken,
|
||||
RefreshToken: result.RefreshToken,
|
||||
ExpiresAt: time.Now().Unix() + int64(result.ExpiresIn),
|
||||
TokenType: result.TokenType,
|
||||
Scope: result.Scope,
|
||||
}
|
||||
|
||||
if err = storage.SaveToken("claude", oauthToken); err != nil {
|
||||
return result.AccessToken, fmt.Errorf("failed to save OAuth token: %w", err)
|
||||
}
|
||||
|
||||
token = result.AccessToken
|
||||
return
|
||||
}
|
||||
|
||||
// RefreshToken refreshes an expired OAuth token using the refresh token
|
||||
func RefreshToken() (string, error) {
|
||||
storage, err := common.NewOAuthStorage()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create OAuth storage: %w", err)
|
||||
}
|
||||
|
||||
// Load existing token
|
||||
token, err := storage.LoadToken("claude")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to load stored token: %w", err)
|
||||
}
|
||||
if token == nil || token.RefreshToken == "" {
|
||||
return "", fmt.Errorf("no refresh token available")
|
||||
}
|
||||
|
||||
// Prepare refresh request
|
||||
refreshReq := map[string]string{
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": token.RefreshToken,
|
||||
"client_id": oauthClientID,
|
||||
}
|
||||
|
||||
reqBody, err := json.Marshal(refreshReq)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal refresh request: %w", err)
|
||||
}
|
||||
|
||||
// Make refresh request
|
||||
resp, err := http.Post(oauthTokenURL, "application/json", bytes.NewBuffer(reqBody))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("refresh request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return "", fmt.Errorf("token refresh failed: %s - %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
TokenType string `json:"token_type"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
if err = json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return "", fmt.Errorf("failed to parse refresh response: %w", err)
|
||||
}
|
||||
|
||||
// Update stored token
|
||||
newToken := &common.OAuthToken{
|
||||
AccessToken: result.AccessToken,
|
||||
RefreshToken: result.RefreshToken,
|
||||
ExpiresAt: time.Now().Unix() + int64(result.ExpiresIn),
|
||||
TokenType: result.TokenType,
|
||||
Scope: result.Scope,
|
||||
}
|
||||
|
||||
// Use existing refresh token if new one not provided
|
||||
if newToken.RefreshToken == "" {
|
||||
newToken.RefreshToken = token.RefreshToken
|
||||
}
|
||||
|
||||
if err = storage.SaveToken("claude", newToken); err != nil {
|
||||
return "", fmt.Errorf("failed to save refreshed token: %w", err)
|
||||
}
|
||||
|
||||
return result.AccessToken, nil
|
||||
}
|
||||
@@ -5,7 +5,8 @@ import (
|
||||
|
||||
"github.com/danielmiessler/fabric/plugins"
|
||||
"github.com/danielmiessler/fabric/plugins/ai/openai"
|
||||
goopenai "github.com/sashabaranov/go-openai"
|
||||
openaiapi "github.com/openai/openai-go"
|
||||
"github.com/openai/openai-go/option"
|
||||
)
|
||||
|
||||
func NewClient() (ret *Client) {
|
||||
@@ -29,11 +30,15 @@ type Client struct {
|
||||
|
||||
func (oi *Client) configure() (err error) {
|
||||
oi.apiDeployments = strings.Split(oi.ApiDeployments.Value, ",")
|
||||
config := goopenai.DefaultAzureConfig(oi.ApiKey.Value, oi.ApiBaseURL.Value)
|
||||
if oi.ApiVersion.Value != "" {
|
||||
config.APIVersion = oi.ApiVersion.Value
|
||||
opts := []option.RequestOption{option.WithAPIKey(oi.ApiKey.Value)}
|
||||
if oi.ApiBaseURL.Value != "" {
|
||||
opts = append(opts, option.WithBaseURL(oi.ApiBaseURL.Value))
|
||||
}
|
||||
oi.ApiClient = goopenai.NewClientWithConfig(config)
|
||||
if oi.ApiVersion.Value != "" {
|
||||
opts = append(opts, option.WithQuery("api-version", oi.ApiVersion.Value))
|
||||
}
|
||||
client := openaiapi.NewClient(opts...)
|
||||
oi.ApiClient = &client
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
|
||||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
|
||||
|
||||
goopenai "github.com/sashabaranov/go-openai"
|
||||
"github.com/danielmiessler/fabric/chat"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -154,7 +154,7 @@ func (c *BedrockClient) ListModels() ([]string, error) {
|
||||
}
|
||||
|
||||
// SendStream sends the messages to the the Bedrock ConverseStream API
|
||||
func (c *BedrockClient) SendStream(msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) (err error) {
|
||||
func (c *BedrockClient) SendStream(msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) (err error) {
|
||||
// Ensure channel is closed on all exit paths to prevent goroutine leaks
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
@@ -208,7 +208,7 @@ func (c *BedrockClient) SendStream(msgs []*goopenai.ChatCompletionMessage, opts
|
||||
}
|
||||
|
||||
// Send sends the messages the Bedrock Converse API
|
||||
func (c *BedrockClient) Send(ctx context.Context, msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions) (ret string, err error) {
|
||||
func (c *BedrockClient) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions) (ret string, err error) {
|
||||
|
||||
messages := c.toMessages(msgs)
|
||||
|
||||
@@ -249,12 +249,12 @@ func (c *BedrockClient) NeedsRawMode(modelName string) bool {
|
||||
// Bedrock Converse Message type.
|
||||
// The system role messages are mapped to the user role as they contain a mix of system messages,
|
||||
// pattern content and user input.
|
||||
func (c *BedrockClient) toMessages(inputMessages []*goopenai.ChatCompletionMessage) (messages []types.Message) {
|
||||
func (c *BedrockClient) toMessages(inputMessages []*chat.ChatCompletionMessage) (messages []types.Message) {
|
||||
for _, msg := range inputMessages {
|
||||
roles := map[string]types.ConversationRole{
|
||||
goopenai.ChatMessageRoleUser: types.ConversationRoleUser,
|
||||
goopenai.ChatMessageRoleAssistant: types.ConversationRoleAssistant,
|
||||
goopenai.ChatMessageRoleSystem: types.ConversationRoleUser,
|
||||
chat.ChatMessageRoleUser: types.ConversationRoleUser,
|
||||
chat.ChatMessageRoleAssistant: types.ConversationRoleAssistant,
|
||||
chat.ChatMessageRoleSystem: types.ConversationRoleUser,
|
||||
}
|
||||
|
||||
role, ok := roles[msg.Role]
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
goopenai "github.com/sashabaranov/go-openai"
|
||||
"github.com/danielmiessler/fabric/chat"
|
||||
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"github.com/danielmiessler/fabric/plugins"
|
||||
@@ -24,14 +24,14 @@ func (c *Client) ListModels() ([]string, error) {
|
||||
return []string{"dry-run-model"}, nil
|
||||
}
|
||||
|
||||
func (c *Client) formatMultiContentMessage(msg *goopenai.ChatCompletionMessage) string {
|
||||
func (c *Client) formatMultiContentMessage(msg *chat.ChatCompletionMessage) string {
|
||||
var builder strings.Builder
|
||||
|
||||
if len(msg.MultiContent) > 0 {
|
||||
builder.WriteString(fmt.Sprintf("%s:\n", msg.Role))
|
||||
for _, part := range msg.MultiContent {
|
||||
builder.WriteString(fmt.Sprintf(" - Type: %s\n", part.Type))
|
||||
if part.Type == goopenai.ChatMessagePartTypeImageURL {
|
||||
if part.Type == chat.ChatMessagePartTypeImageURL {
|
||||
builder.WriteString(fmt.Sprintf(" Image URL: %s\n", part.ImageURL.URL))
|
||||
} else {
|
||||
builder.WriteString(fmt.Sprintf(" Text: %s\n", part.Text))
|
||||
@@ -45,16 +45,16 @@ func (c *Client) formatMultiContentMessage(msg *goopenai.ChatCompletionMessage)
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
func (c *Client) formatMessages(msgs []*goopenai.ChatCompletionMessage) string {
|
||||
func (c *Client) formatMessages(msgs []*chat.ChatCompletionMessage) string {
|
||||
var builder strings.Builder
|
||||
|
||||
for _, msg := range msgs {
|
||||
switch msg.Role {
|
||||
case goopenai.ChatMessageRoleSystem:
|
||||
case chat.ChatMessageRoleSystem:
|
||||
builder.WriteString(fmt.Sprintf("System:\n%s\n\n", msg.Content))
|
||||
case goopenai.ChatMessageRoleAssistant:
|
||||
case chat.ChatMessageRoleAssistant:
|
||||
builder.WriteString(c.formatMultiContentMessage(msg))
|
||||
case goopenai.ChatMessageRoleUser:
|
||||
case chat.ChatMessageRoleUser:
|
||||
builder.WriteString(c.formatMultiContentMessage(msg))
|
||||
default:
|
||||
builder.WriteString(fmt.Sprintf("%s:\n%s\n\n", msg.Role, msg.Content))
|
||||
@@ -76,11 +76,20 @@ func (c *Client) formatOptions(opts *common.ChatOptions) string {
|
||||
if opts.ModelContextLength != 0 {
|
||||
builder.WriteString(fmt.Sprintf("ModelContextLength: %d\n", opts.ModelContextLength))
|
||||
}
|
||||
if opts.Search {
|
||||
builder.WriteString("Search: enabled\n")
|
||||
if opts.SearchLocation != "" {
|
||||
builder.WriteString(fmt.Sprintf("SearchLocation: %s\n", opts.SearchLocation))
|
||||
}
|
||||
}
|
||||
if opts.ImageFile != "" {
|
||||
builder.WriteString(fmt.Sprintf("ImageFile: %s\n", opts.ImageFile))
|
||||
}
|
||||
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
func (c *Client) SendStream(msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) error {
|
||||
func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) error {
|
||||
var builder strings.Builder
|
||||
builder.WriteString("Dry run: Would send the following request:\n\n")
|
||||
builder.WriteString(c.formatMessages(msgs))
|
||||
@@ -91,7 +100,7 @@ func (c *Client) SendStream(msgs []*goopenai.ChatCompletionMessage, opts *common
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) Send(_ context.Context, msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions) (string, error) {
|
||||
func (c *Client) Send(_ context.Context, msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions) (string, error) {
|
||||
fmt.Println("Dry run: Would send the following request:")
|
||||
fmt.Print(c.formatMessages(msgs))
|
||||
fmt.Print(c.formatOptions(opts))
|
||||
|
||||
@@ -4,8 +4,8 @@ import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/danielmiessler/fabric/chat"
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
// Test generated using Keploy
|
||||
@@ -33,7 +33,7 @@ func TestSetup_ReturnsNil(t *testing.T) {
|
||||
// Test generated using Keploy
|
||||
func TestSendStream_SendsMessages(t *testing.T) {
|
||||
client := NewClient()
|
||||
msgs := []*openai.ChatCompletionMessage{
|
||||
msgs := []*chat.ChatCompletionMessage{
|
||||
{Role: "user", Content: "Test message"},
|
||||
}
|
||||
opts := &common.ChatOptions{
|
||||
|
||||
@@ -5,8 +5,8 @@ import (
|
||||
|
||||
"github.com/danielmiessler/fabric/plugins"
|
||||
"github.com/danielmiessler/fabric/plugins/ai/openai"
|
||||
|
||||
goopenai "github.com/sashabaranov/go-openai"
|
||||
openaiapi "github.com/openai/openai-go"
|
||||
"github.com/openai/openai-go/option"
|
||||
)
|
||||
|
||||
func NewClient() (ret *Client) {
|
||||
@@ -32,10 +32,12 @@ type Client struct {
|
||||
func (oi *Client) configure() (err error) {
|
||||
oi.apiModels = strings.Split(oi.ApiModels.Value, ",")
|
||||
|
||||
config := goopenai.DefaultConfig("")
|
||||
config.BaseURL = oi.ApiBaseURL.Value
|
||||
|
||||
oi.ApiClient = goopenai.NewClientWithConfig(config)
|
||||
opts := []option.RequestOption{option.WithAPIKey(oi.ApiKey.Value)}
|
||||
if oi.ApiBaseURL.Value != "" {
|
||||
opts = append(opts, option.WithBaseURL(oi.ApiBaseURL.Value))
|
||||
}
|
||||
client := openaiapi.NewClient(opts...)
|
||||
oi.ApiClient = &client
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -6,8 +6,8 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/danielmiessler/fabric/chat"
|
||||
"github.com/danielmiessler/fabric/plugins"
|
||||
goopenai "github.com/sashabaranov/go-openai"
|
||||
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"github.com/google/generative-ai-go/genai"
|
||||
@@ -60,7 +60,7 @@ func (o *Client) ListModels() (ret []string, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (o *Client) Send(ctx context.Context, msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions) (ret string, err error) {
|
||||
func (o *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions) (ret string, err error) {
|
||||
systemInstruction, messages := toMessages(msgs)
|
||||
|
||||
var client *genai.Client
|
||||
@@ -91,7 +91,7 @@ func (o *Client) buildModelNameFull(modelName string) string {
|
||||
return fmt.Sprintf("%v%v", modelsNamePrefix, modelName)
|
||||
}
|
||||
|
||||
func (o *Client) SendStream(msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) (err error) {
|
||||
func (o *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) (err error) {
|
||||
ctx := context.Background()
|
||||
var client *genai.Client
|
||||
if client, err = genai.NewClient(ctx, option.WithAPIKey(o.ApiKey.Value)); err != nil {
|
||||
@@ -147,7 +147,7 @@ func (o *Client) NeedsRawMode(modelName string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func toMessages(msgs []*goopenai.ChatCompletionMessage) (systemInstruction *genai.Content, messages []genai.Part) {
|
||||
func toMessages(msgs []*chat.ChatCompletionMessage) (systemInstruction *genai.Content, messages []genai.Part) {
|
||||
if len(msgs) >= 2 {
|
||||
systemInstruction = &genai.Content{
|
||||
Parts: []genai.Part{
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
goopenai "github.com/sashabaranov/go-openai"
|
||||
"github.com/danielmiessler/fabric/chat"
|
||||
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"github.com/danielmiessler/fabric/plugins"
|
||||
@@ -87,7 +87,7 @@ func (c *Client) ListModels() ([]string, error) {
|
||||
return models, nil
|
||||
}
|
||||
|
||||
func (c *Client) SendStream(msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) (err error) {
|
||||
func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) (err error) {
|
||||
url := fmt.Sprintf("%s/chat/completions", c.ApiUrl.Value)
|
||||
|
||||
payload := map[string]interface{}{
|
||||
@@ -173,7 +173,7 @@ func (c *Client) SendStream(msgs []*goopenai.ChatCompletionMessage, opts *common
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Client) Send(ctx context.Context, msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions) (content string, err error) {
|
||||
func (c *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions) (content string, err error) {
|
||||
url := fmt.Sprintf("%s/chat/completions", c.ApiUrl.Value)
|
||||
|
||||
payload := map[string]interface{}{
|
||||
|
||||
@@ -8,9 +8,9 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/danielmiessler/fabric/chat"
|
||||
ollamaapi "github.com/ollama/ollama/api"
|
||||
"github.com/samber/lo"
|
||||
goopenai "github.com/sashabaranov/go-openai"
|
||||
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"github.com/danielmiessler/fabric/plugins"
|
||||
@@ -97,7 +97,7 @@ func (o *Client) ListModels() (ret []string, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (o *Client) SendStream(msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) (err error) {
|
||||
func (o *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) (err error) {
|
||||
req := o.createChatRequest(msgs, opts)
|
||||
|
||||
respFunc := func(resp ollamaapi.ChatResponse) (streamErr error) {
|
||||
@@ -115,7 +115,7 @@ func (o *Client) SendStream(msgs []*goopenai.ChatCompletionMessage, opts *common
|
||||
return
|
||||
}
|
||||
|
||||
func (o *Client) Send(ctx context.Context, msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions) (ret string, err error) {
|
||||
func (o *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions) (ret string, err error) {
|
||||
bf := false
|
||||
|
||||
req := o.createChatRequest(msgs, opts)
|
||||
@@ -132,8 +132,8 @@ func (o *Client) Send(ctx context.Context, msgs []*goopenai.ChatCompletionMessag
|
||||
return
|
||||
}
|
||||
|
||||
func (o *Client) createChatRequest(msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions) (ret ollamaapi.ChatRequest) {
|
||||
messages := lo.Map(msgs, func(message *goopenai.ChatCompletionMessage, _ int) (ret ollamaapi.Message) {
|
||||
func (o *Client) createChatRequest(msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions) (ret ollamaapi.ChatRequest) {
|
||||
messages := lo.Map(msgs, func(message *chat.ChatCompletionMessage, _ int) (ret ollamaapi.Message) {
|
||||
return ollamaapi.Message{Role: message.Role, Content: message.Content}
|
||||
})
|
||||
|
||||
|
||||
116
plugins/ai/openai/chat_completions.go
Normal file
116
plugins/ai/openai/chat_completions.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package openai
|
||||
|
||||
// This file contains helper methods for the Chat Completions API.
|
||||
// These methods are used as fallbacks for OpenAI-compatible providers
|
||||
// that don't support the newer Responses API (e.g., Groq, Mistral, etc.).
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/danielmiessler/fabric/chat"
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
openai "github.com/openai/openai-go"
|
||||
"github.com/openai/openai-go/shared"
|
||||
)
|
||||
|
||||
// sendChatCompletions sends a request using the Chat Completions API
|
||||
func (o *Client) sendChatCompletions(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions) (ret string, err error) {
|
||||
req := o.buildChatCompletionParams(msgs, opts)
|
||||
|
||||
var resp *openai.ChatCompletion
|
||||
if resp, err = o.ApiClient.Chat.Completions.New(ctx, req); err != nil {
|
||||
return
|
||||
}
|
||||
if len(resp.Choices) > 0 {
|
||||
ret = resp.Choices[0].Message.Content
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// sendStreamChatCompletions sends a streaming request using the Chat Completions API
|
||||
func (o *Client) sendStreamChatCompletions(
|
||||
msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions, channel chan string,
|
||||
) (err error) {
|
||||
defer close(channel)
|
||||
|
||||
req := o.buildChatCompletionParams(msgs, opts)
|
||||
stream := o.ApiClient.Chat.Completions.NewStreaming(context.Background(), req)
|
||||
for stream.Next() {
|
||||
chunk := stream.Current()
|
||||
if len(chunk.Choices) > 0 && chunk.Choices[0].Delta.Content != "" {
|
||||
channel <- chunk.Choices[0].Delta.Content
|
||||
}
|
||||
}
|
||||
if stream.Err() == nil {
|
||||
channel <- "\n"
|
||||
}
|
||||
return stream.Err()
|
||||
}
|
||||
|
||||
// buildChatCompletionParams builds parameters for the Chat Completions API
|
||||
func (o *Client) buildChatCompletionParams(
|
||||
inputMsgs []*chat.ChatCompletionMessage, opts *common.ChatOptions,
|
||||
) (ret openai.ChatCompletionNewParams) {
|
||||
|
||||
messages := make([]openai.ChatCompletionMessageParamUnion, len(inputMsgs))
|
||||
for i, msgPtr := range inputMsgs {
|
||||
msg := *msgPtr
|
||||
if strings.Contains(opts.Model, "deepseek") && len(inputMsgs) == 1 && msg.Role == chat.ChatMessageRoleSystem {
|
||||
msg.Role = chat.ChatMessageRoleUser
|
||||
}
|
||||
messages[i] = o.convertChatMessage(msg)
|
||||
}
|
||||
|
||||
ret = openai.ChatCompletionNewParams{
|
||||
Model: shared.ChatModel(opts.Model),
|
||||
Messages: messages,
|
||||
}
|
||||
|
||||
if !opts.Raw {
|
||||
ret.Temperature = openai.Float(opts.Temperature)
|
||||
ret.TopP = openai.Float(opts.TopP)
|
||||
if opts.MaxTokens != 0 {
|
||||
ret.MaxTokens = openai.Int(int64(opts.MaxTokens))
|
||||
}
|
||||
if opts.PresencePenalty != 0 {
|
||||
ret.PresencePenalty = openai.Float(opts.PresencePenalty)
|
||||
}
|
||||
if opts.FrequencyPenalty != 0 {
|
||||
ret.FrequencyPenalty = openai.Float(opts.FrequencyPenalty)
|
||||
}
|
||||
if opts.Seed != 0 {
|
||||
ret.Seed = openai.Int(int64(opts.Seed))
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// convertChatMessage converts fabric chat message to OpenAI chat completion message
|
||||
func (o *Client) convertChatMessage(msg chat.ChatCompletionMessage) openai.ChatCompletionMessageParamUnion {
|
||||
result := convertMessageCommon(msg)
|
||||
|
||||
switch result.Role {
|
||||
case chat.ChatMessageRoleSystem:
|
||||
return openai.SystemMessage(result.Content)
|
||||
case chat.ChatMessageRoleUser:
|
||||
// Handle multi-content messages (text + images)
|
||||
if result.HasMultiContent {
|
||||
var parts []openai.ChatCompletionContentPartUnionParam
|
||||
for _, p := range result.MultiContent {
|
||||
switch p.Type {
|
||||
case chat.ChatMessagePartTypeText:
|
||||
parts = append(parts, openai.TextContentPart(p.Text))
|
||||
case chat.ChatMessagePartTypeImageURL:
|
||||
parts = append(parts, openai.ImageContentPart(openai.ChatCompletionContentPartImageImageURLParam{URL: p.ImageURL.URL}))
|
||||
}
|
||||
}
|
||||
return openai.UserMessage(parts)
|
||||
}
|
||||
return openai.UserMessage(result.Content)
|
||||
case chat.ChatMessageRoleAssistant:
|
||||
return openai.AssistantMessage(result.Content)
|
||||
default:
|
||||
return openai.UserMessage(result.Content)
|
||||
}
|
||||
}
|
||||
21
plugins/ai/openai/message_conversion.go
Normal file
21
plugins/ai/openai/message_conversion.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package openai
|
||||
|
||||
import "github.com/danielmiessler/fabric/chat"
|
||||
|
||||
// MessageConversionResult holds the common conversion result
|
||||
type MessageConversionResult struct {
|
||||
Role string
|
||||
Content string
|
||||
MultiContent []chat.ChatMessagePart
|
||||
HasMultiContent bool
|
||||
}
|
||||
|
||||
// convertMessageCommon extracts common conversion logic
|
||||
func convertMessageCommon(msg chat.ChatCompletionMessage) MessageConversionResult {
|
||||
return MessageConversionResult{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
MultiContent: msg.MultiContent,
|
||||
HasMultiContent: len(msg.MultiContent) > 0,
|
||||
}
|
||||
}
|
||||
@@ -2,20 +2,23 @@ package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/danielmiessler/fabric/chat"
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"github.com/danielmiessler/fabric/plugins"
|
||||
goopenai "github.com/sashabaranov/go-openai"
|
||||
openai "github.com/openai/openai-go"
|
||||
"github.com/openai/openai-go/option"
|
||||
"github.com/openai/openai-go/packages/pagination"
|
||||
"github.com/openai/openai-go/responses"
|
||||
"github.com/openai/openai-go/shared"
|
||||
"github.com/openai/openai-go/shared/constant"
|
||||
)
|
||||
|
||||
func NewClient() (ret *Client) {
|
||||
return NewClientCompatible("OpenAI", "https://api.openai.com/v1", nil)
|
||||
return NewClientCompatibleWithResponses("OpenAI", "https://api.openai.com/v1", true, nil)
|
||||
}
|
||||
|
||||
func NewClientCompatible(vendorName string, defaultBaseUrl string, configureCustom func() error) (ret *Client) {
|
||||
@@ -28,6 +31,17 @@ func NewClientCompatible(vendorName string, defaultBaseUrl string, configureCust
|
||||
return
|
||||
}
|
||||
|
||||
func NewClientCompatibleWithResponses(vendorName string, defaultBaseUrl string, implementsResponses bool, configureCustom func() error) (ret *Client) {
|
||||
ret = NewClientCompatibleNoSetupQuestions(vendorName, configureCustom)
|
||||
|
||||
ret.ApiKey = ret.AddSetupQuestion("API Key", true)
|
||||
ret.ApiBaseURL = ret.AddSetupQuestion("API Base URL", false)
|
||||
ret.ApiBaseURL.Value = defaultBaseUrl
|
||||
ret.ImplementsResponses = implementsResponses
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func NewClientCompatibleNoSetupQuestions(vendorName string, configureCustom func() error) (ret *Client) {
|
||||
ret = &Client{}
|
||||
|
||||
@@ -46,82 +60,98 @@ func NewClientCompatibleNoSetupQuestions(vendorName string, configureCustom func
|
||||
|
||||
type Client struct {
|
||||
*plugins.PluginBase
|
||||
ApiKey *plugins.SetupQuestion
|
||||
ApiBaseURL *plugins.SetupQuestion
|
||||
ApiClient *goopenai.Client
|
||||
ApiKey *plugins.SetupQuestion
|
||||
ApiBaseURL *plugins.SetupQuestion
|
||||
ApiClient *openai.Client
|
||||
ImplementsResponses bool // Whether this provider supports the Responses API
|
||||
}
|
||||
|
||||
func (o *Client) configure() (ret error) {
|
||||
config := goopenai.DefaultConfig(o.ApiKey.Value)
|
||||
opts := []option.RequestOption{option.WithAPIKey(o.ApiKey.Value)}
|
||||
if o.ApiBaseURL.Value != "" {
|
||||
config.BaseURL = o.ApiBaseURL.Value
|
||||
opts = append(opts, option.WithBaseURL(o.ApiBaseURL.Value))
|
||||
}
|
||||
o.ApiClient = goopenai.NewClientWithConfig(config)
|
||||
client := openai.NewClient(opts...)
|
||||
o.ApiClient = &client
|
||||
return
|
||||
}
|
||||
|
||||
func (o *Client) ListModels() (ret []string, err error) {
|
||||
var models goopenai.ModelsList
|
||||
if models, err = o.ApiClient.ListModels(context.Background()); err != nil {
|
||||
var page *pagination.Page[openai.Model]
|
||||
if page, err = o.ApiClient.Models.List(context.Background()); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
model := models.Models
|
||||
for _, mod := range model {
|
||||
for _, mod := range page.Data {
|
||||
ret = append(ret, mod.ID)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (o *Client) SendStream(
|
||||
msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions, channel chan string,
|
||||
msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions, channel chan string,
|
||||
) (err error) {
|
||||
req := o.buildChatCompletionRequest(msgs, opts)
|
||||
req.Stream = true
|
||||
// Use Responses API for OpenAI, Chat Completions API for other providers
|
||||
if o.supportsResponsesAPI() {
|
||||
return o.sendStreamResponses(msgs, opts, channel)
|
||||
}
|
||||
return o.sendStreamChatCompletions(msgs, opts, channel)
|
||||
}
|
||||
|
||||
var stream *goopenai.ChatCompletionStream
|
||||
if stream, err = o.ApiClient.CreateChatCompletionStream(context.Background(), req); err != nil {
|
||||
fmt.Printf("ChatCompletionStream error: %v\n", err)
|
||||
func (o *Client) sendStreamResponses(
|
||||
msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions, channel chan string,
|
||||
) (err error) {
|
||||
defer close(channel)
|
||||
|
||||
req := o.buildResponseParams(msgs, opts)
|
||||
stream := o.ApiClient.Responses.NewStreaming(context.Background(), req)
|
||||
for stream.Next() {
|
||||
event := stream.Current()
|
||||
switch event.Type {
|
||||
case string(constant.ResponseOutputTextDelta("").Default()):
|
||||
channel <- event.AsResponseOutputTextDelta().Delta
|
||||
case string(constant.ResponseOutputTextDone("").Default()):
|
||||
channel <- event.AsResponseOutputTextDone().Text
|
||||
}
|
||||
}
|
||||
if stream.Err() == nil {
|
||||
channel <- "\n"
|
||||
}
|
||||
return stream.Err()
|
||||
}
|
||||
|
||||
func (o *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions) (ret string, err error) {
|
||||
// Use Responses API for OpenAI, Chat Completions API for other providers
|
||||
if o.supportsResponsesAPI() {
|
||||
return o.sendResponses(ctx, msgs, opts)
|
||||
}
|
||||
return o.sendChatCompletions(ctx, msgs, opts)
|
||||
}
|
||||
|
||||
func (o *Client) sendResponses(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions) (ret string, err error) {
|
||||
// Validate model supports image generation if image file is specified
|
||||
if opts.ImageFile != "" && !supportsImageGeneration(opts.Model) {
|
||||
return "", fmt.Errorf("model '%s' does not support image generation. Supported models: %s", opts.Model, strings.Join(ImageGenerationSupportedModels, ", "))
|
||||
}
|
||||
|
||||
req := o.buildResponseParams(msgs, opts)
|
||||
|
||||
var resp *responses.Response
|
||||
if resp, err = o.ApiClient.Responses.New(ctx, req); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
defer stream.Close()
|
||||
|
||||
for {
|
||||
var response goopenai.ChatCompletionStreamResponse
|
||||
if response, err = stream.Recv(); err == nil {
|
||||
if len(response.Choices) > 0 {
|
||||
channel <- response.Choices[0].Delta.Content
|
||||
} else {
|
||||
channel <- "\n"
|
||||
close(channel)
|
||||
break
|
||||
}
|
||||
} else if errors.Is(err, io.EOF) {
|
||||
channel <- "\n"
|
||||
close(channel)
|
||||
err = nil
|
||||
break
|
||||
} else if err != nil {
|
||||
fmt.Printf("\nStream error: %v\n", err)
|
||||
break
|
||||
}
|
||||
// Extract and save images if requested
|
||||
if err = o.extractAndSaveImages(resp, opts); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ret = o.extractText(resp)
|
||||
return
|
||||
}
|
||||
|
||||
func (o *Client) Send(ctx context.Context, msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions) (ret string, err error) {
|
||||
req := o.buildChatCompletionRequest(msgs, opts)
|
||||
|
||||
var resp goopenai.ChatCompletionResponse
|
||||
if resp, err = o.ApiClient.CreateChatCompletion(ctx, req); err != nil {
|
||||
return
|
||||
}
|
||||
if len(resp.Choices) > 0 {
|
||||
ret = resp.Choices[0].Message.Content
|
||||
slog.Debug("SystemFingerprint: " + resp.SystemFingerprint)
|
||||
}
|
||||
return
|
||||
// supportsResponsesAPI determines if the provider supports the new Responses API
|
||||
func (o *Client) supportsResponsesAPI() bool {
|
||||
return o.ImplementsResponses
|
||||
}
|
||||
|
||||
func (o *Client) NeedsRawMode(modelName string) bool {
|
||||
@@ -135,8 +165,6 @@ func (o *Client) NeedsRawMode(modelName string) bool {
|
||||
"gpt-4o-mini-search-preview-2025-03-11",
|
||||
"gpt-4o-search-preview",
|
||||
"gpt-4o-search-preview-2025-03-11",
|
||||
"o4-mini-deep-research",
|
||||
"o4-mini-deep-research-2025-06-26",
|
||||
}
|
||||
for _, prefix := range openaiModelsPrefixes {
|
||||
if strings.HasPrefix(modelName, prefix) {
|
||||
@@ -146,57 +174,136 @@ func (o *Client) NeedsRawMode(modelName string) bool {
|
||||
return slices.Contains(openAIModelsNeedingRaw, modelName)
|
||||
}
|
||||
|
||||
func (o *Client) buildChatCompletionRequest(
|
||||
inputMsgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions,
|
||||
) (ret goopenai.ChatCompletionRequest) {
|
||||
func (o *Client) buildResponseParams(
|
||||
inputMsgs []*chat.ChatCompletionMessage, opts *common.ChatOptions,
|
||||
) (ret responses.ResponseNewParams) {
|
||||
|
||||
// Create a new slice for messages to be sent, converting from []*Msg to []Msg.
|
||||
// This also serves as a mutable copy for provider-specific modifications.
|
||||
messagesForRequest := make([]goopenai.ChatCompletionMessage, len(inputMsgs))
|
||||
items := make([]responses.ResponseInputItemUnionParam, len(inputMsgs))
|
||||
for i, msgPtr := range inputMsgs {
|
||||
messagesForRequest[i] = *msgPtr // Dereference and copy
|
||||
msg := *msgPtr
|
||||
if strings.Contains(opts.Model, "deepseek") && len(inputMsgs) == 1 && msg.Role == chat.ChatMessageRoleSystem {
|
||||
msg.Role = chat.ChatMessageRoleUser
|
||||
}
|
||||
items[i] = convertMessage(msg)
|
||||
}
|
||||
|
||||
// Provider-specific modification for DeepSeek:
|
||||
// DeepSeek requires the last message to be a user message.
|
||||
// If fabric constructs a single system message (common when a pattern includes user input),
|
||||
// we change its role to user for DeepSeek.
|
||||
if strings.Contains(opts.Model, "deepseek") { // Heuristic to identify DeepSeek models
|
||||
if len(messagesForRequest) == 1 && messagesForRequest[0].Role == goopenai.ChatMessageRoleSystem {
|
||||
messagesForRequest[0].Role = goopenai.ChatMessageRoleUser
|
||||
}
|
||||
// Note: This handles the most common case arising from pattern usage.
|
||||
// More complex scenarios where a multi-message sequence ends in 'system'
|
||||
// are not currently expected from chatter.go's BuildSession logic for OpenAI providers
|
||||
// but might require further rules if they arise.
|
||||
ret = responses.ResponseNewParams{
|
||||
Model: shared.ResponsesModel(opts.Model),
|
||||
Input: responses.ResponseNewParamsInputUnion{
|
||||
OfInputItemList: items,
|
||||
},
|
||||
}
|
||||
|
||||
if opts.Raw {
|
||||
ret = goopenai.ChatCompletionRequest{
|
||||
Model: opts.Model,
|
||||
Messages: messagesForRequest,
|
||||
// Add tools if enabled
|
||||
var tools []responses.ToolUnionParam
|
||||
|
||||
// Add web search tool if enabled
|
||||
if opts.Search {
|
||||
webSearchTool := responses.ToolParamOfWebSearchPreview("web_search_preview")
|
||||
|
||||
// Add user location if provided
|
||||
if opts.SearchLocation != "" {
|
||||
webSearchTool.OfWebSearchPreview.UserLocation = responses.WebSearchToolUserLocationParam{
|
||||
Type: "approximate",
|
||||
Timezone: openai.String(opts.SearchLocation),
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if opts.Seed == 0 {
|
||||
ret = goopenai.ChatCompletionRequest{
|
||||
Model: opts.Model,
|
||||
Temperature: float32(opts.Temperature),
|
||||
TopP: float32(opts.TopP),
|
||||
PresencePenalty: float32(opts.PresencePenalty),
|
||||
FrequencyPenalty: float32(opts.FrequencyPenalty),
|
||||
Messages: messagesForRequest,
|
||||
}
|
||||
} else {
|
||||
ret = goopenai.ChatCompletionRequest{
|
||||
Model: opts.Model,
|
||||
Temperature: float32(opts.Temperature),
|
||||
TopP: float32(opts.TopP),
|
||||
PresencePenalty: float32(opts.PresencePenalty),
|
||||
FrequencyPenalty: float32(opts.FrequencyPenalty),
|
||||
Messages: messagesForRequest,
|
||||
Seed: &opts.Seed,
|
||||
}
|
||||
|
||||
tools = append(tools, webSearchTool)
|
||||
}
|
||||
|
||||
// Add image generation tool if needed
|
||||
tools = o.addImageGenerationTool(opts, tools)
|
||||
|
||||
if len(tools) > 0 {
|
||||
ret.Tools = tools
|
||||
}
|
||||
|
||||
if !opts.Raw {
|
||||
ret.Temperature = openai.Float(opts.Temperature)
|
||||
ret.TopP = openai.Float(opts.TopP)
|
||||
if opts.MaxTokens != 0 {
|
||||
ret.MaxOutputTokens = openai.Int(int64(opts.MaxTokens))
|
||||
}
|
||||
|
||||
// Add parameters not officially supported by Responses API as extra fields
|
||||
extraFields := make(map[string]any)
|
||||
if opts.PresencePenalty != 0 {
|
||||
extraFields["presence_penalty"] = opts.PresencePenalty
|
||||
}
|
||||
if opts.FrequencyPenalty != 0 {
|
||||
extraFields["frequency_penalty"] = opts.FrequencyPenalty
|
||||
}
|
||||
if opts.Seed != 0 {
|
||||
extraFields["seed"] = opts.Seed
|
||||
}
|
||||
if len(extraFields) > 0 {
|
||||
ret.SetExtraFields(extraFields)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func convertMessage(msg chat.ChatCompletionMessage) responses.ResponseInputItemUnionParam {
|
||||
result := convertMessageCommon(msg)
|
||||
role := responses.EasyInputMessageRole(result.Role)
|
||||
|
||||
if result.HasMultiContent {
|
||||
var parts []responses.ResponseInputContentUnionParam
|
||||
for _, p := range result.MultiContent {
|
||||
switch p.Type {
|
||||
case chat.ChatMessagePartTypeText:
|
||||
parts = append(parts, responses.ResponseInputContentParamOfInputText(p.Text))
|
||||
case chat.ChatMessagePartTypeImageURL:
|
||||
part := responses.ResponseInputContentParamOfInputImage(responses.ResponseInputImageDetailAuto)
|
||||
if part.OfInputImage != nil {
|
||||
part.OfInputImage.ImageURL = openai.String(p.ImageURL.URL)
|
||||
}
|
||||
parts = append(parts, part)
|
||||
}
|
||||
}
|
||||
contentList := responses.ResponseInputMessageContentListParam(parts)
|
||||
return responses.ResponseInputItemParamOfMessage(contentList, role)
|
||||
}
|
||||
return responses.ResponseInputItemParamOfMessage(result.Content, role)
|
||||
}
|
||||
|
||||
func (o *Client) extractText(resp *responses.Response) (ret string) {
|
||||
var textParts []string
|
||||
var citations []string
|
||||
citationMap := make(map[string]bool) // To avoid duplicate citations
|
||||
|
||||
for _, item := range resp.Output {
|
||||
if item.Type == "message" {
|
||||
for _, c := range item.Content {
|
||||
if c.Type == "output_text" {
|
||||
outputText := c.AsOutputText()
|
||||
textParts = append(textParts, outputText.Text)
|
||||
|
||||
// Extract citations from annotations
|
||||
for _, annotation := range outputText.Annotations {
|
||||
if annotation.Type == "url_citation" {
|
||||
urlCitation := annotation.AsURLCitation()
|
||||
citationKey := urlCitation.URL + "|" + urlCitation.Title
|
||||
if !citationMap[citationKey] {
|
||||
citationMap[citationKey] = true
|
||||
citationText := fmt.Sprintf("- [%s](%s)", urlCitation.Title, urlCitation.URL)
|
||||
citations = append(citations, citationText)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
ret = strings.Join(textParts, "")
|
||||
|
||||
// Append citations if any were found
|
||||
if len(citations) > 0 {
|
||||
ret += "\n\n## Sources\n\n" + strings.Join(citations, "\n")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
146
plugins/ai/openai/openai_image.go
Normal file
146
plugins/ai/openai/openai_image.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package openai
|
||||
|
||||
// This file contains helper methods for image generation and processing
|
||||
// using OpenAI's Responses API and Image API.
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"github.com/openai/openai-go/packages/param"
|
||||
"github.com/openai/openai-go/responses"
|
||||
)
|
||||
|
||||
// ImageGenerationResponseType is the type used for image generation calls in responses
|
||||
const ImageGenerationResponseType = "image_generation_call"
|
||||
const ImageGenerationToolType = "image_generation"
|
||||
|
||||
// ImageGenerationSupportedModels lists all models that support image generation
|
||||
var ImageGenerationSupportedModels = []string{
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
"gpt-4.1",
|
||||
"gpt-4.1-mini",
|
||||
"gpt-4.1-nano",
|
||||
"o3",
|
||||
}
|
||||
|
||||
// supportsImageGeneration checks if the given model supports the image_generation tool
|
||||
func supportsImageGeneration(model string) bool {
|
||||
for _, supportedModel := range ImageGenerationSupportedModels {
|
||||
if model == supportedModel {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// getOutputFormatFromExtension determines the API output format based on file extension
|
||||
func getOutputFormatFromExtension(imagePath string) string {
|
||||
if imagePath == "" {
|
||||
return "png" // Default format
|
||||
}
|
||||
|
||||
ext := strings.ToLower(filepath.Ext(imagePath))
|
||||
switch ext {
|
||||
case ".png":
|
||||
return "png"
|
||||
case ".webp":
|
||||
return "webp"
|
||||
case ".jpg":
|
||||
return "jpeg"
|
||||
case ".jpeg":
|
||||
return "jpeg"
|
||||
default:
|
||||
return "png" // Default fallback
|
||||
}
|
||||
}
|
||||
|
||||
// addImageGenerationTool adds the image generation tool to the request if needed
|
||||
func (o *Client) addImageGenerationTool(opts *common.ChatOptions, tools []responses.ToolUnionParam) []responses.ToolUnionParam {
|
||||
// Check if the request seems to be asking for image generation
|
||||
if o.shouldUseImageGeneration(opts) {
|
||||
outputFormat := getOutputFormatFromExtension(opts.ImageFile)
|
||||
|
||||
// Build the image generation tool with user parameters
|
||||
imageGenTool := responses.ToolUnionParam{
|
||||
OfImageGeneration: &responses.ToolImageGenerationParam{
|
||||
Type: ImageGenerationToolType,
|
||||
Model: "gpt-image-1",
|
||||
OutputFormat: outputFormat,
|
||||
},
|
||||
}
|
||||
|
||||
// Set quality if specified by user (otherwise let OpenAI use default)
|
||||
if opts.ImageQuality != "" {
|
||||
imageGenTool.OfImageGeneration.Quality = opts.ImageQuality
|
||||
}
|
||||
|
||||
// Set size if specified by user (otherwise let OpenAI use default)
|
||||
if opts.ImageSize != "" {
|
||||
imageGenTool.OfImageGeneration.Size = opts.ImageSize
|
||||
}
|
||||
|
||||
// Set background if specified by user (otherwise let OpenAI use default)
|
||||
if opts.ImageBackground != "" {
|
||||
imageGenTool.OfImageGeneration.Background = opts.ImageBackground
|
||||
}
|
||||
|
||||
// Set compression if specified by user (only for jpeg/webp)
|
||||
if opts.ImageCompression != 0 {
|
||||
imageGenTool.OfImageGeneration.OutputCompression = param.NewOpt(int64(opts.ImageCompression))
|
||||
}
|
||||
|
||||
tools = append(tools, imageGenTool)
|
||||
}
|
||||
return tools
|
||||
}
|
||||
|
||||
// shouldUseImageGeneration determines if image generation should be enabled
|
||||
// This is a heuristic based on the presence of --image-file flag
|
||||
func (o *Client) shouldUseImageGeneration(opts *common.ChatOptions) bool {
|
||||
return opts.ImageFile != ""
|
||||
}
|
||||
|
||||
// extractAndSaveImages extracts generated images from the response and saves them
|
||||
func (o *Client) extractAndSaveImages(resp *responses.Response, opts *common.ChatOptions) error {
|
||||
if opts.ImageFile == "" {
|
||||
return nil // No image file specified, skip saving
|
||||
}
|
||||
|
||||
// Extract image data from response
|
||||
for _, item := range resp.Output {
|
||||
if item.Type == ImageGenerationResponseType {
|
||||
imageCall := item.AsImageGenerationCall()
|
||||
if imageCall.Status == "completed" && imageCall.Result != "" {
|
||||
// Decode base64 image data
|
||||
imageData, err := base64.StdEncoding.DecodeString(imageCall.Result)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode image data: %w", err)
|
||||
}
|
||||
|
||||
// Ensure directory exists
|
||||
dir := filepath.Dir(opts.ImageFile)
|
||||
if dir != "." {
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create directory %s: %w", dir, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Save image to file
|
||||
if err := os.WriteFile(opts.ImageFile, imageData, 0644); err != nil {
|
||||
return fmt.Errorf("failed to save image to %s: %w", opts.ImageFile, err)
|
||||
}
|
||||
|
||||
fmt.Printf("Image saved to: %s\n", opts.ImageFile)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
444
plugins/ai/openai/openai_image_test.go
Normal file
444
plugins/ai/openai/openai_image_test.go
Normal file
@@ -0,0 +1,444 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/danielmiessler/fabric/chat"
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"github.com/openai/openai-go/responses"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestShouldUseImageGeneration(t *testing.T) {
|
||||
client := NewClient()
|
||||
|
||||
// Test with image file specified
|
||||
opts := &common.ChatOptions{
|
||||
ImageFile: "output.png",
|
||||
}
|
||||
assert.True(t, client.shouldUseImageGeneration(opts), "Should use image generation when image file is specified")
|
||||
|
||||
// Test without image file
|
||||
opts = &common.ChatOptions{
|
||||
ImageFile: "",
|
||||
}
|
||||
assert.False(t, client.shouldUseImageGeneration(opts), "Should not use image generation when no image file is specified")
|
||||
}
|
||||
|
||||
func TestAddImageGenerationTool(t *testing.T) {
|
||||
client := NewClient()
|
||||
|
||||
// Test with image generation enabled
|
||||
opts := &common.ChatOptions{
|
||||
ImageFile: "output.png",
|
||||
}
|
||||
tools := []responses.ToolUnionParam{}
|
||||
result := client.addImageGenerationTool(opts, tools)
|
||||
|
||||
assert.Len(t, result, 1, "Should add one image generation tool")
|
||||
assert.NotNil(t, result[0].OfImageGeneration, "Should have image generation tool")
|
||||
assert.Equal(t, "image_generation", string(result[0].OfImageGeneration.Type))
|
||||
assert.Equal(t, "gpt-image-1", result[0].OfImageGeneration.Model)
|
||||
assert.Equal(t, "png", result[0].OfImageGeneration.OutputFormat)
|
||||
|
||||
// Test without image generation
|
||||
opts = &common.ChatOptions{
|
||||
ImageFile: "",
|
||||
}
|
||||
tools = []responses.ToolUnionParam{}
|
||||
result = client.addImageGenerationTool(opts, tools)
|
||||
|
||||
assert.Len(t, result, 0, "Should not add image generation tool when not needed")
|
||||
}
|
||||
|
||||
func TestBuildResponseParams_WithImageGeneration(t *testing.T) {
|
||||
client := NewClient()
|
||||
opts := &common.ChatOptions{
|
||||
Model: "gpt-image-1",
|
||||
ImageFile: "output.png",
|
||||
}
|
||||
|
||||
msgs := []*chat.ChatCompletionMessage{
|
||||
{Role: "user", Content: "Generate an image of a cat"},
|
||||
}
|
||||
|
||||
params := client.buildResponseParams(msgs, opts)
|
||||
|
||||
assert.NotNil(t, params.Tools, "Expected tools when image generation is enabled")
|
||||
|
||||
// Should have image generation tool
|
||||
hasImageTool := false
|
||||
for _, tool := range params.Tools {
|
||||
if tool.OfImageGeneration != nil {
|
||||
hasImageTool = true
|
||||
assert.Equal(t, "image_generation", string(tool.OfImageGeneration.Type))
|
||||
assert.Equal(t, "gpt-image-1", tool.OfImageGeneration.Model)
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, hasImageTool, "Should have image generation tool")
|
||||
}
|
||||
|
||||
func TestBuildResponseParams_WithBothSearchAndImage(t *testing.T) {
|
||||
client := NewClient()
|
||||
opts := &common.ChatOptions{
|
||||
Model: "gpt-image-1",
|
||||
Search: true,
|
||||
SearchLocation: "America/Los_Angeles",
|
||||
ImageFile: "output.png",
|
||||
}
|
||||
|
||||
msgs := []*chat.ChatCompletionMessage{
|
||||
{Role: "user", Content: "Search for cat images and generate one"},
|
||||
}
|
||||
|
||||
params := client.buildResponseParams(msgs, opts)
|
||||
|
||||
assert.NotNil(t, params.Tools, "Expected tools when both search and image generation are enabled")
|
||||
assert.Len(t, params.Tools, 2, "Should have both search and image generation tools")
|
||||
|
||||
hasSearchTool := false
|
||||
hasImageTool := false
|
||||
|
||||
for _, tool := range params.Tools {
|
||||
if tool.OfWebSearchPreview != nil {
|
||||
hasSearchTool = true
|
||||
}
|
||||
if tool.OfImageGeneration != nil {
|
||||
hasImageTool = true
|
||||
}
|
||||
}
|
||||
|
||||
assert.True(t, hasSearchTool, "Should have web search tool")
|
||||
assert.True(t, hasImageTool, "Should have image generation tool")
|
||||
}
|
||||
|
||||
func TestGetOutputFormatFromExtension(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
imagePath string
|
||||
expectedFormat string
|
||||
}{
|
||||
{
|
||||
name: "PNG extension",
|
||||
imagePath: "/tmp/output.png",
|
||||
expectedFormat: "png",
|
||||
},
|
||||
{
|
||||
name: "WEBP extension",
|
||||
imagePath: "/tmp/output.webp",
|
||||
expectedFormat: "webp",
|
||||
},
|
||||
{
|
||||
name: "JPG extension",
|
||||
imagePath: "/tmp/output.jpg",
|
||||
expectedFormat: "jpeg",
|
||||
},
|
||||
{
|
||||
name: "JPEG extension",
|
||||
imagePath: "/tmp/output.jpeg",
|
||||
expectedFormat: "jpeg",
|
||||
},
|
||||
{
|
||||
name: "Uppercase PNG extension",
|
||||
imagePath: "/tmp/output.PNG",
|
||||
expectedFormat: "png",
|
||||
},
|
||||
{
|
||||
name: "Mixed case JPEG extension",
|
||||
imagePath: "/tmp/output.JpEg",
|
||||
expectedFormat: "jpeg",
|
||||
},
|
||||
{
|
||||
name: "Empty path",
|
||||
imagePath: "",
|
||||
expectedFormat: "png",
|
||||
},
|
||||
{
|
||||
name: "No extension",
|
||||
imagePath: "/tmp/output",
|
||||
expectedFormat: "png",
|
||||
},
|
||||
{
|
||||
name: "Unsupported extension",
|
||||
imagePath: "/tmp/output.gif",
|
||||
expectedFormat: "png",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := getOutputFormatFromExtension(tt.imagePath)
|
||||
assert.Equal(t, tt.expectedFormat, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddImageGenerationToolWithDynamicFormat(t *testing.T) {
|
||||
client := NewClient()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
imageFile string
|
||||
expectedFormat string
|
||||
}{
|
||||
{
|
||||
name: "PNG file",
|
||||
imageFile: "/tmp/output.png",
|
||||
expectedFormat: "png",
|
||||
},
|
||||
{
|
||||
name: "WEBP file",
|
||||
imageFile: "/tmp/output.webp",
|
||||
expectedFormat: "webp",
|
||||
},
|
||||
{
|
||||
name: "JPG file",
|
||||
imageFile: "/tmp/output.jpg",
|
||||
expectedFormat: "jpeg",
|
||||
},
|
||||
{
|
||||
name: "JPEG file",
|
||||
imageFile: "/tmp/output.jpeg",
|
||||
expectedFormat: "jpeg",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
opts := &common.ChatOptions{
|
||||
ImageFile: tt.imageFile,
|
||||
}
|
||||
|
||||
tools := client.addImageGenerationTool(opts, []responses.ToolUnionParam{})
|
||||
|
||||
assert.Len(t, tools, 1, "Should have one tool")
|
||||
assert.NotNil(t, tools[0].OfImageGeneration, "Should be image generation tool")
|
||||
assert.Equal(t, tt.expectedFormat, tools[0].OfImageGeneration.OutputFormat, "Output format should match file extension")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSupportsImageGeneration(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "gpt-4o supports image generation",
|
||||
model: "gpt-4o",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "gpt-4o-mini supports image generation",
|
||||
model: "gpt-4o-mini",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "gpt-4.1 supports image generation",
|
||||
model: "gpt-4.1",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "gpt-4.1-mini supports image generation",
|
||||
model: "gpt-4.1-mini",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "gpt-4.1-nano supports image generation",
|
||||
model: "gpt-4.1-nano",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "o3 supports image generation",
|
||||
model: "o3",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "o1 does not support image generation",
|
||||
model: "o1",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "o1-mini does not support image generation",
|
||||
model: "o1-mini",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "o3-mini does not support image generation",
|
||||
model: "o3-mini",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "gpt-4 does not support image generation",
|
||||
model: "gpt-4",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "gpt-3.5-turbo does not support image generation",
|
||||
model: "gpt-3.5-turbo",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "empty model does not support image generation",
|
||||
model: "",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := supportsImageGeneration(tt.model)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelValidationLogic(t *testing.T) {
|
||||
t.Run("Unsupported model with image file should return validation error", func(t *testing.T) {
|
||||
opts := &common.ChatOptions{
|
||||
Model: "o1-mini",
|
||||
ImageFile: "/tmp/output.png",
|
||||
}
|
||||
|
||||
// Test the validation logic directly
|
||||
if opts.ImageFile != "" && !supportsImageGeneration(opts.Model) {
|
||||
err := fmt.Errorf("model '%s' does not support image generation. Supported models: %s", opts.Model, strings.Join(ImageGenerationSupportedModels, ", "))
|
||||
|
||||
assert.Contains(t, err.Error(), "does not support image generation")
|
||||
assert.Contains(t, err.Error(), "o1-mini")
|
||||
assert.Contains(t, err.Error(), "Supported models:")
|
||||
} else {
|
||||
t.Error("Expected validation to trigger")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Supported model with image file should not trigger validation", func(t *testing.T) {
|
||||
opts := &common.ChatOptions{
|
||||
Model: "gpt-4o",
|
||||
ImageFile: "/tmp/output.png",
|
||||
}
|
||||
|
||||
// Test the validation logic directly
|
||||
shouldFail := opts.ImageFile != "" && !supportsImageGeneration(opts.Model)
|
||||
assert.False(t, shouldFail, "Validation should not trigger for supported model")
|
||||
})
|
||||
|
||||
t.Run("Unsupported model without image file should not trigger validation", func(t *testing.T) {
|
||||
opts := &common.ChatOptions{
|
||||
Model: "o1-mini",
|
||||
ImageFile: "", // No image file
|
||||
}
|
||||
|
||||
// Test the validation logic directly
|
||||
shouldFail := opts.ImageFile != "" && !supportsImageGeneration(opts.Model)
|
||||
assert.False(t, shouldFail, "Validation should not trigger when no image file is specified")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAddImageGenerationToolWithUserParameters(t *testing.T) {
|
||||
client := NewClient()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
opts *common.ChatOptions
|
||||
expected map[string]interface{}
|
||||
}{
|
||||
{
|
||||
name: "All parameters specified",
|
||||
opts: &common.ChatOptions{
|
||||
ImageFile: "/tmp/test.png",
|
||||
ImageSize: "1536x1024",
|
||||
ImageQuality: "high",
|
||||
ImageBackground: "transparent",
|
||||
ImageCompression: 0, // Not applicable for PNG
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"size": "1536x1024",
|
||||
"quality": "high",
|
||||
"background": "transparent",
|
||||
"output_format": "png",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "JPEG with compression",
|
||||
opts: &common.ChatOptions{
|
||||
ImageFile: "/tmp/test.jpg",
|
||||
ImageSize: "1024x1024",
|
||||
ImageQuality: "medium",
|
||||
ImageBackground: "opaque",
|
||||
ImageCompression: 75,
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"size": "1024x1024",
|
||||
"quality": "medium",
|
||||
"background": "opaque",
|
||||
"output_format": "jpeg",
|
||||
"output_compression": int64(75),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Only some parameters specified",
|
||||
opts: &common.ChatOptions{
|
||||
ImageFile: "/tmp/test.webp",
|
||||
ImageQuality: "low",
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"quality": "low",
|
||||
"output_format": "webp",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "No parameters specified (defaults)",
|
||||
opts: &common.ChatOptions{
|
||||
ImageFile: "/tmp/test.png",
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"output_format": "png",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tools := client.addImageGenerationTool(tt.opts, []responses.ToolUnionParam{})
|
||||
|
||||
assert.Len(t, tools, 1)
|
||||
assert.NotNil(t, tools[0].OfImageGeneration)
|
||||
|
||||
tool := tools[0].OfImageGeneration
|
||||
|
||||
// Check required fields
|
||||
assert.Equal(t, "gpt-image-1", tool.Model)
|
||||
assert.Equal(t, tt.expected["output_format"], tool.OutputFormat)
|
||||
|
||||
// Check optional fields
|
||||
if expectedSize, ok := tt.expected["size"]; ok {
|
||||
assert.Equal(t, expectedSize, tool.Size)
|
||||
} else {
|
||||
assert.Empty(t, tool.Size, "Size should not be set when not specified")
|
||||
}
|
||||
|
||||
if expectedQuality, ok := tt.expected["quality"]; ok {
|
||||
assert.Equal(t, expectedQuality, tool.Quality)
|
||||
} else {
|
||||
assert.Empty(t, tool.Quality, "Quality should not be set when not specified")
|
||||
}
|
||||
|
||||
if expectedBackground, ok := tt.expected["background"]; ok {
|
||||
assert.Equal(t, expectedBackground, tool.Background)
|
||||
} else {
|
||||
assert.Empty(t, tool.Background, "Background should not be set when not specified")
|
||||
}
|
||||
|
||||
if expectedCompression, ok := tt.expected["output_compression"]; ok {
|
||||
assert.Equal(t, expectedCompression, tool.OutputCompression.Value)
|
||||
} else {
|
||||
assert.Equal(t, int64(0), tool.OutputCompression.Value, "Compression should not be set when not specified")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,102 +1,177 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/danielmiessler/fabric/chat"
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
goopenai "github.com/sashabaranov/go-openai"
|
||||
openai "github.com/openai/openai-go"
|
||||
"github.com/openai/openai-go/responses"
|
||||
"github.com/openai/openai-go/shared"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestBuildChatCompletionRequestPinSeed(t *testing.T) {
|
||||
func TestBuildResponseRequestWithMaxTokens(t *testing.T) {
|
||||
|
||||
var msgs []*goopenai.ChatCompletionMessage
|
||||
var msgs []*chat.ChatCompletionMessage
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
msgs = append(msgs, &goopenai.ChatCompletionMessage{
|
||||
msgs = append(msgs, &chat.ChatCompletionMessage{
|
||||
Role: "User",
|
||||
Content: "My msg",
|
||||
})
|
||||
}
|
||||
|
||||
opts := &common.ChatOptions{
|
||||
Temperature: 0.8,
|
||||
TopP: 0.9,
|
||||
PresencePenalty: 0.1,
|
||||
FrequencyPenalty: 0.2,
|
||||
Raw: false,
|
||||
Seed: 1,
|
||||
}
|
||||
|
||||
var expectedMessages []openai.ChatCompletionMessage
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
expectedMessages = append(expectedMessages,
|
||||
openai.ChatCompletionMessage{
|
||||
Role: msgs[i].Role,
|
||||
Content: msgs[i].Content,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
var expectedRequest = goopenai.ChatCompletionRequest{
|
||||
Model: opts.Model,
|
||||
Temperature: float32(opts.Temperature),
|
||||
TopP: float32(opts.TopP),
|
||||
PresencePenalty: float32(opts.PresencePenalty),
|
||||
FrequencyPenalty: float32(opts.FrequencyPenalty),
|
||||
Messages: expectedMessages,
|
||||
Seed: &opts.Seed,
|
||||
Temperature: 0.8,
|
||||
TopP: 0.9,
|
||||
Raw: false,
|
||||
MaxTokens: 50,
|
||||
}
|
||||
|
||||
var client = NewClient()
|
||||
request := client.buildChatCompletionRequest(msgs, opts)
|
||||
assert.Equal(t, expectedRequest, request)
|
||||
request := client.buildResponseParams(msgs, opts)
|
||||
assert.Equal(t, shared.ResponsesModel(opts.Model), request.Model)
|
||||
assert.Equal(t, openai.Float(opts.Temperature), request.Temperature)
|
||||
assert.Equal(t, openai.Float(opts.TopP), request.TopP)
|
||||
assert.Equal(t, openai.Int(int64(opts.MaxTokens)), request.MaxOutputTokens)
|
||||
}
|
||||
|
||||
func TestBuildChatCompletionRequestNilSeed(t *testing.T) {
|
||||
func TestBuildResponseRequestNoMaxTokens(t *testing.T) {
|
||||
|
||||
var msgs []*goopenai.ChatCompletionMessage
|
||||
var msgs []*chat.ChatCompletionMessage
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
msgs = append(msgs, &goopenai.ChatCompletionMessage{
|
||||
msgs = append(msgs, &chat.ChatCompletionMessage{
|
||||
Role: "User",
|
||||
Content: "My msg",
|
||||
})
|
||||
}
|
||||
|
||||
opts := &common.ChatOptions{
|
||||
Temperature: 0.8,
|
||||
TopP: 0.9,
|
||||
PresencePenalty: 0.1,
|
||||
FrequencyPenalty: 0.2,
|
||||
Raw: false,
|
||||
Seed: 0,
|
||||
}
|
||||
|
||||
var expectedMessages []openai.ChatCompletionMessage
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
expectedMessages = append(expectedMessages,
|
||||
openai.ChatCompletionMessage{
|
||||
Role: msgs[i].Role,
|
||||
Content: msgs[i].Content,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
var expectedRequest = goopenai.ChatCompletionRequest{
|
||||
Model: opts.Model,
|
||||
Temperature: float32(opts.Temperature),
|
||||
TopP: float32(opts.TopP),
|
||||
PresencePenalty: float32(opts.PresencePenalty),
|
||||
FrequencyPenalty: float32(opts.FrequencyPenalty),
|
||||
Messages: expectedMessages,
|
||||
Seed: nil,
|
||||
Temperature: 0.8,
|
||||
TopP: 0.9,
|
||||
Raw: false,
|
||||
}
|
||||
|
||||
var client = NewClient()
|
||||
request := client.buildChatCompletionRequest(msgs, opts)
|
||||
assert.Equal(t, expectedRequest, request)
|
||||
request := client.buildResponseParams(msgs, opts)
|
||||
assert.Equal(t, shared.ResponsesModel(opts.Model), request.Model)
|
||||
assert.Equal(t, openai.Float(opts.Temperature), request.Temperature)
|
||||
assert.Equal(t, openai.Float(opts.TopP), request.TopP)
|
||||
assert.False(t, request.MaxOutputTokens.Valid())
|
||||
}
|
||||
|
||||
func TestBuildResponseParams_WithoutSearch(t *testing.T) {
|
||||
client := NewClient()
|
||||
opts := &common.ChatOptions{
|
||||
Model: "gpt-4o",
|
||||
Temperature: 0.7,
|
||||
Search: false,
|
||||
}
|
||||
|
||||
msgs := []*chat.ChatCompletionMessage{
|
||||
{Role: "user", Content: "Hello"},
|
||||
}
|
||||
|
||||
params := client.buildResponseParams(msgs, opts)
|
||||
|
||||
assert.Nil(t, params.Tools, "Expected no tools when search is disabled")
|
||||
assert.Equal(t, shared.ResponsesModel(opts.Model), params.Model)
|
||||
assert.Equal(t, openai.Float(opts.Temperature), params.Temperature)
|
||||
}
|
||||
|
||||
func TestBuildResponseParams_WithSearch(t *testing.T) {
|
||||
client := NewClient()
|
||||
opts := &common.ChatOptions{
|
||||
Model: "gpt-4o",
|
||||
Temperature: 0.7,
|
||||
Search: true,
|
||||
}
|
||||
|
||||
msgs := []*chat.ChatCompletionMessage{
|
||||
{Role: "user", Content: "What's the weather today?"},
|
||||
}
|
||||
|
||||
params := client.buildResponseParams(msgs, opts)
|
||||
|
||||
assert.NotNil(t, params.Tools, "Expected tools when search is enabled")
|
||||
assert.Len(t, params.Tools, 1, "Expected exactly one tool")
|
||||
|
||||
tool := params.Tools[0]
|
||||
assert.NotNil(t, tool.OfWebSearchPreview, "Expected web search tool")
|
||||
assert.Equal(t, responses.WebSearchToolType("web_search_preview"), tool.OfWebSearchPreview.Type)
|
||||
}
|
||||
|
||||
func TestBuildResponseParams_WithSearchAndLocation(t *testing.T) {
|
||||
client := NewClient()
|
||||
opts := &common.ChatOptions{
|
||||
Model: "gpt-4o",
|
||||
Temperature: 0.7,
|
||||
Search: true,
|
||||
SearchLocation: "America/Los_Angeles",
|
||||
}
|
||||
|
||||
msgs := []*chat.ChatCompletionMessage{
|
||||
{Role: "user", Content: "What's the weather in San Francisco?"},
|
||||
}
|
||||
|
||||
params := client.buildResponseParams(msgs, opts)
|
||||
|
||||
assert.NotNil(t, params.Tools, "Expected tools when search is enabled")
|
||||
tool := params.Tools[0]
|
||||
assert.NotNil(t, tool.OfWebSearchPreview, "Expected web search tool")
|
||||
|
||||
userLocation := tool.OfWebSearchPreview.UserLocation
|
||||
assert.Equal(t, "approximate", string(userLocation.Type))
|
||||
assert.True(t, userLocation.Timezone.Valid(), "Expected timezone to be set")
|
||||
assert.Equal(t, opts.SearchLocation, userLocation.Timezone.Value)
|
||||
}
|
||||
|
||||
func TestCitationFormatting(t *testing.T) {
|
||||
// Test the citation formatting logic by simulating the citation extraction
|
||||
var textParts []string
|
||||
var citations []string
|
||||
citationMap := make(map[string]bool)
|
||||
|
||||
// Simulate text content
|
||||
textParts = append(textParts, "Based on recent research, artificial intelligence is advancing rapidly.")
|
||||
|
||||
// Simulate citations (as they would be extracted from OpenAI response)
|
||||
mockCitations := []struct {
|
||||
URL string
|
||||
Title string
|
||||
}{
|
||||
{"https://example.com/ai-research", "AI Research Advances 2025"},
|
||||
{"https://another-source.com/tech-news", "Technology News Today"},
|
||||
{"https://example.com/ai-research", "AI Research Advances 2025"}, // Duplicate to test deduplication
|
||||
}
|
||||
|
||||
for _, citation := range mockCitations {
|
||||
citationKey := citation.URL + "|" + citation.Title
|
||||
if !citationMap[citationKey] {
|
||||
citationMap[citationKey] = true
|
||||
citationText := "- [" + citation.Title + "](" + citation.URL + ")"
|
||||
citations = append(citations, citationText)
|
||||
}
|
||||
}
|
||||
|
||||
result := strings.Join(textParts, "")
|
||||
if len(citations) > 0 {
|
||||
result += "\n\n## Sources\n\n" + strings.Join(citations, "\n")
|
||||
}
|
||||
|
||||
// Verify the result contains the expected text
|
||||
expectedText := "Based on recent research, artificial intelligence is advancing rapidly."
|
||||
assert.Contains(t, result, expectedText, "Expected result to contain original text")
|
||||
|
||||
// Verify citations are included
|
||||
assert.Contains(t, result, "## Sources", "Expected result to contain Sources section")
|
||||
assert.Contains(t, result, "[AI Research Advances 2025](https://example.com/ai-research)", "Expected result to contain first citation")
|
||||
assert.Contains(t, result, "[Technology News Today](https://another-source.com/tech-news)", "Expected result to contain second citation")
|
||||
|
||||
// Verify deduplication - should only have 2 unique citations, not 3
|
||||
citationCount := strings.Count(result, "- [")
|
||||
assert.Equal(t, 2, citationCount, "Expected 2 unique citations")
|
||||
}
|
||||
|
||||
@@ -9,8 +9,9 @@ import (
|
||||
|
||||
// ProviderConfig defines the configuration for an OpenAI-compatible API provider
|
||||
type ProviderConfig struct {
|
||||
Name string
|
||||
BaseURL string
|
||||
Name string
|
||||
BaseURL string
|
||||
ImplementsResponses bool // Whether the provider supports OpenAI's new Responses API
|
||||
}
|
||||
|
||||
// Client is the common structure for all OpenAI-compatible providers
|
||||
@@ -21,51 +22,66 @@ type Client struct {
|
||||
// NewClient creates a new OpenAI-compatible client for the specified provider
|
||||
func NewClient(providerConfig ProviderConfig) *Client {
|
||||
client := &Client{}
|
||||
client.Client = openai.NewClientCompatible(providerConfig.Name, providerConfig.BaseURL, nil)
|
||||
client.Client = openai.NewClientCompatibleWithResponses(
|
||||
providerConfig.Name,
|
||||
providerConfig.BaseURL,
|
||||
providerConfig.ImplementsResponses,
|
||||
nil,
|
||||
)
|
||||
return client
|
||||
}
|
||||
|
||||
// ProviderMap is a map of provider name to ProviderConfig for O(1) lookup
|
||||
var ProviderMap = map[string]ProviderConfig{
|
||||
"AIML": {
|
||||
Name: "AIML",
|
||||
BaseURL: "https://api.aimlapi.com/v1",
|
||||
Name: "AIML",
|
||||
BaseURL: "https://api.aimlapi.com/v1",
|
||||
ImplementsResponses: false,
|
||||
},
|
||||
"Cerebras": {
|
||||
Name: "Cerebras",
|
||||
BaseURL: "https://api.cerebras.ai/v1",
|
||||
Name: "Cerebras",
|
||||
BaseURL: "https://api.cerebras.ai/v1",
|
||||
ImplementsResponses: false,
|
||||
},
|
||||
"DeepSeek": {
|
||||
Name: "DeepSeek",
|
||||
BaseURL: "https://api.deepseek.com",
|
||||
Name: "DeepSeek",
|
||||
BaseURL: "https://api.deepseek.com",
|
||||
ImplementsResponses: false,
|
||||
},
|
||||
"GrokAI": {
|
||||
Name: "GrokAI",
|
||||
BaseURL: "https://api.x.ai/v1",
|
||||
Name: "GrokAI",
|
||||
BaseURL: "https://api.x.ai/v1",
|
||||
ImplementsResponses: false,
|
||||
},
|
||||
"Groq": {
|
||||
Name: "Groq",
|
||||
BaseURL: "https://api.groq.com/openai/v1",
|
||||
Name: "Groq",
|
||||
BaseURL: "https://api.groq.com/openai/v1",
|
||||
ImplementsResponses: false,
|
||||
},
|
||||
"Langdock": {
|
||||
Name: "Langdock",
|
||||
BaseURL: "https://api.langdock.com/openai/{{REGION=us}}/v1",
|
||||
Name: "Langdock",
|
||||
BaseURL: "https://api.langdock.com/openai/{{REGION=us}}/v1",
|
||||
ImplementsResponses: false,
|
||||
},
|
||||
"LiteLLM": {
|
||||
Name: "LiteLLM",
|
||||
BaseURL: "http://localhost:4000",
|
||||
Name: "LiteLLM",
|
||||
BaseURL: "http://localhost:4000",
|
||||
ImplementsResponses: false,
|
||||
},
|
||||
"Mistral": {
|
||||
Name: "Mistral",
|
||||
BaseURL: "https://api.mistral.ai/v1",
|
||||
Name: "Mistral",
|
||||
BaseURL: "https://api.mistral.ai/v1",
|
||||
ImplementsResponses: false,
|
||||
},
|
||||
"OpenRouter": {
|
||||
Name: "OpenRouter",
|
||||
BaseURL: "https://openrouter.ai/api/v1",
|
||||
Name: "OpenRouter",
|
||||
BaseURL: "https://openrouter.ai/api/v1",
|
||||
ImplementsResponses: false,
|
||||
},
|
||||
"SiliconCloud": {
|
||||
Name: "SiliconCloud",
|
||||
BaseURL: "https://api.siliconflow.cn/v1",
|
||||
Name: "SiliconCloud",
|
||||
BaseURL: "https://api.siliconflow.cn/v1",
|
||||
ImplementsResponses: false,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"github.com/danielmiessler/fabric/plugins"
|
||||
perplexity "github.com/sgaunet/perplexity-go/v2"
|
||||
|
||||
goopenai "github.com/sashabaranov/go-openai"
|
||||
"github.com/danielmiessler/fabric/chat"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -61,7 +61,7 @@ func (c *Client) ListModels() ([]string, error) {
|
||||
return models, nil
|
||||
}
|
||||
|
||||
func (c *Client) Send(ctx context.Context, msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions) (string, error) {
|
||||
func (c *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions) (string, error) {
|
||||
if c.client == nil {
|
||||
if err := c.Configure(); err != nil {
|
||||
return "", fmt.Errorf("failed to configure Perplexity client: %w", err)
|
||||
@@ -120,7 +120,7 @@ func (c *Client) Send(ctx context.Context, msgs []*goopenai.ChatCompletionMessag
|
||||
return content, nil
|
||||
}
|
||||
|
||||
func (c *Client) SendStream(msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) error {
|
||||
func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) error {
|
||||
if c.client == nil {
|
||||
if err := c.Configure(); err != nil {
|
||||
close(channel) // Ensure channel is closed on error
|
||||
|
||||
@@ -3,8 +3,8 @@ package ai
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/danielmiessler/fabric/chat"
|
||||
"github.com/danielmiessler/fabric/plugins"
|
||||
goopenai "github.com/sashabaranov/go-openai"
|
||||
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
)
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
type Vendor interface {
|
||||
plugins.Plugin
|
||||
ListModels() ([]string, error)
|
||||
SendStream([]*goopenai.ChatCompletionMessage, *common.ChatOptions, chan string) error
|
||||
Send(context.Context, []*goopenai.ChatCompletionMessage, *common.ChatOptions) (string, error)
|
||||
SendStream([]*chat.ChatCompletionMessage, *common.ChatOptions, chan string) error
|
||||
Send(context.Context, []*chat.ChatCompletionMessage, *common.ChatOptions) (string, error)
|
||||
NeedsRawMode(modelName string) bool
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
@@ -19,6 +20,7 @@ func NewDb(dir string) (db *Db) {
|
||||
StorageEntity: &StorageEntity{Label: "Patterns", Dir: db.FilePath("patterns"), ItemIsDir: true},
|
||||
SystemPatternFile: "system.md",
|
||||
UniquePatternsFilePath: db.FilePath("unique_patterns.txt"),
|
||||
CustomPatternsDir: "", // Will be set after loading .env file
|
||||
}
|
||||
|
||||
db.Sessions = &SessionsEntity{
|
||||
@@ -49,6 +51,18 @@ func (o *Db) Configure() (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
// Set custom patterns directory after loading .env file
|
||||
customPatternsDir := os.Getenv("CUSTOM_PATTERNS_DIRECTORY")
|
||||
if customPatternsDir != "" {
|
||||
// Expand home directory if needed
|
||||
if strings.HasPrefix(customPatternsDir, "~/") {
|
||||
if homeDir, err := os.UserHomeDir(); err == nil {
|
||||
customPatternsDir = filepath.Join(homeDir, customPatternsDir[2:])
|
||||
}
|
||||
}
|
||||
o.Patterns.CustomPatternsDir = customPatternsDir
|
||||
}
|
||||
|
||||
if err = o.Patterns.Configure(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
@@ -16,6 +17,7 @@ type PatternsEntity struct {
|
||||
*StorageEntity
|
||||
SystemPatternFile string
|
||||
UniquePatternsFilePath string
|
||||
CustomPatternsDir string
|
||||
}
|
||||
|
||||
// Pattern represents a single pattern with its metadata
|
||||
@@ -43,7 +45,7 @@ func (o *PatternsEntity) GetApplyVariables(
|
||||
}
|
||||
|
||||
// Use the resolved absolute path to get the pattern
|
||||
pattern, err = o.getFromFile(absPath)
|
||||
pattern, _ = o.getFromFile(absPath)
|
||||
} else {
|
||||
// Otherwise, get the pattern from the database
|
||||
pattern, err = o.getFromDB(source)
|
||||
@@ -89,6 +91,19 @@ func (o *PatternsEntity) applyVariables(
|
||||
|
||||
// retrieves a pattern from the database by name
|
||||
func (o *PatternsEntity) getFromDB(name string) (ret *Pattern, err error) {
|
||||
// First check custom patterns directory if it exists
|
||||
if o.CustomPatternsDir != "" {
|
||||
customPatternPath := filepath.Join(o.CustomPatternsDir, name, o.SystemPatternFile)
|
||||
if pattern, customErr := os.ReadFile(customPatternPath); customErr == nil {
|
||||
ret = &Pattern{
|
||||
Name: name,
|
||||
Pattern: string(pattern),
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to main patterns directory
|
||||
patternPath := filepath.Join(o.Dir, name, o.SystemPatternFile)
|
||||
|
||||
var pattern []byte
|
||||
@@ -145,6 +160,71 @@ func (o *PatternsEntity) getFromFile(pathStr string) (pattern *Pattern, err erro
|
||||
return
|
||||
}
|
||||
|
||||
// GetNames overrides StorageEntity.GetNames to include custom patterns directory
|
||||
func (o *PatternsEntity) GetNames() (ret []string, err error) {
|
||||
// Get names from main patterns directory
|
||||
mainNames, err := o.StorageEntity.GetNames()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create a map to track unique pattern names (custom patterns override main ones)
|
||||
nameMap := make(map[string]bool)
|
||||
for _, name := range mainNames {
|
||||
nameMap[name] = true
|
||||
}
|
||||
|
||||
// Get names from custom patterns directory if it exists
|
||||
if o.CustomPatternsDir != "" {
|
||||
// Create a temporary StorageEntity for the custom directory
|
||||
customStorage := &StorageEntity{
|
||||
Dir: o.CustomPatternsDir,
|
||||
ItemIsDir: o.StorageEntity.ItemIsDir,
|
||||
FileExtension: o.StorageEntity.FileExtension,
|
||||
}
|
||||
|
||||
customNames, customErr := customStorage.GetNames()
|
||||
if customErr == nil {
|
||||
// Add custom patterns, they will override main patterns with same name
|
||||
for _, name := range customNames {
|
||||
nameMap[name] = true
|
||||
}
|
||||
}
|
||||
// Ignore errors from custom directory (it might not exist)
|
||||
}
|
||||
|
||||
// Convert map keys back to slice
|
||||
ret = make([]string, 0, len(nameMap))
|
||||
for name := range nameMap {
|
||||
ret = append(ret, name)
|
||||
}
|
||||
|
||||
// Sort the patterns alphabetically
|
||||
sort.Strings(ret)
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// ListNames overrides StorageEntity.ListNames to use PatternsEntity.GetNames
|
||||
func (o *PatternsEntity) ListNames(shellCompleteList bool) (err error) {
|
||||
var names []string
|
||||
if names, err = o.GetNames(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if len(names) == 0 {
|
||||
if !shellCompleteList {
|
||||
fmt.Printf("\nNo %v\n", o.StorageEntity.Label)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
for _, item := range names {
|
||||
fmt.Printf("%s\n", item)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Get required for Storage interface
|
||||
func (o *PatternsEntity) Get(name string) (*Pattern, error) {
|
||||
// Use GetPattern with no variables
|
||||
|
||||
@@ -162,3 +162,123 @@ func TestPatternsEntity_Save(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, content, data)
|
||||
}
|
||||
|
||||
func TestPatternsEntity_CustomPatterns(t *testing.T) {
|
||||
// Create main patterns directory
|
||||
mainDir, err := os.MkdirTemp("", "test-main-patterns-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(mainDir)
|
||||
|
||||
// Create custom patterns directory
|
||||
customDir, err := os.MkdirTemp("", "test-custom-patterns-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(customDir)
|
||||
|
||||
entity := &PatternsEntity{
|
||||
StorageEntity: &StorageEntity{
|
||||
Dir: mainDir,
|
||||
Label: "patterns",
|
||||
ItemIsDir: true,
|
||||
},
|
||||
SystemPatternFile: "system.md",
|
||||
CustomPatternsDir: customDir,
|
||||
}
|
||||
|
||||
// Create a pattern in main directory
|
||||
createTestPattern(t, &PatternsEntity{
|
||||
StorageEntity: &StorageEntity{
|
||||
Dir: mainDir,
|
||||
Label: "patterns",
|
||||
ItemIsDir: true,
|
||||
},
|
||||
SystemPatternFile: "system.md",
|
||||
}, "main-pattern", "Main pattern content")
|
||||
|
||||
// Create a pattern in custom directory
|
||||
createTestPattern(t, &PatternsEntity{
|
||||
StorageEntity: &StorageEntity{
|
||||
Dir: customDir,
|
||||
Label: "patterns",
|
||||
ItemIsDir: true,
|
||||
},
|
||||
SystemPatternFile: "system.md",
|
||||
}, "custom-pattern", "Custom pattern content")
|
||||
|
||||
// Create a pattern with same name in both directories (custom should override)
|
||||
createTestPattern(t, &PatternsEntity{
|
||||
StorageEntity: &StorageEntity{
|
||||
Dir: mainDir,
|
||||
Label: "patterns",
|
||||
ItemIsDir: true,
|
||||
},
|
||||
SystemPatternFile: "system.md",
|
||||
}, "shared-pattern", "Main shared pattern")
|
||||
|
||||
createTestPattern(t, &PatternsEntity{
|
||||
StorageEntity: &StorageEntity{
|
||||
Dir: customDir,
|
||||
Label: "patterns",
|
||||
ItemIsDir: true,
|
||||
},
|
||||
SystemPatternFile: "system.md",
|
||||
}, "shared-pattern", "Custom shared pattern")
|
||||
|
||||
// Test GetNames includes both directories
|
||||
names, err := entity.GetNames()
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, names, "main-pattern")
|
||||
assert.Contains(t, names, "custom-pattern")
|
||||
assert.Contains(t, names, "shared-pattern")
|
||||
|
||||
// Test that custom pattern overrides main pattern
|
||||
pattern, err := entity.getFromDB("shared-pattern")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Custom shared pattern", pattern.Pattern)
|
||||
|
||||
// Test that main pattern is accessible when not overridden
|
||||
pattern, err = entity.getFromDB("main-pattern")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Main pattern content", pattern.Pattern)
|
||||
|
||||
// Test that custom pattern is accessible
|
||||
pattern, err = entity.getFromDB("custom-pattern")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Custom pattern content", pattern.Pattern)
|
||||
}
|
||||
|
||||
func TestPatternsEntity_CustomPatternsEmpty(t *testing.T) {
|
||||
// Test behavior when custom patterns directory is empty or doesn't exist
|
||||
mainDir, err := os.MkdirTemp("", "test-main-patterns-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(mainDir)
|
||||
|
||||
entity := &PatternsEntity{
|
||||
StorageEntity: &StorageEntity{
|
||||
Dir: mainDir,
|
||||
Label: "patterns",
|
||||
ItemIsDir: true,
|
||||
},
|
||||
SystemPatternFile: "system.md",
|
||||
CustomPatternsDir: "/nonexistent/directory",
|
||||
}
|
||||
|
||||
// Create a pattern in main directory
|
||||
createTestPattern(t, &PatternsEntity{
|
||||
StorageEntity: &StorageEntity{
|
||||
Dir: mainDir,
|
||||
Label: "patterns",
|
||||
ItemIsDir: true,
|
||||
},
|
||||
SystemPatternFile: "system.md",
|
||||
}, "main-pattern", "Main pattern content")
|
||||
|
||||
// Test GetNames works even with nonexistent custom directory
|
||||
names, err := entity.GetNames()
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, names, "main-pattern")
|
||||
|
||||
// Test that main pattern is accessible
|
||||
pattern, err := entity.getFromDB("main-pattern")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Main pattern content", pattern.Pattern)
|
||||
}
|
||||
|
||||
@@ -3,8 +3,8 @@ package fsdb
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/danielmiessler/fabric/chat"
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
goopenai "github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
type SessionsEntity struct {
|
||||
@@ -38,16 +38,16 @@ func (o *SessionsEntity) SaveSession(session *Session) (err error) {
|
||||
|
||||
type Session struct {
|
||||
Name string
|
||||
Messages []*goopenai.ChatCompletionMessage
|
||||
Messages []*chat.ChatCompletionMessage
|
||||
|
||||
vendorMessages []*goopenai.ChatCompletionMessage
|
||||
vendorMessages []*chat.ChatCompletionMessage
|
||||
}
|
||||
|
||||
func (o *Session) IsEmpty() bool {
|
||||
return len(o.Messages) == 0
|
||||
}
|
||||
|
||||
func (o *Session) Append(messages ...*goopenai.ChatCompletionMessage) {
|
||||
func (o *Session) Append(messages ...*chat.ChatCompletionMessage) {
|
||||
if o.vendorMessages != nil {
|
||||
for _, message := range messages {
|
||||
o.Messages = append(o.Messages, message)
|
||||
@@ -58,7 +58,7 @@ func (o *Session) Append(messages ...*goopenai.ChatCompletionMessage) {
|
||||
}
|
||||
}
|
||||
|
||||
func (o *Session) GetVendorMessages() (ret []*goopenai.ChatCompletionMessage) {
|
||||
func (o *Session) GetVendorMessages() (ret []*chat.ChatCompletionMessage) {
|
||||
if len(o.vendorMessages) == 0 {
|
||||
for _, message := range o.Messages {
|
||||
o.appendVendorMessage(message)
|
||||
@@ -68,13 +68,13 @@ func (o *Session) GetVendorMessages() (ret []*goopenai.ChatCompletionMessage) {
|
||||
return
|
||||
}
|
||||
|
||||
func (o *Session) appendVendorMessage(message *goopenai.ChatCompletionMessage) {
|
||||
func (o *Session) appendVendorMessage(message *chat.ChatCompletionMessage) {
|
||||
if message.Role != common.ChatMessageRoleMeta {
|
||||
o.vendorMessages = append(o.vendorMessages, message)
|
||||
}
|
||||
}
|
||||
|
||||
func (o *Session) GetLastMessage() (ret *goopenai.ChatCompletionMessage) {
|
||||
func (o *Session) GetLastMessage() (ret *chat.ChatCompletionMessage) {
|
||||
if len(o.Messages) > 0 {
|
||||
ret = o.Messages[len(o.Messages)-1]
|
||||
}
|
||||
@@ -86,9 +86,9 @@ func (o *Session) String() (ret string) {
|
||||
ret += fmt.Sprintf("\n--- \n[%v]\n%v", message.Role, message.Content)
|
||||
if message.MultiContent != nil {
|
||||
for _, part := range message.MultiContent {
|
||||
if part.Type == goopenai.ChatMessagePartTypeImageURL {
|
||||
if part.Type == chat.ChatMessagePartTypeImageURL {
|
||||
ret += fmt.Sprintf("\n%v: %v", part.Type, *part.ImageURL)
|
||||
} else if part.Type == goopenai.ChatMessagePartTypeText {
|
||||
} else if part.Type == chat.ChatMessagePartTypeText {
|
||||
ret += fmt.Sprintf("\n%v: %v", part.Type, part.Text)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ package fsdb
|
||||
import (
|
||||
"testing"
|
||||
|
||||
goopenai "github.com/sashabaranov/go-openai"
|
||||
"github.com/danielmiessler/fabric/chat"
|
||||
)
|
||||
|
||||
func TestSessions_GetOrCreateSession(t *testing.T) {
|
||||
@@ -27,7 +27,7 @@ func TestSessions_SaveSession(t *testing.T) {
|
||||
StorageEntity: &StorageEntity{Dir: dir, FileExtension: ".json"},
|
||||
}
|
||||
sessionName := "testSession"
|
||||
session := &Session{Name: sessionName, Messages: []*goopenai.ChatCompletionMessage{{Content: "message1"}}}
|
||||
session := &Session{Name: sessionName, Messages: []*chat.ChatCompletionMessage{{Content: "message1"}}}
|
||||
err := sessions.SaveSession(session)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to save session: %v", err)
|
||||
|
||||
@@ -152,7 +152,6 @@ func (o *Setting) FillEnvFileContent(buffer *bytes.Buffer) {
|
||||
}
|
||||
buffer.WriteString("\n")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func ParseBoolElseFalse(val string) (ret bool) {
|
||||
@@ -279,7 +278,6 @@ func (o Settings) FillEnvFileContent(buffer *bytes.Buffer) {
|
||||
for _, setting := range o {
|
||||
setting.FillEnvFileContent(buffer)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type SetupQuestions []*SetupQuestion
|
||||
|
||||
@@ -93,7 +93,7 @@ func TestSysPlugin(t *testing.T) {
|
||||
if !filepath.IsAbs(got) {
|
||||
return fmt.Errorf("expected absolute path, got %s", got)
|
||||
}
|
||||
if !strings.Contains(got, "home") && !strings.Contains(got, "Users") {
|
||||
if !strings.Contains(got, "home") && !strings.Contains(got, "Users") && got != "/root" {
|
||||
return fmt.Errorf("path %s doesn't look like a home directory", got)
|
||||
}
|
||||
return nil
|
||||
|
||||
66
plugins/tools/custom_patterns/custom_patterns.go
Normal file
66
plugins/tools/custom_patterns/custom_patterns.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package custom_patterns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/danielmiessler/fabric/plugins"
|
||||
)
|
||||
|
||||
func NewCustomPatterns() (ret *CustomPatterns) {
|
||||
label := "Custom Patterns"
|
||||
ret = &CustomPatterns{}
|
||||
|
||||
ret.PluginBase = &plugins.PluginBase{
|
||||
Name: label,
|
||||
SetupDescription: "Custom Patterns - Set directory for your custom patterns (optional)",
|
||||
EnvNamePrefix: plugins.BuildEnvVariablePrefix(label),
|
||||
ConfigureCustom: ret.configure,
|
||||
}
|
||||
|
||||
ret.CustomPatternsDir = ret.AddSetupQuestionCustom("Directory", false,
|
||||
"Enter the path to your custom patterns directory (leave empty to skip)")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
type CustomPatterns struct {
|
||||
*plugins.PluginBase
|
||||
CustomPatternsDir *plugins.SetupQuestion
|
||||
}
|
||||
|
||||
func (o *CustomPatterns) configure() error {
|
||||
if o.CustomPatternsDir.Value != "" {
|
||||
// Expand home directory if needed
|
||||
if strings.HasPrefix(o.CustomPatternsDir.Value, "~/") {
|
||||
if homeDir, err := os.UserHomeDir(); err == nil {
|
||||
o.CustomPatternsDir.Value = filepath.Join(homeDir, o.CustomPatternsDir.Value[2:])
|
||||
}
|
||||
}
|
||||
|
||||
// Convert to absolute path
|
||||
if absPath, err := filepath.Abs(o.CustomPatternsDir.Value); err == nil {
|
||||
o.CustomPatternsDir.Value = absPath
|
||||
}
|
||||
|
||||
// Check if directory exists, create only if it doesn't
|
||||
if _, err := os.Stat(o.CustomPatternsDir.Value); os.IsNotExist(err) {
|
||||
if err := os.MkdirAll(o.CustomPatternsDir.Value, 0755); err != nil {
|
||||
// Log the error but don't clear the value - let it persist in env file
|
||||
fmt.Printf("Warning: Could not create custom patterns directory %s: %v\n", o.CustomPatternsDir.Value, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsConfigured returns true if a custom patterns directory has been set
|
||||
func (o *CustomPatterns) IsConfigured() bool {
|
||||
// First configure to load values from environment variables
|
||||
o.Configure()
|
||||
// Check if the plugin has been configured with a directory
|
||||
return o.CustomPatternsDir.Value != ""
|
||||
}
|
||||
79
plugins/tools/custom_patterns/custom_patterns_test.go
Normal file
79
plugins/tools/custom_patterns/custom_patterns_test.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package custom_patterns
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewCustomPatterns(t *testing.T) {
|
||||
plugin := NewCustomPatterns()
|
||||
|
||||
assert.NotNil(t, plugin)
|
||||
assert.Equal(t, "Custom Patterns", plugin.GetName())
|
||||
assert.Equal(t, "Custom Patterns - Set directory for your custom patterns (optional)", plugin.GetSetupDescription())
|
||||
assert.False(t, plugin.IsConfigured()) // Should not be configured initially
|
||||
}
|
||||
func TestCustomPatterns_Configure(t *testing.T) {
|
||||
plugin := NewCustomPatterns()
|
||||
|
||||
// Test with empty directory (should work)
|
||||
plugin.CustomPatternsDir.Value = ""
|
||||
err := plugin.configure()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Test with home directory expansion
|
||||
plugin.CustomPatternsDir.Value = "~/test-patterns"
|
||||
err = plugin.configure()
|
||||
assert.NoError(t, err)
|
||||
|
||||
homeDir, _ := os.UserHomeDir()
|
||||
expectedPath := filepath.Join(homeDir, "test-patterns")
|
||||
absExpected, _ := filepath.Abs(expectedPath)
|
||||
assert.Equal(t, absExpected, plugin.CustomPatternsDir.Value)
|
||||
|
||||
// Clean up
|
||||
os.RemoveAll(plugin.CustomPatternsDir.Value)
|
||||
}
|
||||
|
||||
func TestCustomPatterns_ConfigureWithTempDir(t *testing.T) {
|
||||
plugin := NewCustomPatterns()
|
||||
|
||||
// Test with a temporary directory
|
||||
tmpDir, err := os.MkdirTemp("", "test-custom-patterns-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
plugin.CustomPatternsDir.Value = tmpDir
|
||||
err = plugin.configure()
|
||||
assert.NoError(t, err)
|
||||
|
||||
absPath, _ := filepath.Abs(tmpDir)
|
||||
assert.Equal(t, absPath, plugin.CustomPatternsDir.Value)
|
||||
|
||||
// Verify directory exists
|
||||
info, err := os.Stat(plugin.CustomPatternsDir.Value)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, info.IsDir())
|
||||
|
||||
// Should be configured now
|
||||
assert.True(t, plugin.IsConfigured())
|
||||
}
|
||||
|
||||
func TestCustomPatterns_IsConfigured(t *testing.T) {
|
||||
plugin := NewCustomPatterns()
|
||||
|
||||
// Initially not configured
|
||||
assert.False(t, plugin.IsConfigured())
|
||||
|
||||
// Set a directory
|
||||
plugin.CustomPatternsDir.Value = "/some/path"
|
||||
assert.True(t, plugin.IsConfigured())
|
||||
|
||||
// Clear the directory
|
||||
plugin.CustomPatternsDir.Value = ""
|
||||
assert.False(t, plugin.IsConfigured())
|
||||
}
|
||||
@@ -3,14 +3,13 @@ package restapi
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
goopenai "github.com/sashabaranov/go-openai"
|
||||
"github.com/danielmiessler/fabric/chat"
|
||||
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"github.com/danielmiessler/fabric/core"
|
||||
@@ -95,7 +94,7 @@ func (h *ChatHandler) HandleChat(c *gin.Context) {
|
||||
// Load and prepend strategy prompt if strategyName is set
|
||||
if p.StrategyName != "" {
|
||||
strategyFile := filepath.Join(os.Getenv("HOME"), ".config", "fabric", "strategies", p.StrategyName+".json")
|
||||
data, err := ioutil.ReadFile(strategyFile)
|
||||
data, err := os.ReadFile(strategyFile)
|
||||
if err == nil {
|
||||
var s struct {
|
||||
Prompt string `json:"prompt"`
|
||||
@@ -115,7 +114,7 @@ func (h *ChatHandler) HandleChat(c *gin.Context) {
|
||||
|
||||
// Pass the language received in the initial request to the common.ChatRequest
|
||||
chatReq := &common.ChatRequest{
|
||||
Message: &goopenai.ChatCompletionMessage{
|
||||
Message: &chat.ChatCompletionMessage{
|
||||
Role: "user",
|
||||
Content: p.UserInput,
|
||||
},
|
||||
|
||||
@@ -57,17 +57,18 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
|
||||
}
|
||||
|
||||
config := map[string]string{
|
||||
"openai": os.Getenv("OPENAI_API_KEY"),
|
||||
"anthropic": os.Getenv("ANTHROPIC_API_KEY"),
|
||||
"groq": os.Getenv("GROQ_API_KEY"),
|
||||
"mistral": os.Getenv("MISTRAL_API_KEY"),
|
||||
"gemini": os.Getenv("GEMINI_API_KEY"),
|
||||
"ollama": os.Getenv("OLLAMA_URL"),
|
||||
"openrouter": os.Getenv("OPENROUTER_API_KEY"),
|
||||
"silicon": os.Getenv("SILICON_API_KEY"),
|
||||
"deepseek": os.Getenv("DEEPSEEK_API_KEY"),
|
||||
"grokai": os.Getenv("GROKAI_API_KEY"),
|
||||
"lmstudio": os.Getenv("LM_STUDIO_API_BASE_URL"),
|
||||
"openai": os.Getenv("OPENAI_API_KEY"),
|
||||
"anthropic": os.Getenv("ANTHROPIC_API_KEY"),
|
||||
"anthropic_use_oauth_login": os.Getenv("ANTHROPIC_USE_OAUTH_LOGIN"),
|
||||
"groq": os.Getenv("GROQ_API_KEY"),
|
||||
"mistral": os.Getenv("MISTRAL_API_KEY"),
|
||||
"gemini": os.Getenv("GEMINI_API_KEY"),
|
||||
"ollama": os.Getenv("OLLAMA_URL"),
|
||||
"openrouter": os.Getenv("OPENROUTER_API_KEY"),
|
||||
"silicon": os.Getenv("SILICON_API_KEY"),
|
||||
"deepseek": os.Getenv("DEEPSEEK_API_KEY"),
|
||||
"grokai": os.Getenv("GROKAI_API_KEY"),
|
||||
"lmstudio": os.Getenv("LM_STUDIO_API_BASE_URL"),
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, config)
|
||||
@@ -80,17 +81,18 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||||
}
|
||||
|
||||
var config struct {
|
||||
OpenAIApiKey string `json:"openai_api_key"`
|
||||
AnthropicApiKey string `json:"anthropic_api_key"`
|
||||
GroqApiKey string `json:"groq_api_key"`
|
||||
MistralApiKey string `json:"mistral_api_key"`
|
||||
GeminiApiKey string `json:"gemini_api_key"`
|
||||
OllamaURL string `json:"ollama_url"`
|
||||
OpenRouterApiKey string `json:"openrouter_api_key"`
|
||||
SiliconApiKey string `json:"silicon_api_key"`
|
||||
DeepSeekApiKey string `json:"deepseek_api_key"`
|
||||
GrokaiApiKey string `json:"grokai_api_key"`
|
||||
LMStudioURL string `json:"lm_studio_base_url"`
|
||||
OpenAIApiKey string `json:"openai_api_key"`
|
||||
AnthropicApiKey string `json:"anthropic_api_key"`
|
||||
AnthropicUseAuthToken string `json:"anthropic_use_auth_token"`
|
||||
GroqApiKey string `json:"groq_api_key"`
|
||||
MistralApiKey string `json:"mistral_api_key"`
|
||||
GeminiApiKey string `json:"gemini_api_key"`
|
||||
OllamaURL string `json:"ollama_url"`
|
||||
OpenRouterApiKey string `json:"openrouter_api_key"`
|
||||
SiliconApiKey string `json:"silicon_api_key"`
|
||||
DeepSeekApiKey string `json:"deepseek_api_key"`
|
||||
GrokaiApiKey string `json:"grokai_api_key"`
|
||||
LMStudioURL string `json:"lm_studio_base_url"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&config); err != nil {
|
||||
@@ -99,17 +101,18 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||||
}
|
||||
|
||||
envVars := map[string]string{
|
||||
"OPENAI_API_KEY": config.OpenAIApiKey,
|
||||
"ANTHROPIC_API_KEY": config.AnthropicApiKey,
|
||||
"GROQ_API_KEY": config.GroqApiKey,
|
||||
"MISTRAL_API_KEY": config.MistralApiKey,
|
||||
"GEMINI_API_KEY": config.GeminiApiKey,
|
||||
"OLLAMA_URL": config.OllamaURL,
|
||||
"OPENROUTER_API_KEY": config.OpenRouterApiKey,
|
||||
"SILICON_API_KEY": config.SiliconApiKey,
|
||||
"DEEPSEEK_API_KEY": config.DeepSeekApiKey,
|
||||
"GROKAI_API_KEY": config.GrokaiApiKey,
|
||||
"LM_STUDIO_API_BASE_URL": config.LMStudioURL,
|
||||
"OPENAI_API_KEY": config.OpenAIApiKey,
|
||||
"ANTHROPIC_API_KEY": config.AnthropicApiKey,
|
||||
"ANTHROPIC_USE_OAUTH_LOGIN": config.AnthropicUseAuthToken,
|
||||
"GROQ_API_KEY": config.GroqApiKey,
|
||||
"MISTRAL_API_KEY": config.MistralApiKey,
|
||||
"GEMINI_API_KEY": config.GeminiApiKey,
|
||||
"OLLAMA_URL": config.OllamaURL,
|
||||
"OPENROUTER_API_KEY": config.OpenRouterApiKey,
|
||||
"SILICON_API_KEY": config.SiliconApiKey,
|
||||
"DEEPSEEK_API_KEY": config.DeepSeekApiKey,
|
||||
"GROKAI_API_KEY": config.GrokaiApiKey,
|
||||
"LM_STUDIO_API_BASE_URL": config.LMStudioURL,
|
||||
}
|
||||
|
||||
var envContent strings.Builder
|
||||
|
||||
@@ -103,7 +103,6 @@ func ServeOllama(registry *core.PluginRegistry, address string, version string)
|
||||
r.GET("/api/tags", typeConversion.ollamaTags)
|
||||
r.GET("/api/version", func(c *gin.Context) {
|
||||
c.Data(200, "application/json", []byte(fmt.Sprintf("{\"%s\"}", version)))
|
||||
return
|
||||
})
|
||||
r.POST("/api/chat", typeConversion.ollamaChat)
|
||||
|
||||
@@ -262,15 +261,10 @@ func (f APIConvert) ollamaChat(c *gin.Context) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err})
|
||||
return
|
||||
}
|
||||
for _, bytein := range marshalled {
|
||||
res = append(res, bytein)
|
||||
}
|
||||
for _, bytebreak := range []byte("\n") {
|
||||
res = append(res, bytebreak)
|
||||
}
|
||||
res = append(res, marshalled...)
|
||||
res = append(res, '\n')
|
||||
}
|
||||
c.Data(200, "application/json", res)
|
||||
|
||||
//c.JSON(200, forwardedResponse)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package restapi
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -23,7 +22,7 @@ func NewStrategiesHandler(r *gin.Engine) {
|
||||
r.GET("/strategies", func(c *gin.Context) {
|
||||
strategiesDir := filepath.Join(os.Getenv("HOME"), ".config", "fabric", "strategies")
|
||||
|
||||
files, err := ioutil.ReadDir(strategiesDir)
|
||||
files, err := os.ReadDir(strategiesDir)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to read strategies directory"})
|
||||
return
|
||||
@@ -37,7 +36,7 @@ func NewStrategiesHandler(r *gin.Engine) {
|
||||
}
|
||||
|
||||
fullPath := filepath.Join(strategiesDir, file.Name())
|
||||
data, err := ioutil.ReadFile(fullPath)
|
||||
data, err := os.ReadFile(fullPath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
package main
|
||||
|
||||
var version = "v1.4.220"
|
||||
var version = "v1.4.235"
|
||||
|
||||
@@ -1710,7 +1710,7 @@
|
||||
},
|
||||
{
|
||||
"patternName": "analyze_bill_short",
|
||||
"description": "Consended - Analyze a legislative bill and implications.",
|
||||
"description": "Condensed - Analyze a legislative bill and implications.",
|
||||
"tags": [
|
||||
"ANALYSIS",
|
||||
"BILL"
|
||||
@@ -1815,6 +1815,35 @@
|
||||
"WRITING",
|
||||
"CREATIVITY"
|
||||
]
|
||||
},
|
||||
{
|
||||
"patternName": "extract_alpha",
|
||||
"description": "Extracts the most novel and surprising ideas (\"alpha\") from content, inspired by information theory.",
|
||||
"tags": [
|
||||
"EXTRACT",
|
||||
"ANALYSIS",
|
||||
"CR THINKING",
|
||||
"WISDOM"
|
||||
]
|
||||
},
|
||||
{
|
||||
"patternName": "extract_mcp_servers",
|
||||
"description": "Analyzes content to identify and extract detailed information about Model Context Protocol (MCP) servers.",
|
||||
"tags": [
|
||||
"ANALYSIS",
|
||||
"EXTRACT",
|
||||
"DEVELOPMENT",
|
||||
"AI"
|
||||
]
|
||||
},
|
||||
{
|
||||
"patternName": "review_code",
|
||||
"description": "Performs a comprehensive code review, providing detailed feedback on correctness, security, and performance.",
|
||||
"tags": [
|
||||
"DEVELOPMENT",
|
||||
"REVIEW",
|
||||
"SECURITY"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user