fix(backend): address PR review — scan bug, quota overwrite, test coverage

- Fix VirusScanError (ClamAV outage) returning 409 instead of 500
- Fix duplicate virus scan in CoPilot WriteWorkspaceFileTool
- Fix overwrite quota miscalculation (subtract old file size from usage)
- Add parametrized tests for get_workspace_storage_limit_bytes (all tiers)
- Add test for quota rejection path in WorkspaceManager.write_file()
- Add test for 80% warning log
- Add test for VirusScanError → 500 mapping
- Add test for get_storage_usage endpoint with tier-based limit
- Extract projected_usage variable, remove redundant truthiness guard

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Nicholas Tindle
2026-04-14 14:59:11 -05:00
parent 406ea4213d
commit 5df958d302
6 changed files with 232 additions and 24 deletions

View File

@@ -14,7 +14,7 @@ from fastapi import Query, UploadFile
from fastapi.responses import Response
from pydantic import BaseModel, Field
from backend.api.features.store.exceptions import VirusDetectedError
from backend.api.features.store.exceptions import VirusDetectedError, VirusScanError
from backend.copilot.rate_limit import get_workspace_storage_limit_bytes
from backend.data.workspace import (
WorkspaceFile,
@@ -250,6 +250,8 @@ async def upload_file(
)
except VirusDetectedError as e:
raise fastapi.HTTPException(status_code=400, detail=str(e)) from e
except VirusScanError as e:
raise fastapi.HTTPException(status_code=500, detail=str(e)) from e
except ValueError as e:
# write_file raises ValueError for path-conflict, size-limit, and
# storage-quota cases; map each to its correct HTTP status.

View File

@@ -582,3 +582,51 @@ def test_list_files_offset_is_echoed_back(mock_manager_cls, mock_get_workspace):
mock_instance.list_files.assert_called_once_with(
limit=11, offset=50, include_all_sessions=True
)
def test_upload_virus_scan_infrastructure_error_returns_500(mocker):
"""VirusScanError (ClamAV outage) should return 500, not 409."""
from backend.api.features.store.exceptions import VirusScanError
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=_make_workspace(),
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(
side_effect=VirusScanError("ClamAV connection refused"),
)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
)
response = _upload()
assert response.status_code == 500
def test_get_storage_usage_returns_tier_based_limit(mocker):
"""get_storage_usage should return the user's tier-based limit, not a static config."""
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=_make_workspace(),
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
return_value=100 * 1024 * 1024, # 100 MB used
)
mocker.patch(
"backend.api.features.workspace.routes.count_workspace_files",
return_value=5,
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_storage_limit_bytes",
return_value=1024 * 1024 * 1024, # 1 GB (PRO tier)
)
response = client.get("/storage/usage")
assert response.status_code == 200
data = response.json()
assert data["limit_bytes"] == 1024 * 1024 * 1024
assert data["used_bytes"] == 100 * 1024 * 1024
assert data["file_count"] == 5

View File

