Fix broken tests (casualties from the past few days)

This commit is contained in:
Reinier van der Leer
2023-08-23 02:26:39 +02:00
parent a660619ea8
commit 97ccaba45f
14 changed files with 89 additions and 57 deletions

View File

@@ -347,7 +347,7 @@ def execute_command(
raise CommandExecutionError(str(e))
# Handle non-native commands (e.g. from plugins)
for command in agent.ai_config.prompt_generator.commands:
for command in agent.prompt_generator.commands:
if (
command_name == command.label.lower()
or command_name == command.name.lower()

View File

@@ -53,7 +53,7 @@ class ContextMixin:
0, Message("system", "# Context\n" + self.context.format_numbered())
)
return super(ContextMixin, self).construct_base_prompt(*args, **kwargs)
return super(ContextMixin, self).construct_base_prompt(*args, **kwargs) # type: ignore
def get_agent_context(agent: BaseAgent) -> AgentContext | None:

View File

@@ -19,7 +19,7 @@ from autogpt.command_decorator import command
from autogpt.memory.vector import MemoryItem, VectorMemory
from .decorators import sanitize_path_arg
from .file_context import open_file, open_folder # NOQA
from .file_context import open_file, open_folder # NOQA
from .file_operations_utils import read_textual_file
logger = logging.getLogger(__name__)

View File

@@ -169,15 +169,15 @@ def retry_api(
warn_user bool: Whether to warn the user. Defaults to True.
"""
error_messages = {
ServiceUnavailableError: f"{Fore.RED}Error: The OpenAI API engine is currently overloaded{Fore.RESET}",
RateLimitError: f"{Fore.RED}Error: Reached rate limit{Fore.RESET}",
ServiceUnavailableError: "The OpenAI API engine is currently overloaded",
RateLimitError: "Reached rate limit",
}
api_key_error_msg = (
f"Please double check that you have setup a "
f"{Fore.CYAN + Style.BRIGHT}PAID{Style.RESET_ALL} OpenAI API Account. You can "
f"read more here: {Fore.CYAN}https://docs.agpt.co/setup/#getting-an-api-key{Fore.RESET}"
)
backoff_msg = f"{Fore.RED}Waiting {{backoff}} seconds...{Fore.RESET}"
backoff_msg = "Waiting {backoff} seconds..."
def _wrapper(func: Callable):
@functools.wraps(func)

View File

@@ -12,6 +12,7 @@ from openai.util import logger as openai_logger
if TYPE_CHECKING:
from autogpt.config import Config
from .filters import BelowLevelFilter
from .formatters import AutoGptFormatter
from .handlers import TTSHandler, TypingConsoleHandler
@@ -42,10 +43,14 @@ def configure_logging(config: Config, log_dir: Path = LOG_DIR) -> None:
log_format = DEBUG_LOG_FORMAT if config.debug_mode else SIMPLE_LOG_FORMAT
console_formatter = AutoGptFormatter(log_format)
# Console output handler
console_handler = logging.StreamHandler(stream=sys.stdout)
console_handler.setLevel(log_level)
console_handler.setFormatter(console_formatter)
# Console output handlers
stdout = logging.StreamHandler(stream=sys.stdout)
stdout.setLevel(log_level)
stdout.addFilter(BelowLevelFilter(logging.WARNING))
stdout.setFormatter(console_formatter)
stderr = logging.StreamHandler()
stderr.setLevel(logging.WARNING)
stderr.setFormatter(console_formatter)
# INFO log file handler
activity_log_handler = logging.FileHandler(log_dir / LOG_FILE, "a", "utf-8")
@@ -68,7 +73,7 @@ def configure_logging(config: Config, log_dir: Path = LOG_DIR) -> None:
format=log_format,
level=log_level,
handlers=(
[console_handler, activity_log_handler, error_log_handler]
[stdout, stderr, activity_log_handler, error_log_handler]
+ ([debug_log_handler] if config.debug_mode else [])
),
)
@@ -81,13 +86,14 @@ def configure_logging(config: Config, log_dir: Path = LOG_DIR) -> None:
typing_console_handler.setFormatter(console_formatter)
user_friendly_output_logger = logging.getLogger(USER_FRIENDLY_OUTPUT_LOGGER)
user_friendly_output_logger.setLevel(logging.INFO)
user_friendly_output_logger.addHandler(
typing_console_handler if not config.plain_output else console_handler
typing_console_handler if not config.plain_output else stdout
)
user_friendly_output_logger.addHandler(TTSHandler(config))
user_friendly_output_logger.addHandler(activity_log_handler)
user_friendly_output_logger.addHandler(error_log_handler)
user_friendly_output_logger.setLevel(logging.INFO)
user_friendly_output_logger.addHandler(stderr)
user_friendly_output_logger.propagate = False
# JSON logger with better formatting

12
autogpt/logs/filters.py Normal file
View File

@@ -0,0 +1,12 @@
import logging
class BelowLevelFilter(logging.Filter):
"""Filter for logging levels below a certain threshold."""
def __init__(self, below_level: int):
super().__init__()
self.below_level = below_level
def filter(self, record: logging.LogRecord):
return record.levelno < self.below_level

View File

@@ -11,6 +11,7 @@ def user_friendly_output(
level: int = logging.INFO,
title: str = "",
title_color: str = "",
preserve_message_color: bool = False,
) -> None:
"""Outputs a message to the user in a user-friendly way.
@@ -24,7 +25,15 @@ def user_friendly_output(
for plugin in _chat_plugins:
plugin.report(f"{title}: {message}")
logger.log(level, message, extra={"title": title, "title_color": title_color})
logger.log(
level,
message,
extra={
"title": title,
"title_color": title_color,
"preserve_color": preserve_message_color,
},
)
def print_attribute(
@@ -51,5 +60,8 @@ def request_user_double_check(additionalText: Optional[str] = None) -> None:
)
user_friendly_output(
additionalText, level=logging.WARN, title="DOUBLE CHECK CONFIGURATION"
additionalText,
level=logging.WARN,
title="DOUBLE CHECK CONFIGURATION",
preserve_message_color=True,
)

View File

@@ -93,9 +93,9 @@ class CommandRegistry:
if name in self.commands_aliases:
return self.commands_aliases[name]
def call(self, command_name: str, **kwargs) -> Any:
def call(self, command_name: str, agent: BaseAgent, **kwargs) -> Any:
if command := self.get_command(command_name):
return command(**kwargs)
return command(**kwargs, agent=agent)
raise KeyError(f"Command '{command_name}' not found in registry")
def list_available_commands(self, agent: BaseAgent) -> Iterator[Command]:

View File

@@ -57,7 +57,6 @@ def config(
config.plugins_dir = "tests/unit/data/test_plugins"
config.plugins_config_file = temp_plugins_config_file
# HACK: this is necessary to ensure PLAIN_OUTPUT takes effect
config.plain_output = True
configure_logging(config, Path(__file__).parent / "logs")
@@ -95,7 +94,6 @@ def agent(config: Config) -> Agent:
)
command_registry = CommandRegistry()
ai_config.command_registry = command_registry
config.memory_backend = "json_file"
memory_json_file = get_memory(config)
memory_json_file.clear()

View File

@@ -11,7 +11,7 @@ def test_agent_initialization(agent: Agent):
def test_execute_command_plugin(agent: Agent):
"""Test that executing a command that came from a plugin works as expected"""
command_name = "check_plan"
agent.ai_config.prompt_generator.add_command(
agent.prompt_generator.add_command(
command_name,
"Read the plan.md with the next goals to achieve",
{},

View File

@@ -55,8 +55,6 @@ def test_ai_config_file_not_exists(workspace):
assert ai_config.ai_role == ""
assert ai_config.ai_goals == []
assert ai_config.api_budget == 0.0
assert ai_config.prompt_generator is None
assert ai_config.command_registry is None
def test_ai_config_file_is_empty(workspace):
@@ -70,5 +68,3 @@ def test_ai_config_file_is_empty(workspace):
assert ai_config.ai_role == ""
assert ai_config.ai_goals == []
assert ai_config.api_budget == 0.0
assert ai_config.prompt_generator is None
assert ai_config.command_registry is None

View File

@@ -1,10 +1,16 @@
from __future__ import annotations
import os
import shutil
import sys
from pathlib import Path
from typing import TYPE_CHECKING
import pytest
if TYPE_CHECKING:
from autogpt.agents import Agent, BaseAgent
from autogpt.models.command import Command, CommandParameter
from autogpt.models.command_registry import CommandRegistry
@@ -14,7 +20,7 @@ PARAMETERS = [
]
def example_command_method(arg1: int, arg2: str) -> str:
def example_command_method(arg1: int, arg2: str, agent: BaseAgent) -> str:
"""Example function for testing the Command class."""
# This function is static because it is not used by any other test cases.
return f"{arg1} - {arg2}"
@@ -47,16 +53,16 @@ def example_command():
)
def test_command_call(example_command: Command):
def test_command_call(example_command: Command, agent: Agent):
"""Test that Command(*args) calls and returns the result of method(*args)."""
result = example_command(arg1=1, arg2="test")
result = example_command(arg1=1, arg2="test", agent=agent)
assert result == "1 - test"
def test_command_call_with_invalid_arguments(example_command: Command):
def test_command_call_with_invalid_arguments(example_command: Command, agent: Agent):
"""Test that calling a Command object with invalid arguments raises a TypeError."""
with pytest.raises(TypeError):
example_command(arg1="invalid", does_not_exist="test")
example_command(arg1="invalid", does_not_exist="test", agent=agent)
def test_register_command(example_command: Command):
@@ -148,7 +154,7 @@ def test_get_nonexistent_command():
assert "nonexistent_command" not in registry
def test_call_command():
def test_call_command(agent: Agent):
"""Test that a command can be called through the registry."""
registry = CommandRegistry()
cmd = Command(
@@ -159,17 +165,17 @@ def test_call_command():
)
registry.register(cmd)
result = registry.call("example", arg1=1, arg2="test")
result = registry.call("example", arg1=1, arg2="test", agent=agent)
assert result == "1 - test"
def test_call_nonexistent_command():
def test_call_nonexistent_command(agent: Agent):
"""Test that attempting to call a nonexistent command raises a KeyError."""
registry = CommandRegistry()
with pytest.raises(KeyError):
registry.call("nonexistent_command", arg1=1, arg2="test")
registry.call("nonexistent_command", arg1=1, arg2="test", agent=agent)
def test_import_mock_commands_module():

View File

@@ -18,7 +18,7 @@ def test_clone_auto_gpt_repository(workspace, mock_clone_from, agent: Agent):
repo = "github.com/Significant-Gravitas/Auto-GPT.git"
scheme = "https://"
url = scheme + repo
clone_path = str(workspace.get_path("auto-gpt-repo"))
clone_path = workspace.get_path("auto-gpt-repo")
expected_output = f"Cloned {url} to {clone_path}"
@@ -33,7 +33,7 @@ def test_clone_auto_gpt_repository(workspace, mock_clone_from, agent: Agent):
def test_clone_repository_error(workspace, mock_clone_from, agent: Agent):
url = "https://github.com/this-repository/does-not-exist.git"
clone_path = str(workspace.get_path("does-not-exist"))
clone_path = workspace.get_path("does-not-exist")
mock_clone_from.side_effect = GitCommandError(
"clone", "fatal: repository not found", ""

View File

@@ -31,7 +31,7 @@ def error_factory(error_instance, error_count, retry_count, warn_user=True):
return RaisesError()
def test_retry_open_api_no_error(capsys):
def test_retry_open_api_no_error(caplog: pytest.LogCaptureFixture):
"""Tests the retry functionality with no errors expected"""
@openai.retry_api()
@@ -41,9 +41,9 @@ def test_retry_open_api_no_error(capsys):
result = f()
assert result == 1
output = capsys.readouterr()
assert output.out == ""
assert output.err == ""
output = caplog.text
assert output == ""
assert output == ""
@pytest.mark.parametrize(
@@ -51,7 +51,9 @@ def test_retry_open_api_no_error(capsys):
[(2, 10, False), (2, 2, False), (10, 2, True), (3, 2, True), (1, 0, True)],
ids=["passing", "passing_edge", "failing", "failing_edge", "failing_no_retries"],
)
def test_retry_open_api_passing(capsys, error, error_count, retry_count, failure):
def test_retry_open_api_passing(
caplog: pytest.LogCaptureFixture, error, error_count, retry_count, failure
):
"""Tests the retry with simulated errors [RateLimitError, ServiceUnavailableError, APIError], but should ulimately pass"""
call_count = min(error_count, retry_count) + 1
@@ -65,20 +67,20 @@ def test_retry_open_api_passing(capsys, error, error_count, retry_count, failure
assert raises.count == call_count
output = capsys.readouterr()
output = caplog.text
if error_count and retry_count:
if type(error) == RateLimitError:
assert "Reached rate limit" in output.out
assert "Please double check" in output.out
assert "Reached rate limit" in output
assert "Please double check" in output
if type(error) == ServiceUnavailableError:
assert "The OpenAI API engine is currently overloaded" in output.out
assert "Please double check" in output.out
assert "The OpenAI API engine is currently overloaded" in output
assert "Please double check" in output
else:
assert output.out == ""
assert output == ""
def test_retry_open_api_rate_limit_no_warn(capsys):
def test_retry_open_api_rate_limit_no_warn(caplog: pytest.LogCaptureFixture):
"""Tests the retry logic with a rate limit error"""
error_count = 2
retry_count = 10
@@ -89,13 +91,13 @@ def test_retry_open_api_rate_limit_no_warn(capsys):
assert result == call_count
assert raises.count == call_count
output = capsys.readouterr()
output = caplog.text
assert "Reached rate limit" in output.out
assert "Please double check" not in output.out
assert "Reached rate limit" in output
assert "Please double check" not in output
def test_retry_open_api_service_unavairable_no_warn(capsys):
def test_retry_open_api_service_unavairable_no_warn(caplog: pytest.LogCaptureFixture):
"""Tests the retry logic with a service unavairable error"""
error_count = 2
retry_count = 10
@@ -108,13 +110,13 @@ def test_retry_open_api_service_unavairable_no_warn(capsys):
assert result == call_count
assert raises.count == call_count
output = capsys.readouterr()
output = caplog.text
assert "The OpenAI API engine is currently overloaded" in output.out
assert "Please double check" not in output.out
assert "The OpenAI API engine is currently overloaded" in output
assert "Please double check" not in output
def test_retry_openapi_other_api_error(capsys):
def test_retry_openapi_other_api_error(caplog: pytest.LogCaptureFixture):
"""Tests the Retry logic with a non rate limit error such as HTTP500"""
error_count = 2
retry_count = 10
@@ -126,5 +128,5 @@ def test_retry_openapi_other_api_error(capsys):
call_count = 1
assert raises.count == call_count
output = capsys.readouterr()
assert output.out == ""
output = caplog.text
assert output == ""