mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
fix(backend): address PR review — scan bug, quota overwrite, test coverage
- Fix VirusScanError (ClamAV outage) returning 409 instead of 500 - Fix duplicate virus scan in CoPilot WriteWorkspaceFileTool - Fix overwrite quota miscalculation (subtract old file size from usage) - Add parametrized tests for get_workspace_storage_limit_bytes (all tiers) - Add test for quota rejection path in WorkspaceManager.write_file() - Add test for 80% warning log - Add test for VirusScanError → 500 mapping - Add test for get_storage_usage endpoint with tier-based limit - Extract projected_usage variable, remove redundant truthiness guard Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -14,7 +14,7 @@ from fastapi import Query, UploadFile
|
||||
from fastapi.responses import Response
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.api.features.store.exceptions import VirusDetectedError
|
||||
from backend.api.features.store.exceptions import VirusDetectedError, VirusScanError
|
||||
from backend.copilot.rate_limit import get_workspace_storage_limit_bytes
|
||||
from backend.data.workspace import (
|
||||
WorkspaceFile,
|
||||
@@ -250,6 +250,8 @@ async def upload_file(
|
||||
)
|
||||
except VirusDetectedError as e:
|
||||
raise fastapi.HTTPException(status_code=400, detail=str(e)) from e
|
||||
except VirusScanError as e:
|
||||
raise fastapi.HTTPException(status_code=500, detail=str(e)) from e
|
||||
except ValueError as e:
|
||||
# write_file raises ValueError for path-conflict, size-limit, and
|
||||
# storage-quota cases; map each to its correct HTTP status.
|
||||
|
||||
@@ -582,3 +582,51 @@ def test_list_files_offset_is_echoed_back(mock_manager_cls, mock_get_workspace):
|
||||
mock_instance.list_files.assert_called_once_with(
|
||||
limit=11, offset=50, include_all_sessions=True
|
||||
)
|
||||
|
||||
|
||||
def test_upload_virus_scan_infrastructure_error_returns_500(mocker):
|
||||
"""VirusScanError (ClamAV outage) should return 500, not 409."""
|
||||
from backend.api.features.store.exceptions import VirusScanError
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(
|
||||
side_effect=VirusScanError("ClamAV connection refused"),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
response = _upload()
|
||||
assert response.status_code == 500
|
||||
|
||||
|
||||
def test_get_storage_usage_returns_tier_based_limit(mocker):
|
||||
"""get_storage_usage should return the user's tier-based limit, not a static config."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=100 * 1024 * 1024, # 100 MB used
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.count_workspace_files",
|
||||
return_value=5,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_storage_limit_bytes",
|
||||
return_value=1024 * 1024 * 1024, # 1 GB (PRO tier)
|
||||
)
|
||||
|
||||
response = client.get("/storage/usage")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["limit_bytes"] == 1024 * 1024 * 1024
|
||||
assert data["used_bytes"] == 100 * 1024 * 1024
|
||||
assert data["file_count"] == 5
|
||||
|
||||
@@ -9,6 +9,7 @@ from redis.exceptions import RedisError
|
||||
from .rate_limit import (
|
||||
DEFAULT_TIER,
|
||||
TIER_MULTIPLIERS,
|
||||
TIER_WORKSPACE_STORAGE_MB,
|
||||
CoPilotUsageStatus,
|
||||
RateLimitExceeded,
|
||||
SubscriptionTier,
|
||||
@@ -23,6 +24,7 @@ from .rate_limit import (
|
||||
get_global_rate_limits,
|
||||
get_usage_status,
|
||||
get_user_tier,
|
||||
get_workspace_storage_limit_bytes,
|
||||
increment_daily_reset_count,
|
||||
record_token_usage,
|
||||
release_reset_lock,
|
||||
@@ -1367,3 +1369,52 @@ class TestResetUserUsage:
|
||||
):
|
||||
with pytest.raises(RedisError):
|
||||
await reset_user_usage("user-1")
|
||||
|
||||
|
||||
class TestWorkspaceStorageLimits:
|
||||
"""Tests for tier-based workspace storage limits."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tier,expected_mb",
|
||||
[
|
||||
(SubscriptionTier.FREE, 250),
|
||||
(SubscriptionTier.PRO, 1024),
|
||||
(SubscriptionTier.BUSINESS, 5 * 1024),
|
||||
(SubscriptionTier.ENTERPRISE, 15 * 1024),
|
||||
],
|
||||
)
|
||||
def test_tier_workspace_storage_mapping_covers_all_tiers(self, tier, expected_mb):
|
||||
"""Every tier has an explicit storage limit in the mapping."""
|
||||
assert tier in TIER_WORKSPACE_STORAGE_MB
|
||||
assert TIER_WORKSPACE_STORAGE_MB[tier] == expected_mb
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"tier,expected_bytes",
|
||||
[
|
||||
(SubscriptionTier.FREE, 250 * 1024 * 1024),
|
||||
(SubscriptionTier.PRO, 1024 * 1024 * 1024),
|
||||
(SubscriptionTier.BUSINESS, 5 * 1024 * 1024 * 1024),
|
||||
(SubscriptionTier.ENTERPRISE, 15 * 1024 * 1024 * 1024),
|
||||
],
|
||||
)
|
||||
async def test_get_workspace_storage_limit_bytes_per_tier(
|
||||
self, tier, expected_bytes
|
||||
):
|
||||
"""get_workspace_storage_limit_bytes returns correct bytes for each tier."""
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_user_tier",
|
||||
return_value=tier,
|
||||
):
|
||||
result = await get_workspace_storage_limit_bytes("user-1")
|
||||
assert result == expected_bytes
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_workspace_storage_limit_bytes_defaults_to_free_on_unknown(self):
|
||||
"""Unknown tier falls back to FREE tier limit."""
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_user_tier",
|
||||
return_value="UNKNOWN_TIER",
|
||||
):
|
||||
result = await get_workspace_storage_limit_bytes("user-1")
|
||||
assert result == 250 * 1024 * 1024
|
||||
|
||||
@@ -19,7 +19,6 @@ from backend.copilot.context import (
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.sandbox import make_session_path
|
||||
from backend.util.settings import Config
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
from backend.util.workspace import WorkspaceManager
|
||||
|
||||
from .base import BaseTool
|
||||
@@ -837,7 +836,6 @@ class WriteWorkspaceFileTool(BaseTool):
|
||||
)
|
||||
|
||||
try:
|
||||
await scan_content_safe(content_bytes, filename=filename)
|
||||
manager = await get_workspace_manager(user_id, session_id)
|
||||
rec = await manager.write_file(
|
||||
content=content_bytes,
|
||||
|
||||
@@ -186,27 +186,7 @@ class WorkspaceManager:
|
||||
f"{Config().max_file_size_mb}MB limit"
|
||||
)
|
||||
|
||||
# Enforce per-user workspace storage quota (tier-based).
|
||||
storage_limit = await get_workspace_storage_limit_bytes(self.user_id)
|
||||
current_usage = await get_workspace_total_size(self.workspace_id)
|
||||
if current_usage + len(content) > storage_limit:
|
||||
used_pct = (current_usage / storage_limit) * 100
|
||||
raise ValueError(
|
||||
f"Storage limit exceeded: {current_usage:,} bytes used "
|
||||
f"of {storage_limit:,} bytes ({used_pct:.1f}%)"
|
||||
)
|
||||
if storage_limit and (current_usage + len(content)) / storage_limit >= 0.8:
|
||||
logger.warning(
|
||||
f"User {self.user_id} workspace storage at "
|
||||
f"{(current_usage + len(content)) / storage_limit * 100:.1f}% "
|
||||
f"({current_usage + len(content)} / {storage_limit} bytes)"
|
||||
)
|
||||
|
||||
# Scan here — callers must NOT duplicate this scan.
|
||||
# WorkspaceManager owns virus scanning for all persisted files.
|
||||
await scan_content_safe(content, filename=filename)
|
||||
|
||||
# Determine path with session scoping
|
||||
# Determine path with session scoping (needed before quota check for overwrites)
|
||||
if path is None:
|
||||
path = f"/{filename}"
|
||||
elif not path.startswith("/"):
|
||||
@@ -215,6 +195,35 @@ class WorkspaceManager:
|
||||
# Resolve path with session prefix
|
||||
path = self._resolve_path(path)
|
||||
|
||||
# Enforce per-user workspace storage quota (tier-based).
|
||||
# For overwrites, subtract the existing file's size so replacing a file
|
||||
# with a same-size or smaller file is not rejected near the cap.
|
||||
storage_limit = await get_workspace_storage_limit_bytes(self.user_id)
|
||||
current_usage = await get_workspace_total_size(self.workspace_id)
|
||||
if overwrite:
|
||||
db = workspace_db()
|
||||
existing = await db.get_workspace_file_by_path(self.workspace_id, path)
|
||||
if existing is not None:
|
||||
current_usage = max(0, current_usage - existing.size_bytes)
|
||||
|
||||
projected_usage = current_usage + len(content)
|
||||
if projected_usage > storage_limit:
|
||||
used_pct = (current_usage / storage_limit) * 100
|
||||
raise ValueError(
|
||||
f"Storage limit exceeded: {current_usage:,} bytes used "
|
||||
f"of {storage_limit:,} bytes ({used_pct:.1f}%)"
|
||||
)
|
||||
if projected_usage / storage_limit >= 0.8:
|
||||
logger.warning(
|
||||
f"User {self.user_id} workspace storage at "
|
||||
f"{projected_usage / storage_limit * 100:.1f}% "
|
||||
f"({projected_usage} / {storage_limit} bytes)"
|
||||
)
|
||||
|
||||
# Scan here — callers must NOT duplicate this scan.
|
||||
# WorkspaceManager owns virus scanning for all persisted files.
|
||||
await scan_content_safe(content, filename=filename)
|
||||
|
||||
# Check if file exists at path (only error for non-overwrite case)
|
||||
# For overwrite=True, we let the write proceed and handle via UniqueViolationError
|
||||
# This ensures the new file is written to storage BEFORE the old one is deleted,
|
||||
|
||||
@@ -171,3 +171,103 @@ async def test_write_file_overwrite_exhausted_retries_raises_and_cleans_up(
|
||||
)
|
||||
|
||||
mock_storage.delete.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_file_quota_exceeded_raises_value_error(
|
||||
manager, mock_storage, mock_db
|
||||
):
|
||||
"""write_file raises ValueError when workspace storage quota is exceeded."""
|
||||
mock_db.get_workspace_file_by_path.return_value = None
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_storage",
|
||||
return_value=mock_storage,
|
||||
),
|
||||
patch("backend.util.workspace.workspace_db", return_value=mock_db),
|
||||
patch("backend.util.workspace.scan_content_safe", new_callable=AsyncMock),
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_storage_limit_bytes",
|
||||
return_value=250 * 1024 * 1024, # 250 MB limit
|
||||
),
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_total_size",
|
||||
return_value=250 * 1024 * 1024, # already at limit
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError, match="Storage limit exceeded"):
|
||||
await manager.write_file(filename="test.txt", content=b"hello")
|
||||
|
||||
# Storage should NOT have been written to
|
||||
mock_storage.store.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_file_80pct_warning_logged(manager, mock_storage, mock_db, caplog):
|
||||
"""write_file logs a warning when workspace usage crosses 80%."""
|
||||
created_file = _make_workspace_file()
|
||||
mock_db.get_workspace_file_by_path.return_value = None
|
||||
mock_db.create_workspace_file.return_value = created_file
|
||||
|
||||
limit_bytes = 100 # 100 bytes total limit
|
||||
current_usage = 75 # 75 bytes used → 75% before write
|
||||
content = b"123456" # 6 bytes → 81% after write
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_storage",
|
||||
return_value=mock_storage,
|
||||
),
|
||||
patch("backend.util.workspace.workspace_db", return_value=mock_db),
|
||||
patch("backend.util.workspace.scan_content_safe", new_callable=AsyncMock),
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_storage_limit_bytes",
|
||||
return_value=limit_bytes,
|
||||
),
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_total_size",
|
||||
return_value=current_usage,
|
||||
),
|
||||
):
|
||||
import logging
|
||||
|
||||
with caplog.at_level(logging.WARNING, logger="backend.util.workspace"):
|
||||
await manager.write_file(filename="test.txt", content=content)
|
||||
|
||||
assert any("workspace storage at" in record.message for record in caplog.records)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_file_overwrite_not_double_counted(manager, mock_storage, mock_db):
|
||||
"""Overwriting a file subtracts the old file size from usage check."""
|
||||
existing_file = _make_workspace_file(size_bytes=50)
|
||||
created_file = _make_workspace_file()
|
||||
mock_db.get_workspace_file_by_path.return_value = existing_file
|
||||
mock_db.create_workspace_file.return_value = created_file
|
||||
|
||||
limit_bytes = 100
|
||||
current_usage = 90 # 90 bytes used, 50 of which is the file being replaced
|
||||
content = b"x" * 50 # replacing with same-size file — should succeed
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_storage",
|
||||
return_value=mock_storage,
|
||||
),
|
||||
patch("backend.util.workspace.workspace_db", return_value=mock_db),
|
||||
patch("backend.util.workspace.scan_content_safe", new_callable=AsyncMock),
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_storage_limit_bytes",
|
||||
return_value=limit_bytes,
|
||||
),
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_total_size",
|
||||
return_value=current_usage,
|
||||
),
|
||||
):
|
||||
# Should NOT raise — net usage after overwrite is 90 - 50 + 50 = 90, under 100
|
||||
result = await manager.write_file(
|
||||
filename="test.txt", content=content, overwrite=True
|
||||
)
|
||||
assert result == created_file
|
||||
|
||||
Reference in New Issue
Block a user