@@ -9,6 +9,7 @@ from redis.exceptions import RedisError
from .rate_limit import (
DEFAULT_TIER,
TIER_MULTIPLIERS,
TIER_WORKSPACE_STORAGE_MB,
CoPilotUsageStatus,
RateLimitExceeded,
SubscriptionTier,
@@ -23,6 +24,7 @@ from .rate_limit import (
get_global_rate_limits,
get_usage_status,
get_user_tier,
get_workspace_storage_limit_bytes,
increment_daily_reset_count,
record_token_usage,
release_reset_lock,
@@ -1367,3 +1369,52 @@ class TestResetUserUsage:
):
with pytest.raises(RedisError):
await reset_user_usage("user-1")
class TestWorkspaceStorageLimits:
"""Tests for tier-based workspace storage limits."""
@pytest.mark.parametrize(
"tier,expected_mb",
[
(SubscriptionTier.FREE, 250),
(SubscriptionTier.PRO, 1024),
(SubscriptionTier.BUSINESS, 5 * 1024),
(SubscriptionTier.ENTERPRISE, 15 * 1024),
],
)
def test_tier_workspace_storage_mapping_covers_all_tiers(self, tier, expected_mb):
"""Every tier has an explicit storage limit in the mapping."""
assert tier in TIER_WORKSPACE_STORAGE_MB
assert TIER_WORKSPACE_STORAGE_MB[tier] == expected_mb
@pytest.mark.asyncio
@pytest.mark.parametrize(
"tier,expected_bytes",
[
(SubscriptionTier.FREE, 250 * 1024 * 1024),
(SubscriptionTier.PRO, 1024 * 1024 * 1024),
(SubscriptionTier.BUSINESS, 5 * 1024 * 1024 * 1024),
(SubscriptionTier.ENTERPRISE, 15 * 1024 * 1024 * 1024),
],
)
async def test_get_workspace_storage_limit_bytes_per_tier(
self, tier, expected_bytes
):
"""get_workspace_storage_limit_bytes returns correct bytes for each tier."""
with patch(
"backend.copilot.rate_limit.get_user_tier",
return_value=tier,
):
result = await get_workspace_storage_limit_bytes("user-1")
assert result == expected_bytes
@pytest.mark.asyncio
async def test_get_workspace_storage_limit_bytes_defaults_to_free_on_unknown(self):
"""Unknown tier falls back to FREE tier limit."""
with patch(
"backend.copilot.rate_limit.get_user_tier",
return_value="UNKNOWN_TIER",
):
result = await get_workspace_storage_limit_bytes("user-1")
assert result == 250 * 1024 * 1024

View File

@@ -19,7 +19,6 @@ from backend.copilot.context import (
from backend.copilot.model import ChatSession
from backend.copilot.tools.sandbox import make_session_path
from backend.util.settings import Config
from backend.util.virus_scanner import scan_content_safe
from backend.util.workspace import WorkspaceManager
from .base import BaseTool
@@ -837,7 +836,6 @@ class WriteWorkspaceFileTool(BaseTool):
)
try:
await scan_content_safe(content_bytes, filename=filename)
manager = await get_workspace_manager(user_id, session_id)
rec = await manager.write_file(
content=content_bytes,

View File

@@ -186,27 +186,7 @@ class WorkspaceManager:
f"{Config().max_file_size_mb}MB limit"
)
# Enforce per-user workspace storage quota (tier-based).
storage_limit = await get_workspace_storage_limit_bytes(self.user_id)
current_usage = await get_workspace_total_size(self.workspace_id)
if current_usage + len(content) > storage_limit:
used_pct = (current_usage / storage_limit) * 100
raise ValueError(
f"Storage limit exceeded: {current_usage:,} bytes used "
f"of {storage_limit:,} bytes ({used_pct:.1f}%)"
)
if storage_limit and (current_usage + len(content)) / storage_limit >= 0.8:
logger.warning(
f"User {self.user_id} workspace storage at "
f"{(current_usage + len(content)) / storage_limit * 100:.1f}% "
f"({current_usage + len(content)} / {storage_limit} bytes)"
)
# Scan here — callers must NOT duplicate this scan.
# WorkspaceManager owns virus scanning for all persisted files.
await scan_content_safe(content, filename=filename)
# Determine path with session scoping
# Determine path with session scoping (needed before quota check for overwrites)
if path is None:
path = f"/{filename}"
elif not path.startswith("/"):
@@ -215,6 +195,35 @@ class WorkspaceManager:
# Resolve path with session prefix
path = self._resolve_path(path)
# Enforce per-user workspace storage quota (tier-based).
# For overwrites, subtract the existing file's size so replacing a file
# with a same-size or smaller file is not rejected near the cap.
storage_limit = await get_workspace_storage_limit_bytes(self.user_id)
current_usage = await get_workspace_total_size(self.workspace_id)
if overwrite:
db = workspace_db()
existing = await db.get_workspace_file_by_path(self.workspace_id, path)
if existing is not None:
current_usage = max(0, current_usage - existing.size_bytes)
projected_usage = current_usage + len(content)
if projected_usage > storage_limit:
used_pct = (current_usage / storage_limit) * 100
raise ValueError(
f"Storage limit exceeded: {current_usage:,} bytes used "
f"of {storage_limit:,} bytes ({used_pct:.1f}%)"
)
if projected_usage / storage_limit >= 0.8:
logger.warning(
f"User {self.user_id} workspace storage at "
f"{projected_usage / storage_limit * 100:.1f}% "
f"({projected_usage} / {storage_limit} bytes)"
)
# Scan here — callers must NOT duplicate this scan.
# WorkspaceManager owns virus scanning for all persisted files.
await scan_content_safe(content, filename=filename)
# Check if file exists at path (only error for non-overwrite case)
# For overwrite=True, we let the write proceed and handle via UniqueViolationError
# This ensures the new file is written to storage BEFORE the old one is deleted,

View File

@@ -171,3 +171,103 @@ async def test_write_file_overwrite_exhausted_retries_raises_and_cleans_up(
)
mock_storage.delete.assert_called_once()
@pytest.mark.asyncio
async def test_write_file_quota_exceeded_raises_value_error(
manager, mock_storage, mock_db
):
"""write_file raises ValueError when workspace storage quota is exceeded."""
mock_db.get_workspace_file_by_path.return_value = None
with (
patch(
"backend.util.workspace.get_workspace_storage",
return_value=mock_storage,
),
patch("backend.util.workspace.workspace_db", return_value=mock_db),
patch("backend.util.workspace.scan_content_safe", new_callable=AsyncMock),
patch(
"backend.util.workspace.get_workspace_storage_limit_bytes",
return_value=250 * 1024 * 1024, # 250 MB limit
),
patch(
"backend.util.workspace.get_workspace_total_size",
return_value=250 * 1024 * 1024, # already at limit
),
):
with pytest.raises(ValueError, match="Storage limit exceeded"):
await manager.write_file(filename="test.txt", content=b"hello")
# Storage should NOT have been written to
mock_storage.store.assert_not_called()
@pytest.mark.asyncio
async def test_write_file_80pct_warning_logged(manager, mock_storage, mock_db, caplog):
"""write_file logs a warning when workspace usage crosses 80%."""
created_file = _make_workspace_file()
mock_db.get_workspace_file_by_path.return_value = None
mock_db.create_workspace_file.return_value = created_file
limit_bytes = 100 # 100 bytes total limit
current_usage = 75 # 75 bytes used → 75% before write
content = b"123456" # 6 bytes → 81% after write
with (
patch(
"backend.util.workspace.get_workspace_storage",
return_value=mock_storage,
),
patch("backend.util.workspace.workspace_db", return_value=mock_db),
patch("backend.util.workspace.scan_content_safe", new_callable=AsyncMock),
patch(
"backend.util.workspace.get_workspace_storage_limit_bytes",
return_value=limit_bytes,
),
patch(
"backend.util.workspace.get_workspace_total_size",
return_value=current_usage,
),
):
import logging
with caplog.at_level(logging.WARNING, logger="backend.util.workspace"):
await manager.write_file(filename="test.txt", content=content)
assert any("workspace storage at" in record.message for record in caplog.records)
@pytest.mark.asyncio
async def test_write_file_overwrite_not_double_counted(manager, mock_storage, mock_db):
"""Overwriting a file subtracts the old file size from usage check."""
existing_file = _make_workspace_file(size_bytes=50)
created_file = _make_workspace_file()
mock_db.get_workspace_file_by_path.return_value = existing_file
mock_db.create_workspace_file.return_value = created_file
limit_bytes = 100
current_usage = 90 # 90 bytes used, 50 of which is the file being replaced
content = b"x" * 50 # replacing with same-size file — should succeed
with (
patch(
"backend.util.workspace.get_workspace_storage",
return_value=mock_storage,
),
patch("backend.util.workspace.workspace_db", return_value=mock_db),
patch("backend.util.workspace.scan_content_safe", new_callable=AsyncMock),
patch(
"backend.util.workspace.get_workspace_storage_limit_bytes",
return_value=limit_bytes,
),
patch(
"backend.util.workspace.get_workspace_total_size",
return_value=current_usage,
),
):
# Should NOT raise — net usage after overwrite is 90 - 50 + 50 = 90, under 100
result = await manager.write_file(
filename="test.txt", content=content, overwrite=True
)
assert result == created_file