mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
refactor(agent, forge): Move tests from autogpt to forge (#7247)
- Move `autogpt/tests/vcr_cassettes` submodule to `forge/tests/vcr_cassettes` - Remove unneeded markers from `pyproject.toml`: `"requires_openai_api_key", "requires_huggingface_api_key"` - Update relevant GitHub workflows Moved relevant tests from `autogpt/tests` to appropriate directories: - Component tests to their respective component dirs - `autogpt/tests/unit/test_web_search.py` → `forge/components/web/test_search.py` - `autogpt/tests/unit/test_git_commands.py` → `forge/components/git_operations/test_git_operations.py` - `autogpt/tests/unit/test_file_operations.py` → `forge/components/file_manager/test_file_manager.py` - `autogpt/tests/integration/test_image_gen.py` → `forge/components/image_gen/test_image_gen.py` - `autogpt/tests/integration/test_web_selenium.py` → `forge/components/web/test_selenium.py` - `autogpt/tests/integration/test_execute_code.py` → `forge/components/code_executor/test_code_executor.py` - `autogpt/tests/unit/test_s3_file_storage.py` → `forge/file_storage/test_s3_file_storage.py` - `autogpt/tests/unit/test_gcs_file_storage.py` → `forge/file_storage/test_gcs_file_storage.py` - `autogpt/tests/unit/test_local_file_storage.py` → `forge/file_storage/test_local_file_storage.py` - `autogpt/tests/unit/test_json.py` → `forge/json/test_parsing.py` - `autogpt/tests/unit/test_logs.py` → `forge/logging/test_utils.py` - `autogpt/tests/unit/test_url_validation.py` → `forge/utils/test_url_validator.py` - `autogpt/tests/unit/test_text_file_parsers.py` → `forge/utils/test_file_operations.py` - (Re)moved dependencies from `autogpt/pyproject.toml` that were only used in these test files. Also: - Added `load_env_vars` fixture to `forge/conftest.py` - Fixed a type error in `forge/components/web/test_search.py` - Merged `autogpt/.gitattributes` into root `.gitattributes` --------- Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
This commit is contained in:
committed by
GitHub
parent
7415e24fc3
commit
08612cc3bf
41
forge/conftest.py
Normal file
41
forge/conftest.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from forge.file_storage.base import FileStorage, FileStorageConfiguration
|
||||
from forge.file_storage.local import LocalFileStorage
|
||||
|
||||
pytest_plugins = [
|
||||
"tests.vcr",
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
def load_env_vars():
    """Load variables from a `.env` file once for the whole test session."""
    # Imported lazily so merely importing conftest does not require dotenv.
    from dotenv import load_dotenv

    load_dotenv()
|
||||
|
||||
|
||||
@pytest.fixture()
def tmp_project_root(tmp_path: Path) -> Path:
    """Expose pytest's per-test temporary directory as the project root."""
    project_root = tmp_path
    return project_root
|
||||
|
||||
|
||||
@pytest.fixture()
def app_data_dir(tmp_project_root: Path) -> Path:
    """Create (if needed) and return the app's `data` directory under the
    temporary project root.

    Fix: the local variable was named `dir`, shadowing the `dir` builtin.
    """
    data_dir = tmp_project_root / "data"
    data_dir.mkdir(parents=True, exist_ok=True)
    return data_dir
|
||||
|
||||
|
||||
@pytest.fixture()
def storage(app_data_dir: Path) -> FileStorage:
    """Initialized LocalFileStorage rooted at a unique subdir of app_data_dir."""
    # A fresh UUID per test keeps storage roots isolated from each other.
    unique_root = Path(f"{app_data_dir}/{str(uuid.uuid4())}")
    config = FileStorageConfiguration(root=unique_root, restrict_to_root=False)
    local_storage = LocalFileStorage(config)
    local_storage.initialize()
    return local_storage
|
||||
@@ -12,9 +12,9 @@ from .models.task import StepRequestBody, Task, TaskListResponse, TaskRequestBod
|
||||
|
||||
|
||||
@pytest.fixture
def agent(tmp_project_root: Path):
    """Build a ProtocolAgent backed by a SQLite DB and a local workspace.

    Fix: the scraped diff retained both the pre-change lines
    (`def agent(test_workspace: Path)` / `config = ...root=test_workspace`) and
    the post-change lines, yielding duplicate, conflicting statements. Only the
    updated version (using the `tmp_project_root` fixture) is kept.
    """
    db = AgentDB("sqlite:///test.db")
    config = FileStorageConfiguration(root=tmp_project_root)
    workspace = LocalFileStorage(config)
    return ProtocolAgent(db, workspace)
|
||||
|
||||
|
||||
0
forge/forge/components/__init__.py
Normal file
0
forge/forge/components/__init__.py
Normal file
@@ -1,8 +1,6 @@
|
||||
from .code_executor import CodeExecutionError, CodeExecutorComponent
|
||||
|
||||
__all__ = [
|
||||
"ALLOWLIST_CONTROL",
|
||||
"DENYLIST_CONTROL",
|
||||
"CodeExecutionError",
|
||||
"CodeExecutorComponent",
|
||||
]
|
||||
|
||||
161
forge/forge/components/code_executor/test_code_executor.py
Normal file
161
forge/forge/components/code_executor/test_code_executor.py
Normal file
@@ -0,0 +1,161 @@
|
||||
import random
|
||||
import string
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from forge.file_storage.base import FileStorage
|
||||
from forge.utils.exceptions import InvalidArgumentError, OperationNotAllowedError
|
||||
|
||||
from .code_executor import (
|
||||
CodeExecutorComponent,
|
||||
is_docker_available,
|
||||
we_are_running_in_a_docker_container,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def code_executor_component(storage: FileStorage):
|
||||
return CodeExecutorComponent(storage)
|
||||
|
||||
|
||||
@pytest.fixture
def random_code(random_string) -> str:
    """One-line Python program that prints a greeting with the random string."""
    greeting = "print('Hello " + random_string + "!')"
    return greeting
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def python_test_file(storage: FileStorage, random_code: str):
|
||||
temp_file = tempfile.NamedTemporaryFile(dir=storage.root, suffix=".py")
|
||||
temp_file.write(str.encode(random_code))
|
||||
temp_file.flush()
|
||||
|
||||
yield Path(temp_file.name)
|
||||
temp_file.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def python_test_args_file(storage: FileStorage):
|
||||
temp_file = tempfile.NamedTemporaryFile(dir=storage.root, suffix=".py")
|
||||
temp_file.write(str.encode("import sys\nprint(sys.argv[1], sys.argv[2])"))
|
||||
temp_file.flush()
|
||||
|
||||
yield Path(temp_file.name)
|
||||
temp_file.close()
|
||||
|
||||
|
||||
@pytest.fixture
def random_string():
    """Ten random lowercase ASCII letters."""
    letters = [random.choice(string.ascii_lowercase) for _ in range(10)]
    return "".join(letters)
|
||||
|
||||
|
||||
def test_execute_python_file(
|
||||
code_executor_component: CodeExecutorComponent,
|
||||
python_test_file: Path,
|
||||
random_string: str,
|
||||
):
|
||||
if not (is_docker_available() or we_are_running_in_a_docker_container()):
|
||||
pytest.skip("Docker is not available")
|
||||
|
||||
result: str = code_executor_component.execute_python_file(python_test_file)
|
||||
assert result.replace("\r", "") == f"Hello {random_string}!\n"
|
||||
|
||||
|
||||
def test_execute_python_file_args(
|
||||
code_executor_component: CodeExecutorComponent,
|
||||
python_test_args_file: Path,
|
||||
random_string: str,
|
||||
):
|
||||
if not (is_docker_available() or we_are_running_in_a_docker_container()):
|
||||
pytest.skip("Docker is not available")
|
||||
|
||||
random_args = [random_string] * 2
|
||||
random_args_string = " ".join(random_args)
|
||||
result = code_executor_component.execute_python_file(
|
||||
python_test_args_file, args=random_args
|
||||
)
|
||||
assert result == f"{random_args_string}\n"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_python_code(
|
||||
code_executor_component: CodeExecutorComponent,
|
||||
random_code: str,
|
||||
random_string: str,
|
||||
):
|
||||
if not (is_docker_available() or we_are_running_in_a_docker_container()):
|
||||
pytest.skip("Docker is not available")
|
||||
|
||||
result: str = await code_executor_component.execute_python_code(random_code)
|
||||
assert result.replace("\r", "") == f"Hello {random_string}!\n"
|
||||
|
||||
|
||||
def test_execute_python_file_invalid(code_executor_component: CodeExecutorComponent):
|
||||
with pytest.raises(InvalidArgumentError):
|
||||
code_executor_component.execute_python_file(Path("not_python.txt"))
|
||||
|
||||
|
||||
def test_execute_python_file_not_found(code_executor_component: CodeExecutorComponent):
|
||||
with pytest.raises(
|
||||
FileNotFoundError,
|
||||
match=r"python: can't open file '([a-zA-Z]:)?[/\\\-\w]*notexist.py': "
|
||||
r"\[Errno 2\] No such file or directory",
|
||||
):
|
||||
code_executor_component.execute_python_file(Path("notexist.py"))
|
||||
|
||||
|
||||
def test_execute_shell(
|
||||
code_executor_component: CodeExecutorComponent, random_string: str
|
||||
):
|
||||
code_executor_component.config.shell_command_control = "allowlist"
|
||||
code_executor_component.config.shell_allowlist = ["echo"]
|
||||
result = code_executor_component.execute_shell(f"echo 'Hello {random_string}!'")
|
||||
assert f"Hello {random_string}!" in result
|
||||
|
||||
|
||||
def test_execute_shell_local_commands_not_allowed(
|
||||
code_executor_component: CodeExecutorComponent, random_string: str
|
||||
):
|
||||
with pytest.raises(OperationNotAllowedError, match="not allowed"):
|
||||
code_executor_component.execute_shell(f"echo 'Hello {random_string}!'")
|
||||
|
||||
|
||||
def test_execute_shell_denylist_should_deny(
|
||||
code_executor_component: CodeExecutorComponent, random_string: str
|
||||
):
|
||||
code_executor_component.config.shell_command_control = "denylist"
|
||||
code_executor_component.config.shell_denylist = ["echo"]
|
||||
|
||||
with pytest.raises(OperationNotAllowedError, match="not allowed"):
|
||||
code_executor_component.execute_shell(f"echo 'Hello {random_string}!'")
|
||||
|
||||
|
||||
def test_execute_shell_denylist_should_allow(
|
||||
code_executor_component: CodeExecutorComponent, random_string: str
|
||||
):
|
||||
code_executor_component.config.shell_command_control = "denylist"
|
||||
code_executor_component.config.shell_denylist = ["cat"]
|
||||
|
||||
result = code_executor_component.execute_shell(f"echo 'Hello {random_string}!'")
|
||||
assert "Hello" in result and random_string in result
|
||||
|
||||
|
||||
def test_execute_shell_allowlist_should_deny(
|
||||
code_executor_component: CodeExecutorComponent, random_string: str
|
||||
):
|
||||
code_executor_component.config.shell_command_control = "allowlist"
|
||||
code_executor_component.config.shell_allowlist = ["cat"]
|
||||
|
||||
with pytest.raises(OperationNotAllowedError, match="not allowed"):
|
||||
code_executor_component.execute_shell(f"echo 'Hello {random_string}!'")
|
||||
|
||||
|
||||
def test_execute_shell_allowlist_should_allow(
|
||||
code_executor_component: CodeExecutorComponent, random_string: str
|
||||
):
|
||||
code_executor_component.config.shell_command_control = "allowlist"
|
||||
code_executor_component.config.shell_allowlist = ["echo"]
|
||||
|
||||
result = code_executor_component.execute_shell(f"echo 'Hello {random_string}!'")
|
||||
assert "Hello" in result and random_string in result
|
||||
118
forge/forge/components/file_manager/test_file_manager.py
Normal file
118
forge/forge/components/file_manager/test_file_manager.py
Normal file
@@ -0,0 +1,118 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from forge.agent.base import BaseAgentSettings
|
||||
from forge.file_storage import FileStorage
|
||||
|
||||
from . import FileManagerComponent
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def file_content():
|
||||
return "This is a test file.\n"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def file_manager_component(storage: FileStorage):
|
||||
return FileManagerComponent(
|
||||
storage,
|
||||
BaseAgentSettings(
|
||||
agent_id="TestAgent", name="TestAgent", description="Test Agent description"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def test_file_name():
|
||||
return Path("test_file.txt")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_file_path(test_file_name: Path, storage: FileStorage):
|
||||
return storage.get_path(test_file_name)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def test_directory(storage: FileStorage):
|
||||
return storage.get_path("test_directory")
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def test_nested_file(storage: FileStorage):
|
||||
return storage.get_path("nested/test_file.txt")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_file(
|
||||
test_file_path: Path,
|
||||
file_content,
|
||||
file_manager_component: FileManagerComponent,
|
||||
):
|
||||
await file_manager_component.workspace.write_file(test_file_path.name, file_content)
|
||||
content = file_manager_component.read_file(test_file_path.name)
|
||||
assert content.replace("\r", "") == file_content
|
||||
|
||||
|
||||
def test_read_file_not_found(file_manager_component: FileManagerComponent):
|
||||
filename = "does_not_exist.txt"
|
||||
with pytest.raises(FileNotFoundError):
|
||||
file_manager_component.read_file(filename)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_to_file_relative_path(
|
||||
test_file_name: Path, file_manager_component: FileManagerComponent
|
||||
):
|
||||
new_content = "This is new content.\n"
|
||||
await file_manager_component.write_to_file(test_file_name, new_content)
|
||||
with open(
|
||||
file_manager_component.workspace.get_path(test_file_name), "r", encoding="utf-8"
|
||||
) as f:
|
||||
content = f.read()
|
||||
assert content == new_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_to_file_absolute_path(
|
||||
test_file_path: Path, file_manager_component: FileManagerComponent
|
||||
):
|
||||
new_content = "This is new content.\n"
|
||||
await file_manager_component.write_to_file(test_file_path, new_content)
|
||||
with open(test_file_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
assert content == new_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_files(file_manager_component: FileManagerComponent):
|
||||
# Create files A and B
|
||||
file_a_name = "file_a.txt"
|
||||
file_b_name = "file_b.txt"
|
||||
test_directory = Path("test_directory")
|
||||
|
||||
await file_manager_component.workspace.write_file(file_a_name, "This is file A.")
|
||||
await file_manager_component.workspace.write_file(file_b_name, "This is file B.")
|
||||
|
||||
# Create a subdirectory and place a copy of file_a in it
|
||||
file_manager_component.workspace.make_dir(test_directory)
|
||||
await file_manager_component.workspace.write_file(
|
||||
test_directory / file_a_name, "This is file A in the subdirectory."
|
||||
)
|
||||
|
||||
files = file_manager_component.list_folder(".")
|
||||
assert file_a_name in files
|
||||
assert file_b_name in files
|
||||
assert os.path.join(test_directory, file_a_name) in files
|
||||
|
||||
# Clean up
|
||||
file_manager_component.workspace.delete_file(file_a_name)
|
||||
file_manager_component.workspace.delete_file(file_b_name)
|
||||
file_manager_component.workspace.delete_file(test_directory / file_a_name)
|
||||
file_manager_component.workspace.delete_dir(test_directory)
|
||||
|
||||
# Case 2: Search for a file that does not exist and make sure we don't throw
|
||||
non_existent_file = "non_existent_file.txt"
|
||||
files = file_manager_component.list_folder("")
|
||||
assert non_existent_file not in files
|
||||
57
forge/forge/components/git_operations/test_git_operations.py
Normal file
57
forge/forge/components/git_operations/test_git_operations.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import pytest
|
||||
from git.exc import GitCommandError
|
||||
from git.repo.base import Repo
|
||||
|
||||
from forge.file_storage.base import FileStorage
|
||||
from forge.utils.exceptions import CommandExecutionError
|
||||
|
||||
from . import GitOperationsComponent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_clone_from(mocker):
|
||||
return mocker.patch.object(Repo, "clone_from")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def git_ops_component():
|
||||
return GitOperationsComponent()
|
||||
|
||||
|
||||
def test_clone_auto_gpt_repository(
|
||||
git_ops_component: GitOperationsComponent,
|
||||
storage: FileStorage,
|
||||
mock_clone_from,
|
||||
):
|
||||
mock_clone_from.return_value = None
|
||||
|
||||
repo = "github.com/Significant-Gravitas/Auto-GPT.git"
|
||||
scheme = "https://"
|
||||
url = scheme + repo
|
||||
clone_path = storage.get_path("auto-gpt-repo")
|
||||
|
||||
expected_output = f"Cloned {url} to {clone_path}"
|
||||
|
||||
clone_result = git_ops_component.clone_repository(url, clone_path)
|
||||
|
||||
assert clone_result == expected_output
|
||||
mock_clone_from.assert_called_once_with(
|
||||
url=f"{scheme}{git_ops_component.config.github_username}:{git_ops_component.config.github_api_key}@{repo}", # noqa: E501
|
||||
to_path=clone_path,
|
||||
)
|
||||
|
||||
|
||||
def test_clone_repository_error(
|
||||
git_ops_component: GitOperationsComponent,
|
||||
storage: FileStorage,
|
||||
mock_clone_from,
|
||||
):
|
||||
url = "https://github.com/this-repository/does-not-exist.git"
|
||||
clone_path = storage.get_path("does-not-exist")
|
||||
|
||||
mock_clone_from.side_effect = GitCommandError(
|
||||
"clone", "fatal: repository not found", ""
|
||||
)
|
||||
|
||||
with pytest.raises(CommandExecutionError):
|
||||
git_ops_component.clone_repository(url, clone_path)
|
||||
247
forge/forge/components/image_gen/test_image_gen.py
Normal file
247
forge/forge/components/image_gen/test_image_gen.py
Normal file
@@ -0,0 +1,247 @@
|
||||
import functools
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from PIL import Image
|
||||
from pydantic import SecretStr, ValidationError
|
||||
|
||||
from forge.components.image_gen import ImageGeneratorComponent
|
||||
from forge.components.image_gen.image_gen import ImageGeneratorConfiguration
|
||||
from forge.file_storage.base import FileStorage
|
||||
from forge.llm.providers.openai import OpenAICredentials
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def image_gen_component(storage: FileStorage):
|
||||
try:
|
||||
cred = OpenAICredentials.from_env()
|
||||
except ValidationError:
|
||||
cred = OpenAICredentials(api_key=SecretStr("test"))
|
||||
|
||||
return ImageGeneratorComponent(storage, openai_credentials=cred)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def huggingface_image_gen_component(storage: FileStorage):
|
||||
config = ImageGeneratorConfiguration(
|
||||
image_provider="huggingface",
|
||||
huggingface_api_token=SecretStr("1"),
|
||||
huggingface_image_model="CompVis/stable-diffusion-v1-4",
|
||||
)
|
||||
return ImageGeneratorComponent(storage, config=config)
|
||||
|
||||
|
||||
@pytest.fixture(params=[256, 512, 1024])
|
||||
def image_size(request):
|
||||
"""Parametrize image size."""
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.mark.vcr
|
||||
def test_dalle(
|
||||
image_gen_component: ImageGeneratorComponent,
|
||||
image_size,
|
||||
):
|
||||
"""Test DALL-E image generation."""
|
||||
generate_and_validate(
|
||||
image_gen_component,
|
||||
image_provider="dalle",
|
||||
image_size=image_size,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="The image is too big to be put in a cassette for a CI pipeline. "
|
||||
"We're looking into a solution."
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"image_model",
|
||||
["CompVis/stable-diffusion-v1-4", "stabilityai/stable-diffusion-2-1"],
|
||||
)
|
||||
def test_huggingface(
|
||||
image_gen_component: ImageGeneratorComponent,
|
||||
image_size,
|
||||
image_model,
|
||||
):
|
||||
"""Test HuggingFace image generation."""
|
||||
generate_and_validate(
|
||||
image_gen_component,
|
||||
image_provider="huggingface",
|
||||
image_size=image_size,
|
||||
hugging_face_image_model=image_model,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="SD WebUI call does not work.")
|
||||
def test_sd_webui(image_gen_component: ImageGeneratorComponent, image_size):
|
||||
"""Test SD WebUI image generation."""
|
||||
generate_and_validate(
|
||||
image_gen_component,
|
||||
image_provider="sd_webui",
|
||||
image_size=image_size,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="SD WebUI call does not work.")
|
||||
def test_sd_webui_negative_prompt(
|
||||
image_gen_component: ImageGeneratorComponent, image_size
|
||||
):
|
||||
gen_image = functools.partial(
|
||||
image_gen_component.generate_image_with_sd_webui,
|
||||
prompt="astronaut riding a horse",
|
||||
size=image_size,
|
||||
extra={"seed": 123},
|
||||
)
|
||||
|
||||
# Generate an image with a negative prompt
|
||||
image_path = lst(
|
||||
gen_image(negative_prompt="horse", output_file=Path("negative.jpg"))
|
||||
)
|
||||
with Image.open(image_path) as img:
|
||||
neg_image_hash = hashlib.md5(img.tobytes()).hexdigest()
|
||||
|
||||
# Generate an image without a negative prompt
|
||||
image_path = lst(gen_image(output_file=Path("positive.jpg")))
|
||||
with Image.open(image_path) as img:
|
||||
image_hash = hashlib.md5(img.tobytes()).hexdigest()
|
||||
|
||||
assert image_hash != neg_image_hash
|
||||
|
||||
|
||||
def lst(txt):
    """Extract the file path from the output of `generate_image()`.

    The output has the form ``"<message>: <path>"``; everything after the
    first ``": "`` is taken as the path.
    """
    path_str = txt.split(": ", maxsplit=1)[1]
    return Path(path_str.strip())
|
||||
|
||||
|
||||
def generate_and_validate(
    image_gen_component: ImageGeneratorComponent,
    image_size,
    image_provider,
    hugging_face_image_model=None,
    **kwargs,
):
    """Generate an image and validate the output."""
    config = image_gen_component.config
    config.image_provider = image_provider
    if hugging_face_image_model:
        config.huggingface_image_model = hugging_face_image_model

    prompt = "astronaut riding a horse"
    output = image_gen_component.generate_image(prompt, image_size, **kwargs)
    image_path = lst(output)
    assert image_path.exists()
    # The generated image must be square with the requested edge length.
    with Image.open(image_path) as img:
        assert img.size == (image_size, image_size)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"return_text",
|
||||
[
|
||||
# Delay
|
||||
'{"error":"Model [model] is currently loading","estimated_time": [delay]}',
|
||||
'{"error":"Model [model] is currently loading"}', # No delay
|
||||
'{"error:}', # Bad JSON
|
||||
"", # Bad Image
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"image_model",
|
||||
["CompVis/stable-diffusion-v1-4", "stabilityai/stable-diffusion-2-1"],
|
||||
)
|
||||
@pytest.mark.parametrize("delay", [10, 0])
|
||||
def test_huggingface_fail_request_with_delay(
|
||||
huggingface_image_gen_component: ImageGeneratorComponent,
|
||||
image_size,
|
||||
image_model,
|
||||
return_text,
|
||||
delay,
|
||||
):
|
||||
return_text = return_text.replace("[model]", image_model).replace(
|
||||
"[delay]", str(delay)
|
||||
)
|
||||
|
||||
with patch("requests.post") as mock_post:
|
||||
if return_text == "":
|
||||
# Test bad image
|
||||
mock_post.return_value.status_code = 200
|
||||
mock_post.return_value.ok = True
|
||||
mock_post.return_value.content = b"bad image"
|
||||
else:
|
||||
# Test delay and bad json
|
||||
mock_post.return_value.status_code = 500
|
||||
mock_post.return_value.ok = False
|
||||
mock_post.return_value.text = return_text
|
||||
|
||||
huggingface_image_gen_component.config.huggingface_image_model = image_model
|
||||
prompt = "astronaut riding a horse"
|
||||
|
||||
with patch("time.sleep") as mock_sleep:
|
||||
# Verify request fails.
|
||||
result = huggingface_image_gen_component.generate_image(prompt, image_size)
|
||||
assert result == "Error creating image."
|
||||
|
||||
# Verify retry was called with delay if delay is in return_text
|
||||
if "estimated_time" in return_text:
|
||||
mock_sleep.assert_called_with(delay)
|
||||
else:
|
||||
mock_sleep.assert_not_called()
|
||||
|
||||
|
||||
def test_huggingface_fail_request_no_delay(
|
||||
mocker, huggingface_image_gen_component: ImageGeneratorComponent
|
||||
):
|
||||
# Mock requests.post
|
||||
mock_post = mocker.patch("requests.post")
|
||||
mock_post.return_value.status_code = 500
|
||||
mock_post.return_value.ok = False
|
||||
mock_post.return_value.text = (
|
||||
'{"error":"Model CompVis/stable-diffusion-v1-4 is currently loading"}'
|
||||
)
|
||||
|
||||
# Mock time.sleep
|
||||
mock_sleep = mocker.patch("time.sleep")
|
||||
|
||||
result = huggingface_image_gen_component.generate_image(
|
||||
"astronaut riding a horse", 512
|
||||
)
|
||||
|
||||
assert result == "Error creating image."
|
||||
|
||||
# Verify retry was not called.
|
||||
mock_sleep.assert_not_called()
|
||||
|
||||
|
||||
def test_huggingface_fail_request_bad_json(
|
||||
mocker, huggingface_image_gen_component: ImageGeneratorComponent
|
||||
):
|
||||
# Mock requests.post
|
||||
mock_post = mocker.patch("requests.post")
|
||||
mock_post.return_value.status_code = 500
|
||||
mock_post.return_value.ok = False
|
||||
mock_post.return_value.text = '{"error:}'
|
||||
|
||||
# Mock time.sleep
|
||||
mock_sleep = mocker.patch("time.sleep")
|
||||
|
||||
result = huggingface_image_gen_component.generate_image(
|
||||
"astronaut riding a horse", 512
|
||||
)
|
||||
|
||||
assert result == "Error creating image."
|
||||
|
||||
# Verify retry was not called.
|
||||
mock_sleep.assert_not_called()
|
||||
|
||||
|
||||
def test_huggingface_fail_request_bad_image(
|
||||
mocker, huggingface_image_gen_component: ImageGeneratorComponent
|
||||
):
|
||||
# Mock requests.post
|
||||
mock_post = mocker.patch("requests.post")
|
||||
mock_post.return_value.status_code = 200
|
||||
|
||||
result = huggingface_image_gen_component.generate_image(
|
||||
"astronaut riding a horse", 512
|
||||
)
|
||||
|
||||
assert result == "Error creating image."
|
||||
152
forge/forge/components/web/test_search.py
Normal file
152
forge/forge/components/web/test_search.py
Normal file
@@ -0,0 +1,152 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from googleapiclient.errors import HttpError
|
||||
from httplib2 import Response
|
||||
from pydantic import SecretStr
|
||||
|
||||
from forge.utils.exceptions import ConfigurationError
|
||||
|
||||
from . import WebSearchComponent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def web_search_component():
|
||||
component = WebSearchComponent()
|
||||
if component.config.google_api_key is None:
|
||||
component.config.google_api_key = SecretStr("test")
|
||||
if component.config.google_custom_search_engine_id is None:
|
||||
component.config.google_custom_search_engine_id = SecretStr("test")
|
||||
return component
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "query, expected_output",
    [("test", "test"), (["test1", "test2"], '["test1", "test2"]')],
)
def test_safe_google_results(
    query, expected_output, web_search_component: WebSearchComponent
):
    """safe_google_results returns a str: lists are JSON-encoded, strs pass through.

    Fix: this function was also decorated with `@pytest.fixture`, which made
    pytest collect it as a fixture instead of running it as a parametrized
    test; the stray decorator is removed.
    """
    result = web_search_component.safe_google_results(query)
    assert isinstance(result, str)
    assert result == expected_output
|
||||
|
||||
|
||||
def test_safe_google_results_invalid_input(web_search_component: WebSearchComponent):
    """Non-str/list input raises AttributeError.

    Fix: removed a stray `@pytest.fixture` decorator that prevented this
    function from ever being run as a test.
    """
    with pytest.raises(AttributeError):
        web_search_component.safe_google_results(123)  # type: ignore
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"query, num_results, expected_output_parts, return_value",
|
||||
[
|
||||
(
|
||||
"test",
|
||||
1,
|
||||
("Result 1", "https://example.com/result1"),
|
||||
[{"title": "Result 1", "href": "https://example.com/result1"}],
|
||||
),
|
||||
("", 1, (), []),
|
||||
("no results", 1, (), []),
|
||||
],
|
||||
)
|
||||
def test_google_search(
|
||||
query,
|
||||
num_results,
|
||||
expected_output_parts,
|
||||
return_value,
|
||||
mocker,
|
||||
web_search_component: WebSearchComponent,
|
||||
):
|
||||
mock_ddg = mocker.Mock()
|
||||
mock_ddg.return_value = return_value
|
||||
|
||||
mocker.patch("forge.components.web.search.DDGS.text", mock_ddg)
|
||||
actual_output = web_search_component.web_search(query, num_results=num_results)
|
||||
for o in expected_output_parts:
|
||||
assert o in actual_output
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_googleapiclient(mocker):
|
||||
mock_build = mocker.patch("googleapiclient.discovery.build")
|
||||
mock_service = mocker.Mock()
|
||||
mock_build.return_value = mock_service
|
||||
return mock_service.cse().list().execute().get
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"query, num_results, search_results, expected_output",
|
||||
[
|
||||
(
|
||||
"test",
|
||||
3,
|
||||
[
|
||||
{"link": "http://example.com/result1"},
|
||||
{"link": "http://example.com/result2"},
|
||||
{"link": "http://example.com/result3"},
|
||||
],
|
||||
[
|
||||
"http://example.com/result1",
|
||||
"http://example.com/result2",
|
||||
"http://example.com/result3",
|
||||
],
|
||||
),
|
||||
("", 3, [], []),
|
||||
],
|
||||
)
|
||||
def test_google_official_search(
|
||||
query,
|
||||
num_results,
|
||||
expected_output,
|
||||
search_results,
|
||||
mock_googleapiclient,
|
||||
web_search_component: WebSearchComponent,
|
||||
):
|
||||
mock_googleapiclient.return_value = search_results
|
||||
actual_output = web_search_component.google(query, num_results=num_results)
|
||||
assert actual_output == web_search_component.safe_google_results(expected_output)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"query, num_results, expected_error_type, http_code, error_msg",
|
||||
[
|
||||
(
|
||||
"invalid query",
|
||||
3,
|
||||
HttpError,
|
||||
400,
|
||||
"Invalid Value",
|
||||
),
|
||||
(
|
||||
"invalid API key",
|
||||
3,
|
||||
ConfigurationError,
|
||||
403,
|
||||
"invalid API key",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_google_official_search_errors(
|
||||
query,
|
||||
num_results,
|
||||
expected_error_type,
|
||||
mock_googleapiclient,
|
||||
http_code,
|
||||
error_msg,
|
||||
web_search_component: WebSearchComponent,
|
||||
):
|
||||
response_content = {
|
||||
"error": {"code": http_code, "message": error_msg, "reason": "backendError"}
|
||||
}
|
||||
error = HttpError(
|
||||
resp=Response({"status": http_code, "reason": error_msg}),
|
||||
content=str.encode(json.dumps(response_content)),
|
||||
uri="https://www.googleapis.com/customsearch/v1?q=invalid+query&cx",
|
||||
)
|
||||
|
||||
mock_googleapiclient.side_effect = error
|
||||
with pytest.raises(expected_error_type):
|
||||
web_search_component.google(query, num_results=num_results)
|
||||
26
forge/forge/components/web/test_selenium.py
Normal file
26
forge/forge/components/web/test_selenium.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from forge.llm.providers.multi import MultiProvider
|
||||
|
||||
from . import BrowsingError, WebSeleniumComponent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def web_selenium_component(app_data_dir: Path):
|
||||
return WebSeleniumComponent(MultiProvider(), app_data_dir)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_browse_website_nonexistent_url(
|
||||
web_selenium_component: WebSeleniumComponent,
|
||||
):
|
||||
url = "https://auto-gpt-thinks-this-website-does-not-exist.com"
|
||||
question = "How to execute a barrel roll"
|
||||
|
||||
with pytest.raises(BrowsingError, match="NAME_NOT_RESOLVED") as raised:
|
||||
await web_selenium_component.read_webpage(url=url, question=question)
|
||||
|
||||
# Sanity check that the response is not too long
|
||||
assert len(raised.exconly()) < 200
|
||||
@@ -1,8 +0,0 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def test_workspace(tmp_path: Path) -> Path:
|
||||
return tmp_path
|
||||
202
forge/forge/file_storage/test_gcs_file_storage.py
Normal file
202
forge/forge/file_storage/test_gcs_file_storage.py
Normal file
@@ -0,0 +1,202 @@
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from google.auth.exceptions import GoogleAuthError
|
||||
from google.cloud import storage
|
||||
from google.cloud.exceptions import NotFound
|
||||
|
||||
from .gcs import GCSFileStorage, GCSFileStorageConfiguration
|
||||
|
||||
# Skip the whole module when no Google Cloud credentials are configured,
# so the suite still runs in environments without GCP access.
try:
    storage.Client()
except GoogleAuthError:
    pytest.skip("Google Cloud Authentication not configured", allow_module_level=True)

# These tests talk to a real GCS bucket, hence the slow marker.
pytestmark = pytest.mark.slow


@pytest.fixture(scope="module")
def gcs_bucket_name() -> str:
    """Random bucket name so concurrent test runs don't collide."""
    return f"test-bucket-{str(uuid.uuid4())[:8]}"


@pytest.fixture(scope="module")
def gcs_root() -> Path:
    """Workspace root path used inside the test bucket."""
    return Path("/workspaces/AutoGPT-some-unique-task-id")


@pytest.fixture(scope="module")
def gcs_storage_uninitialized(gcs_bucket_name: str, gcs_root: Path):
    """GCSFileStorage configured from env, with no bucket created yet."""
    os.environ["STORAGE_BUCKET"] = gcs_bucket_name  # consumed by from_env()
    storage_config = GCSFileStorageConfiguration.from_env()
    storage_config.root = gcs_root
    storage = GCSFileStorage(storage_config)
    yield storage  # type: ignore
    # Teardown: don't leak the env var into other tests
    del os.environ["STORAGE_BUCKET"]


def test_initialize(gcs_bucket_name: str, gcs_storage_uninitialized: GCSFileStorage):
    """initialize() must create the configured bucket on demand."""
    gcs = gcs_storage_uninitialized._gcs

    # test that the bucket doesn't exist yet
    with pytest.raises(NotFound):
        gcs.get_bucket(gcs_bucket_name)

    gcs_storage_uninitialized.initialize()

    # test that the bucket has been created
    bucket = gcs.get_bucket(gcs_bucket_name)

    # clean up
    bucket.delete(force=True)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def gcs_storage(gcs_storage_uninitialized: GCSFileStorage):
    """Yield an initialized GCSFileStorage; delete its bucket on teardown."""
    initialized_storage = gcs_storage_uninitialized
    initialized_storage.initialize()
    yield initialized_storage  # type: ignore
    # Teardown: remove the test bucket together with all of its contents
    initialized_storage._bucket.delete(force=True)
|
||||
|
||||
|
||||
def test_workspace_bucket_name(
    gcs_storage: GCSFileStorage,
    gcs_bucket_name: str,
):
    """The initialized storage must use the bucket name taken from the env."""
    assert gcs_storage._bucket.name == gcs_bucket_name


# Pre-seeded directory and files used by the read/list/rename tests below.
NESTED_DIR = "existing/test/dir"
TEST_FILES: list[tuple[str | Path, str]] = [
    ("existing_test_file_1", "test content 1"),
    ("existing_test_file_2.txt", "test content 2"),
    (Path("existing_test_file_3"), "test content 3"),
    (Path(f"{NESTED_DIR}/test_file_4"), "test content 4"),
]


@pytest_asyncio.fixture
async def gcs_storage_with_files(gcs_storage: GCSFileStorage):
    """gcs_storage with TEST_FILES uploaded directly through the GCS client."""
    for file_name, file_content in TEST_FILES:
        gcs_storage._bucket.blob(
            str(gcs_storage.get_path(file_name))
        ).upload_from_string(file_content)
    yield gcs_storage  # type: ignore


@pytest.mark.asyncio
async def test_read_file(gcs_storage_with_files: GCSFileStorage):
    """read_file() returns each seeded file's content; missing files raise NotFound."""
    for file_name, file_content in TEST_FILES:
        content = gcs_storage_with_files.read_file(file_name)
        assert content == file_content

    with pytest.raises(NotFound):
        gcs_storage_with_files.read_file("non_existent_file")


def test_list_files(gcs_storage_with_files: GCSFileStorage):
    """list_files() is stable across calls and matches the seeded TEST_FILES."""
    # List at root level
    # (two consecutive calls must agree; the walrus binds the first result)
    assert (
        files := gcs_storage_with_files.list_files()
    ) == gcs_storage_with_files.list_files()
    assert len(files) > 0
    assert set(files) == set(Path(file_name) for file_name, _ in TEST_FILES)

    # List at nested path
    assert (
        nested_files := gcs_storage_with_files.list_files(NESTED_DIR)
    ) == gcs_storage_with_files.list_files(NESTED_DIR)
    assert len(nested_files) > 0
    assert set(nested_files) == set(
        p.relative_to(NESTED_DIR)
        for file_name, _ in TEST_FILES
        if (p := Path(file_name)).is_relative_to(NESTED_DIR)
    )


def test_list_folders(gcs_storage_with_files: GCSFileStorage):
    """list_folders() honors the recursive flag."""
    # List recursive
    folders = gcs_storage_with_files.list_folders(recursive=True)
    assert len(folders) > 0
    assert set(folders) == {
        Path("existing"),
        Path("existing/test"),
        Path("existing/test/dir"),
    }
    # List non-recursive
    folders = gcs_storage_with_files.list_folders(recursive=False)
    assert len(folders) > 0
    assert set(folders) == {Path("existing")}


@pytest.mark.asyncio
async def test_write_read_file(gcs_storage: GCSFileStorage):
    """A written file can be read back unchanged."""
    await gcs_storage.write_file("test_file", "test_content")
    assert gcs_storage.read_file("test_file") == "test_content"


@pytest.mark.asyncio
async def test_overwrite_file(gcs_storage_with_files: GCSFileStorage):
    """write_file() replaces the content of an existing file."""
    for file_name, _ in TEST_FILES:
        await gcs_storage_with_files.write_file(file_name, "new content")
        assert gcs_storage_with_files.read_file(file_name) == "new content"


def test_delete_file(gcs_storage_with_files: GCSFileStorage):
    """delete_file() removes the file; exists() then reports False."""
    for file_to_delete, _ in TEST_FILES:
        gcs_storage_with_files.delete_file(file_to_delete)
        assert not gcs_storage_with_files.exists(file_to_delete)


def test_exists(gcs_storage_with_files: GCSFileStorage):
    """exists() is True for seeded files and False for unknown names."""
    for file_name, _ in TEST_FILES:
        assert gcs_storage_with_files.exists(file_name)

    assert not gcs_storage_with_files.exists("non_existent_file")


def test_rename_file(gcs_storage_with_files: GCSFileStorage):
    """rename() moves each file to its new name."""
    for file_name, _ in TEST_FILES:
        new_name = str(file_name) + "_renamed"
        gcs_storage_with_files.rename(file_name, new_name)
        assert gcs_storage_with_files.exists(new_name)
        assert not gcs_storage_with_files.exists(file_name)


def test_rename_dir(gcs_storage_with_files: GCSFileStorage):
    """rename() also works on a directory (path prefix)."""
    gcs_storage_with_files.rename(NESTED_DIR, "existing/test/dir_renamed")
    assert gcs_storage_with_files.exists("existing/test/dir_renamed")
    assert not gcs_storage_with_files.exists(NESTED_DIR)


def test_clone(gcs_storage_with_files: GCSFileStorage, gcs_root: Path):
    """clone_with_subroot() shares the bucket but scopes paths to the subroot."""
    cloned = gcs_storage_with_files.clone_with_subroot("existing/test")
    assert cloned.root == gcs_root / Path("existing/test")
    assert cloned._bucket.name == gcs_storage_with_files._bucket.name
    assert cloned.exists("dir")
    assert cloned.exists("dir/test_file_4")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_copy_file(gcs_storage: GCSFileStorage):
    """copy() duplicates a file at root level and into a subdirectory.

    Fix: the test previously requested a fixture named `storage`, which this
    module does not define (its fixtures are `gcs_storage*`), so it either
    errored out or silently ran against an unrelated conftest fixture instead
    of GCS. Use `gcs_storage`, consistent with the sibling tests.
    """
    await gcs_storage.write_file("test_file.txt", "test content")
    gcs_storage.copy("test_file.txt", "test_file_copy.txt")
    gcs_storage.make_dir("dir")
    gcs_storage.copy("test_file.txt", "dir/test_file_copy.txt")
    assert gcs_storage.read_file("test_file_copy.txt") == "test content"
    assert gcs_storage.read_file("dir/test_file_copy.txt") == "test content"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_copy_dir(gcs_storage: GCSFileStorage):
    """copy() recursively duplicates a directory tree.

    Fix: use the module's `gcs_storage` fixture instead of the undefined
    `storage` fixture name (see test_copy_file).
    """
    gcs_storage.make_dir("dir")
    gcs_storage.make_dir("dir/sub_dir")
    await gcs_storage.write_file("dir/test_file.txt", "test content")
    await gcs_storage.write_file("dir/sub_dir/test_file.txt", "test content")
    gcs_storage.copy("dir", "dir_copy")
    assert gcs_storage.read_file("dir_copy/test_file.txt") == "test content"
    assert gcs_storage.read_file("dir_copy/sub_dir/test_file.txt") == "test content"
|
||||
211
forge/forge/file_storage/test_local_file_storage.py
Normal file
211
forge/forge/file_storage/test_local_file_storage.py
Normal file
@@ -0,0 +1,211 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from .local import FileStorageConfiguration, LocalFileStorage
|
||||
|
||||
# Paths that must resolve to locations inside the storage root.
_ACCESSIBLE_PATHS = [
    Path("."),
    Path("test_file.txt"),
    Path("test_folder"),
    Path("test_folder/test_file.txt"),
    Path("test_folder/.."),
    Path("test_folder/../test_file.txt"),
    Path("test_folder/../test_folder"),
    Path("test_folder/../test_folder/test_file.txt"),
]

# Paths that a root-restricted storage must refuse to resolve.
_INACCESSIBLE_PATHS = (
    [
        # Takes us out of the workspace
        Path(".."),
        Path("../test_file.txt"),
        Path("../not_auto_gpt_workspace"),
        Path("../not_auto_gpt_workspace/test_file.txt"),
        Path("test_folder/../.."),
        Path("test_folder/../../test_file.txt"),
        Path("test_folder/../../not_auto_gpt_workspace"),
        Path("test_folder/../../not_auto_gpt_workspace/test_file.txt"),
    ]
    + [
        # Contains null byte
        Path("\0"),
        Path("\0test_file.txt"),
        Path("test_folder/\0"),
        Path("test_folder/\0test_file.txt"),
    ]
    + [
        # Absolute paths
        Path("/"),
        Path("/test_file.txt"),
        Path("/home"),
    ]
)

# Sample files (some nested) created by the file-level tests.
_TEST_FILES = [
    Path("test_file.txt"),
    Path("dir/test_file.txt"),
    Path("dir/test_file2.txt"),
    Path("dir/sub_dir/test_file.txt"),
]

# Sample directories used by the directory-level tests.
_TEST_DIRS = [
    Path("dir"),
    Path("dir/sub_dir"),
]
|
||||
|
||||
|
||||
@pytest.fixture()
def storage_root(tmp_path):
    """Directory under pytest's tmp_path used as the storage root."""
    return tmp_path / "data"


@pytest.fixture()
def storage(storage_root):
    """LocalFileStorage confined to storage_root."""
    config = FileStorageConfiguration(root=storage_root, restrict_to_root=True)
    return LocalFileStorage(config)


@pytest.fixture()
def content():
    """Payload written into files by the tests."""
    return "test content"


@pytest.fixture(params=_ACCESSIBLE_PATHS)
def accessible_path(request):
    """Each path that must resolve inside the storage root."""
    return request.param


@pytest.fixture(params=_INACCESSIBLE_PATHS)
def inaccessible_path(request):
    """Each path that the storage must reject."""
    return request.param


@pytest.fixture(params=_TEST_FILES)
def file_path(request):
    """Each sample file path exercised by the file-level tests."""
    return request.param
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_open_file(file_path: Path, content: str, storage: LocalFileStorage):
    """open_file() yields a readable handle for a previously written file."""
    if file_path.parent:
        storage.make_dir(file_path.parent)
    await storage.write_file(file_path, content)
    file = storage.open_file(file_path)
    assert file.read() == content
    file.close()
    storage.delete_file(file_path)


@pytest.mark.asyncio
async def test_write_read_file(content: str, storage: LocalFileStorage):
    """A written file can be read back unchanged."""
    await storage.write_file("test_file.txt", content)
    assert storage.read_file("test_file.txt") == content


@pytest.mark.asyncio
async def test_list_files(content: str, storage: LocalFileStorage):
    """list_files() reports files at the root and in nested directories."""
    storage.make_dir("dir")
    storage.make_dir("dir/sub_dir")
    await storage.write_file("test_file.txt", content)
    await storage.write_file("dir/test_file.txt", content)
    await storage.write_file("dir/test_file2.txt", content)
    await storage.write_file("dir/sub_dir/test_file.txt", content)
    files = storage.list_files()
    assert Path("test_file.txt") in files
    assert Path("dir/test_file.txt") in files
    assert Path("dir/test_file2.txt") in files
    assert Path("dir/sub_dir/test_file.txt") in files
    # clean up so later assertions don't see leftovers
    storage.delete_file("test_file.txt")
    storage.delete_file("dir/test_file.txt")
    storage.delete_file("dir/test_file2.txt")
    storage.delete_file("dir/sub_dir/test_file.txt")
    storage.delete_dir("dir/sub_dir")
    storage.delete_dir("dir")


@pytest.mark.asyncio
async def test_list_folders(content: str, storage: LocalFileStorage):
    """list_folders() honors the recursive flag."""
    storage.make_dir("dir")
    storage.make_dir("dir/sub_dir")
    await storage.write_file("dir/test_file.txt", content)
    await storage.write_file("dir/sub_dir/test_file.txt", content)
    folders = storage.list_folders(recursive=False)
    folders_recursive = storage.list_folders(recursive=True)
    assert Path("dir") in folders
    assert Path("dir/sub_dir") not in folders
    assert Path("dir") in folders_recursive
    assert Path("dir/sub_dir") in folders_recursive
    # clean up
    storage.delete_file("dir/test_file.txt")
    storage.delete_file("dir/sub_dir/test_file.txt")
    storage.delete_dir("dir/sub_dir")
    storage.delete_dir("dir")


@pytest.mark.asyncio
async def test_exists_delete_file(
    file_path: Path, content: str, storage: LocalFileStorage
):
    """exists() flips from True to False after delete_file()."""
    if file_path.parent:
        storage.make_dir(file_path.parent)
    await storage.write_file(file_path, content)
    assert storage.exists(file_path)
    storage.delete_file(file_path)
    assert not storage.exists(file_path)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dir_path", _TEST_DIRS)
def test_make_delete_dir(dir_path: Path, storage: LocalFileStorage):
    """make_dir() creates a directory and delete_dir() removes it again.

    Fix: this was decorated with `@pytest.fixture(params=_TEST_DIRS)`, so
    pytest never collected it as a test, and it passed the `request`
    FixtureRequest object — not `request.param` — to the storage API.
    Parametrizing over _TEST_DIRS makes it an actual test of each directory.
    """
    storage.make_dir(dir_path)
    assert storage.exists(dir_path)
    storage.delete_dir(dir_path)
    assert not storage.exists(dir_path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_rename(file_path: Path, content: str, storage: LocalFileStorage):
    """rename() moves a file to the new name."""
    if file_path.parent:
        storage.make_dir(file_path.parent)
    await storage.write_file(file_path, content)
    assert storage.exists(file_path)
    storage.rename(file_path, Path(str(file_path) + "_renamed"))
    assert not storage.exists(file_path)
    assert storage.exists(Path(str(file_path) + "_renamed"))


def test_clone_with_subroot(storage: LocalFileStorage):
    """clone_with_subroot() re-roots the storage at root/dir."""
    subroot = storage.clone_with_subroot("dir")
    assert subroot.root == storage.root / "dir"


def test_get_path_accessible(accessible_path: Path, storage: LocalFileStorage):
    """Accessible paths resolve to absolute paths inside the root."""
    full_path = storage.get_path(accessible_path)
    assert full_path.is_absolute()
    assert full_path.is_relative_to(storage.root)


def test_get_path_inaccessible(inaccessible_path: Path, storage: LocalFileStorage):
    """Paths escaping the root (or otherwise invalid) raise ValueError."""
    with pytest.raises(ValueError):
        storage.get_path(inaccessible_path)


@pytest.mark.asyncio
async def test_copy_file(storage: LocalFileStorage):
    """copy() duplicates a file at root level and into a subdirectory."""
    await storage.write_file("test_file.txt", "test content")
    storage.copy("test_file.txt", "test_file_copy.txt")
    storage.make_dir("dir")
    storage.copy("test_file.txt", "dir/test_file_copy.txt")
    assert storage.read_file("test_file_copy.txt") == "test content"
    assert storage.read_file("dir/test_file_copy.txt") == "test content"


@pytest.mark.asyncio
async def test_copy_dir(storage: LocalFileStorage):
    """copy() recursively duplicates a directory tree."""
    storage.make_dir("dir")
    storage.make_dir("dir/sub_dir")
    await storage.write_file("dir/test_file.txt", "test content")
    await storage.write_file("dir/sub_dir/test_file.txt", "test content")
    storage.copy("dir", "dir_copy")
    assert storage.read_file("dir_copy/test_file.txt") == "test content"
    assert storage.read_file("dir_copy/sub_dir/test_file.txt") == "test content"
|
||||
196
forge/forge/file_storage/test_s3_file_storage.py
Normal file
196
forge/forge/file_storage/test_s3_file_storage.py
Normal file
@@ -0,0 +1,196 @@
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
from .s3 import S3FileStorage, S3FileStorageConfiguration
|
||||
|
||||
# Skip the whole module unless S3 credentials and an endpoint are configured.
if not (os.getenv("S3_ENDPOINT_URL") and os.getenv("AWS_ACCESS_KEY_ID")):
    pytest.skip("S3 environment variables are not set", allow_module_level=True)


@pytest.fixture
def s3_bucket_name() -> str:
    """Random bucket name so concurrent test runs don't collide."""
    return f"test-bucket-{str(uuid.uuid4())[:8]}"


@pytest.fixture
def s3_root() -> Path:
    """Workspace root path used inside the test bucket."""
    return Path("/workspaces/AutoGPT-some-unique-task-id")


@pytest.fixture
def s3_storage_uninitialized(s3_bucket_name: str, s3_root: Path):
    """S3FileStorage configured from env, with no bucket created yet."""
    os.environ["STORAGE_BUCKET"] = s3_bucket_name  # consumed by from_env()
    storage_config = S3FileStorageConfiguration.from_env()
    storage_config.root = s3_root
    storage = S3FileStorage(storage_config)
    yield storage  # type: ignore
    # Teardown: don't leak the env var into other tests
    del os.environ["STORAGE_BUCKET"]


def test_initialize(s3_bucket_name: str, s3_storage_uninitialized: S3FileStorage):
    """initialize() must create the configured bucket on demand."""
    s3 = s3_storage_uninitialized._s3

    # test that the bucket doesn't exist yet
    with pytest.raises(ClientError):
        s3.meta.client.head_bucket(Bucket=s3_bucket_name)  # pyright: ignore

    s3_storage_uninitialized.initialize()

    # test that the bucket has been created
    s3.meta.client.head_bucket(Bucket=s3_bucket_name)  # pyright: ignore
    # NOTE: the old FIXME about "moving this test file to forge" is resolved
    # (the move happened); the "pyright: ignore" comments are kept for the
    # dynamically-typed boto3 resource API.


def test_workspace_bucket_name(
    s3_storage: S3FileStorage,
    s3_bucket_name: str,
):
    """The initialized storage must use the bucket name taken from the env."""
    assert s3_storage._bucket.name == s3_bucket_name
|
||||
|
||||
|
||||
@pytest.fixture
def s3_storage(s3_storage_uninitialized: S3FileStorage):
    """Yield an initialized S3FileStorage; empty and drop its bucket after."""
    initialized_storage = s3_storage_uninitialized
    initialized_storage.initialize()
    yield initialized_storage  # type: ignore
    # Teardown: the bucket must be emptied before it can be deleted
    initialized_storage._bucket.objects.all().delete()
    initialized_storage._bucket.delete()
|
||||
|
||||
|
||||
# Pre-seeded directory and files used by the read/list/rename tests below.
NESTED_DIR = "existing/test/dir"
TEST_FILES: list[tuple[str | Path, str]] = [
    ("existing_test_file_1", "test content 1"),
    ("existing_test_file_2.txt", "test content 2"),
    (Path("existing_test_file_3"), "test content 3"),
    (Path(f"{NESTED_DIR}/test_file_4"), "test content 4"),
]


@pytest_asyncio.fixture
async def s3_storage_with_files(s3_storage: S3FileStorage):
    """s3_storage with TEST_FILES uploaded directly through the S3 client."""
    for file_name, file_content in TEST_FILES:
        s3_storage._bucket.Object(str(s3_storage.get_path(file_name))).put(
            Body=file_content
        )
    yield s3_storage  # type: ignore


@pytest.mark.asyncio
async def test_read_file(s3_storage_with_files: S3FileStorage):
    """read_file() returns each seeded file's content; missing files raise."""
    for file_name, file_content in TEST_FILES:
        content = s3_storage_with_files.read_file(file_name)
        assert content == file_content

    with pytest.raises(ClientError):
        s3_storage_with_files.read_file("non_existent_file")


def test_list_files(s3_storage_with_files: S3FileStorage):
    """list_files() is stable across calls and matches the seeded TEST_FILES."""
    # List at root level
    # (two consecutive calls must agree; the walrus binds the first result)
    assert (
        files := s3_storage_with_files.list_files()
    ) == s3_storage_with_files.list_files()
    assert len(files) > 0
    assert set(files) == set(Path(file_name) for file_name, _ in TEST_FILES)

    # List at nested path
    assert (
        nested_files := s3_storage_with_files.list_files(NESTED_DIR)
    ) == s3_storage_with_files.list_files(NESTED_DIR)
    assert len(nested_files) > 0
    assert set(nested_files) == set(
        p.relative_to(NESTED_DIR)
        for file_name, _ in TEST_FILES
        if (p := Path(file_name)).is_relative_to(NESTED_DIR)
    )


def test_list_folders(s3_storage_with_files: S3FileStorage):
    """list_folders() honors the recursive flag."""
    # List recursive
    folders = s3_storage_with_files.list_folders(recursive=True)
    assert len(folders) > 0
    assert set(folders) == {
        Path("existing"),
        Path("existing/test"),
        Path("existing/test/dir"),
    }
    # List non-recursive
    folders = s3_storage_with_files.list_folders(recursive=False)
    assert len(folders) > 0
    assert set(folders) == {Path("existing")}


@pytest.mark.asyncio
async def test_write_read_file(s3_storage: S3FileStorage):
    """A written file can be read back unchanged."""
    await s3_storage.write_file("test_file", "test_content")
    assert s3_storage.read_file("test_file") == "test_content"


@pytest.mark.asyncio
async def test_overwrite_file(s3_storage_with_files: S3FileStorage):
    """write_file() replaces the content of an existing file."""
    for file_name, _ in TEST_FILES:
        await s3_storage_with_files.write_file(file_name, "new content")
        assert s3_storage_with_files.read_file(file_name) == "new content"


def test_delete_file(s3_storage_with_files: S3FileStorage):
    """delete_file() removes the file; subsequent reads raise ClientError."""
    for file_to_delete, _ in TEST_FILES:
        s3_storage_with_files.delete_file(file_to_delete)
        with pytest.raises(ClientError):
            s3_storage_with_files.read_file(file_to_delete)


def test_exists(s3_storage_with_files: S3FileStorage):
    """exists() is True for seeded files and False for unknown names."""
    for file_name, _ in TEST_FILES:
        assert s3_storage_with_files.exists(file_name)

    assert not s3_storage_with_files.exists("non_existent_file")


def test_rename_file(s3_storage_with_files: S3FileStorage):
    """rename() moves each file to its new name."""
    for file_name, _ in TEST_FILES:
        new_name = str(file_name) + "_renamed"
        s3_storage_with_files.rename(file_name, new_name)
        assert s3_storage_with_files.exists(new_name)
        assert not s3_storage_with_files.exists(file_name)


def test_rename_dir(s3_storage_with_files: S3FileStorage):
    """rename() also works on a directory (path prefix)."""
    s3_storage_with_files.rename(NESTED_DIR, "existing/test/dir_renamed")
    assert s3_storage_with_files.exists("existing/test/dir_renamed")
    assert not s3_storage_with_files.exists(NESTED_DIR)


def test_clone(s3_storage_with_files: S3FileStorage, s3_root: Path):
    """clone_with_subroot() shares the bucket but scopes paths to the subroot."""
    cloned = s3_storage_with_files.clone_with_subroot("existing/test")
    assert cloned.root == s3_root / Path("existing/test")
    assert cloned._bucket.name == s3_storage_with_files._bucket.name
    assert cloned.exists("dir")
    assert cloned.exists("dir/test_file_4")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_copy_file(s3_storage: S3FileStorage):
    """copy() duplicates a file at root level and into a subdirectory.

    Fix: the test previously requested a fixture named `storage`, which this
    module does not define (its fixtures are `s3_storage*`), so it either
    errored out or silently ran against an unrelated conftest fixture instead
    of S3. Use `s3_storage`, consistent with the sibling tests.
    """
    await s3_storage.write_file("test_file.txt", "test content")
    s3_storage.copy("test_file.txt", "test_file_copy.txt")
    s3_storage.make_dir("dir")
    s3_storage.copy("test_file.txt", "dir/test_file_copy.txt")
    assert s3_storage.read_file("test_file_copy.txt") == "test content"
    assert s3_storage.read_file("dir/test_file_copy.txt") == "test content"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_copy_dir(s3_storage: S3FileStorage):
    """copy() recursively duplicates a directory tree.

    Fix: use the module's `s3_storage` fixture instead of the undefined
    `storage` fixture name (see test_copy_file).
    """
    s3_storage.make_dir("dir")
    s3_storage.make_dir("dir/sub_dir")
    await s3_storage.write_file("dir/test_file.txt", "test content")
    await s3_storage.write_file("dir/sub_dir/test_file.txt", "test content")
    s3_storage.copy("dir", "dir_copy")
    assert s3_storage.read_file("dir_copy/test_file.txt") == "test content"
    assert s3_storage.read_file("dir_copy/sub_dir/test_file.txt") == "test content"
|
||||
93
forge/forge/json/test_parsing.py
Normal file
93
forge/forge/json/test_parsing.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from .parsing import json_loads
|
||||
|
||||
# Each pair: (malformed JSON input, canonical JSON it should be repaired to).
_JSON_FIXABLE: list[tuple[str, str]] = [
    # Missing comma
    ('{"name": "John Doe" "age": 30,}', '{"name": "John Doe", "age": 30}'),
    ("[1, 2 3]", "[1, 2, 3]"),
    # Trailing comma
    ('{"name": "John Doe", "age": 30,}', '{"name": "John Doe", "age": 30}'),
    ("[1, 2, 3,]", "[1, 2, 3]"),
    # Extra comma in object
    ('{"name": "John Doe",, "age": 30}', '{"name": "John Doe", "age": 30}'),
    # Extra newlines
    ('{"name": "John Doe",\n"age": 30}', '{"name": "John Doe", "age": 30}'),
    ("[1, 2,\n3]", "[1, 2, 3]"),
    # Missing closing brace or bracket
    ('{"name": "John Doe", "age": 30', '{"name": "John Doe", "age": 30}'),
    ("[1, 2, 3", "[1, 2, 3]"),
    # Different numerals
    ("[+1, ---2, .5, +-4.5, 123.]", "[1, -2, 0.5, -4.5, 123]"),
    ('{"bin": 0b1001, "hex": 0x1A, "oct": 0o17}', '{"bin": 9, "hex": 26, "oct": 15}'),
    # Broken array
    (
        '[1, 2 3, "yes" true, false null, 25, {"obj": "var"}',
        '[1, 2, 3, "yes", true, false, null, 25, {"obj": "var"}]',
    ),
    # Codeblock
    (
        '```json\n{"name": "John Doe", "age": 30}\n```',
        '{"name": "John Doe", "age": 30}',
    ),
    # Multiple problems
    (
        '{"name":"John Doe" "age": 30\n "empty": "","address": '
        "// random comment\n"
        '{"city": "New York", "state": "NY"},'
        '"skills": ["Python" "C++", "Java",""],',
        '{"name": "John Doe", "age": 30, "empty": "", "address": '
        '{"city": "New York", "state": "NY"}, '
        '"skills": ["Python", "C++", "Java", ""]}',
    ),
    # All good
    (
        '{"name": "John Doe", "age": 30, "address": '
        '{"city": "New York", "state": "NY"}, '
        '"skills": ["Python", "C++", "Java"]}',
        '{"name": "John Doe", "age": 30, "address": '
        '{"city": "New York", "state": "NY"}, '
        '"skills": ["Python", "C++", "Java"]}',
    ),
    ("true", "true"),
    ("false", "false"),
    ("null", "null"),
    ("123.5", "123.5"),
    ('"Hello, World!"', '"Hello, World!"'),
    ("{}", "{}"),
    ("[]", "[]"),
]

# Pairs where repair is known to fail: parsing must NOT yield the second value.
_JSON_UNFIXABLE: list[tuple[str, str]] = [
    # Broken booleans and null
    ("[TRUE, False, NULL]", "[true, false, null]"),
    # Missing values in array
    ("[1, , 3]", "[1, 3]"),
    # Leading zeros (are treated as octal)
    ("[0023, 015]", "[23, 15]"),
    # Missing quotes
    ('{"name": John Doe}', '{"name": "John Doe"}'),
    # Missing opening braces or bracket
    ('"name": "John Doe"}', '{"name": "John Doe"}'),
    ("1, 2, 3]", "[1, 2, 3]"),
]


@pytest.fixture(params=_JSON_FIXABLE)
def fixable_json(request: pytest.FixtureRequest) -> tuple[str, str]:
    """Yield each (broken, fixed) pair from _JSON_FIXABLE."""
    return request.param


@pytest.fixture(params=_JSON_UNFIXABLE)
def unfixable_json(request: pytest.FixtureRequest) -> tuple[str, str]:
    """Yield each (broken, intended) pair from _JSON_UNFIXABLE."""
    return request.param
|
||||
|
||||
|
||||
def test_json_loads_fixable(fixable_json: tuple[str, str]):
    """json_loads() repairs each fixable input to its canonical value."""
    broken, canonical = fixable_json
    assert json_loads(broken) == json.loads(canonical)


def test_json_loads_unfixable(unfixable_json: tuple[str, str]):
    """json_loads() cannot recover these inputs to the intended value."""
    broken, intended = unfixable_json
    assert json_loads(broken) != json.loads(intended)
|
||||
36
forge/forge/logging/test_utils.py
Normal file
36
forge/forge/logging/test_utils.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import pytest
|
||||
|
||||
from .utils import remove_color_codes
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "raw_text, clean_text",
    [
        # Typical colored log line: cyan command + argument segments
        (
            "COMMAND = \x1b[36mbrowse_website\x1b[0m "
            "ARGUMENTS = \x1b[36m{'url': 'https://www.google.com',"
            " 'question': 'What is the capital of France?'}\x1b[0m",
            "COMMAND = browse_website "
            "ARGUMENTS = {'url': 'https://www.google.com',"
            " 'question': 'What is the capital of France?'}",
        ),
        # No escape codes at all: text must pass through unchanged
        (
            "{'Schaue dir meine Projekte auf github () an, als auch meine Webseiten': "
            "'https://github.com/Significant-Gravitas/AutoGPT,"
            " https://discord.gg/autogpt und https://twitter.com/Auto_GPT'}",
            "{'Schaue dir meine Projekte auf github () an, als auch meine Webseiten': "
            "'https://github.com/Significant-Gravitas/AutoGPT,"
            " https://discord.gg/autogpt und https://twitter.com/Auto_GPT'}",
        ),
        ("", ""),
        ("hello", "hello"),
        ("hello\x1B[31m world", "hello world"),
        ("\x1B[36mHello,\x1B[32m World!", "Hello, World!"),
        # Consecutive escape sequences
        (
            "\x1B[1m\x1B[31mError:\x1B[0m\x1B[31m file not found",
            "Error: file not found",
        ),
    ],
)
def test_remove_color_codes(raw_text, clean_text):
    """remove_color_codes() strips ANSI escapes, leaving the text intact."""
    assert remove_color_codes(raw_text) == clean_text
|
||||
0
forge/forge/utils/__init__.py
Normal file
0
forge/forge/utils/__init__.py
Normal file
167
forge/forge/utils/test_file_operations.py
Normal file
167
forge/forge/utils/test_file_operations.py
Normal file
@@ -0,0 +1,167 @@
|
||||
import json
|
||||
import logging
|
||||
import os.path
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from xml.etree import ElementTree
|
||||
|
||||
import docx
|
||||
import pytest
|
||||
import yaml
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from .file_operations import decode_textual_file, is_file_binary_fn
|
||||
|
||||
logger = logging.getLogger(__name__)

# Text payload that every mock file embeds and every parser must recover.
plain_text_str = "Hello, world!"


def mock_text_file():
    """Write plain_text_str to a temporary .txt file and return its path."""
    handle = tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt")
    with handle as f:
        f.write(plain_text_str)
    return handle.name
||||
|
||||
|
||||
def mock_csv_file():
    """Write plain_text_str to a temporary .csv file and return its path."""
    with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".csv") as f:
        f.write(plain_text_str)
    return f.name


def mock_pdf_file():
    """Hand-assemble a minimal one-page PDF containing plain_text_str."""
    with tempfile.NamedTemporaryFile(mode="wb", delete=False, suffix=".pdf") as f:
        # Create a new PDF and add a page with the text plain_text_str
        # Write the PDF header
        f.write(b"%PDF-1.7\n")
        # Write the document catalog
        f.write(b"1 0 obj\n")
        f.write(b"<< /Type /Catalog /Pages 2 0 R >>\n")
        f.write(b"endobj\n")
        # Write the page object
        f.write(b"2 0 obj\n")
        f.write(
            b"<< /Type /Page /Parent 1 0 R /Resources << /Font << /F1 3 0 R >> >> "
            b"/MediaBox [0 0 612 792] /Contents 4 0 R >>\n"
        )
        f.write(b"endobj\n")
        # Write the font object
        f.write(b"3 0 obj\n")
        f.write(
            b"<< /Type /Font /Subtype /Type1 /Name /F1 /BaseFont /Helvetica-Bold >>\n"
        )
        f.write(b"endobj\n")
        # Write the page contents object
        f.write(b"4 0 obj\n")
        f.write(b"<< /Length 25 >>\n")
        f.write(b"stream\n")
        f.write(b"BT\n/F1 12 Tf\n72 720 Td\n(Hello, world!) Tj\nET\n")
        f.write(b"endstream\n")
        f.write(b"endobj\n")
        # Write the cross-reference table
        f.write(b"xref\n")
        f.write(b"0 5\n")
        f.write(b"0000000000 65535 f \n")
        f.write(b"0000000017 00000 n \n")
        f.write(b"0000000073 00000 n \n")
        f.write(b"0000000123 00000 n \n")
        f.write(b"0000000271 00000 n \n")
        f.write(b"trailer\n")
        f.write(b"<< /Size 5 /Root 1 0 R >>\n")
        f.write(b"startxref\n")
        f.write(b"380\n")
        f.write(b"%%EOF\n")
        # Trailing null byte — presumably so the file reads as binary; confirm
        # against is_file_binary_fn's detection heuristic.
        f.write(b"\x00")
    return f.name


def mock_docx_file():
    """Create a temporary .docx with a single paragraph of plain_text_str."""
    with tempfile.NamedTemporaryFile(mode="wb", delete=False, suffix=".docx") as f:
        document = docx.Document()
        document.add_paragraph(plain_text_str)
        document.save(f.name)
    return f.name


def mock_json_file():
    """Create a temporary .json file holding {"text": plain_text_str}."""
    with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as f:
        json.dump({"text": plain_text_str}, f)
    return f.name


def mock_xml_file():
    """Create a temporary .xml file whose root element holds plain_text_str."""
    root = ElementTree.Element("text")
    root.text = plain_text_str
    tree = ElementTree.ElementTree(root)
    with tempfile.NamedTemporaryFile(mode="wb", delete=False, suffix=".xml") as f:
        tree.write(f)
    return f.name


def mock_yaml_file():
    """Create a temporary .yaml file holding {"text": plain_text_str}."""
    with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".yaml") as f:
        yaml.dump({"text": plain_text_str}, f)
    return f.name


def mock_html_file():
    """Create a temporary .html page embedding plain_text_str in a <p>."""
    html = BeautifulSoup(
        "<html>"
        "<head><title>This is a test</title></head>"
        f"<body><p>{plain_text_str}</p></body>"
        "</html>",
        "html.parser",
    )
    with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".html") as f:
        f.write(str(html))
    return f.name


def mock_md_file():
    """Create a temporary .md file with plain_text_str in a heading."""
    with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".md") as f:
        f.write(f"# {plain_text_str}!\n")
    return f.name


def mock_latex_file():
    """Create a temporary .tex document whose body is plain_text_str."""
    with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".tex") as f:
        latex_str = (
            r"\documentclass{article}"
            r"\begin{document}"
            f"{plain_text_str}"
            r"\end{document}"
        )
        f.write(latex_str)
    return f.name
