mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
AutoGPT: Improve function scopes and data flow in app.main and config
* Move TTS related config into TTSConfig
This commit is contained in:
@@ -32,7 +32,11 @@ def bootstrap_agent(task: str, continuous_mode: bool) -> Agent:
|
||||
config.workspace_path = Workspace.init_workspace_directory(config)
|
||||
config.file_logger_path = Workspace.build_file_logger_path(config.workspace_path)
|
||||
|
||||
configure_logging(config, LOG_DIR)
|
||||
configure_logging(
|
||||
debug_mode=config.debug_mode,
|
||||
plain_output=config.plain_output,
|
||||
log_dir=LOG_DIR,
|
||||
)
|
||||
|
||||
command_registry = CommandRegistry.with_command_modules(COMMAND_CATEGORIES, config)
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ import click
|
||||
@click.option(
|
||||
"--ai-settings",
|
||||
"-C",
|
||||
type=click.Path(exists=True, dir_okay=False, path_type=Path),
|
||||
help=(
|
||||
"Specifies which ai_settings.yaml file to use, relative to the AutoGPT"
|
||||
" root directory. Will also automatically skip the re-prompt."
|
||||
@@ -24,6 +25,7 @@ import click
|
||||
@click.option(
|
||||
"--prompt-settings",
|
||||
"-P",
|
||||
type=click.Path(exists=True, dir_okay=False, path_type=Path),
|
||||
help="Specifies which prompt_settings.yaml file to use.",
|
||||
)
|
||||
@click.option(
|
||||
@@ -92,8 +94,8 @@ def main(
|
||||
ctx: click.Context,
|
||||
continuous: bool,
|
||||
continuous_limit: int,
|
||||
ai_settings: str,
|
||||
prompt_settings: str,
|
||||
ai_settings: Optional[Path],
|
||||
prompt_settings: Optional[Path],
|
||||
skip_reprompt: bool,
|
||||
speak: bool,
|
||||
debug: bool,
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Literal
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional
|
||||
|
||||
import click
|
||||
from colorama import Back, Fore, Style
|
||||
@@ -17,29 +18,29 @@ from autogpt.memory.vector import get_supported_memory_backends
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_config(
|
||||
def apply_overrides_to_config(
|
||||
config: Config,
|
||||
continuous: bool,
|
||||
continuous_limit: int,
|
||||
ai_settings_file: str,
|
||||
prompt_settings_file: str,
|
||||
skip_reprompt: bool,
|
||||
speak: bool,
|
||||
debug: bool,
|
||||
gpt3only: bool,
|
||||
gpt4only: bool,
|
||||
memory_type: str,
|
||||
browser_name: str,
|
||||
allow_downloads: bool,
|
||||
skip_news: bool,
|
||||
continuous: bool = False,
|
||||
continuous_limit: Optional[int] = None,
|
||||
ai_settings_file: Optional[Path] = None,
|
||||
prompt_settings_file: Optional[Path] = None,
|
||||
skip_reprompt: bool = False,
|
||||
speak: bool = False,
|
||||
debug: bool = False,
|
||||
gpt3only: bool = False,
|
||||
gpt4only: bool = False,
|
||||
memory_type: str = "",
|
||||
browser_name: str = "",
|
||||
allow_downloads: bool = False,
|
||||
skip_news: bool = False,
|
||||
) -> None:
|
||||
"""Updates the config object with the given arguments.
|
||||
|
||||
Args:
|
||||
continuous (bool): Whether to run in continuous mode
|
||||
continuous_limit (int): The number of times to run in continuous mode
|
||||
ai_settings_file (str): The path to the ai_settings.yaml file
|
||||
prompt_settings_file (str): The path to the prompt_settings.yaml file
|
||||
ai_settings_file (Path): The path to the ai_settings.yaml file
|
||||
prompt_settings_file (Path): The path to the prompt_settings.yaml file
|
||||
skip_reprompt (bool): Whether to skip the re-prompting messages at the beginning of the script
|
||||
speak (bool): Whether to enable speak mode
|
||||
debug (bool): Whether to enable debug mode
|
||||
@@ -52,7 +53,7 @@ def create_config(
|
||||
"""
|
||||
config.debug_mode = False
|
||||
config.continuous_mode = False
|
||||
config.speak_mode = False
|
||||
config.tts_config.speak_mode = False
|
||||
|
||||
if debug:
|
||||
print_attribute("Debug mode", "ENABLED")
|
||||
@@ -77,7 +78,7 @@ def create_config(
|
||||
|
||||
if speak:
|
||||
print_attribute("Speak Mode", "ENABLED")
|
||||
config.speak_mode = True
|
||||
config.tts_config.speak_mode = True
|
||||
|
||||
# Set the default LLM models
|
||||
if gpt3only:
|
||||
|
||||
@@ -14,7 +14,7 @@ from pydantic import SecretStr
|
||||
from autogpt.agents import AgentThoughts, CommandArgs, CommandName
|
||||
from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings
|
||||
from autogpt.agents.utils.exceptions import InvalidAgentResponseError
|
||||
from autogpt.app.configurator import create_config
|
||||
from autogpt.app.configurator import apply_overrides_to_config
|
||||
from autogpt.app.setup import interactive_ai_config_setup
|
||||
from autogpt.app.spinner import Spinner
|
||||
from autogpt.app.utils import (
|
||||
@@ -25,7 +25,12 @@ from autogpt.app.utils import (
|
||||
markdown_to_ansi_style,
|
||||
)
|
||||
from autogpt.commands import COMMAND_CATEGORIES
|
||||
from autogpt.config import AIConfig, Config, ConfigBuilder, check_openai_api_key
|
||||
from autogpt.config import (
|
||||
AIConfig,
|
||||
Config,
|
||||
ConfigBuilder,
|
||||
assert_config_has_openai_api_key,
|
||||
)
|
||||
from autogpt.core.resource.model_providers import (
|
||||
ChatModelProvider,
|
||||
ModelProviderCredentials,
|
||||
@@ -46,8 +51,8 @@ from scripts.install_plugin_deps import install_plugin_dependencies
|
||||
async def run_auto_gpt(
|
||||
continuous: bool,
|
||||
continuous_limit: int,
|
||||
ai_settings: str,
|
||||
prompt_settings: str,
|
||||
ai_settings: Optional[Path],
|
||||
prompt_settings: Optional[Path],
|
||||
skip_reprompt: bool,
|
||||
speak: bool,
|
||||
debug: bool,
|
||||
@@ -67,9 +72,9 @@ async def run_auto_gpt(
|
||||
config = ConfigBuilder.build_config_from_env(workdir=working_directory)
|
||||
|
||||
# TODO: fill in llm values here
|
||||
check_openai_api_key(config)
|
||||
assert_config_has_openai_api_key(config)
|
||||
|
||||
create_config(
|
||||
apply_overrides_to_config(
|
||||
config,
|
||||
continuous,
|
||||
continuous_limit,
|
||||
@@ -87,7 +92,11 @@ async def run_auto_gpt(
|
||||
)
|
||||
|
||||
# Set up logging module
|
||||
configure_logging(config)
|
||||
configure_logging(
|
||||
debug_mode=debug,
|
||||
plain_output=config.plain_output,
|
||||
tts_config=config.tts_config,
|
||||
)
|
||||
|
||||
llm_provider = _configure_openai_provider(config)
|
||||
|
||||
@@ -105,39 +114,9 @@ async def run_auto_gpt(
|
||||
)
|
||||
|
||||
if not config.skip_news:
|
||||
motd, is_new_motd = get_latest_bulletin()
|
||||
if motd:
|
||||
motd = markdown_to_ansi_style(motd)
|
||||
for motd_line in motd.split("\n"):
|
||||
logger.info(
|
||||
extra={
|
||||
"title": "NEWS:",
|
||||
"title_color": Fore.GREEN,
|
||||
"preserve_color": True,
|
||||
},
|
||||
msg=motd_line,
|
||||
)
|
||||
if is_new_motd and not config.chat_messages_enabled:
|
||||
input(
|
||||
Fore.MAGENTA
|
||||
+ Style.BRIGHT
|
||||
+ "NEWS: Bulletin was updated! Press Enter to continue..."
|
||||
+ Style.RESET_ALL
|
||||
)
|
||||
|
||||
git_branch = get_current_git_branch()
|
||||
if git_branch and git_branch != "stable":
|
||||
logger.warn(
|
||||
f"You are running on `{git_branch}` branch"
|
||||
" - this is not a supported branch."
|
||||
)
|
||||
if sys.version_info < (3, 10):
|
||||
logger.error(
|
||||
"WARNING: You are running on an older version of Python. "
|
||||
"Some people have observed problems with certain "
|
||||
"parts of AutoGPT with this version. "
|
||||
"Please consider upgrading to Python 3.10 or higher.",
|
||||
)
|
||||
print_motd(config, logger)
|
||||
print_git_branch_info(logger)
|
||||
print_python_version_info(logger)
|
||||
|
||||
if install_plugin_deps:
|
||||
install_plugin_dependencies()
|
||||
@@ -165,15 +144,6 @@ async def run_auto_gpt(
|
||||
role=ai_role,
|
||||
goals=ai_goals,
|
||||
)
|
||||
# print(prompt)
|
||||
|
||||
# Initialize memory and make sure it is empty.
|
||||
# this is particularly important for indexing and referencing pinecone memory
|
||||
memory = get_memory(config)
|
||||
memory.clear()
|
||||
print_attribute("Configured Memory", memory.__class__.__name__)
|
||||
|
||||
print_attribute("Configured Browser", config.selenium_web_browser)
|
||||
|
||||
agent_prompt_config = Agent.default_settings.prompt_config.copy(deep=True)
|
||||
agent_prompt_config.use_functions_api = config.openai_functions
|
||||
@@ -192,6 +162,14 @@ async def run_auto_gpt(
|
||||
history=Agent.default_settings.history.copy(deep=True),
|
||||
)
|
||||
|
||||
# Initialize memory and make sure it is empty.
|
||||
# this is particularly important for indexing and referencing pinecone memory
|
||||
memory = get_memory(config)
|
||||
memory.clear()
|
||||
print_attribute("Configured Memory", memory.__class__.__name__)
|
||||
|
||||
print_attribute("Configured Browser", config.selenium_web_browser)
|
||||
|
||||
agent = Agent(
|
||||
settings=agent_settings,
|
||||
llm_provider=llm_provider,
|
||||
@@ -203,6 +181,47 @@ async def run_auto_gpt(
|
||||
await run_interaction_loop(agent)
|
||||
|
||||
|
||||
def print_motd(config: Config, logger: logging.Logger):
|
||||
motd, is_new_motd = get_latest_bulletin()
|
||||
if motd:
|
||||
motd = markdown_to_ansi_style(motd)
|
||||
for motd_line in motd.split("\n"):
|
||||
logger.info(
|
||||
extra={
|
||||
"title": "NEWS:",
|
||||
"title_color": Fore.GREEN,
|
||||
"preserve_color": True,
|
||||
},
|
||||
msg=motd_line,
|
||||
)
|
||||
if is_new_motd and not config.chat_messages_enabled:
|
||||
input(
|
||||
Fore.MAGENTA
|
||||
+ Style.BRIGHT
|
||||
+ "NEWS: Bulletin was updated! Press Enter to continue..."
|
||||
+ Style.RESET_ALL
|
||||
)
|
||||
|
||||
|
||||
def print_git_branch_info(logger: logging.Logger):
|
||||
git_branch = get_current_git_branch()
|
||||
if git_branch and git_branch != "stable":
|
||||
logger.warn(
|
||||
f"You are running on `{git_branch}` branch"
|
||||
" - this is not a supported branch."
|
||||
)
|
||||
|
||||
|
||||
def print_python_version_info(logger: logging.Logger):
|
||||
if sys.version_info < (3, 10):
|
||||
logger.error(
|
||||
"WARNING: You are running on an older version of Python. "
|
||||
"Some people have observed problems with certain "
|
||||
"parts of AutoGPT with this version. "
|
||||
"Please consider upgrading to Python 3.10 or higher.",
|
||||
)
|
||||
|
||||
|
||||
def _configure_openai_provider(config: Config) -> OpenAIProvider:
|
||||
"""Create a configured OpenAIProvider object.
|
||||
|
||||
@@ -331,7 +350,11 @@ async def run_interaction_loop(
|
||||
###############
|
||||
# Print the assistant's thoughts and the next command to the user.
|
||||
update_user(
|
||||
legacy_config, ai_config, command_name, command_args, assistant_reply_dict
|
||||
ai_config,
|
||||
command_name,
|
||||
command_args,
|
||||
assistant_reply_dict,
|
||||
speak_mode=legacy_config.tts_config.speak_mode,
|
||||
)
|
||||
|
||||
##################
|
||||
@@ -405,11 +428,11 @@ async def run_interaction_loop(
|
||||
|
||||
|
||||
def update_user(
|
||||
config: Config,
|
||||
ai_config: AIConfig,
|
||||
command_name: CommandName,
|
||||
command_args: CommandArgs,
|
||||
assistant_reply_dict: AgentThoughts,
|
||||
speak_mode: bool = False,
|
||||
) -> None:
|
||||
"""Prints the assistant's thoughts and the next command to the user.
|
||||
|
||||
@@ -422,9 +445,13 @@ def update_user(
|
||||
"""
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
print_assistant_thoughts(ai_config.ai_name, assistant_reply_dict, config)
|
||||
print_assistant_thoughts(
|
||||
ai_name=ai_config.ai_name,
|
||||
assistant_reply_json_valid=assistant_reply_dict,
|
||||
speak_mode=speak_mode,
|
||||
)
|
||||
|
||||
if config.speak_mode:
|
||||
if speak_mode:
|
||||
speak(f"I want to execute {command_name}")
|
||||
|
||||
# First log new-line so user can differentiate sections better in console
|
||||
@@ -589,7 +616,7 @@ Continue ({config.authorise_key}/{config.exit_key}): """,
|
||||
def print_assistant_thoughts(
|
||||
ai_name: str,
|
||||
assistant_reply_json_valid: dict,
|
||||
config: Config,
|
||||
speak_mode: bool = False,
|
||||
) -> None:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -634,7 +661,7 @@ def print_assistant_thoughts(
|
||||
|
||||
# Speak the assistant's thoughts
|
||||
if assistant_thoughts_speak:
|
||||
if config.speak_mode:
|
||||
if speak_mode:
|
||||
speak(assistant_thoughts_speak)
|
||||
else:
|
||||
print_attribute("SPEAK", assistant_thoughts_speak, title_color=Fore.YELLOW)
|
||||
|
||||
@@ -3,10 +3,10 @@ This module contains the configuration classes for AutoGPT.
|
||||
"""
|
||||
from .ai_config import AIConfig
|
||||
from .ai_directives import AIDirectives
|
||||
from .config import Config, ConfigBuilder, check_openai_api_key
|
||||
from .config import Config, ConfigBuilder, assert_config_has_openai_api_key
|
||||
|
||||
__all__ = [
|
||||
"check_openai_api_key",
|
||||
"assert_config_has_openai_api_key",
|
||||
"AIConfig",
|
||||
"AIDirectives",
|
||||
"Config",
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel
|
||||
@@ -25,7 +26,7 @@ class AIDirectives(BaseModel):
|
||||
best_practices: list[str]
|
||||
|
||||
@staticmethod
|
||||
def from_file(prompt_settings_file: str) -> AIDirectives:
|
||||
def from_file(prompt_settings_file: Path) -> AIDirectives:
|
||||
(validated, message) = validate_yaml_file(prompt_settings_file)
|
||||
if not validated:
|
||||
logger.error(message, extra={"title": "FAILED FILE VALIDATION"})
|
||||
|
||||
@@ -15,11 +15,12 @@ from pydantic import Field, validator
|
||||
from autogpt.core.configuration.schema import Configurable, SystemSettings
|
||||
from autogpt.core.resource.model_providers.openai import OPEN_AI_CHAT_MODELS
|
||||
from autogpt.plugins.plugins_config import PluginsConfig
|
||||
from autogpt.speech import TTSConfig
|
||||
|
||||
AI_SETTINGS_FILE = "ai_settings.yaml"
|
||||
AZURE_CONFIG_FILE = "azure.yaml"
|
||||
PLUGINS_CONFIG_FILE = "plugins_config.yaml"
|
||||
PROMPT_SETTINGS_FILE = "prompt_settings.yaml"
|
||||
AI_SETTINGS_FILE = Path("ai_settings.yaml")
|
||||
AZURE_CONFIG_FILE = Path("azure.yaml")
|
||||
PLUGINS_CONFIG_FILE = Path("plugins_config.yaml")
|
||||
PROMPT_SETTINGS_FILE = Path("prompt_settings.yaml")
|
||||
|
||||
GPT_4_MODEL = "gpt-4"
|
||||
GPT_3_MODEL = "gpt-3.5-turbo"
|
||||
@@ -31,6 +32,7 @@ class Config(SystemSettings, arbitrary_types_allowed=True):
|
||||
########################
|
||||
# Application Settings #
|
||||
########################
|
||||
workdir: Path = None
|
||||
skip_news: bool = False
|
||||
skip_reprompt: bool = False
|
||||
authorise_key: str = "y"
|
||||
@@ -40,18 +42,14 @@ class Config(SystemSettings, arbitrary_types_allowed=True):
|
||||
noninteractive_mode: bool = False
|
||||
chat_messages_enabled: bool = True
|
||||
# TTS configuration
|
||||
speak_mode: bool = False
|
||||
text_to_speech_provider: str = "gtts"
|
||||
streamelements_voice: str = "Brian"
|
||||
elevenlabs_voice_id: Optional[str] = None
|
||||
tts_config: TTSConfig = TTSConfig()
|
||||
|
||||
##########################
|
||||
# Agent Control Settings #
|
||||
##########################
|
||||
# Paths
|
||||
ai_settings_file: str = AI_SETTINGS_FILE
|
||||
prompt_settings_file: str = PROMPT_SETTINGS_FILE
|
||||
workdir: Path = None
|
||||
ai_settings_file: Path = AI_SETTINGS_FILE
|
||||
prompt_settings_file: Path = PROMPT_SETTINGS_FILE
|
||||
workspace_path: Optional[Path] = None
|
||||
file_logger_path: Optional[Path] = None
|
||||
# Model configuration
|
||||
@@ -105,7 +103,7 @@ class Config(SystemSettings, arbitrary_types_allowed=True):
|
||||
# Plugin Settings #
|
||||
###################
|
||||
plugins_dir: str = "plugins"
|
||||
plugins_config_file: str = PLUGINS_CONFIG_FILE
|
||||
plugins_config_file: Path = PLUGINS_CONFIG_FILE
|
||||
plugins_config: PluginsConfig = Field(
|
||||
default_factory=lambda: PluginsConfig(plugins={})
|
||||
)
|
||||
@@ -124,10 +122,8 @@ class Config(SystemSettings, arbitrary_types_allowed=True):
|
||||
openai_api_version: Optional[str] = None
|
||||
openai_organization: Optional[str] = None
|
||||
use_azure: bool = False
|
||||
azure_config_file: Optional[str] = AZURE_CONFIG_FILE
|
||||
azure_config_file: Optional[Path] = AZURE_CONFIG_FILE
|
||||
azure_model_to_deployment_id_map: Optional[Dict[str, str]] = None
|
||||
# Elevenlabs
|
||||
elevenlabs_api_key: Optional[str] = None
|
||||
# Github
|
||||
github_api_key: Optional[str] = None
|
||||
github_username: Optional[str] = None
|
||||
@@ -233,9 +229,9 @@ class ConfigBuilder(Configurable[Config]):
|
||||
"exit_key": os.getenv("EXIT_KEY"),
|
||||
"plain_output": os.getenv("PLAIN_OUTPUT", "False") == "True",
|
||||
"shell_command_control": os.getenv("SHELL_COMMAND_CONTROL"),
|
||||
"ai_settings_file": os.getenv("AI_SETTINGS_FILE", AI_SETTINGS_FILE),
|
||||
"prompt_settings_file": os.getenv(
|
||||
"PROMPT_SETTINGS_FILE", PROMPT_SETTINGS_FILE
|
||||
"ai_settings_file": Path(os.getenv("AI_SETTINGS_FILE", AI_SETTINGS_FILE)),
|
||||
"prompt_settings_file": Path(
|
||||
os.getenv("PROMPT_SETTINGS_FILE", PROMPT_SETTINGS_FILE)
|
||||
),
|
||||
"fast_llm": os.getenv("FAST_LLM", os.getenv("FAST_LLM_MODEL")),
|
||||
"smart_llm": os.getenv("SMART_LLM", os.getenv("SMART_LLM_MODEL")),
|
||||
@@ -243,15 +239,17 @@ class ConfigBuilder(Configurable[Config]):
|
||||
"browse_spacy_language_model": os.getenv("BROWSE_SPACY_LANGUAGE_MODEL"),
|
||||
"openai_api_key": os.getenv("OPENAI_API_KEY"),
|
||||
"use_azure": os.getenv("USE_AZURE") == "True",
|
||||
"azure_config_file": os.getenv("AZURE_CONFIG_FILE", AZURE_CONFIG_FILE),
|
||||
"azure_config_file": Path(
|
||||
os.getenv("AZURE_CONFIG_FILE", AZURE_CONFIG_FILE)
|
||||
),
|
||||
"execute_local_commands": os.getenv("EXECUTE_LOCAL_COMMANDS", "False")
|
||||
== "True",
|
||||
"restrict_to_workspace": os.getenv("RESTRICT_TO_WORKSPACE", "True")
|
||||
== "True",
|
||||
"openai_functions": os.getenv("OPENAI_FUNCTIONS", "False") == "True",
|
||||
"elevenlabs_api_key": os.getenv("ELEVENLABS_API_KEY"),
|
||||
"streamelements_voice": os.getenv("STREAMELEMENTS_VOICE"),
|
||||
"text_to_speech_provider": os.getenv("TEXT_TO_SPEECH_PROVIDER"),
|
||||
"tts_config": {
|
||||
"provider": os.getenv("TEXT_TO_SPEECH_PROVIDER"),
|
||||
},
|
||||
"github_api_key": os.getenv("GITHUB_API_KEY"),
|
||||
"github_username": os.getenv("GITHUB_USERNAME"),
|
||||
"google_api_key": os.getenv("GOOGLE_API_KEY"),
|
||||
@@ -273,8 +271,8 @@ class ConfigBuilder(Configurable[Config]):
|
||||
"redis_password": os.getenv("REDIS_PASSWORD"),
|
||||
"wipe_redis_on_start": os.getenv("WIPE_REDIS_ON_START", "True") == "True",
|
||||
"plugins_dir": os.getenv("PLUGINS_DIR"),
|
||||
"plugins_config_file": os.getenv(
|
||||
"PLUGINS_CONFIG_FILE", PLUGINS_CONFIG_FILE
|
||||
"plugins_config_file": Path(
|
||||
os.getenv("PLUGINS_CONFIG_FILE", PLUGINS_CONFIG_FILE)
|
||||
),
|
||||
"chat_messages_enabled": os.getenv("CHAT_MESSAGES_ENABLED") == "True",
|
||||
}
|
||||
@@ -294,19 +292,26 @@ class ConfigBuilder(Configurable[Config]):
|
||||
"GOOGLE_CUSTOM_SEARCH_ENGINE_ID", os.getenv("CUSTOM_SEARCH_ENGINE_ID")
|
||||
)
|
||||
|
||||
config_dict["elevenlabs_voice_id"] = os.getenv(
|
||||
"ELEVENLABS_VOICE_ID", os.getenv("ELEVENLABS_VOICE_1_ID")
|
||||
)
|
||||
if not config_dict["text_to_speech_provider"]:
|
||||
if os.getenv("ELEVENLABS_API_KEY"):
|
||||
config_dict["tts_config"]["elevenlabs"] = {
|
||||
"api_key": os.getenv("ELEVENLABS_API_KEY"),
|
||||
"voice_id": os.getenv("ELEVENLABS_VOICE_ID", ""),
|
||||
}
|
||||
if os.getenv("STREAMELEMENTS_VOICE"):
|
||||
config_dict["tts_config"]["streamelements"] = {
|
||||
"voice": os.getenv("STREAMELEMENTS_VOICE"),
|
||||
}
|
||||
|
||||
if not config_dict["tts_config"]["provider"]:
|
||||
if os.getenv("USE_MAC_OS_TTS"):
|
||||
default_tts_provider = "macos"
|
||||
elif config_dict["elevenlabs_api_key"]:
|
||||
elif "elevenlabs" in config_dict["tts_config"]:
|
||||
default_tts_provider = "elevenlabs"
|
||||
elif os.getenv("USE_BRIAN_TTS"):
|
||||
default_tts_provider = "streamelements"
|
||||
else:
|
||||
default_tts_provider = "gtts"
|
||||
config_dict["text_to_speech_provider"] = default_tts_provider
|
||||
config_dict["tts_config"]["provider"] = default_tts_provider
|
||||
|
||||
config_dict["plugins_allowlist"] = _safe_split(os.getenv("ALLOWLISTED_PLUGINS"))
|
||||
config_dict["plugins_denylist"] = _safe_split(os.getenv("DENYLISTED_PLUGINS"))
|
||||
@@ -374,7 +379,7 @@ class ConfigBuilder(Configurable[Config]):
|
||||
}
|
||||
|
||||
|
||||
def check_openai_api_key(config: Config) -> None:
|
||||
def assert_config_has_openai_api_key(config: Config) -> None:
|
||||
"""Check if the OpenAI API key is set in config.py or as an environment variable."""
|
||||
if not config.openai_api_key:
|
||||
print(
|
||||
|
||||
@@ -4,13 +4,14 @@ from __future__ import annotations
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from auto_gpt_plugin_template import AutoGPTPluginTemplate
|
||||
from openai.util import logger as openai_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.config import Config
|
||||
from autogpt.speech import TTSConfig
|
||||
|
||||
from autogpt.core.runner.client_lib.logging import BelowLevelFilter
|
||||
|
||||
@@ -34,15 +35,20 @@ USER_FRIENDLY_OUTPUT_LOGGER = "USER_FRIENDLY_OUTPUT"
|
||||
_chat_plugins: list[AutoGPTPluginTemplate] = []
|
||||
|
||||
|
||||
def configure_logging(config: Config, log_dir: Path = LOG_DIR) -> None:
|
||||
def configure_logging(
|
||||
debug_mode: bool = False,
|
||||
plain_output: bool = False,
|
||||
tts_config: Optional[TTSConfig] = None,
|
||||
log_dir: Path = LOG_DIR,
|
||||
) -> None:
|
||||
"""Configure the native logging module."""
|
||||
|
||||
# create log directory if it doesn't exist
|
||||
if not log_dir.exists():
|
||||
log_dir.mkdir()
|
||||
|
||||
log_level = logging.DEBUG if config.debug_mode else logging.INFO
|
||||
log_format = DEBUG_LOG_FORMAT if config.debug_mode else SIMPLE_LOG_FORMAT
|
||||
log_level = logging.DEBUG if debug_mode else logging.INFO
|
||||
log_format = DEBUG_LOG_FORMAT if debug_mode else SIMPLE_LOG_FORMAT
|
||||
console_formatter = AutoGptFormatter(log_format)
|
||||
|
||||
# Console output handlers
|
||||
@@ -61,7 +67,7 @@ def configure_logging(config: Config, log_dir: Path = LOG_DIR) -> None:
|
||||
AutoGptFormatter(SIMPLE_LOG_FORMAT, no_color=True)
|
||||
)
|
||||
|
||||
if config.debug_mode:
|
||||
if debug_mode:
|
||||
# DEBUG log file handler
|
||||
debug_log_handler = logging.FileHandler(log_dir / DEBUG_LOG_FILE, "a", "utf-8")
|
||||
debug_log_handler.setLevel(logging.DEBUG)
|
||||
@@ -80,7 +86,7 @@ def configure_logging(config: Config, log_dir: Path = LOG_DIR) -> None:
|
||||
level=log_level,
|
||||
handlers=(
|
||||
[stdout, stderr, activity_log_handler, error_log_handler]
|
||||
+ ([debug_log_handler] if config.debug_mode else [])
|
||||
+ ([debug_log_handler] if debug_mode else [])
|
||||
),
|
||||
)
|
||||
|
||||
@@ -94,9 +100,10 @@ def configure_logging(config: Config, log_dir: Path = LOG_DIR) -> None:
|
||||
user_friendly_output_logger = logging.getLogger(USER_FRIENDLY_OUTPUT_LOGGER)
|
||||
user_friendly_output_logger.setLevel(logging.INFO)
|
||||
user_friendly_output_logger.addHandler(
|
||||
typing_console_handler if not config.plain_output else stdout
|
||||
typing_console_handler if not plain_output else stdout
|
||||
)
|
||||
user_friendly_output_logger.addHandler(TTSHandler(config))
|
||||
if tts_config:
|
||||
user_friendly_output_logger.addHandler(TTSHandler(tts_config))
|
||||
user_friendly_output_logger.addHandler(activity_log_handler)
|
||||
user_friendly_output_logger.addHandler(error_log_handler)
|
||||
user_friendly_output_logger.addHandler(stderr)
|
||||
@@ -104,7 +111,8 @@ def configure_logging(config: Config, log_dir: Path = LOG_DIR) -> None:
|
||||
|
||||
speech_output_logger = logging.getLogger(SPEECH_OUTPUT_LOGGER)
|
||||
speech_output_logger.setLevel(logging.INFO)
|
||||
speech_output_logger.addHandler(TTSHandler(config))
|
||||
if tts_config:
|
||||
speech_output_logger.addHandler(TTSHandler(tts_config))
|
||||
speech_output_logger.propagate = False
|
||||
|
||||
# JSON logger with better formatting
|
||||
|
||||
@@ -11,7 +11,7 @@ from autogpt.logs.utils import remove_color_codes
|
||||
from autogpt.speech import TextToSpeechProvider
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.config import Config
|
||||
from autogpt.speech import TTSConfig
|
||||
|
||||
|
||||
class TypingConsoleHandler(logging.StreamHandler):
|
||||
@@ -50,7 +50,7 @@ class TypingConsoleHandler(logging.StreamHandler):
|
||||
class TTSHandler(logging.Handler):
|
||||
"""Output messages to the configured TTS engine (if any)"""
|
||||
|
||||
def __init__(self, config: Config):
|
||||
def __init__(self, config: TTSConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.tts_provider = TextToSpeechProvider(config)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""This module contains the speech recognition and speech synthesis functions."""
|
||||
from autogpt.speech.say import TextToSpeechProvider
|
||||
from autogpt.speech.say import TextToSpeechProvider, TTSConfig
|
||||
|
||||
__all__ = ["TextToSpeechProvider"]
|
||||
__all__ = ["TextToSpeechProvider", "TTSConfig"]
|
||||
|
||||
@@ -4,10 +4,6 @@ from __future__ import annotations
|
||||
import abc
|
||||
import re
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.config import Config
|
||||
|
||||
|
||||
class VoiceBase:
|
||||
@@ -15,7 +11,7 @@ class VoiceBase:
|
||||
Base class for all voice classes.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Config):
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""
|
||||
Initialize the voice class.
|
||||
"""
|
||||
@@ -24,7 +20,7 @@ class VoiceBase:
|
||||
self._api_key = None
|
||||
self._voices = []
|
||||
self._mutex = Lock()
|
||||
self._setup(config)
|
||||
self._setup(*args, **kwargs)
|
||||
|
||||
def say(self, text: str, voice_index: int = 0) -> bool:
|
||||
"""
|
||||
@@ -43,7 +39,7 @@ class VoiceBase:
|
||||
return self._speech(text, voice_index)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _setup(self, config: Config) -> None:
|
||||
def _setup(self, *args, **kwargs) -> None:
|
||||
"""
|
||||
Setup the voices, API key, etc.
|
||||
"""
|
||||
|
||||
@@ -3,13 +3,12 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import requests
|
||||
from playsound import playsound
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.config import Config
|
||||
from autogpt.core.configuration import SystemConfiguration, UserConfigurable
|
||||
|
||||
from .base import VoiceBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -17,10 +16,15 @@ logger = logging.getLogger(__name__)
|
||||
PLACEHOLDERS = {"your-voice-id"}
|
||||
|
||||
|
||||
class ElevenLabsConfig(SystemConfiguration):
|
||||
api_key: str = UserConfigurable()
|
||||
voice_id: str = UserConfigurable()
|
||||
|
||||
|
||||
class ElevenLabsSpeech(VoiceBase):
|
||||
"""ElevenLabs speech class"""
|
||||
|
||||
def _setup(self, config: Config) -> None:
|
||||
def _setup(self, config: ElevenLabsConfig) -> None:
|
||||
"""Set up the voices, API key, etc.
|
||||
|
||||
Returns:
|
||||
@@ -41,12 +45,12 @@ class ElevenLabsSpeech(VoiceBase):
|
||||
}
|
||||
self._headers = {
|
||||
"Content-Type": "application/json",
|
||||
"xi-api-key": config.elevenlabs_api_key,
|
||||
"xi-api-key": config.api_key,
|
||||
}
|
||||
self._voices = default_voices.copy()
|
||||
if config.elevenlabs_voice_id in voice_options:
|
||||
config.elevenlabs_voice_id = voice_options[config.elevenlabs_voice_id]
|
||||
self._use_custom_voice(config.elevenlabs_voice_id, 0)
|
||||
if config.voice_id in voice_options:
|
||||
config.voice_id = voice_options[config.voice_id]
|
||||
self._use_custom_voice(config.voice_id, 0)
|
||||
|
||||
def _use_custom_voice(self, voice, voice_index) -> None:
|
||||
"""Use a custom voice if provided and not a placeholder
|
||||
|
||||
@@ -2,21 +2,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import gtts
|
||||
from playsound import playsound
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.config import Config
|
||||
|
||||
from autogpt.speech.base import VoiceBase
|
||||
|
||||
|
||||
class GTTSVoice(VoiceBase):
|
||||
"""GTTS Voice."""
|
||||
|
||||
def _setup(self, config: Config) -> None:
|
||||
def _setup(self) -> None:
|
||||
pass
|
||||
|
||||
def _speech(self, text: str, _: int = 0) -> bool:
|
||||
|
||||
@@ -2,10 +2,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.config import Config
|
||||
|
||||
from autogpt.speech.base import VoiceBase
|
||||
|
||||
@@ -13,7 +9,7 @@ from autogpt.speech.base import VoiceBase
|
||||
class MacOSTTS(VoiceBase):
|
||||
"""MacOS TTS Voice."""
|
||||
|
||||
def _setup(self, config: Config) -> None:
|
||||
def _setup(self) -> None:
|
||||
pass
|
||||
|
||||
def _speech(self, text: str, voice_index: int = 0) -> bool:
|
||||
|
||||
@@ -3,24 +3,32 @@ from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from threading import Semaphore
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Literal, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.config import Config
|
||||
from autogpt.core.configuration.schema import SystemConfiguration, UserConfigurable
|
||||
|
||||
from .base import VoiceBase
|
||||
from .eleven_labs import ElevenLabsSpeech
|
||||
from .eleven_labs import ElevenLabsConfig, ElevenLabsSpeech
|
||||
from .gtts import GTTSVoice
|
||||
from .macos_tts import MacOSTTS
|
||||
from .stream_elements_speech import StreamElementsSpeech
|
||||
from .stream_elements_speech import StreamElementsConfig, StreamElementsSpeech
|
||||
|
||||
_QUEUE_SEMAPHORE = Semaphore(
|
||||
1
|
||||
) # The amount of sounds to queue before blocking the main thread
|
||||
|
||||
|
||||
class TTSConfig(SystemConfiguration):
|
||||
speak_mode: bool = False
|
||||
provider: Literal[
|
||||
"elevenlabs", "gtts", "macos", "streamelements"
|
||||
] = UserConfigurable(default="gtts")
|
||||
elevenlabs: Optional[ElevenLabsConfig] = None
|
||||
streamelements: Optional[StreamElementsConfig] = None
|
||||
|
||||
|
||||
class TextToSpeechProvider:
|
||||
def __init__(self, config: Config):
|
||||
def __init__(self, config: TTSConfig):
|
||||
self._config = config
|
||||
self._default_voice_engine, self._voice_engine = self._get_voice_engine(config)
|
||||
|
||||
@@ -37,19 +45,19 @@ class TextToSpeechProvider:
|
||||
thread.start()
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}(enabled={self._config.speak_mode}, provider={self._voice_engine.__class__.__name__})"
|
||||
return f"{self.__class__.__name__}(provider={self._voice_engine.__class__.__name__})"
|
||||
|
||||
@staticmethod
|
||||
def _get_voice_engine(config: Config) -> tuple[VoiceBase, VoiceBase]:
|
||||
def _get_voice_engine(config: TTSConfig) -> tuple[VoiceBase, VoiceBase]:
|
||||
"""Get the voice engine to use for the given configuration"""
|
||||
tts_provider = config.text_to_speech_provider
|
||||
tts_provider = config.provider
|
||||
if tts_provider == "elevenlabs":
|
||||
voice_engine = ElevenLabsSpeech(config)
|
||||
voice_engine = ElevenLabsSpeech(config.elevenlabs)
|
||||
elif tts_provider == "macos":
|
||||
voice_engine = MacOSTTS(config)
|
||||
voice_engine = MacOSTTS()
|
||||
elif tts_provider == "streamelements":
|
||||
voice_engine = StreamElementsSpeech(config)
|
||||
voice_engine = StreamElementsSpeech(config.streamelements)
|
||||
else:
|
||||
voice_engine = GTTSVoice(config)
|
||||
voice_engine = GTTSVoice()
|
||||
|
||||
return GTTSVoice(config), voice_engine
|
||||
return GTTSVoice(), voice_engine
|
||||
|
||||
@@ -2,28 +2,29 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import requests
|
||||
from playsound import playsound
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.config import Config
|
||||
|
||||
from autogpt.core.configuration import SystemConfiguration, UserConfigurable
|
||||
from autogpt.speech.base import VoiceBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StreamElementsConfig(SystemConfiguration):
|
||||
voice: str = UserConfigurable(default="Brian")
|
||||
|
||||
|
||||
class StreamElementsSpeech(VoiceBase):
|
||||
"""Streamelements speech module for autogpt"""
|
||||
|
||||
def _setup(self, config: Config) -> None:
|
||||
def _setup(self, config: StreamElementsConfig) -> None:
|
||||
"""Setup the voices, API key, etc."""
|
||||
self.config = config
|
||||
|
||||
def _speech(self, text: str, voice: str, _: int = 0) -> bool:
|
||||
voice = self.config.streamelements_voice
|
||||
voice = self.config.voice
|
||||
"""Speak text using the streamelements API
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
from colorama import Fore
|
||||
|
||||
|
||||
def validate_yaml_file(file: str):
|
||||
def validate_yaml_file(file: str | Path):
|
||||
try:
|
||||
with open(file, encoding="utf-8") as fp:
|
||||
yaml.load(fp.read(), Loader=yaml.FullLoader)
|
||||
|
||||
@@ -38,7 +38,7 @@ def workspace(workspace_root: Path) -> Workspace:
|
||||
def temp_plugins_config_file():
|
||||
"""Create a plugins_config.yaml file in a temp directory so that it doesn't mess with existing ones"""
|
||||
config_directory = TemporaryDirectory()
|
||||
config_file = os.path.join(config_directory.name, "plugins_config.yaml")
|
||||
config_file = Path(config_directory.name) / "plugins_config.yaml"
|
||||
with open(config_file, "w+") as f:
|
||||
f.write(yaml.dump({}))
|
||||
|
||||
@@ -46,7 +46,7 @@ def temp_plugins_config_file():
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def config(temp_plugins_config_file: str, mocker: MockerFixture, workspace: Workspace):
|
||||
def config(temp_plugins_config_file: Path, mocker: MockerFixture, workspace: Workspace):
|
||||
config = ConfigBuilder.build_config_from_env(workspace.root.parent)
|
||||
if not os.environ.get("OPENAI_API_KEY"):
|
||||
os.environ["OPENAI_API_KEY"] = "sk-dummy"
|
||||
@@ -79,7 +79,11 @@ def config(temp_plugins_config_file: str, mocker: MockerFixture, workspace: Work
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def setup_logger(config: Config):
|
||||
configure_logging(config, Path(__file__).parent / "logs")
|
||||
configure_logging(
|
||||
debug_mode=config.debug_mode,
|
||||
plain_output=config.plain_output,
|
||||
log_dir=Path(__file__).parent / "logs",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
|
||||
@@ -9,7 +9,7 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from autogpt.app.configurator import GPT_3_MODEL, GPT_4_MODEL, create_config
|
||||
from autogpt.app.configurator import GPT_3_MODEL, GPT_4_MODEL, apply_overrides_to_config
|
||||
from autogpt.config import Config, ConfigBuilder
|
||||
from autogpt.workspace.workspace import Workspace
|
||||
|
||||
@@ -18,9 +18,9 @@ def test_initial_values(config: Config) -> None:
|
||||
"""
|
||||
Test if the initial values of the config class attributes are set correctly.
|
||||
"""
|
||||
assert config.debug_mode == False
|
||||
assert config.continuous_mode == False
|
||||
assert config.speak_mode == False
|
||||
assert config.debug_mode is False
|
||||
assert config.continuous_mode is False
|
||||
assert config.tts_config.speak_mode is False
|
||||
assert config.fast_llm == "gpt-3.5-turbo-16k"
|
||||
assert config.smart_llm == "gpt-4-0314"
|
||||
|
||||
@@ -33,7 +33,7 @@ def test_set_continuous_mode(config: Config) -> None:
|
||||
continuous_mode = config.continuous_mode
|
||||
|
||||
config.continuous_mode = True
|
||||
assert config.continuous_mode == True
|
||||
assert config.continuous_mode is True
|
||||
|
||||
# Reset continuous mode
|
||||
config.continuous_mode = continuous_mode
|
||||
@@ -44,13 +44,13 @@ def test_set_speak_mode(config: Config) -> None:
|
||||
Test if the set_speak_mode() method updates the speak_mode attribute.
|
||||
"""
|
||||
# Store speak mode to reset it after the test
|
||||
speak_mode = config.speak_mode
|
||||
speak_mode = config.tts_config.speak_mode
|
||||
|
||||
config.speak_mode = True
|
||||
assert config.speak_mode == True
|
||||
config.tts_config.speak_mode = True
|
||||
assert config.tts_config.speak_mode is True
|
||||
|
||||
# Reset speak mode
|
||||
config.speak_mode = speak_mode
|
||||
config.tts_config.speak_mode = speak_mode
|
||||
|
||||
|
||||
def test_set_fast_llm(config: Config) -> None:
|
||||
@@ -89,7 +89,7 @@ def test_set_debug_mode(config: Config) -> None:
|
||||
debug_mode = config.debug_mode
|
||||
|
||||
config.debug_mode = True
|
||||
assert config.debug_mode == True
|
||||
assert config.debug_mode is True
|
||||
|
||||
# Reset debug mode
|
||||
config.debug_mode = debug_mode
|
||||
@@ -98,7 +98,7 @@ def test_set_debug_mode(config: Config) -> None:
|
||||
@patch("openai.Model.list")
|
||||
def test_smart_and_fast_llms_set_to_gpt4(mock_list_models: Any, config: Config) -> None:
|
||||
"""
|
||||
Test if models update to gpt-3.5-turbo if both are set to gpt-4.
|
||||
Test if models update to gpt-3.5-turbo if gpt-4 is not available.
|
||||
"""
|
||||
fast_llm = config.fast_llm
|
||||
smart_llm = config.smart_llm
|
||||
@@ -108,21 +108,10 @@ def test_smart_and_fast_llms_set_to_gpt4(mock_list_models: Any, config: Config)
|
||||
|
||||
mock_list_models.return_value = {"data": [{"id": "gpt-3.5-turbo"}]}
|
||||
|
||||
create_config(
|
||||
apply_overrides_to_config(
|
||||
config=config,
|
||||
continuous=False,
|
||||
continuous_limit=False,
|
||||
ai_settings_file="",
|
||||
prompt_settings_file="",
|
||||
skip_reprompt=False,
|
||||
speak=False,
|
||||
debug=False,
|
||||
gpt3only=False,
|
||||
gpt4only=False,
|
||||
memory_type="",
|
||||
browser_name="",
|
||||
allow_downloads=False,
|
||||
skip_news=False,
|
||||
)
|
||||
|
||||
assert config.fast_llm == "gpt-3.5-turbo"
|
||||
@@ -136,10 +125,10 @@ def test_smart_and_fast_llms_set_to_gpt4(mock_list_models: Any, config: Config)
|
||||
def test_missing_azure_config(workspace: Workspace) -> None:
|
||||
config_file = workspace.get_path("azure_config.yaml")
|
||||
with pytest.raises(FileNotFoundError):
|
||||
ConfigBuilder.load_azure_config(str(config_file))
|
||||
ConfigBuilder.load_azure_config(config_file)
|
||||
|
||||
config_file.write_text("")
|
||||
azure_config = ConfigBuilder.load_azure_config(str(config_file))
|
||||
azure_config = ConfigBuilder.load_azure_config(config_file)
|
||||
|
||||
assert azure_config["openai_api_type"] == "azure"
|
||||
assert azure_config["openai_api_base"] == ""
|
||||
@@ -149,7 +138,7 @@ def test_missing_azure_config(workspace: Workspace) -> None:
|
||||
|
||||
def test_azure_config(config: Config, workspace: Workspace) -> None:
|
||||
config_file = workspace.get_path("azure_config.yaml")
|
||||
yaml_content = f"""
|
||||
yaml_content = """
|
||||
azure_api_type: azure
|
||||
azure_api_base: https://dummy.openai.azure.com
|
||||
azure_api_version: 2023-06-01-preview
|
||||
@@ -209,21 +198,9 @@ azure_model_map:
|
||||
def test_create_config_gpt4only(config: Config) -> None:
|
||||
with mock.patch("autogpt.llm.api_manager.ApiManager.get_models") as mock_get_models:
|
||||
mock_get_models.return_value = [{"id": GPT_4_MODEL}]
|
||||
create_config(
|
||||
apply_overrides_to_config(
|
||||
config=config,
|
||||
continuous=False,
|
||||
continuous_limit=None,
|
||||
ai_settings_file=None,
|
||||
prompt_settings_file=None,
|
||||
skip_reprompt=False,
|
||||
speak=False,
|
||||
debug=False,
|
||||
gpt3only=False,
|
||||
gpt4only=True,
|
||||
memory_type=None,
|
||||
browser_name=None,
|
||||
allow_downloads=False,
|
||||
skip_news=False,
|
||||
)
|
||||
assert config.fast_llm == GPT_4_MODEL
|
||||
assert config.smart_llm == GPT_4_MODEL
|
||||
@@ -232,21 +209,9 @@ def test_create_config_gpt4only(config: Config) -> None:
|
||||
def test_create_config_gpt3only(config: Config) -> None:
|
||||
with mock.patch("autogpt.llm.api_manager.ApiManager.get_models") as mock_get_models:
|
||||
mock_get_models.return_value = [{"id": GPT_3_MODEL}]
|
||||
create_config(
|
||||
apply_overrides_to_config(
|
||||
config=config,
|
||||
continuous=False,
|
||||
continuous_limit=None,
|
||||
ai_settings_file=None,
|
||||
prompt_settings_file=None,
|
||||
skip_reprompt=False,
|
||||
speak=False,
|
||||
debug=False,
|
||||
gpt3only=True,
|
||||
gpt4only=False,
|
||||
memory_type=None,
|
||||
browser_name=None,
|
||||
allow_downloads=False,
|
||||
skip_news=False,
|
||||
)
|
||||
assert config.fast_llm == GPT_3_MODEL
|
||||
assert config.smart_llm == GPT_3_MODEL
|
||||
|
||||
Reference in New Issue
Block a user