Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-01-31 09:58:19 -05:00)

Merge branch 'master' into patch-2
@@ -7,6 +7,7 @@ agents = {}  # key, (task, full_message_history, model)

# TODO: Centralise use of create_chat_completion() to globally enforce token limit


def create_agent(task, prompt, model):
    """Create a new agent and return its key"""
    global next_key
    global agents

@@ -32,6 +33,7 @@ def create_agent(task, prompt, model):


def message_agent(key, message):
    """Send a message to an agent and return its response"""
    global agents

    task, messages, model = agents[int(key)]

@@ -52,6 +54,7 @@ def message_agent(key, message):


def list_agents():
    """Return a list of all agents"""
    global agents

    # Return a list of agent keys and their tasks

@@ -59,6 +62,7 @@ def list_agents():


def delete_agent(key):
    """Delete an agent and return True if successful, False otherwise"""
    global agents

    try:
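
A minimal usage sketch of this key-based registry, assuming the four functions above are imported as in commands.py (task, prompt, and model values are illustrative):

    key = create_agent("summarize the README", "You are a summarizer.", "gpt-3.5-turbo")
    reply = message_agent(key, "Summarize the README in three bullets.")
    print(list_agents())        # [(key, task), ...]
    print(delete_agent(key))    # True on success, False otherwise
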
@@ -1,19 +1,21 @@
import yaml
import data
import os


class AIConfig:
    """Class to store the AI's name, role, and goals."""
    def __init__(self, ai_name="", ai_role="", ai_goals=[]):
        """Initialize the AIConfig class"""
        self.ai_name = ai_name
        self.ai_role = ai_role
        self.ai_goals = ai_goals

    # Soon this will go in a folder where it remembers more stuff about the run(s)
    SAVE_FILE = "../ai_settings.yaml"
    SAVE_FILE = os.path.join(os.path.dirname(__file__), '..', 'ai_settings.yaml')

    @classmethod
    def load(cls, config_file=SAVE_FILE):
        # Load variables from yaml file if it exists
        """Load variables from yaml file if it exists, otherwise use defaults."""
        try:
            with open(config_file) as file:
                config_params = yaml.load(file, Loader=yaml.FullLoader)

@@ -27,11 +29,14 @@ class AIConfig:
        return cls(ai_name, ai_role, ai_goals)

    def save(self, config_file=SAVE_FILE):
        """Save variables to yaml file."""
        config = {"ai_name": self.ai_name, "ai_role": self.ai_role, "ai_goals": self.ai_goals}
        with open(config_file, "w") as file:
            yaml.dump(config, file)

    def construct_full_prompt(self):
        """Construct the full prompt for the AI to use."""
        prompt_start = """Your decisions must always be made independently without seeking user assistance. Play to your strengths as an LLM and pursue simple strategies with no legal complications."""

        # Construct full prompt
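
A round-trip sketch for the class above; the example values are illustrative, not defaults from the commit:

    config = AIConfig("Demo-GPT", "an autonomous assistant", ["write a haiku"])
    config.save()                 # dumps ai_name/ai_role/ai_goals to ai_settings.yaml
    loaded = AIConfig.load()      # falls back to empty defaults if the file is absent
    assert loaded.ai_goals == ["write a haiku"]
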
@@ -6,8 +6,8 @@ from json_parser import fix_and_parse_json
cfg = Config()

# Evaluating code

def evaluate_code(code: str) -> List[str]:
    """Evaluates the given code and returns a list of suggestions for improvements."""
    function_string = "def analyze_code(code: str) -> List[str]:"
    args = [code]
    description_string = """Analyzes the given code and returns a list of suggestions for improvements."""

@@ -18,8 +18,8 @@ def evaluate_code(code: str) -> List[str]:

# Improving code

def improve_code(suggestions: List[str], code: str) -> str:
    """Improves the provided code based on the suggestions provided, making no other changes."""
    function_string = (
        "def generate_improved_code(suggestions: List[str], code: str) -> str:"
    )

@@ -31,9 +31,8 @@ def improve_code(suggestions: List[str], code: str) -> str:

# Writing tests


def write_tests(code: str, focus: List[str]) -> str:
    """Generates test cases for the existing code, focusing on specific areas if required."""
    function_string = (
        "def create_test_cases(code: str, focus: Optional[str] = None) -> str:"
    )
@@ -6,7 +6,15 @@ from llm_utils import create_chat_completion
cfg = Config()

def scrape_text(url):
    response = requests.get(url, headers=cfg.user_agent_header)
    """Scrape text from a webpage"""
    # Most basic check if the URL is valid:
    if not url.startswith('http'):
        return "Error: Invalid URL"

    try:
        response = requests.get(url, headers=cfg.user_agent_header)
    except requests.exceptions.RequestException as e:
        return "Error: " + str(e)

    # Check if the response contains an HTTP error
    if response.status_code >= 400:

@@ -26,6 +34,7 @@ def scrape_text(url):

def extract_hyperlinks(soup):
    """Extract hyperlinks from a BeautifulSoup object"""
    hyperlinks = []
    for link in soup.find_all('a', href=True):
        hyperlinks.append((link.text, link['href']))

@@ -33,6 +42,7 @@ def extract_hyperlinks(soup):

def format_hyperlinks(hyperlinks):
    """Format hyperlinks into a list of strings"""
    formatted_links = []
    for link_text, link_url in hyperlinks:
        formatted_links.append(f"{link_text} ({link_url})")

@@ -40,6 +50,7 @@ def format_hyperlinks(hyperlinks):

def scrape_links(url):
    """Scrape links from a webpage"""
    response = requests.get(url, headers=cfg.user_agent_header)

    # Check if the response contains an HTTP error

@@ -57,6 +68,7 @@ def scrape_links(url):

def split_text(text, max_length=8192):
    """Split text into chunks of a maximum length"""
    paragraphs = text.split("\n")
    current_length = 0
    current_chunk = []

@@ -75,12 +87,14 @@ def split_text(text, max_length=8192):

def create_message(chunk, question):
    """Create a message for the user to summarize a chunk of text"""
    return {
        "role": "user",
        "content": f"\"\"\"{chunk}\"\"\" Using the above text, please answer the following question: \"{question}\" -- if the question cannot be answered using the text, please summarize the text."
    }

def summarize_text(text, question):
    """Summarize text using the LLM model"""
    if not text:
        return "Error: No text to summarize"
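
The body of split_text is cut off at the hunk boundary; a sketch of the paragraph-packing loop it sets up (the generator form and variable names here are assumptions, not the committed code):

    def split_text_sketch(text, max_length=8192):
        chunk, length = [], 0
        for paragraph in text.split("\n"):
            if length + len(paragraph) + 1 <= max_length:
                chunk.append(paragraph)
                length += len(paragraph) + 1
            else:
                yield "\n".join(chunk)
                chunk, length = [paragraph], len(paragraph) + 1
        if chunk:
            yield "\n".join(chunk)
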
@@ -1,11 +1,14 @@
from config import Config

cfg = Config()

from llm_utils import create_chat_completion

# This is a magic function that can do anything with no-code. See
# https://github.com/Torantulino/AI-Functions for more info.
def call_ai_function(function, args, description, model=cfg.smart_llm_model):
def call_ai_function(function, args, description, model=None):
    """Call an AI function"""
    if model is None:
        model = cfg.smart_llm_model
    # For each arg, if any are None, convert to "None":
    args = [str(arg) if arg is not None else "None" for arg in args]
    # parse args to comma separated string
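
A usage sketch mirroring evaluate_code in ai_functions.py: the model is prompted to act as the declared function and reply with only its return value:

    function_string = "def analyze_code(code: str) -> List[str]:"
    description = "Analyzes the given code and returns a list of suggestions for improvements."
    suggestions = call_ai_function(function_string, ["print('hi')"], description)
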
@@ -3,11 +3,9 @@ import openai
from dotenv import load_dotenv
from config import Config
import token_counter

cfg = Config()

from llm_utils import create_chat_completion

cfg = Config()

def create_chat_message(role, content):
    """

@@ -26,8 +24,11 @@ def create_chat_message(role, content):
def generate_context(prompt, relevant_memory, full_message_history, model):
    current_context = [
        create_chat_message(
            "system", prompt), create_chat_message(
            "system", f"Permanent memory: {relevant_memory}")]
            "system", prompt),
        create_chat_message(
            "system", f"The current time and date is {time.strftime('%c')}"),
        create_chat_message(
            "system", f"This reminds you of these events from your past:\n{relevant_memory}\n\n")]

    # Add messages from the full message history until we reach the token limit
    next_message_to_add_index = len(full_message_history) - 1

@@ -45,6 +46,7 @@ def chat_with_ai(
        permanent_memory,
        token_limit,
        debug=False):
    """Interact with the OpenAI API, sending the prompt, user input, message history, and permanent memory."""
    while True:
        try:
            """

@@ -63,7 +65,7 @@ def chat_with_ai(
            model = cfg.fast_llm_model  # TODO: Change model from hardcode to argument
            # Reserve 1000 tokens for the response
            if debug:
                print(f"Token limit: {token_limit}")
            print(f"Token limit: {token_limit}")
            send_token_limit = token_limit - 1000

            relevant_memory = permanent_memory.get_relevant(str(full_message_history[-5:]), 10)

@@ -95,7 +97,7 @@ def chat_with_ai(

            # Count the currently used tokens
            current_tokens_used += tokens_to_add

            # Move to the next most recent message in the full message history
            next_message_to_add_index -= 1
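
A sketch of the history-packing loop inside chat_with_ai: walk the history newest-to-oldest and stop before the 1000-token completion reserve is exceeded (count_message_tokens is an assumed helper from the token_counter module imported above; insertion_index marks the slot after the system messages):

    send_token_limit = token_limit - 1000   # reserve 1000 tokens for the reply
    index = len(full_message_history) - 1
    while index >= 0:
        tokens = token_counter.count_message_tokens([full_message_history[index]], model)
        if current_tokens_used + tokens > send_token_limit:
            break
        current_context.insert(insertion_index, full_message_history[index])
        current_tokens_used += tokens
        index -= 1
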
@@ -1,6 +1,6 @@
import browse
import json
from memory import PineconeMemory
from memory import get_memory
import datetime
import agent_manager as agents
import speak

@@ -9,6 +9,7 @@ import ai_functions as ai
from file_operations import read_file, write_to_file, append_to_file, delete_file, search_files
from execute_code import execute_python_file
from json_parser import fix_and_parse_json
from image_gen import generate_image
from duckduckgo_search import ddg
from googleapiclient.discovery import build
from googleapiclient.errors import HttpError

@@ -24,6 +25,7 @@ def is_valid_int(value):
        return False

def get_command(response):
    """Parse the response and return the command name and arguments"""
    try:
        response_json = fix_and_parse_json(response)

@@ -52,10 +54,12 @@ def get_command(response):

def execute_command(command_name, arguments):
    memory = PineconeMemory()
    """Execute the command and return the result"""
    memory = get_memory(cfg)

    try:
        if command_name == "google":

            # Check if the Google API key is set and use the official search method
            # If the API key is not set or has only whitespaces, use the unofficial search method
            if cfg.google_api_key and (cfg.google_api_key.strip() if cfg.google_api_key else None):

@@ -102,21 +106,27 @@ def execute_command(command_name, arguments):
            return ai.write_tests(arguments["code"], arguments.get("focus"))
        elif command_name == "execute_python_file":  # Add this command
            return execute_python_file(arguments["file"])
        elif command_name == "generate_image":
            return generate_image(arguments["prompt"])
        elif command_name == "do_nothing":
            return "No action performed."
        elif command_name == "task_complete":
            shutdown()
        else:
            return f"Unknown command {command_name}"
            return f"Unknown command '{command_name}'. Please refer to the 'COMMANDS' list for available commands and only respond in the specified JSON format."
    # All errors, return "Error: + error message"
    except Exception as e:
        return "Error: " + str(e)


def get_datetime():
    """Return the current date and time"""
    return "Current date and time: " + \
        datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")


def google_search(query, num_results=8):
    """Return the results of a google search"""
    search_results = []
    for j in ddg(query, max_results=num_results):
        search_results.append(j)

@@ -124,6 +134,7 @@ def google_search(query, num_results=8):
    return json.dumps(search_results, ensure_ascii=False, indent=4)

def google_official_search(query, num_results=8):
    """Return the results of a google search using the official Google API"""
    from googleapiclient.discovery import build
    from googleapiclient.errors import HttpError
    import json

@@ -159,6 +170,7 @@ def google_official_search(query, num_results=8):
    return search_results_links

def browse_website(url, question):
    """Browse a website and return the summary and links"""
    summary = get_text_summary(url, question)
    links = get_hyperlinks(url)

@@ -172,23 +184,27 @@ def browse_website(url, question):

def get_text_summary(url, question):
    """Return the text summary of a webpage"""
    text = browse.scrape_text(url)
    summary = browse.summarize_text(text, question)
    return """ "Result" : """ + summary


def get_hyperlinks(url):
    """Return the hyperlinks of a webpage"""
    link_list = browse.scrape_links(url)
    return link_list


def commit_memory(string):
    """Commit a string to memory"""
    _text = f"""Committing memory with string "{string}" """
    mem.permanent_memory.append(string)
    return _text


def delete_memory(key):
    """Delete a memory with a given key"""
    if key >= 0 and key < len(mem.permanent_memory):
        _text = "Deleting memory with key " + str(key)
        del mem.permanent_memory[key]

@@ -200,6 +216,7 @@ def delete_memory(key):

def overwrite_memory(key, string):
    """Overwrite a memory with a given key and string"""
    # Check if the key is a valid integer
    if is_valid_int(key):
        key_int = int(key)

@@ -226,11 +243,13 @@ def overwrite_memory(key, string):

def shutdown():
    """Shut down the program"""
    print("Shutting down...")
    quit()


def start_agent(name, task, prompt, model=cfg.fast_llm_model):
    """Start an agent with a given name, task, and prompt"""
    global cfg

    # Remove underscores from name

@@ -254,6 +273,7 @@ def start_agent(name, task, prompt, model=cfg.fast_llm_model):

def message_agent(key, message):
    """Message an agent with a given key and message"""
    global cfg

    # Check if the key is a valid integer

@@ -272,11 +292,13 @@ def message_agent(key, message):

def list_agents():
    """List all agents"""
    return agents.list_agents()


def delete_agent(key):
    """Delete an agent with a given key"""
    result = agents.delete_agent(key)
    if not result:
        return f"Agent {key} does not exist."
    return f"Agent {key} deleted."
    return f"Agent {key} deleted."
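
For reference, a reply shape that get_command can parse, matching the JSON_SCHEMA defined in json_parser.py (field values are illustrative):

    reply = '''{"command": {"name": "do_nothing", "args": {}},
                "thoughts": {"text": "waiting", "speak": "Nothing to do yet."}}'''
    command_name, arguments = get_command(reply)
    print(execute_command(command_name, arguments))   # "No action performed."
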
@@ -1,3 +1,4 @@
import abc
import os
import openai
from dotenv import load_dotenv

@@ -5,7 +6,7 @@ from dotenv import load_dotenv
load_dotenv()


class Singleton(type):
class Singleton(abc.ABCMeta, type):
    """
    Singleton metaclass for ensuring only one instance of a class.
    """

@@ -13,6 +14,7 @@ class Singleton(type):
    _instances = {}

    def __call__(cls, *args, **kwargs):
        """Call method for the singleton metaclass."""
        if cls not in cls._instances:
            cls._instances[cls] = super(
                Singleton, cls).__call__(

@@ -20,12 +22,18 @@ class Singleton(type):
        return cls._instances[cls]


class AbstractSingleton(abc.ABC, metaclass=Singleton):
    pass


class Config(metaclass=Singleton):
    """
    Configuration class to store the state of bools for different scripts access.
    """

    def __init__(self):
        """Initialize the Config class"""
        self.debug = False
        self.continuous_mode = False
        self.speak_mode = False
        # TODO - make these models be self-contained, using langchain, so we can configure them once and call it good

@@ -53,45 +61,71 @@ class Config(metaclass=Singleton):
        self.pinecone_api_key = os.getenv("PINECONE_API_KEY")
        self.pinecone_region = os.getenv("PINECONE_ENV")

        self.image_provider = os.getenv("IMAGE_PROVIDER")
        self.huggingface_api_token = os.getenv("HUGGINGFACE_API_TOKEN")

        # User agent headers to use when browsing web
        # Some websites might just completely deny request with an error code if no user agent was found.
        self.user_agent_header = {"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_4) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/83.0.4103.97 Safari/537.36"}
        self.redis_host = os.getenv("REDIS_HOST", "localhost")
        self.redis_port = os.getenv("REDIS_PORT", "6379")
        self.redis_password = os.getenv("REDIS_PASSWORD", "")
        self.wipe_redis_on_start = os.getenv("WIPE_REDIS_ON_START", "True") == 'True'
        self.memory_index = os.getenv("MEMORY_INDEX", 'auto-gpt')
        # Note that indexes must be created on db 0 in redis, this is not configurable.

        self.memory_backend = os.getenv("MEMORY_BACKEND", 'local')
        # Initialize the OpenAI API client
        openai.api_key = self.openai_api_key

    def set_continuous_mode(self, value: bool):
        """Set the continuous mode value."""
        self.continuous_mode = value

    def set_speak_mode(self, value: bool):
        """Set the speak mode value."""
        self.speak_mode = value

    def set_fast_llm_model(self, value: str):
        """Set the fast LLM model value."""
        self.fast_llm_model = value

    def set_smart_llm_model(self, value: str):
        """Set the smart LLM model value."""
        self.smart_llm_model = value

    def set_fast_token_limit(self, value: int):
        """Set the fast token limit value."""
        self.fast_token_limit = value

    def set_smart_token_limit(self, value: int):
        """Set the smart token limit value."""
        self.smart_token_limit = value

    def set_openai_api_key(self, value: str):
        """Set the OpenAI API key value."""
        self.openai_api_key = value

    def set_elevenlabs_api_key(self, value: str):
        """Set the ElevenLabs API key value."""
        self.elevenlabs_api_key = value

    def set_google_api_key(self, value: str):
        """Set the Google API key value."""
        self.google_api_key = value

    def set_custom_search_engine_id(self, value: str):
        """Set the custom search engine id value."""
        self.custom_search_engine_id = value

    def set_pinecone_api_key(self, value: str):
        """Set the Pinecone API key value."""
        self.pinecone_api_key = value

    def set_pinecone_region(self, value: str):
        """Set the Pinecone region value."""
        self.pinecone_region = value

    def set_debug_mode(self, value: bool):
        """Set the debug mode value."""
        self.debug = value
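
A behavior sketch of the Singleton metaclass above: constructing Config twice yields the same cached instance, which is why every module-level cfg = Config() in this diff shares state:

    a = Config()
    b = Config()
    assert a is b          # both resolve to Singleton._instances[Config]
    a.set_debug_mode(True)
    assert b.debug is True
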
@@ -2,6 +2,7 @@ import os
from pathlib import Path

def load_prompt():
    """Load the prompt from data/prompt.txt"""
    try:
        # get directory of this file:
        file_dir = Path(__file__).parent
@@ -18,11 +18,13 @@ COMMANDS:
12. Append to file: "append_to_file", args: "file": "<file>", "text": "<text>"
13. Delete file: "delete_file", args: "file": "<file>"
14. Search Files: "search_files", args: "directory": "<directory>"
15. Evaluate Code: "evaluate_code", args: "code": "<full _code_string>"
15. Evaluate Code: "evaluate_code", args: "code": "<full_code_string>"
16. Get Improved Code: "improve_code", args: "suggestions": "<list_of_suggestions>", "code": "<full_code_string>"
17. Write Tests: "write_tests", args: "code": "<full_code_string>", "focus": "<list_of_focus_areas>"
18. Execute Python File: "execute_python_file", args: "file": "<file>"
19. Task Complete (Shutdown): "task_complete", args: "reason": "<reason>"
20. Generate Image: "generate_image", args: "prompt": "<prompt>"
21. Do Nothing: "do_nothing", args: ""

RESOURCES:
@@ -3,6 +3,7 @@ import os

def execute_python_file(file):
    """Execute a Python file in a Docker container and return the output"""
    workspace_folder = "auto_gpt_workspace"

    print(f"Executing file '{file}' in workspace '{workspace_folder}'")
@@ -4,11 +4,13 @@ import os.path
# Set a dedicated folder for file I/O
working_directory = "auto_gpt_workspace"

# Create the directory if it doesn't exist
if not os.path.exists(working_directory):
    os.makedirs(working_directory)


def safe_join(base, *paths):
    """Join one or more path components intelligently."""
    new_path = os.path.join(base, *paths)
    norm_new_path = os.path.normpath(new_path)

@@ -19,6 +21,7 @@ def safe_join(base, *paths):

def read_file(filename):
    """Read a file and return the contents"""
    try:
        filepath = safe_join(working_directory, filename)
        with open(filepath, "r") as f:

@@ -29,6 +32,7 @@ def read_file(filename):

def write_to_file(filename, text):
    """Write text to a file"""
    try:
        filepath = safe_join(working_directory, filename)
        directory = os.path.dirname(filepath)

@@ -42,6 +46,7 @@ def write_to_file(filename, text):

def append_to_file(filename, text):
    """Append text to a file"""
    try:
        filepath = safe_join(working_directory, filename)
        with open(filepath, "a") as f:

@@ -52,6 +57,7 @@ def append_to_file(filename, text):

def delete_file(filename):
    """Delete a file"""
    try:
        filepath = safe_join(working_directory, filename)
        os.remove(filepath)
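
safe_join's containment check falls at the hunk boundary; the intent is to confine all file I/O to the workspace. A sketch of the comparison it performs (the exception type here is an assumption):

    norm = os.path.normpath(os.path.join(working_directory, "../secrets.txt"))
    if not norm.startswith(os.path.normpath(working_directory)):
        raise ValueError("Attempted to access a path outside the working directory")
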
scripts/image_gen.py (new file, 57 lines)
@@ -0,0 +1,57 @@
import requests
import io
import os.path
from PIL import Image
from config import Config
import uuid
import openai
from base64 import b64decode

cfg = Config()

working_directory = "auto_gpt_workspace"

def generate_image(prompt):

    filename = str(uuid.uuid4()) + ".jpg"

    # DALL-E
    if cfg.image_provider == 'dalle':

        openai.api_key = cfg.openai_api_key

        response = openai.Image.create(
            prompt=prompt,
            n=1,
            size="256x256",
            response_format="b64_json",
        )

        print("Image Generated for prompt:" + prompt)

        image_data = b64decode(response["data"][0]["b64_json"])

        with open(working_directory + "/" + filename, mode="wb") as png:
            png.write(image_data)

        return "Saved to disk:" + filename

    # STABLE DIFFUSION
    elif cfg.image_provider == 'sd':

        API_URL = "https://api-inference.huggingface.co/models/CompVis/stable-diffusion-v1-4"
        headers = {"Authorization": "Bearer " + cfg.huggingface_api_token}

        response = requests.post(API_URL, headers=headers, json={
            "inputs": prompt,
        })

        image = Image.open(io.BytesIO(response.content))
        print("Image Generated for prompt:" + prompt)

        image.save(os.path.join(working_directory, filename))

        return "Saved to disk:" + filename

    else:
        return "No Image Provider Set"
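
Usage sketch: the provider comes from the IMAGE_PROVIDER environment variable read by Config; 'dalle' needs OPENAI_API_KEY and 'sd' needs HUGGINGFACE_API_TOKEN (the prompt text is illustrative):

    print(generate_image("a watercolor fox"))   # "Saved to disk:<uuid>.jpg"
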
@@ -1,11 +1,13 @@
import json
from typing import Any, Dict, Union
from call_ai_function import call_ai_function
from config import Config
from json_utils import correct_json

cfg = Config()

def fix_and_parse_json(json_str: str, try_to_fix_with_gpt: bool = True):
    json_schema = """
    {
JSON_SCHEMA = """
{
    "command": {
        "name": "command name",
        "args": {

@@ -20,44 +22,70 @@ def fix_and_parse_json(json_str: str, try_to_fix_with_gpt: bool = True):
    "criticism": "constructive self-criticism",
    "speak": "thoughts summary to say to user"
    }
}
"""
}
"""


def fix_and_parse_json(
    json_str: str,
    try_to_fix_with_gpt: bool = True
) -> Union[str, Dict[Any, Any]]:
    """Fix and parse JSON string"""
    try:
        json_str = json_str.replace('\t', '')
        return json.loads(json_str)
    except Exception as e:
        # Let's do something manually - sometimes GPT responds with something BEFORE the braces:
        # "I'm sorry, I don't understand. Please try again."{"text": "I'm sorry, I don't understand. Please try again.", "confidence": 0.0}
        # So let's try to find the first brace and then parse the rest of the string
    except json.JSONDecodeError as _:  # noqa: F841
        json_str = correct_json(json_str)
        try:
            brace_index = json_str.index("{")
            json_str = json_str[brace_index:]
            last_brace_index = json_str.rindex("}")
            json_str = json_str[:last_brace_index+1]
            return json.loads(json_str)
        except Exception as e:
            if try_to_fix_with_gpt:
                print(f"Warning: Failed to parse AI output, attempting to fix.\n If you see this warning frequently, it's likely that your prompt is confusing the AI. Try changing it up slightly.")
            return json.loads(json_str)
        except json.JSONDecodeError as _:  # noqa: F841
            pass
        # Let's do something manually:
        # sometimes GPT responds with something BEFORE the braces:
        # "I'm sorry, I don't understand. Please try again."
        # {"text": "I'm sorry, I don't understand. Please try again.",
        # "confidence": 0.0}
        # So let's try to find the first brace and then parse the rest
        # of the string
        try:
            brace_index = json_str.index("{")
            json_str = json_str[brace_index:]
            last_brace_index = json_str.rindex("}")
            json_str = json_str[:last_brace_index+1]
            return json.loads(json_str)
        except json.JSONDecodeError as e:  # noqa: F841
            if try_to_fix_with_gpt:
                print("Warning: Failed to parse AI output, attempting to fix."
                      "\n If you see this warning frequently, it's likely that"
                      " your prompt is confusing the AI. Try changing it up"
                      " slightly.")
                # Now try to fix this up using the ai_functions
                ai_fixed_json = fix_json(json_str, json_schema, False)
                ai_fixed_json = fix_json(json_str, JSON_SCHEMA, cfg.debug)
                if ai_fixed_json != "failed":
                    return json.loads(ai_fixed_json)
                    return json.loads(ai_fixed_json)
                else:
                    print(f"Failed to fix ai output, telling the AI.")  # This allows the AI to react to the error message, which usually results in it correcting its ways.
                    return json_str
                else:
                    # This allows the AI to react to the error message,
                    # which usually results in it correcting its ways.
                    print("Failed to fix ai output, telling the AI.")
                    return json_str
            else:
                raise e


def fix_json(json_str: str, schema: str, debug=False) -> str:
    """Fix the given JSON string to make it parseable and fully compliant with the provided schema."""
    # Try to fix the JSON using gpt:
    function_string = "def fix_json(json_str: str, schema:str=None) -> str:"
    args = [f"'''{json_str}'''", f"'''{schema}'''"]
    description_string = """Fixes the provided JSON string to make it parseable and fully compliant with the provided schema.\n If an object or field specified in the schema isn't contained within the correct JSON, it is omitted.\n This function is brilliant at guessing when the format is incorrect."""
    description_string = "Fixes the provided JSON string to make it parseable"\
        " and fully compliant with the provided schema.\n If an object or"\
        " field specified in the schema isn't contained within the correct"\
        " JSON, it is omitted.\n This function is brilliant at guessing"\
        " when the format is incorrect."

    # If it doesn't already start with a "`", add one:
    if not json_str.startswith("`"):
        json_str = "```json\n" + json_str + "\n```"
        json_str = "```json\n" + json_str + "\n```"
    result_string = call_ai_function(
        function_string, args, description_string, model=cfg.fast_llm_model
    )

@@ -67,10 +95,11 @@ def fix_json(json_str: str, schema: str, debug=False) -> str:
    print("-----------")
    print(f"Fixed JSON: {result_string}")
    print("----------- END OF FIX ATTEMPT ----------------")

    try:
        json.loads(result_string)  # just check the validity
        json.loads(result_string)  # just check the validity
        return result_string
    except:
    except:  # noqa: E722
        # Get the call stack:
        # import traceback
        # call_stack = traceback.format_exc()
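
A sketch of the brace-trimming fallback above, which strips any chatter the model wraps around the JSON body:

    raw = 'Sure! {"command": {"name": "do_nothing", "args": {}}} Hope that helps.'
    start, end = raw.index("{"), raw.rindex("}")
    print(json.loads(raw[start:end + 1]))   # {'command': {'name': 'do_nothing', 'args': {}}}
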
scripts/json_utils.py (new file, 127 lines)
@@ -0,0 +1,127 @@
import re
import json
from config import Config

cfg = Config()


def extract_char_position(error_message: str) -> int:
    """Extract the character position from the JSONDecodeError message.

    Args:
        error_message (str): The error message from the JSONDecodeError
        exception.

    Returns:
        int: The character position.
    """
    import re

    char_pattern = re.compile(r'\(char (\d+)\)')
    if match := char_pattern.search(error_message):
        return int(match[1])
    else:
        raise ValueError("Character position not found in the error message.")


def add_quotes_to_property_names(json_string: str) -> str:
    """
    Add quotes to property names in a JSON string.

    Args:
        json_string (str): The JSON string.

    Returns:
        str: The JSON string with quotes added to property names.
    """

    def replace_func(match):
        return f'"{match.group(1)}":'

    property_name_pattern = re.compile(r'(\w+):')
    corrected_json_string = property_name_pattern.sub(
        replace_func,
        json_string)

    try:
        json.loads(corrected_json_string)
        return corrected_json_string
    except json.JSONDecodeError as e:
        raise e


def balance_braces(json_string: str) -> str:
    """
    Balance the braces in a JSON string.

    Args:
        json_string (str): The JSON string.

    Returns:
        str: The JSON string with braces balanced.
    """

    open_braces_count = json_string.count('{')
    close_braces_count = json_string.count('}')

    while open_braces_count > close_braces_count:
        json_string += '}'
        close_braces_count += 1

    while close_braces_count > open_braces_count:
        json_string = json_string.rstrip('}')
        close_braces_count -= 1

    try:
        json.loads(json_string)
        return json_string
    except json.JSONDecodeError as e:
        raise e


def fix_invalid_escape(json_str: str, error_message: str) -> str:
    while error_message.startswith('Invalid \\escape'):
        bad_escape_location = extract_char_position(error_message)
        json_str = json_str[:bad_escape_location] + \
            json_str[bad_escape_location + 1:]
        try:
            json.loads(json_str)
            return json_str
        except json.JSONDecodeError as e:
            if cfg.debug:
                print('json loads error - fix invalid escape', e)
            error_message = str(e)
    return json_str


def correct_json(json_str: str) -> str:
    """
    Correct common JSON errors.

    Args:
        json_str (str): The JSON string.
    """

    try:
        if cfg.debug:
            print("json", json_str)
        json.loads(json_str)
        return json_str
    except json.JSONDecodeError as e:
        if cfg.debug:
            print('json loads error', e)
        error_message = str(e)
        if error_message.startswith('Invalid \\escape'):
            json_str = fix_invalid_escape(json_str, error_message)
        if error_message.startswith('Expecting property name enclosed in double quotes'):
            json_str = add_quotes_to_property_names(json_str)
            try:
                json.loads(json_str)
                return json_str
            except json.JSONDecodeError as e:
                if cfg.debug:
                    print('json loads error - add quotes', e)
                error_message = str(e)
        if balanced_str := balance_braces(json_str):
            return balanced_str
    return json_str
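
Two examples of model output correct_json can repair (each input exercises one repair path; combining them would raise, since add_quotes_to_property_names re-raises when its result is still invalid):

    print(correct_json('{"command": {"name": "do_nothing", "args": {}}'))   # balances the missing brace
    print(correct_json('{command: "do_nothing"}'))                          # quotes the bare property name
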
@@ -6,6 +6,7 @@ openai.api_key = cfg.openai_api_key

# Overly simple abstraction until we create something better
def create_chat_completion(messages, model=None, temperature=None, max_tokens=None) -> str:
    """Create a chat completion using the OpenAI API"""
    if cfg.use_azure:
        response = openai.ChatCompletion.create(
            deployment_id=cfg.openai_deployment_id,
@@ -1,7 +1,7 @@
import json
import random
import commands as cmd
from memory import PineconeMemory
from memory import get_memory
import data
import chat
from colorama import Fore, Style

@@ -25,6 +25,7 @@ def print_to_console(
        speak_text=False,
        min_typing_speed=0.05,
        max_typing_speed=0.01):
    """Prints text to the console with a typing effect"""
    global cfg
    if speak_text and cfg.speak_mode:
        speak.say_text(f"{title}. {content}")

@@ -46,6 +47,7 @@ def print_to_console(

def print_assistant_thoughts(assistant_reply):
    """Prints the assistant's thoughts to the console"""
    global ai_name
    global cfg
    try:

@@ -105,7 +107,7 @@ def print_assistant_thoughts(assistant_reply):

def load_variables(config_file="config.yaml"):
    # Load variables from yaml file if it exists
    """Load variables from yaml file if it exists, otherwise prompt the user for input"""
    try:
        with open(config_file) as file:
            config = yaml.load(file, Loader=yaml.FullLoader)

@@ -159,6 +161,7 @@ def load_variables(config_file="config.yaml"):

def construct_prompt():
    """Construct the prompt for the AI to respond to"""
    config = AIConfig.load()
    if config.ai_name:
        print_to_console(

@@ -187,6 +190,7 @@ Continue (y/n): """)

def prompt_user():
    """Prompt the user for input"""
    ai_name = ""
    # Construct the prompt
    print_to_console(

@@ -239,6 +243,7 @@ def prompt_user():
    return config

def parse_arguments():
    """Parses the arguments passed to the script"""
    global cfg
    cfg.set_continuous_mode(False)
    cfg.set_speak_mode(False)

@@ -266,6 +271,10 @@ def parse_arguments():
        print_to_console("GPT3.5 Only Mode: ", Fore.GREEN, "ENABLED")
        cfg.set_smart_llm_model(cfg.fast_llm_model)

    if args.debug:
        print_to_console("Debug Mode: ", Fore.GREEN, "ENABLED")
        cfg.set_debug_mode(True)


# TODO: fill in llm values here

@@ -283,9 +292,7 @@ user_input = "Determine which next command to use, and respond using the format

# Initialize memory and make sure it is empty.
# this is particularly important for indexing and referencing pinecone memory
memory = PineconeMemory()
memory.clear()
memory = get_memory(cfg, init=True)
print('Using memory of type: ' + memory.__class__.__name__)

# Interaction Loop

@@ -297,7 +304,7 @@ while True:
        user_input,
        full_message_history,
        memory,
        cfg.fast_token_limit)  # TODO: This hardcodes the model to use GPT3.5. Make this an argument
        cfg.fast_token_limit, cfg.debug)  # TODO: This hardcodes the model to use GPT3.5. Make this an argument

    # Print Assistant thoughts
    print_assistant_thoughts(assistant_reply)

@@ -357,7 +364,7 @@ while True:
        f"COMMAND = {Fore.CYAN}{command_name}{Style.RESET_ALL} ARGUMENTS = {Fore.CYAN}{arguments}{Style.RESET_ALL}")

    # Execute command
    if command_name.lower() == "error":
    if command_name.lower().startswith("error"):
        result = f"Command {command_name} threw the following error: " + arguments
    elif command_name == "human_feedback":
        result = f"Human feedback: {user_input}"
scripts/memory/__init__.py (new file, 44 lines)
@@ -0,0 +1,44 @@
from memory.local import LocalCache
try:
    from memory.redismem import RedisMemory
except ImportError:
    print("Redis not installed. Skipping import.")
    RedisMemory = None

try:
    from memory.pinecone import PineconeMemory
except ImportError:
    print("Pinecone not installed. Skipping import.")
    PineconeMemory = None


def get_memory(cfg, init=False):
    memory = None
    if cfg.memory_backend == "pinecone":
        if not PineconeMemory:
            print("Error: Pinecone is not installed. Please install pinecone"
                  " to use Pinecone as a memory backend.")
        else:
            memory = PineconeMemory(cfg)
            if init:
                memory.clear()
    elif cfg.memory_backend == "redis":
        if not RedisMemory:
            print("Error: Redis is not installed. Please install redis-py to"
                  " use Redis as a memory backend.")
        else:
            memory = RedisMemory(cfg)

    if memory is None:
        memory = LocalCache(cfg)
        if init:
            memory.clear()
    return memory


__all__ = [
    "get_memory",
    "LocalCache",
    "RedisMemory",
    "PineconeMemory",
]
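
Usage sketch for the factory above: the backend comes from the MEMORY_BACKEND environment variable read by Config, with LocalCache as the fallback; the stored text is illustrative:

    cfg = Config()
    memory = get_memory(cfg, init=True)   # init=True wipes the chosen store
    memory.add("The deployment failed at step 3.")
    print(memory.get_relevant("why did the deployment fail?", 1))
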
scripts/memory/base.py (new file, 31 lines)
@@ -0,0 +1,31 @@
"""Base class for memory providers."""
import abc
from config import AbstractSingleton
import openai


def get_ada_embedding(text):
    text = text.replace("\n", " ")
    return openai.Embedding.create(input=[text], model="text-embedding-ada-002")["data"][0]["embedding"]


class MemoryProviderSingleton(AbstractSingleton):
    @abc.abstractmethod
    def add(self, data):
        pass

    @abc.abstractmethod
    def get(self, data):
        pass

    @abc.abstractmethod
    def clear(self):
        pass

    @abc.abstractmethod
    def get_relevant(self, data, num_relevant=5):
        pass

    @abc.abstractmethod
    def get_stats(self):
        pass
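
A minimal conforming provider, to illustrate the abstract surface a new backend must implement (a throwaway example, not part of the commit):

    class NullMemory(MemoryProviderSingleton):
        def add(self, data): return ""
        def get(self, data): return None
        def clear(self): return "cleared"
        def get_relevant(self, data, num_relevant=5): return []
        def get_stats(self): return (0, None)
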
scripts/memory/local.py (new file, 114 lines)
@@ -0,0 +1,114 @@
import dataclasses
import orjson
from typing import Any, List, Optional
import numpy as np
import os
from memory.base import MemoryProviderSingleton, get_ada_embedding


EMBED_DIM = 1536
SAVE_OPTIONS = orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_SERIALIZE_DATACLASS


def create_default_embeddings():
    return np.zeros((0, EMBED_DIM)).astype(np.float32)


@dataclasses.dataclass
class CacheContent:
    texts: List[str] = dataclasses.field(default_factory=list)
    embeddings: np.ndarray = dataclasses.field(
        default_factory=create_default_embeddings
    )


class LocalCache(MemoryProviderSingleton):

    # on load, load our database
    def __init__(self, cfg) -> None:
        self.filename = f"{cfg.memory_index}.json"
        if os.path.exists(self.filename):
            with open(self.filename, 'rb') as f:
                loaded = orjson.loads(f.read())
                self.data = CacheContent(**loaded)
        else:
            self.data = CacheContent()

    def add(self, text: str):
        """
        Add text to our list of texts, add embedding as row to our
        embeddings-matrix

        Args:
            text: str

        Returns: None
        """
        if 'Command Error:' in text:
            return ""
        self.data.texts.append(text)

        embedding = get_ada_embedding(text)

        vector = np.array(embedding).astype(np.float32)
        vector = vector[np.newaxis, :]
        self.data.embeddings = np.concatenate(
            [
                vector,
                self.data.embeddings,
            ],
            axis=0,
        )

        with open(self.filename, 'wb') as f:
            out = orjson.dumps(
                self.data,
                option=SAVE_OPTIONS
            )
            f.write(out)
        return text

    def clear(self) -> str:
        """
        Clears the local cache.

        Returns: A message indicating that the memory has been cleared.
        """
        self.data = CacheContent()
        return "Obliviated"

    def get(self, data: str) -> Optional[List[Any]]:
        """
        Gets the data from the memory that is most relevant to the given data.

        Args:
            data: The data to compare to.

        Returns: The most relevant data.
        """
        return self.get_relevant(data, 1)

    def get_relevant(self, text: str, k: int) -> List[Any]:
        """
        matrix-vector mult to find score-for-each-row-of-matrix
        get indices for top-k winning scores
        return texts for those indices

        Args:
            text: str
            k: int

        Returns: List[str]
        """
        embedding = get_ada_embedding(text)

        scores = np.dot(self.data.embeddings, embedding)

        top_k_indices = np.argsort(scores)[-k:][::-1]

        return [self.data.texts[i] for i in top_k_indices]

    def get_stats(self):
        """
        Returns: The stats of the local cache.
        """
        return len(self.data.texts), self.data.embeddings.shape
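
The retrieval math in get_relevant, isolated: ada-002 embeddings are unit-normalized, so a plain dot product ranks rows by cosine similarity (random vectors stand in for real embeddings here):

    import numpy as np
    embeddings = np.random.rand(10, 1536).astype(np.float32)   # stand-in for data.embeddings
    query = np.random.rand(1536).astype(np.float32)
    scores = np.dot(embeddings, query)
    top_k = np.argsort(scores)[-3:][::-1]   # indices of the 3 best-scoring texts
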
@@ -1,21 +1,11 @@
from config import Config, Singleton

import pinecone
import openai

cfg = Config()
from memory.base import MemoryProviderSingleton, get_ada_embedding


def get_ada_embedding(text):
    text = text.replace("\n", " ")
    return openai.Embedding.create(input=[text], model="text-embedding-ada-002")["data"][0]["embedding"]


def get_text_from_embedding(embedding):
    return openai.Embedding.retrieve(embedding, model="text-embedding-ada-002")["data"][0]["text"]


class PineconeMemory(metaclass=Singleton):
    def __init__(self):
class PineconeMemory(MemoryProviderSingleton):
    def __init__(self, cfg):
        pinecone_api_key = cfg.pinecone_api_key
        pinecone_region = cfg.pinecone_region
        pinecone.init(api_key=pinecone_api_key, environment=pinecone_region)
scripts/memory/redismem.py (new file, 143 lines)
@@ -0,0 +1,143 @@
"""Redis memory provider."""
from typing import Any, List, Optional
import redis
from redis.commands.search.field import VectorField, TextField
from redis.commands.search.query import Query
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
import numpy as np

from memory.base import MemoryProviderSingleton, get_ada_embedding


SCHEMA = [
    TextField("data"),
    VectorField(
        "embedding",
        "HNSW",
        {
            "TYPE": "FLOAT32",
            "DIM": 1536,
            "DISTANCE_METRIC": "COSINE"
        }
    ),
]


class RedisMemory(MemoryProviderSingleton):
    def __init__(self, cfg):
        """
        Initializes the Redis memory provider.

        Args:
            cfg: The config object.

        Returns: None
        """
        redis_host = cfg.redis_host
        redis_port = cfg.redis_port
        redis_password = cfg.redis_password
        self.dimension = 1536
        self.redis = redis.Redis(
            host=redis_host,
            port=redis_port,
            password=redis_password,
            db=0  # Cannot be changed
        )
        self.cfg = cfg
        if cfg.wipe_redis_on_start:
            self.redis.flushall()
        try:
            self.redis.ft(f"{cfg.memory_index}").create_index(
                fields=SCHEMA,
                definition=IndexDefinition(
                    prefix=[f"{cfg.memory_index}:"],
                    index_type=IndexType.HASH
                )
            )
        except Exception as e:
            print("Error creating Redis search index: ", e)
        existing_vec_num = self.redis.get(f'{cfg.memory_index}-vec_num')
        self.vec_num = int(existing_vec_num.decode('utf-8')) if\
            existing_vec_num else 0

    def add(self, data: str) -> str:
        """
        Adds a data point to the memory.

        Args:
            data: The data to add.

        Returns: Message indicating that the data has been added.
        """
        if 'Command Error:' in data:
            return ""
        vector = get_ada_embedding(data)
        vector = np.array(vector).astype(np.float32).tobytes()
        data_dict = {
            b"data": data,
            "embedding": vector
        }
        pipe = self.redis.pipeline()
        pipe.hset(f"{self.cfg.memory_index}:{self.vec_num}", mapping=data_dict)
        _text = f"Inserting data into memory at index: {self.vec_num}:\n"\
            f"data: {data}"
        self.vec_num += 1
        pipe.set(f'{self.cfg.memory_index}-vec_num', self.vec_num)
        pipe.execute()
        return _text

    def get(self, data: str) -> Optional[List[Any]]:
        """
        Gets the data from the memory that is most relevant to the given data.

        Args:
            data: The data to compare to.

        Returns: The most relevant data.
        """
        return self.get_relevant(data, 1)

    def clear(self) -> str:
        """
        Clears the redis server.

        Returns: A message indicating that the memory has been cleared.
        """
        self.redis.flushall()
        return "Obliviated"

    def get_relevant(
        self,
        data: str,
        num_relevant: int = 5
    ) -> Optional[List[Any]]:
        """
        Returns all the data in the memory that is relevant to the given data.
        Args:
            data: The data to compare to.
            num_relevant: The number of relevant data to return.

        Returns: A list of the most relevant data.
        """
        query_embedding = get_ada_embedding(data)
        base_query = f"*=>[KNN {num_relevant} @embedding $vector AS vector_score]"
        query = Query(base_query).return_fields(
            "data",
            "vector_score"
        ).sort_by("vector_score").dialect(2)
        query_vector = np.array(query_embedding).astype(np.float32).tobytes()

        try:
            results = self.redis.ft(f"{self.cfg.memory_index}").search(
                query, query_params={"vector": query_vector}
            )
        except Exception as e:
            print("Error calling Redis search: ", e)
            return None
        return [result.data for result in results.docs]

    def get_stats(self):
        """
        Returns: The stats of the memory index.
        """
        return self.redis.ft(f"{self.cfg.memory_index}").info()
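
Usage sketch, assuming a local redis-stack instance (the vector search needs the RediSearch module; plain Redis will fail at index creation) and the REDIS_* variables read by Config:

    cfg = Config()
    mem = RedisMemory(cfg)
    mem.add("Paris is the capital of France.")
    print(mem.get_relevant("capital of France", 1))
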
@@ -15,6 +15,7 @@ tts_headers = {
}

def eleven_labs_speech(text, voice_index=0):
    """Speak text using elevenlabs.io's API"""
    tts_url = "https://api.elevenlabs.io/v1/text-to-speech/{voice_id}".format(
        voice_id=voices[voice_index])
    formatted_message = {"text": text}
@@ -5,7 +5,9 @@ import time


class Spinner:
    """A simple spinner class"""
    def __init__(self, message="Loading...", delay=0.1):
        """Initialize the spinner class"""
        self.spinner = itertools.cycle(['-', '/', '|', '\\'])
        self.delay = delay
        self.message = message

@@ -13,6 +15,7 @@ class Spinner:
        self.spinner_thread = None

    def spin(self):
        """Spin the spinner"""
        while self.running:
            sys.stdout.write(next(self.spinner) + " " + self.message + "\r")
            sys.stdout.flush()

@@ -20,11 +23,13 @@ class Spinner:
        sys.stdout.write('\b' * (len(self.message) + 2))

    def __enter__(self):
        """Start the spinner"""
        self.running = True
        self.spinner_thread = threading.Thread(target=self.spin)
        self.spinner_thread.start()

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """Stop the spinner"""
        self.running = False
        self.spinner_thread.join()
        sys.stdout.write('\r' + ' ' * (len(self.message) + 2) + '\r')
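
Usage sketch: the context-manager protocol starts the spinner thread on entry and joins it on exit:

    with Spinner("Thinking... "):
        time.sleep(2)   # stands in for a long-running API call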