|
||||
|
||||
|
||||
respective_file_creation_functions = {
|
||||
".txt": mock_text_file,
|
||||
".csv": mock_csv_file,
|
||||
".pdf": mock_pdf_file,
|
||||
".docx": mock_docx_file,
|
||||
".json": mock_json_file,
|
||||
".xml": mock_xml_file,
|
||||
".yaml": mock_yaml_file,
|
||||
".html": mock_html_file,
|
||||
".md": mock_md_file,
|
||||
".tex": mock_latex_file,
|
||||
}
|
||||
binary_files_extensions = [".pdf", ".docx"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "file_extension, c_file_creator",
    respective_file_creation_functions.items(),
)
def test_parsers(file_extension, c_file_creator):
    """Round-trip check per file type: create a mock file, then verify that
    decode_textual_file recovers the embedded plain-text string and that
    binary formats (pdf/docx) are detected as binary."""
    created_file_path = Path(c_file_creator())
    with open(created_file_path, "rb") as file:
        loaded_text = decode_textual_file(file, os.path.splitext(file.name)[1], logger)

        assert plain_text_str in loaded_text

        # Only the formats listed in binary_files_extensions should register
        # as binary; everything else is textual.
        should_be_binary = file_extension in binary_files_extensions
        assert should_be_binary == is_file_binary_fn(file)

    created_file_path.unlink()  # cleanup
