Mirror of https://github.com/Significant-Gravitas/AutoGPT.git
Vector memory revamp (part 1: refactoring) (#4208)
Additional changes:

* Improve typing
* Modularize message history memory & fix/refactor lots of things
* Fix summarization
* Move memory relevance calculation to MemoryItem & improve test
* Fix import warnings in web_selenium.py
* Remove `memory_add` ghost command
* Implement overlap in `split_text` (sketched below)
* Move memory tests into subdirectory
* Remove deprecated `get_ada_embedding()` and helpers
* Fix used token calculation in `chat_with_ai`
* Replace Message TypedDict by dataclass
* Fix AgentManager singleton issues in tests

---------

Co-authored-by: Auto-GPT-Bot <github-bot@agpt.co>
Commit bfbe613960 (parent 10489e0df2), committed via GitHub.
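For reference, the "overlap in `split_text`" item is easy to picture. Below is a minimal sketch of overlap-aware chunking; it is a hypothetical stand-in, not the PR's actual implementation (which budgets by tokens rather than characters):

from typing import Iterator


def split_text(text: str, max_length: int = 300, overlap: int = 0) -> Iterator[str]:
    """Yield chunks of at most max_length characters; consecutive chunks
    share `overlap` characters so context isn't lost at chunk borders."""
    if overlap >= max_length:
        raise ValueError("overlap must be smaller than max_length")
    step = max_length - overlap
    for start in range(0, max(len(text) - overlap, 1), step):
        yield text[start : start + max_length]

For example, split_text("abcdefgh", max_length=4, overlap=2) yields "abcd", "cdef", "efgh": every boundary is covered twice, which is the point of the change.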
@@ -1,78 +0,0 @@
-# Generated by CodiumAI
-
-import time
-from unittest.mock import patch
-
-from autogpt.llm import create_chat_message, generate_context
-
-
-def test_happy_path_role_content():
-    """Test that the function returns a dictionary with the correct keys and values when valid strings are provided for role and content."""
-    result = create_chat_message("system", "Hello, world!")
-    assert result == {"role": "system", "content": "Hello, world!"}
-
-
-def test_empty_role_content():
-    """Test that the function returns a dictionary with the correct keys and values when empty strings are provided for role and content."""
-    result = create_chat_message("", "")
-    assert result == {"role": "", "content": ""}
-
-
-def test_generate_context_empty_inputs(mocker):
-    """Test the behavior of the generate_context function when all input parameters are empty."""
-    # Mock the time.strftime function to return a fixed value
-    mocker.patch("time.strftime", return_value="Sat Apr 15 00:00:00 2023")
-    # Arrange
-    prompt = ""
-    relevant_memory = ""
-    full_message_history = []
-    model = "gpt-3.5-turbo-0301"
-
-    # Act
-    result = generate_context(prompt, relevant_memory, full_message_history, model)
-
-    # Assert
-    expected_result = (
-        -1,
-        32,
-        2,
-        [
-            {"role": "system", "content": ""},
-            {
-                "role": "system",
-                "content": f"The current time and date is {time.strftime('%c')}",
-            },
-        ],
-    )
-    assert result == expected_result
-
-
-def test_generate_context_valid_inputs():
-    """Test that the function successfully generates a current_context given valid inputs."""
-    # Given
-    prompt = "What is your favorite color?"
-    relevant_memory = "You once painted your room blue."
-    full_message_history = [
-        create_chat_message("user", "Hi there!"),
-        create_chat_message("assistant", "Hello! How can I assist you today?"),
-        create_chat_message("user", "Can you tell me a joke?"),
-        create_chat_message(
-            "assistant",
-            "Why did the tomato turn red? Because it saw the salad dressing!",
-        ),
-        create_chat_message("user", "Haha, that's funny."),
-    ]
-    model = "gpt-3.5-turbo-0301"
-
-    # When
-    result = generate_context(prompt, relevant_memory, full_message_history, model)
-
-    # Then
-    assert isinstance(result[0], int)
-    assert isinstance(result[1], int)
-    assert isinstance(result[2], int)
-    assert isinstance(result[3], list)
-    assert result[0] >= 0
-    assert result[2] >= 0
-    assert result[1] >= 0
-    assert len(result[3]) >= 2  # current_context should have at least 2 messages
-    assert result[1] <= 2048  # token limit for GPT-3.5-turbo-0301 is 2048 tokens
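For orientation, the expected tuple in the deleted tests reads as (next_message_to_add_index, current_tokens_used, insertion_index, current_context). Below is a hedged reconstruction of the removed generate_context, written only to satisfy the assertions above; the real function (superseded by this PR's message-history refactor) differed in detail, and relevant_memory went unused by the time these tests were written:

import time

import tiktoken


def create_chat_message(role: str, content: str) -> dict:
    # The removed helper returned a plain dict, as the first two tests assert.
    return {"role": role, "content": content}


def generate_context(prompt, relevant_memory, full_message_history, model):
    current_context = [
        create_chat_message("system", prompt),
        create_chat_message(
            "system", f"The current time and date is {time.strftime('%c')}"
        ),
    ]
    # -1 when the history is empty (matching expected_result above); otherwise
    # the index of the newest history message not yet added to the context.
    next_message_to_add_index = len(full_message_history) - 1
    insertion_index = len(current_context)  # 2 in the empty-inputs test
    # Rough chat-token accounting: ~4 tokens of overhead per message plus 3
    # to prime the reply (gpt-3.5-turbo-0301 figures).
    encoding = tiktoken.encoding_for_model(model)
    current_tokens_used = 3 + sum(
        4 + len(encoding.encode(m["role"])) + len(encoding.encode(m["content"]))
        for m in current_context
    )
    return (
        next_message_to_add_index,
        current_tokens_used,
        insertion_index,
        current_context,
    )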
@@ -13,6 +13,8 @@ from pytest_mock import MockerFixture

 import autogpt.commands.file_operations as file_ops
 from autogpt.config import Config
+from autogpt.memory.vector.memory_item import MemoryItem
+from autogpt.memory.vector.utils import Embedding
 from autogpt.utils import readable_file_size
 from autogpt.workspace import Workspace

@@ -22,6 +24,23 @@ def file_content():
     return "This is a test file.\n"


+@pytest.fixture()
+def mock_MemoryItem_from_text(mocker: MockerFixture, mock_embedding: Embedding):
+    mocker.patch.object(
+        file_ops.MemoryItem,
+        "from_text",
+        new=lambda content, source_type, metadata: MemoryItem(
+            raw_content=content,
+            summary=f"Summary of content '{content}'",
+            chunk_summaries=[f"Summary of content '{content}'"],
+            chunks=[content],
+            e_summary=mock_embedding,
+            e_chunks=[mock_embedding],
+            metadata=metadata | {"source_type": source_type},
+        ),
+    )
+
+
 @pytest.fixture()
 def test_file_path(config, workspace: Workspace):
     return workspace.get_path("test_file.txt")
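Judging from the fixture's keyword arguments, MemoryItem is roughly shaped like the dataclass below. This is a sketch inferred from the diff alone; the real class in autogpt/memory/vector/memory_item.py carries behavior as well, such as the from_text constructor and the relevance scoring the commit message mentions:

from dataclasses import dataclass

Embedding = list[float]  # simplified; the real alias lives in autogpt.memory.vector.utils


@dataclass
class MemoryItem:
    raw_content: str
    summary: str
    chunk_summaries: list[str]
    chunks: list[str]
    e_summary: Embedding       # embedding of the whole-item summary
    e_chunks: list[Embedding]  # one embedding per chunk
    metadata: dict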
@@ -188,7 +207,11 @@ def test_split_file(max_length, overlap, content, expected):
     )


-def test_read_file(test_file_with_content_path: Path, file_content):
+def test_read_file(
+    mock_MemoryItem_from_text,
+    test_file_with_content_path: Path,
+    file_content,
+):
     content = file_ops.read_file(test_file_with_content_path)
     assert content == file_content

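The new fixture argument implies that read_file now feeds what it reads into vector memory via MemoryItem.from_text, so the test must stub it out to avoid computing real embeddings. Roughly (the source_type value and metadata keys below are illustrative guesses, not the PR's literals):

from autogpt.memory.vector.memory_item import MemoryItem


def read_file(filename: str) -> str:
    with open(filename, "r", encoding="utf-8") as f:
        content = f.read()
    # Inferred side effect: memorize the file's contents.
    MemoryItem.from_text(content, "text_file", {"location": filename})
    return content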
@@ -2,7 +2,7 @@ from datetime import datetime

 from autogpt.agent.agent import Agent
 from autogpt.config import AIConfig
-from autogpt.llm import create_chat_completion
+from autogpt.llm.chat import create_chat_completion
 from autogpt.log_cycle.log_cycle import LogCycleHandler


@@ -1,11 +1,9 @@
 from unittest.mock import patch

 import pytest
-from openai import InvalidRequestError
 from openai.error import APIError, RateLimitError

-from autogpt.llm import llm_utils
-from autogpt.llm.llm_utils import check_model
+from autogpt.llm import utils as llm_utils


 @pytest.fixture(params=[RateLimitError, APIError])
@@ -107,36 +105,6 @@ def test_retry_openapi_other_api_error(capsys):
     assert output.out == ""


-def test_chunked_tokens():
-    text = "Auto-GPT is an experimental open-source application showcasing the capabilities of the GPT-4 language model"
-    expected_output = [
-        (
-            13556,
-            12279,
-            2898,
-            374,
-            459,
-            22772,
-            1825,
-            31874,
-            3851,
-            67908,
-            279,
-            17357,
-            315,
-            279,
-            480,
-            2898,
-            12,
-            19,
-            4221,
-            1646,
-        )
-    ]
-    output = list(llm_utils.chunked_tokens(text, "cl100k_base", 8191))
-    assert output == expected_output
-
-
 def test_check_model(api_manager):
     """
     Test if check_model() returns original model when valid.
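The removed helper's contract is pinned down by its deleted test: encode the text and yield it back in tuples of at most chunk_length token ids. A sketch consistent with that assertion (hedged; the removed implementation may have been structured differently):

import tiktoken


def chunked_tokens(text: str, encoding_name: str, chunk_length: int):
    """Yield the token ids of `text` in tuples of at most chunk_length."""
    encoding = tiktoken.get_encoding(encoding_name)
    tokens = encoding.encode(text)
    for i in range(0, len(tokens), chunk_length):
        yield tuple(tokens[i : i + chunk_length])

With the 110-character test string and a chunk length of 8191, all 20 tokens fit in a single tuple, which is exactly the expected_output above.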
@@ -145,7 +113,7 @@ def test_check_model(api_manager):
     with patch("openai.Model.list") as mock_list_models:
         # Test when correct model is returned
        mock_list_models.return_value = {"data": [{"id": "gpt-4"}]}
-        result = check_model("gpt-4", "smart_llm_model")
+        result = llm_utils.check_model("gpt-4", "smart_llm_model")
         assert result == "gpt-4"

         # Reset api manager models
@@ -153,7 +121,7 @@ def test_check_model(api_manager):

         # Test when incorrect model is returned
         mock_list_models.return_value = {"data": [{"id": "gpt-3.5-turbo"}]}
-        result = check_model("gpt-4", "fast_llm_model")
+        result = llm_utils.check_model("gpt-4", "fast_llm_model")
         assert result == "gpt-3.5-turbo"

         # Reset api manager models
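These two hunks only change the call site (the function moved into autogpt.llm.utils), but they also document check_model's observable contract: return the requested model if the API lists it, otherwise fall back to gpt-3.5-turbo. A sketch consistent with the tests (the warning wording is a guess):

import openai


def check_model(model: str, model_type: str) -> str:
    """Return `model` if the OpenAI API lists it, else fall back."""
    available = openai.Model.list()["data"]
    if any(m["id"] == model for m in available):
        return model
    print(f"You don't have access to {model}. Using gpt-3.5-turbo as {model_type}.")
    return "gpt-3.5-turbo"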
@@ -1,16 +1,11 @@
 import pytest

-from autogpt.config import Config
-from autogpt.plugins import (
-    denylist_allowlist_check,
-    inspect_zip_for_modules,
-    scan_plugins,
-)
+from autogpt.plugins import denylist_allowlist_check, inspect_zip_for_modules

 PLUGINS_TEST_DIR = "tests/unit/data/test_plugins"
 PLUGIN_TEST_ZIP_FILE = "Auto-GPT-Plugin-Test-master.zip"
 PLUGIN_TEST_INIT_PY = "Auto-GPT-Plugin-Test-master/src/auto_gpt_vicuna/__init__.py"
 PLUGIN_TEST_OPENAI = "https://weathergpt.vercel.app/"


 def test_inspect_zip_for_modules():
@@ -77,54 +72,3 @@ def test_denylist_allowlist_check_user_input_invalid(
     assert not denylist_allowlist_check(
         "UnknownPlugin", mock_config_denylist_allowlist_check
     )
-
-
-@pytest.fixture
-def config_with_plugins():
-    """Mock config object for testing the scan_plugins function"""
-    # Test that the function returns the correct number of plugins
-    cfg = Config()
-    cfg.plugins_dir = PLUGINS_TEST_DIR
-    cfg.plugins_openai = ["https://weathergpt.vercel.app/"]
-    return cfg
-
-
-@pytest.fixture
-def mock_config_openai_plugin():
-    """Mock config object for testing the scan_plugins function"""
-
-    class MockConfig:
-        """Mock config object for testing the scan_plugins function"""
-
-        plugins_dir = PLUGINS_TEST_DIR
-        plugins_openai = [PLUGIN_TEST_OPENAI]
-        plugins_denylist = ["AutoGPTPVicuna"]
-        plugins_allowlist = [PLUGIN_TEST_OPENAI]
-
-    return MockConfig()
-
-
-def test_scan_plugins_openai(mock_config_openai_plugin):
-    # Test that the function returns the correct number of plugins
-    result = scan_plugins(mock_config_openai_plugin, debug=True)
-    assert len(result) == 1
-
-
-@pytest.fixture
-def mock_config_generic_plugin():
-    """Mock config object for testing the scan_plugins function"""
-
-    # Test that the function returns the correct number of plugins
-    class MockConfig:
-        plugins_dir = PLUGINS_TEST_DIR
-        plugins_openai = []
-        plugins_denylist = []
-        plugins_allowlist = ["AutoGPTPVicuna"]
-
-    return MockConfig()
-
-
-def test_scan_plugins_generic(mock_config_generic_plugin):
-    # Test that the function returns the correct number of plugins
-    result = scan_plugins(mock_config_generic_plugin, debug=True)
-    assert len(result) == 1
tests/unit/test_token_counter.py (new file, 55 lines)
@@ -0,0 +1,55 @@
+import pytest
+
+from autogpt.llm.base import Message
+from autogpt.llm.utils import count_message_tokens, count_string_tokens
+
+
+def test_count_message_tokens():
+    messages = [
+        Message("user", "Hello"),
+        Message("assistant", "Hi there!"),
+    ]
+    assert count_message_tokens(messages) == 17
+
+
+def test_count_message_tokens_empty_input():
+    """Empty input should return 3 tokens"""
+    assert count_message_tokens([]) == 3
+
+
+def test_count_message_tokens_invalid_model():
+    """Invalid model should raise a NotImplementedError"""
+    messages = [
+        Message("user", "Hello"),
+        Message("assistant", "Hi there!"),
+    ]
+    with pytest.raises(NotImplementedError):
+        count_message_tokens(messages, model="invalid_model")
+
+
+def test_count_message_tokens_gpt_4():
+    messages = [
+        Message("user", "Hello"),
+        Message("assistant", "Hi there!"),
+    ]
+    assert count_message_tokens(messages, model="gpt-4-0314") == 15
+
+
+def test_count_string_tokens():
+    """Test that the string tokens are counted correctly."""
+
+    string = "Hello, world!"
+    assert count_string_tokens(string, model_name="gpt-3.5-turbo-0301") == 4
+
+
+def test_count_string_tokens_empty_input():
+    """Test that the string tokens are counted correctly."""
+
+    assert count_string_tokens("", model_name="gpt-3.5-turbo-0301") == 0
+
+
+def test_count_string_tokens_gpt_4():
+    """Test that the string tokens are counted correctly."""
+
+    string = "Hello, world!"
+    assert count_string_tokens(string, model_name="gpt-4-0314") == 4
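The expected counts follow OpenAI's published chat-token accounting, and Message("user", "Hello") being constructed positionally matches the commit message's "Replace Message TypedDict by dataclass". A sketch that reproduces all six numbers (a hedged stand-in for the real helpers in autogpt/llm, which cover more models):

from dataclasses import dataclass

import tiktoken


@dataclass
class Message:
    role: str
    content: str


def count_message_tokens(messages: list[Message], model: str = "gpt-3.5-turbo-0301") -> int:
    # Per-message overhead from OpenAI's token-counting guide.
    if model == "gpt-3.5-turbo-0301":
        tokens_per_message = 4
    elif model == "gpt-4-0314":
        tokens_per_message = 3
    else:
        raise NotImplementedError(f"count_message_tokens() not implemented for {model}")
    encoding = tiktoken.encoding_for_model(model)
    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        num_tokens += len(encoding.encode(message.role))
        num_tokens += len(encoding.encode(message.content))
    return num_tokens + 3  # every reply is primed with <|start|>assistant<|message|>


def count_string_tokens(string: str, model_name: str) -> int:
    return len(tiktoken.encoding_for_model(model_name).encode(string))

Worked check against the tests: for gpt-3.5-turbo-0301, two messages cost (4+1+1) + (4+1+3) + 3 = 17, and an empty list costs just the 3 priming tokens; for gpt-4-0314 the per-message overhead drops to 3, giving 15.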