mirror of
https://github.com/danielmiessler/Fabric.git
synced 2026-04-24 03:00:15 -04:00
added persistance
This commit is contained in:
@@ -118,7 +118,10 @@ def main():
|
||||
print("No patterns found")
|
||||
sys.exit()
|
||||
if args.listmodels:
|
||||
standalone.fetch_available_models()
|
||||
setup = Setup()
|
||||
allmodels = setup.fetch_available_models()
|
||||
for model in allmodels:
|
||||
print(model)
|
||||
sys.exit()
|
||||
if args.text is not None:
|
||||
text = args.text
|
||||
|
||||
@@ -8,6 +8,7 @@ import platform
|
||||
from dotenv import load_dotenv
|
||||
import zipfile
|
||||
import tempfile
|
||||
import re
|
||||
import shutil
|
||||
|
||||
current_directory = os.path.dirname(os.path.realpath(__file__))
|
||||
@@ -424,17 +425,24 @@ class Setup:
|
||||
self.gptlist = []
|
||||
self.fullOllamaList = []
|
||||
self.claudeList = ['claude-3-opus-20240229']
|
||||
load_dotenv(self.env_file)
|
||||
try:
|
||||
openaiapikey = os.environ["OPENAI_API_KEY"]
|
||||
self.openaiapi_key = openaiapikey
|
||||
except KeyError:
|
||||
print("OPENAI_API_KEY not found in environment variables.")
|
||||
sys.exit()
|
||||
self.fetch_available_models()
|
||||
|
||||
def fetch_available_models(self):
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.client.api_key}"
|
||||
"Authorization": f"Bearer {self.openaiapi_key}"
|
||||
}
|
||||
|
||||
response = requests.get(
|
||||
"https://api.openai.com/v1/models", headers=headers)
|
||||
|
||||
if response.status_code == 200:
|
||||
print("OpenAI GPT models:\n")
|
||||
models = response.json().get("data", [])
|
||||
# Filter only gpt models
|
||||
gpt_models = [model for model in models if model.get(
|
||||
@@ -444,18 +452,19 @@ class Setup:
|
||||
gpt_models, key=lambda x: x.get("id"))
|
||||
|
||||
for model in sorted_gpt_models:
|
||||
print(model.get("id"))
|
||||
self.gptlist.append(model.get("id"))
|
||||
print("\nLocal Ollama models:")
|
||||
import ollama
|
||||
default_modelollamaList = ollama.list()['models']
|
||||
for model in ollamaList:
|
||||
print(model['name'].rstrip(":latest"))
|
||||
self.fullOllamaList.append(model['name'].rstrip(":latest"))
|
||||
print("\nClaude models:")
|
||||
print("claude-3-opus-20240229")
|
||||
else:
|
||||
print(f"Failed to fetch models: HTTP {response.status_code}")
|
||||
sys.exit()
|
||||
import ollama
|
||||
try:
|
||||
default_modelollamaList = ollama.list()['models']
|
||||
for model in default_modelollamaList:
|
||||
self.fullOllamaList.append(model['name'].rstrip(":latest"))
|
||||
except:
|
||||
self.fullOllamaList = []
|
||||
allmodels = self.gptlist + self.fullOllamaList + self.claudeList
|
||||
return allmodels
|
||||
|
||||
def api_key(self, api_key):
|
||||
""" Set the OpenAI API key in the environment file.
|
||||
@@ -509,36 +518,69 @@ class Setup:
|
||||
with open(self.env_file, "w") as f:
|
||||
f.write(f"CLAUDE_API_KEY={claude_key}")
|
||||
|
||||
def update_fabric_command(self, line, model):
|
||||
fabric_command_regex = re.compile(
|
||||
r"(fabric --pattern\s+\S+.*?)( --claude| --local)?'")
|
||||
match = fabric_command_regex.search(line)
|
||||
if match:
|
||||
base_command = match.group(1)
|
||||
# Provide a default value for current_flag
|
||||
current_flag = match.group(2) if match.group(2) else ""
|
||||
new_flag = ""
|
||||
if model in self.claudeList:
|
||||
new_flag = " --claude"
|
||||
elif model in self.fullOllamaList:
|
||||
new_flag = " --local"
|
||||
# Update the command if the new flag is different or to remove an existing flag.
|
||||
# Ensure to add the closing quote that was part of the original regex
|
||||
return f"{base_command}{new_flag}'\n"
|
||||
else:
|
||||
return line # Return the line unmodified if no match is found.
|
||||
|
||||
def update_fabric_alias(self, line, model):
|
||||
fabric_alias_regex = re.compile(
|
||||
r"(alias fabric='[^']+?)( --claude| --local)?'")
|
||||
match = fabric_alias_regex.search(line)
|
||||
if match:
|
||||
base_command, current_flag = match.groups()
|
||||
new_flag = ""
|
||||
if model in self.claudeList:
|
||||
new_flag = " --claude"
|
||||
elif model in self.fullOllamaList:
|
||||
new_flag = " --local"
|
||||
# Update the alias if the new flag is different or to remove an existing flag.
|
||||
return f"{base_command}{new_flag}'\n"
|
||||
else:
|
||||
return line # Return the line unmodified if no match is found.
|
||||
|
||||
def default_model(self, model):
|
||||
""" Set the default model in the environment file.
|
||||
"""Set the default model in the environment file.
|
||||
|
||||
Args:
|
||||
model (str): The model to be set.
|
||||
"""
|
||||
|
||||
model = model.strip()
|
||||
if os.path.exists(self.env_file) and model:
|
||||
with open(self.env_file, "r") as f:
|
||||
lines = f.readlines()
|
||||
with open(self.env_file, "w") as f:
|
||||
for line in lines:
|
||||
if "DEFAULT_MODEL" not in line:
|
||||
f.write(line)
|
||||
f.write(f"DEFAULT_MODEL={model}")
|
||||
elif model:
|
||||
with open(self.env_file, "w") as f:
|
||||
f.write(f"DEFAULT_MODEL={model}")
|
||||
else:
|
||||
with open(self.env_file, "r") as f:
|
||||
lines = f.readlines()
|
||||
with open(self.env_file, "w") as f:
|
||||
for line in lines:
|
||||
if "DEFAULT_MODEL" not in line:
|
||||
f.write(line)
|
||||
import re
|
||||
plain_fabric_regex = re.compile(
|
||||
r"(fabric='.*fabric)( --claude| --local)?'"
|
||||
fabric_regex = re.compile(r"(fabric --pattern.*)( --claude|--local)'")
|
||||
if model:
|
||||
# Write or update the DEFAULT_MODEL in env_file
|
||||
if os.path.exists(self.env_file):
|
||||
with open(self.env_file, "r") as f:
|
||||
lines = f.readlines()
|
||||
with open(self.env_file, "w") as f:
|
||||
found = False
|
||||
for line in lines:
|
||||
if line.startswith("DEFAULT_MODEL"):
|
||||
f.write(f"DEFAULT_MODEL={model}\n")
|
||||
found = True
|
||||
else:
|
||||
f.write(line)
|
||||
if not found:
|
||||
f.write(f"DEFAULT_MODEL={model}\n")
|
||||
else:
|
||||
with open(self.env_file, "w") as f:
|
||||
f.write(f"DEFAULT_MODEL={model}\n")
|
||||
|
||||
# Compile regular expressions outside of the loop for efficiency
|
||||
|
||||
user_home = os.path.expanduser("~")
|
||||
sh_config = None
|
||||
# Check for shell configuration files
|
||||
@@ -552,17 +594,14 @@ class Setup:
|
||||
lines = f.readlines()
|
||||
with open(sh_config, "w") as f:
|
||||
for line in lines:
|
||||
# Remove existing --claude or --local
|
||||
modified_line = re.sub(fabric_regex, r"\1'", line)
|
||||
modified_line = line
|
||||
# Update existing fabric commands
|
||||
if "fabric --pattern" in line:
|
||||
if model in self.claudeList:
|
||||
whole_thing = plain_fabric_regex.search(line)[0]
|
||||
beginning_match = plain_fabric_regex.search(line)[1]
|
||||
modified_line = re.sub(
|
||||
fabric_regex, r"\1 --claude'", line)
|
||||
elif model in self.fullOllamaList:
|
||||
modified_line = re.sub(
|
||||
fabric_regex, r"\1 --local'", line)
|
||||
modified_line = self.update_fabric_command(
|
||||
modified_line, model)
|
||||
elif "fabric=" in line:
|
||||
modified_line = self.update_fabric_alias(
|
||||
modified_line, model)
|
||||
f.write(modified_line)
|
||||
print(f"""Default model changed to {
|
||||
model}. Please restart your terminal to use it.""")
|
||||
|
||||
Reference in New Issue
Block a user