|
||||
157
forge/forge/utils/test_url_validator.py
Normal file
157
forge/forge/utils/test_url_validator.py
Normal file
@@ -0,0 +1,157 @@
|
||||
import pytest
|
||||
from pytest import raises
|
||||
|
||||
from .url_validator import validate_url
|
||||
|
||||
|
||||
@validate_url
def dummy_method(url):
    """Identity function used throughout this module to exercise @validate_url."""
    return url
|
||||
|
||||
|
||||
# URLs that must pass validation unchanged: http/https schemes, query
# strings, multiple query parameters, and multi-level subdomains.
successful_test_data = (
    ("https://google.com/search?query=abc"),
    ("https://google.com/search?query=abc&p=123"),
    ("http://google.com/"),
    ("http://a.lot.of.domain.net/param1/param2"),
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("url", successful_test_data)
|
||||
def test_url_validation_succeeds(url):
|
||||
assert dummy_method(url) == url
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "url,expected_error",
    [
        ("htt://example.com", "Invalid URL format"),
        ("httppp://example.com", "Invalid URL format"),
        (" https://example.com", "Invalid URL format"),
        ("http://?query=q", "Missing Scheme or Network location"),
    ],
)
def test_url_validation_fails_invalid_url(url, expected_error):
    """Malformed URLs are rejected with a ValueError matching *expected_error*."""
    with raises(ValueError, match=expected_error):
        dummy_method(url)
