diff --git a/scripts/main.py b/scripts/main.py index d7b94085c6..0a4e97a256 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -2,7 +2,7 @@ import json import random import commands as cmd import utils -from memory import get_memory +from memory import get_memory, get_supported_memory_backends import data import chat from colorama import Fore, Style @@ -275,6 +275,7 @@ def parse_arguments(): parser.add_argument('--debug', action='store_true', help='Enable Debug Mode') parser.add_argument('--gpt3only', action='store_true', help='Enable GPT3.5 Only Mode') parser.add_argument('--gpt4only', action='store_true', help='Enable GPT4 Only Mode') + parser.add_argument('--use-memory', '-m', dest="memory_type", help='Defines which Memory backend to use') args = parser.parse_args() if args.debug: @@ -305,6 +306,15 @@ def parse_arguments(): logger.typewriter_log("Debug Mode: ", Fore.GREEN, "ENABLED") cfg.set_debug_mode(True) + if args.memory_type: + supported_memory = get_supported_memory_backends() + chosen = args.memory_type + if not chosen in supported_memory: + print_to_console("ONLY THE FOLLOWING MEMORY BACKENDS ARE SUPPORTED: ", Fore.RED, f'{supported_memory}') + print_to_console(f"Defaulting to: ", Fore.YELLOW, cfg.memory_backend) + else: + cfg.memory_backend = chosen + # TODO: fill in llm values here check_openai_api_key() diff --git a/scripts/memory/__init__.py b/scripts/memory/__init__.py index a441a46aa9..2900353ed9 100644 --- a/scripts/memory/__init__.py +++ b/scripts/memory/__init__.py @@ -1,17 +1,23 @@ from memory.local import LocalCache + +# List of supported memory backends +# Add a backend to this list if the import attempt is successful +supported_memory = ['local'] + try: from memory.redismem import RedisMemory + supported_memory.append('redis') except ImportError: print("Redis not installed. Skipping import.") RedisMemory = None try: from memory.pinecone import PineconeMemory + supported_memory.append('pinecone') except ImportError: print("Pinecone not installed. Skipping import.") PineconeMemory = None - def get_memory(cfg, init=False): memory = None if cfg.memory_backend == "pinecone": @@ -35,6 +41,8 @@ def get_memory(cfg, init=False): memory.clear() return memory +def get_supported_memory_backends(): + return supported_memory __all__ = [ "get_memory",