|
||||
|
||||
|
||||
# file:// URLs (with and without a host, POSIX and Windows paths) — all must
# be rejected to prevent local file access through the validator.
local_file = (
    ("file://localhost"),
    ("file://localhost/home/reinier/secrets.txt"),
    ("file:///home/reinier/secrets.txt"),
    ("file:///C:/Users/Reinier/secrets.txt"),
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("url", local_file)
|
||||
def test_url_validation_fails_local_path(url):
|
||||
with raises(ValueError):
|
||||
dummy_method(url)
|
||||
|
||||
|
||||
def test_happy_path_valid_url():
    """
    Test that the function successfully validates a valid URL with `http://` or
    `https://` prefix.
    """
    # Use the module-level `dummy_method` (already wrapped with @validate_url)
    # instead of redefining an identical decorated helper inside every test.
    assert dummy_method("https://www.google.com") == "https://www.google.com"
    assert dummy_method("http://www.google.com") == "http://www.google.com"
|
||||
|
||||
|
||||
def test_general_behavior_additional_path_parameters_query_string():
    """
    Test that the function successfully validates a valid URL with additional path,
    parameters, and query string.
    """
    # Reuse the module-level validated `dummy_method` rather than redefining
    # an identical local helper.
    assert (
        dummy_method("https://www.google.com/search?q=python")
        == "https://www.google.com/search?q=python"
    )
|
||||
|
||||
|
||||
def test_edge_case_missing_scheme_or_network_location():
    """
    Test that the function raises a ValueError if the URL is missing scheme or
    network location.
    """
    # Reuse the module-level validated `dummy_method` rather than redefining
    # an identical local helper.
    with pytest.raises(ValueError):
        dummy_method("www.google.com")
|
||||
|
||||
|
||||
def test_edge_case_local_file_access():
    """Test that the function raises a ValueError if the URL has local file access"""
    # Reuse the module-level validated `dummy_method` rather than redefining
    # an identical local helper.
    with pytest.raises(ValueError):
        dummy_method("file:///etc/passwd")
|
||||
|
||||
|
||||
def test_general_behavior_sanitizes_url():
    """Test that the function sanitizes the URL by removing unnecessary components"""
    # Reuse the module-level validated `dummy_method` rather than redefining
    # an identical local helper. The "#top" fragment must be stripped.
    assert (
        dummy_method("https://www.google.com/search?q=python#top")
        == "https://www.google.com/search?q=python"
    )
|
||||
|
||||
|
||||
def test_general_behavior_invalid_url_format():
    """
    Test that the function raises a ValueError if the URL has an invalid format
    (e.g. missing slashes)
    """
    # Reuse the module-level validated `dummy_method` rather than redefining
    # an identical local helper.
    with pytest.raises(ValueError):
        dummy_method("https:www.google.com")
|
||||
|
||||
|
||||
def test_url_with_special_chars():
    """
    Tests that the function can handle URLs that contain unusual but valid characters.
    """
    percent_encoded_url = "https://example.com/path%20with%20spaces"
    assert dummy_method(percent_encoded_url) == percent_encoded_url
|
||||
|
||||
|
||||
def test_extremely_long_url():
    """
    Tests that the function raises a ValueError if the URL is over 2000 characters.
    """
    overlong_url = "http://example.com/" + 2000 * "a"
    with pytest.raises(ValueError, match="URL is too long"):
        dummy_method(overlong_url)
|
||||
|
||||
|
||||
def test_internationalized_url():
    """
    Tests that the function can handle internationalized URLs with non-ASCII characters.
    """
    idn_url = "http://例子.测试"
    assert dummy_method(idn_url) == idn_url
|
||||
58
forge/poetry.lock
generated
58
forge/poetry.lock
generated
@@ -4986,6 +4986,42 @@ pytest = ">=4.6"
|
||||
[package.extras]
|
||||
testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-mock"
|
||||
version = "3.14.0"
|
||||
description = "Thin-wrapper around the mock package for easier use with pytest"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "pytest-mock-3.14.0.tar.gz", hash = "sha256:2719255a1efeceadbc056d6bf3df3d1c5015530fb40cf347c0f9afac88410bd0"},
|
||||
{file = "pytest_mock-3.14.0-py3-none-any.whl", hash = "sha256:0b72c38033392a5f4621342fe11e9219ac11ec9d375f8e2a0c164539e0d70f6f"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
pytest = ">=6.2.5"
|
||||
|
||||
[package.extras]
|
||||
dev = ["pre-commit", "pytest-asyncio", "tox"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-recording"
|
||||
version = "0.13.1"
|
||||
description = "A pytest plugin that allows you recording of network interactions via VCR.py"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "pytest_recording-0.13.1-py3-none-any.whl", hash = "sha256:e5c75feb2593eb4ed9362182c6640bfe19004204bf9a6082d62c91b5fdb50a3e"},
|
||||
{file = "pytest_recording-0.13.1.tar.gz", hash = "sha256:1265d679f39263f115968ec01c2a3bfed250170fd1b0d9e288970b2e4a13737a"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
pytest = ">=3.5.0"
|
||||
vcrpy = ">=2.0.1"
|
||||
|
||||
[package.extras]
|
||||
dev = ["pytest-recording[tests]"]
|
||||
tests = ["pytest-httpbin", "pytest-mock", "requests", "werkzeug (==3.0.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "python-dateutil"
|
||||
version = "2.8.2"
|
||||
@@ -6504,6 +6540,26 @@ files = [
|
||||
docs = ["Sphinx (>=4.1.2,<4.2.0)", "sphinx-rtd-theme (>=0.5.2,<0.6.0)", "sphinxcontrib-asyncio (>=0.3.0,<0.4.0)"]
|
||||
test = ["Cython (>=0.29.36,<0.30.0)", "aiohttp (==3.9.0b0)", "aiohttp (>=3.8.1)", "flake8 (>=5.0,<6.0)", "mypy (>=0.800)", "psutil", "pyOpenSSL (>=23.0.0,<23.1.0)", "pycodestyle (>=2.9.0,<2.10.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "vcrpy"
|
||||
version = "5.1.0"
|
||||
description = "Automatically mock your HTTP interactions to simplify and speed up testing"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = []
|
||||
develop = false
|
||||
|
||||
[package.dependencies]
|
||||
PyYAML = "*"
|
||||
wrapt = "*"
|
||||
yarl = "*"
|
||||
|
||||
[package.source]
|
||||
type = "git"
|
||||
url = "https://github.com/Significant-Gravitas/vcrpy.git"
|
||||
reference = "master"
|
||||
resolved_reference = "bfd15f9d06a516138b673cb481547f3352d9cc43"
|
||||
|
||||
[[package]]
|
||||
name = "virtualenv"
|
||||
version = "20.25.0"
|
||||
@@ -7029,4 +7085,4 @@ benchmark = ["agbenchmark"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.10"
|
||||
content-hash = "5b8cca9caced2687d88fc61dc263054f15c49f2daa1560fa4d94fb5b38d461aa"
|
||||
content-hash = "7523abd672967cbe924f045a00bf519ee08c8537fdf2f2191d2928201497d7b7"
|
||||
|
||||
@@ -76,7 +76,10 @@ types-requests = "^2.31.0.2"
|
||||
pytest = "^7.4.0"
|
||||
pytest-asyncio = "^0.23.3"
|
||||
pytest-cov = "^5.0.0"
|
||||
pytest-mock = "*"
|
||||
pytest-recording = "*"
|
||||
mock = "^5.1.0"
|
||||
vcrpy = { git = "https://github.com/Significant-Gravitas/vcrpy.git", rev = "master" }
|
||||
|
||||
|
||||
[build-system]
|
||||
@@ -101,3 +104,4 @@ pythonVersion = "3.10"
|
||||
[tool.pytest.ini_options]
|
||||
pythonpath = ["forge"]
|
||||
testpaths = ["forge", "tests"]
|
||||
markers = ["slow"]
|
||||
|
||||
81
forge/tests/vcr/__init__.py
Normal file
81
forge/tests/vcr/__init__.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import logging
|
||||
import os
|
||||
from hashlib import sha256
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
from openai import OpenAI
|
||||
from openai._models import FinalRequestOptions
|
||||
from openai._types import Omit
|
||||
from openai._utils import is_given
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from .vcr_filter import (
|
||||
before_record_request,
|
||||
before_record_response,
|
||||
freeze_request_body,
|
||||
)
|
||||
|
||||
DEFAULT_RECORD_MODE = "new_episodes"
|
||||
BASE_VCR_CONFIG = {
|
||||
"before_record_request": before_record_request,
|
||||
"before_record_response": before_record_response,
|
||||
"match_on": ["method", "headers"],
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def vcr_config(get_base_vcr_config):
|
||||
return get_base_vcr_config
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def get_base_vcr_config(request):
|
||||
record_mode = request.config.getoption("--record-mode", default="new_episodes")
|
||||
config = BASE_VCR_CONFIG
|
||||
|
||||
if record_mode is None:
|
||||
config["record_mode"] = DEFAULT_RECORD_MODE
|
||||
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture()
def vcr_cassette_dir(request):
    """Directory holding this test's cassettes, derived from the test's name."""
    test_name, _ = os.path.splitext(request.node.name)
    return os.path.join("tests/vcr_cassettes", test_name)
|
||||
|
||||
|
||||
@pytest.fixture
def cached_openai_client(mocker: MockerFixture) -> OpenAI:
    """OpenAI client whose outgoing requests carry an X-Content-Hash header.

    The hash is a SHA-256 of the frozen (normalized) JSON body, letting VCR
    match cassette entries cheaply on headers instead of full bodies.
    """
    client = OpenAI()
    # Keep a reference to the real method; the patched version calls it first.
    _prepare_options = client._prepare_options

    def _patched_prepare_options(self, options: FinalRequestOptions):
        _prepare_options(options)

        # Only JSON-bodied requests participate in hash-based matching.
        if not options.json_data:
            return

        # `options.headers` may be an Omit sentinel; normalize to a dict copy.
        headers: dict[str, str | Omit] = (
            {**options.headers} if is_given(options.headers) else {}
        )
        options.headers = headers
        data = cast(dict, options.json_data)

        logging.getLogger("cached_openai_client").debug(
            f"Outgoing API request: {headers}\n{data if data else None}"
        )

        # Add hash header for cheap & fast matching on cassette playback
        headers["X-Content-Hash"] = sha256(
            freeze_request_body(data), usedforsecurity=False
        ).hexdigest()

    mocker.patch.object(
        client,
        "_prepare_options",
        new=_patched_prepare_options,
    )

    return client
|
||||
110
forge/tests/vcr/vcr_filter.py
Normal file
110
forge/tests/vcr/vcr_filter.py
Normal file
@@ -0,0 +1,110 @@
|
||||
import contextlib
|
||||
import json
|
||||
import re
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
|
||||
from vcr.request import Request
|
||||
|
||||
# Hosts whose HTTP interactions are recorded to / replayed from cassettes;
# requests to any other host are dropped by before_record_request.
HOSTNAMES_TO_CACHE: list[str] = [
    "api.openai.com",
    "localhost:50337",
    "duckduckgo.com",
]

# Request headers stripped before recording: credentials, cookies, and
# client-fingerprint headers that vary between environments and runs.
IGNORE_REQUEST_HEADERS: set[str | re.Pattern] = {
    "Authorization",
    "Cookie",
    "OpenAI-Organization",
    "X-OpenAI-Client-User-Agent",
    "User-Agent",
    re.compile(r"X-Stainless-[\w\-]+", re.IGNORECASE),
}

# Volatile substrings in LLM system messages (ctime-style timestamps, live
# WebDriver reprs) and the fixed placeholders they are replaced with, so that
# request bodies hash identically across runs.
LLM_MESSAGE_REPLACEMENTS: list[dict[str, str]] = [
    {
        "regex": r"\w{3} \w{3} {1,2}\d{1,2} \d{2}:\d{2}:\d{2} \d{4}",
        "replacement": "Tue Jan 1 00:00:00 2000",
    },
    {
        "regex": r"<selenium.webdriver.chrome.webdriver.WebDriver[^>]*>",
        "replacement": "",
    },
]

# NOTE(review): not referenced within this module — possibly kept for
# external importers; confirm before removing.
OPENAI_URL = "api.openai.com"
|
||||
|
||||
|
||||
def before_record_request(request: Request) -> Request | None:
    """Scrub a request before it is written to a cassette.

    Returns None (drop the interaction entirely) for hosts we don't cache;
    otherwise strips sensitive headers and normalizes the body.
    """
    if should_cache_request(request):
        return freeze_request(filter_request_headers(request))
    return None
|
||||
|
||||
|
||||
def should_cache_request(request: Request) -> bool:
    """Return True when the request URL targets a host we record cassettes for."""
    for hostname in HOSTNAMES_TO_CACHE:
        if hostname in request.url:
            return True
    return False
|
||||
|
||||
|
||||
def filter_request_headers(request: Request) -> Request:
    """Delete sensitive/noisy headers (auth, cookies, client fingerprints) in place."""

    def _is_ignored(name: str) -> bool:
        # IGNORE_REQUEST_HEADERS mixes exact names (case-insensitive) and
        # compiled patterns; a header matching either is removed.
        for ignore in IGNORE_REQUEST_HEADERS:
            if type(ignore) is str:
                if ignore.lower() == name.lower():
                    return True
            elif isinstance(ignore, re.Pattern) and ignore.match(name):
                return True
        return False

    # Iterate over a snapshot so deletion during the loop is safe.
    for header_name in list(request.headers):
        if _is_ignored(header_name):
            del request.headers[header_name]
    return request
|
||||
|
||||
|
||||
def freeze_request(request: Request) -> Request:
    """Normalize a JSON request body so equivalent requests serialize identically.

    Non-JSON bodies are left untouched (json.loads failures are suppressed).
    """
    if not request or not request.body:
        return request

    raw = request.body
    if isinstance(raw, BytesIO):
        raw = raw.getvalue()
    with contextlib.suppress(ValueError):
        request.body = freeze_request_body(json.loads(raw))

    return request
|
||||
|
||||
|
||||
def freeze_request_body(body: dict) -> bytes:
    """Serialize *body* with dynamic fields removed, without mutating the input.

    Non-chat bodies (no "messages" key) are serialized as-is. For chat bodies,
    the token limit is dropped and volatile substrings in system messages are
    replaced with fixed placeholders so cassette matching is stable across runs.
    """
    if "messages" not in body:
        return json.dumps(body, sort_keys=True).encode()

    # Deep-copy via a JSON round-trip before editing: the previous version
    # deleted "max_tokens" and rewrote message contents in place, mutating the
    # caller's dict (and, through cached_openai_client, the real outgoing
    # request body). The body must be JSON-serializable anyway.
    body = json.loads(json.dumps(body))

    body.pop("max_tokens", None)

    for message in body["messages"]:
        if "content" in message and "role" in message:
            if message["role"] == "system":
                message["content"] = replace_message_content(
                    message["content"], LLM_MESSAGE_REPLACEMENTS
                )

    return json.dumps(body, sort_keys=True).encode()
|
||||
|
||||
|
||||
def replace_message_content(content: str, replacements: list[dict[str, str]]) -> str:
    """Apply each {"regex", "replacement"} rule in order and return the result."""
    for rule in replacements:
        content = re.sub(rule["regex"], rule["replacement"], content)
    return content
|
||||
|
||||
|
||||
def before_record_response(response: dict[str, Any]) -> dict[str, Any]:
    """Drop the Transfer-Encoding header before the response is cassetted."""
    response["headers"].pop("Transfer-Encoding", None)
    return response
|
||||
1
forge/tests/vcr_cassettes
Submodule
1
forge/tests/vcr_cassettes
Submodule
Submodule forge/tests/vcr_cassettes added at e0f7f4a599
Reference in New Issue
Block a user