mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
Compare commits
23 Commits
master
...
pr/12780/b
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fe9d53f520 | ||
|
|
489ccf96f5 | ||
|
|
db66f34cf4 | ||
|
|
bf044a2634 | ||
|
|
33eb9e9ad9 | ||
|
|
938cf9ea5f | ||
|
|
1fed970a88 | ||
|
|
4d85a68d9a | ||
|
|
f3b4b0e1d9 | ||
|
|
64b7560291 | ||
|
|
7351fba17a | ||
|
|
5c5dd3733e | ||
|
|
6f80a109ef | ||
|
|
81e4403216 | ||
|
|
6905038081 | ||
|
|
0af8f239f0 | ||
|
|
6794908692 | ||
|
|
9aa72fb46d | ||
|
|
c06880412d | ||
|
|
ad971698a0 | ||
|
|
a81f305a74 | ||
|
|
5df958d302 | ||
|
|
406ea4213d |
@@ -2,6 +2,7 @@
|
||||
Workspace API routes for managing user file storage.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
@@ -14,6 +15,8 @@ from fastapi import Query, UploadFile
|
||||
from fastapi.responses import Response
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.api.features.store.exceptions import VirusDetectedError, VirusScanError
|
||||
from backend.copilot.rate_limit import get_workspace_storage_limit_bytes
|
||||
from backend.data.workspace import (
|
||||
WorkspaceFile,
|
||||
count_workspace_files,
|
||||
@@ -24,8 +27,7 @@ from backend.data.workspace import (
|
||||
soft_delete_workspace_file,
|
||||
)
|
||||
from backend.util.settings import Config
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
from backend.util.workspace import WorkspaceManager
|
||||
from backend.util.workspace import WorkspaceManager, format_bytes
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
|
||||
@@ -249,50 +251,28 @@ async def upload_file(
|
||||
# Get or create workspace
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
|
||||
# Pre-write storage cap check (soft check — final enforcement is post-write)
|
||||
storage_limit_bytes = config.max_workspace_storage_mb * 1024 * 1024
|
||||
current_usage = await get_workspace_total_size(workspace.id)
|
||||
if storage_limit_bytes and current_usage + len(content) > storage_limit_bytes:
|
||||
used_percent = (current_usage / storage_limit_bytes) * 100
|
||||
raise fastapi.HTTPException(
|
||||
status_code=413,
|
||||
detail={
|
||||
"message": "Storage limit exceeded",
|
||||
"used_bytes": current_usage,
|
||||
"limit_bytes": storage_limit_bytes,
|
||||
"used_percent": round(used_percent, 1),
|
||||
},
|
||||
)
|
||||
|
||||
# Warn at 80% usage
|
||||
if (
|
||||
storage_limit_bytes
|
||||
and (usage_ratio := (current_usage + len(content)) / storage_limit_bytes) >= 0.8
|
||||
):
|
||||
logger.warning(
|
||||
f"User {user_id} workspace storage at {usage_ratio * 100:.1f}% "
|
||||
f"({current_usage + len(content)} / {storage_limit_bytes} bytes)"
|
||||
)
|
||||
|
||||
# Virus scan
|
||||
await scan_content_safe(content, filename=filename)
|
||||
|
||||
# Write file via WorkspaceManager
|
||||
# Write file via WorkspaceManager (handles virus scan, per-file size,
|
||||
# and per-user tier-based storage quota internally).
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
try:
|
||||
workspace_file = await manager.write_file(
|
||||
content, filename, overwrite=overwrite, metadata={"origin": "user-upload"}
|
||||
)
|
||||
except VirusDetectedError as e:
|
||||
raise fastapi.HTTPException(status_code=400, detail=str(e)) from e
|
||||
except VirusScanError as e:
|
||||
raise fastapi.HTTPException(status_code=500, detail=str(e)) from e
|
||||
except ValueError as e:
|
||||
# write_file raises ValueError for both path-conflict and size-limit
|
||||
# cases; map each to its correct HTTP status.
|
||||
# write_file raises ValueError for path-conflict, size-limit, and
|
||||
# storage-quota cases; map each to its correct HTTP status.
|
||||
message = str(e)
|
||||
if message.startswith("File too large"):
|
||||
if message.startswith(("File too large", "Storage limit exceeded")):
|
||||
raise fastapi.HTTPException(status_code=413, detail=message) from e
|
||||
raise fastapi.HTTPException(status_code=409, detail=message) from e
|
||||
|
||||
# Post-write storage check — eliminates TOCTOU race on the quota.
|
||||
# If a concurrent upload pushed us over the limit, undo this write.
|
||||
storage_limit_bytes = await get_workspace_storage_limit_bytes(user_id)
|
||||
new_total = await get_workspace_total_size(workspace.id)
|
||||
if storage_limit_bytes and new_total > storage_limit_bytes:
|
||||
try:
|
||||
@@ -304,11 +284,12 @@ async def upload_file(
|
||||
)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=413,
|
||||
detail={
|
||||
"message": "Storage limit exceeded (concurrent upload)",
|
||||
"used_bytes": new_total,
|
||||
"limit_bytes": storage_limit_bytes,
|
||||
},
|
||||
detail=(
|
||||
f"Storage limit exceeded. "
|
||||
f"You've used {format_bytes(new_total)} of your "
|
||||
f"{format_bytes(storage_limit_bytes)} quota. "
|
||||
f"Delete some files or upgrade your plan for more storage."
|
||||
),
|
||||
)
|
||||
|
||||
return UploadFileResponse(
|
||||
@@ -331,12 +312,13 @@ async def get_storage_usage(
|
||||
"""
|
||||
Get storage usage information for the user's workspace.
|
||||
"""
|
||||
config = Config()
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
|
||||
used_bytes = await get_workspace_total_size(workspace.id)
|
||||
file_count = await count_workspace_files(workspace.id)
|
||||
limit_bytes = config.max_workspace_storage_mb * 1024 * 1024
|
||||
used_bytes, file_count, limit_bytes = await asyncio.gather(
|
||||
get_workspace_total_size(workspace.id),
|
||||
count_workspace_files(workspace.id),
|
||||
get_workspace_storage_limit_bytes(user_id),
|
||||
)
|
||||
|
||||
return StorageUsageResponse(
|
||||
used_bytes=used_bytes,
|
||||
|
||||
@@ -151,15 +151,16 @@ def test_list_files_null_metadata_coerced_to_empty_dict(
|
||||
# -- upload_file metadata tests --
|
||||
|
||||
|
||||
@patch("backend.api.features.workspace.routes.get_workspace_storage_limit_bytes")
|
||||
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
|
||||
@patch("backend.api.features.workspace.routes.get_workspace_total_size")
|
||||
@patch("backend.api.features.workspace.routes.scan_content_safe")
|
||||
@patch("backend.api.features.workspace.routes.WorkspaceManager")
|
||||
def test_upload_passes_user_upload_origin_metadata(
|
||||
mock_manager_cls, mock_scan, mock_total_size, mock_get_workspace
|
||||
mock_manager_cls, mock_total_size, mock_get_workspace, mock_storage_limit
|
||||
):
|
||||
mock_get_workspace.return_value = _make_workspace()
|
||||
mock_total_size.return_value = 100
|
||||
mock_storage_limit.return_value = 250 * 1024 * 1024
|
||||
written = _make_file(id="new-file", name="doc.pdf")
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.write_file.return_value = written
|
||||
@@ -178,10 +179,9 @@ def test_upload_passes_user_upload_origin_metadata(
|
||||
|
||||
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
|
||||
@patch("backend.api.features.workspace.routes.get_workspace_total_size")
|
||||
@patch("backend.api.features.workspace.routes.scan_content_safe")
|
||||
@patch("backend.api.features.workspace.routes.WorkspaceManager")
|
||||
def test_upload_returns_409_on_file_conflict(
|
||||
mock_manager_cls, mock_scan, mock_total_size, mock_get_workspace
|
||||
mock_manager_cls, mock_total_size, mock_get_workspace
|
||||
):
|
||||
mock_get_workspace.return_value = _make_workspace()
|
||||
mock_total_size.return_value = 100
|
||||
@@ -234,8 +234,8 @@ def test_upload_happy_path(mocker):
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
"backend.api.features.workspace.routes.get_workspace_storage_limit_bytes",
|
||||
return_value=250 * 1024 * 1024,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
|
||||
@@ -256,20 +256,24 @@ def test_upload_exceeds_max_file_size(mocker):
|
||||
"""Files larger than max_file_size_mb should be rejected with 413."""
|
||||
cfg = mocker.patch("backend.api.features.workspace.routes.Config")
|
||||
cfg.return_value.max_file_size_mb = 0 # 0 MB → any content is too big
|
||||
cfg.return_value.max_workspace_storage_mb = 500
|
||||
|
||||
response = _upload(content=b"x" * 1024)
|
||||
assert response.status_code == 413
|
||||
|
||||
|
||||
def test_upload_storage_quota_exceeded(mocker):
|
||||
"""WorkspaceManager.write_file raises ValueError when quota exceeded → 413."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(
|
||||
side_effect=ValueError("Storage limit exceeded: 500 MB used of 250 MB (200.0%)")
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=500 * 1024 * 1024,
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
response = _upload()
|
||||
@@ -283,13 +287,14 @@ def test_upload_post_write_quota_race(mocker):
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
# Post-write total exceeds the tier-based limit (250 MB for FREE).
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
side_effect=[0, 600 * 1024 * 1024],
|
||||
return_value=600 * 1024 * 1024,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
"backend.api.features.workspace.routes.get_workspace_storage_limit_bytes",
|
||||
return_value=250 * 1024 * 1024,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
|
||||
@@ -318,8 +323,8 @@ def test_upload_any_extension(mocker):
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
"backend.api.features.workspace.routes.get_workspace_storage_limit_bytes",
|
||||
return_value=250 * 1024 * 1024,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
|
||||
@@ -333,23 +338,17 @@ def test_upload_any_extension(mocker):
|
||||
|
||||
|
||||
def test_upload_blocked_by_virus_scan(mocker):
|
||||
"""Files flagged by ClamAV should be rejected and never written to storage."""
|
||||
"""Files flagged by ClamAV should be rejected via WorkspaceManager."""
|
||||
from backend.api.features.store.exceptions import VirusDetectedError
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(
|
||||
side_effect=VirusDetectedError("Eicar-Test-Signature"),
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
@@ -357,7 +356,6 @@ def test_upload_blocked_by_virus_scan(mocker):
|
||||
|
||||
response = _upload(filename="evil.exe", content=b"X5O!P%@AP...")
|
||||
assert response.status_code == 400
|
||||
mock_manager.write_file.assert_not_called()
|
||||
|
||||
|
||||
def test_upload_file_without_extension(mocker):
|
||||
@@ -371,8 +369,8 @@ def test_upload_file_without_extension(mocker):
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
"backend.api.features.workspace.routes.get_workspace_storage_limit_bytes",
|
||||
return_value=250 * 1024 * 1024,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
|
||||
@@ -402,8 +400,8 @@ def test_upload_strips_path_components(mocker):
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
"backend.api.features.workspace.routes.get_workspace_storage_limit_bytes",
|
||||
return_value=250 * 1024 * 1024,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
|
||||
@@ -488,14 +486,6 @@ def test_upload_write_file_too_large_returns_413(mocker):
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(
|
||||
side_effect=ValueError("File too large: 900 bytes exceeds 1MB limit")
|
||||
@@ -516,14 +506,6 @@ def test_upload_write_file_conflict_returns_409(mocker):
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(
|
||||
side_effect=ValueError("File already exists at path: /sessions/x/a.txt")
|
||||
@@ -602,6 +584,54 @@ def test_list_files_offset_is_echoed_back(mock_manager_cls, mock_get_workspace):
|
||||
)
|
||||
|
||||
|
||||
def test_upload_virus_scan_infrastructure_error_returns_500(mocker):
|
||||
"""VirusScanError (ClamAV outage) should return 500, not 409."""
|
||||
from backend.api.features.store.exceptions import VirusScanError
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(
|
||||
side_effect=VirusScanError("ClamAV connection refused"),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
response = _upload()
|
||||
assert response.status_code == 500
|
||||
|
||||
|
||||
def test_get_storage_usage_returns_tier_based_limit(mocker):
|
||||
"""get_storage_usage should return the user's tier-based limit, not a static config."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=100 * 1024 * 1024, # 100 MB used
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.count_workspace_files",
|
||||
return_value=5,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_storage_limit_bytes",
|
||||
return_value=1024 * 1024 * 1024, # 1 GB (PRO tier)
|
||||
)
|
||||
|
||||
response = client.get("/storage/usage")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["limit_bytes"] == 1024 * 1024 * 1024
|
||||
assert data["used_bytes"] == 100 * 1024 * 1024
|
||||
assert data["file_count"] == 5
|
||||
|
||||
|
||||
# -- _sanitize_filename_for_header tests --
|
||||
|
||||
|
||||
|
||||
@@ -111,6 +111,19 @@ TIER_MULTIPLIERS = _DEFAULT_TIER_MULTIPLIERS
|
||||
DEFAULT_TIER = SubscriptionTier.NO_TIER
|
||||
|
||||
|
||||
# Per-tier workspace storage caps in MB. NO_TIER keeps the same baseline as
|
||||
# BASIC so users who cancel retain a small quota and see a real overage cap,
|
||||
# while LaunchDarkly can still tune tiers without a deploy.
|
||||
_DEFAULT_TIER_WORKSPACE_STORAGE_MB: dict[SubscriptionTier, int] = {
|
||||
SubscriptionTier.NO_TIER: 250, # 250 MB
|
||||
SubscriptionTier.BASIC: 250, # 250 MB
|
||||
SubscriptionTier.PRO: 1024, # 1 GB
|
||||
SubscriptionTier.MAX: 5 * 1024, # 5 GB
|
||||
SubscriptionTier.BUSINESS: 15 * 1024, # 15 GB
|
||||
SubscriptionTier.ENTERPRISE: 15 * 1024, # 15 GB
|
||||
}
|
||||
|
||||
|
||||
@cached(ttl_seconds=60, maxsize=1, cache_none=False)
|
||||
async def _fetch_tier_multipliers_flag() -> dict[SubscriptionTier, float] | None:
|
||||
"""Fetch the ``copilot-tier-multipliers`` LD flag and parse it.
|
||||
@@ -232,6 +245,82 @@ async def get_tier_multipliers() -> dict[str, float]:
|
||||
return {tier.value: multiplier for tier, multiplier in merged.items()}
|
||||
|
||||
|
||||
@cached(ttl_seconds=60, maxsize=1, cache_none=False)
|
||||
async def _fetch_workspace_storage_limits_flag() -> dict[SubscriptionTier, int] | None:
|
||||
"""Fetch the ``copilot-tier-workspace-storage-limits`` LD flag and parse it.
|
||||
|
||||
Returns a sparse ``{tier: megabytes}`` map built from whichever keys are
|
||||
valid in the flag payload, or ``None`` when the flag is unset / invalid /
|
||||
LD is unavailable. Callers merge whatever survives into
|
||||
:data:`_DEFAULT_TIER_WORKSPACE_STORAGE_MB`.
|
||||
|
||||
The LD value is expected to be a JSON object keyed by tier enum name
|
||||
(``{"NO_TIER": 250, "PRO": 1024, "BUSINESS": 15360}``). Non-int,
|
||||
negative, or zero values are skipped so a broken key degrades to the
|
||||
code default instead of wiping out the limit. Zero is rejected because
|
||||
downstream guards treat ``storage_limit == 0`` as "uncapped".
|
||||
"""
|
||||
# Lazy import: rate_limit -> feature_flag -> settings -> ... -> rate_limit.
|
||||
from backend.util.feature_flag import Flag, get_feature_flag_value
|
||||
|
||||
raw = await get_feature_flag_value(
|
||||
Flag.COPILOT_TIER_WORKSPACE_STORAGE_LIMITS.value, "system", None
|
||||
)
|
||||
if raw is None:
|
||||
return None
|
||||
if not isinstance(raw, dict):
|
||||
logger.warning(
|
||||
"Invalid LD value for copilot-tier-workspace-storage-limits "
|
||||
"(expected JSON object): %r",
|
||||
raw,
|
||||
)
|
||||
return None
|
||||
|
||||
parsed: dict[SubscriptionTier, int] = {}
|
||||
for key, value in raw.items():
|
||||
try:
|
||||
tier = SubscriptionTier(key)
|
||||
except ValueError:
|
||||
continue
|
||||
if isinstance(value, bool) or not isinstance(value, int):
|
||||
logger.warning(
|
||||
"Invalid LD value for copilot-tier-workspace-storage-limits[%s]: %r",
|
||||
key,
|
||||
value,
|
||||
)
|
||||
continue
|
||||
if value <= 0:
|
||||
logger.warning(
|
||||
"Non-positive LD value for copilot-tier-workspace-storage-limits[%s]: %r",
|
||||
key,
|
||||
value,
|
||||
)
|
||||
continue
|
||||
parsed[tier] = value
|
||||
return parsed or None
|
||||
|
||||
|
||||
async def get_workspace_storage_limits_mb() -> dict[str, int]:
|
||||
"""Return the effective ``{tier_value: megabytes}`` workspace limit map.
|
||||
|
||||
Honours the ``copilot-tier-workspace-storage-limits`` LD flag when set;
|
||||
missing tiers inherit :data:`_DEFAULT_TIER_WORKSPACE_STORAGE_MB`.
|
||||
Unparseable flag values or LD fetch failures fall back to the defaults.
|
||||
"""
|
||||
try:
|
||||
override = await _fetch_workspace_storage_limits_flag()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"get_workspace_storage_limits_mb: LD lookup failed", exc_info=True
|
||||
)
|
||||
override = None
|
||||
|
||||
merged: dict[SubscriptionTier, int] = dict(_DEFAULT_TIER_WORKSPACE_STORAGE_MB)
|
||||
if override:
|
||||
merged.update(override)
|
||||
return {tier.value: megabytes for tier, megabytes in merged.items()}
|
||||
|
||||
|
||||
class UsageWindow(BaseModel):
|
||||
"""Usage within a single time window.
|
||||
|
||||
@@ -661,6 +750,19 @@ get_user_tier.cache_clear = _fetch_user_tier.cache_clear # type: ignore[attr-de
|
||||
get_user_tier.cache_delete = _fetch_user_tier.cache_delete # type: ignore[attr-defined]
|
||||
|
||||
|
||||
async def get_workspace_storage_limit_bytes(user_id: str) -> int:
|
||||
"""Return the workspace storage cap in bytes for the user's subscription tier."""
|
||||
tier = await get_user_tier(user_id)
|
||||
limits_mb = await get_workspace_storage_limits_mb()
|
||||
tier_key = getattr(tier, "value", str(tier))
|
||||
fallback_mb = limits_mb.get(
|
||||
DEFAULT_TIER.value,
|
||||
_DEFAULT_TIER_WORKSPACE_STORAGE_MB[DEFAULT_TIER],
|
||||
)
|
||||
mb = limits_mb.get(tier_key, fallback_mb)
|
||||
return mb * 1024 * 1024
|
||||
|
||||
|
||||
async def set_user_tier(user_id: str, tier: SubscriptionTier) -> None:
|
||||
"""Persist the user's rate-limit tier to the database.
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from redis.exceptions import RedisError
|
||||
|
||||
from .rate_limit import (
|
||||
_DEFAULT_TIER_MULTIPLIERS,
|
||||
_DEFAULT_TIER_WORKSPACE_STORAGE_MB,
|
||||
DEFAULT_TIER,
|
||||
TIER_MULTIPLIERS,
|
||||
CoPilotUsageStatus,
|
||||
@@ -18,6 +19,7 @@ from .rate_limit import (
|
||||
_daily_reset_time,
|
||||
_fetch_cost_limits_flag,
|
||||
_fetch_tier_multipliers_flag,
|
||||
_fetch_workspace_storage_limits_flag,
|
||||
_weekly_key,
|
||||
_weekly_reset_time,
|
||||
acquire_reset_lock,
|
||||
@@ -27,6 +29,8 @@ from .rate_limit import (
|
||||
get_tier_multipliers,
|
||||
get_usage_status,
|
||||
get_user_tier,
|
||||
get_workspace_storage_limit_bytes,
|
||||
get_workspace_storage_limits_mb,
|
||||
increment_daily_reset_count,
|
||||
record_cost_usage,
|
||||
release_reset_lock,
|
||||
@@ -471,6 +475,85 @@ class TestGetTierMultipliers:
|
||||
assert result == {t.value: m for t, m in _DEFAULT_TIER_MULTIPLIERS.items()}
|
||||
|
||||
|
||||
class TestGetWorkspaceStorageLimits:
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_flag_cache(self):
|
||||
"""Clear the LD flag cache between tests so patches don't leak."""
|
||||
_fetch_workspace_storage_limits_flag.cache_clear() # type: ignore[attr-defined]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_defaults_when_flag_unset(self):
|
||||
"""With no LD override, the resolver returns the default map."""
|
||||
with patch(
|
||||
"backend.util.feature_flag.get_feature_flag_value",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
result = await get_workspace_storage_limits_mb()
|
||||
assert result == {
|
||||
t.value: mb for t, mb in _DEFAULT_TIER_WORKSPACE_STORAGE_MB.items()
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ld_override(self):
|
||||
"""LD override populates targeted tiers; others inherit defaults."""
|
||||
with patch(
|
||||
"backend.util.feature_flag.get_feature_flag_value",
|
||||
new_callable=AsyncMock,
|
||||
return_value={"NO_TIER": 300, "PRO": 2048},
|
||||
):
|
||||
result = await get_workspace_storage_limits_mb()
|
||||
assert result["NO_TIER"] == 300
|
||||
assert result["PRO"] == 2048
|
||||
assert (
|
||||
result["BASIC"]
|
||||
== _DEFAULT_TIER_WORKSPACE_STORAGE_MB[SubscriptionTier.BASIC]
|
||||
)
|
||||
assert result["MAX"] == _DEFAULT_TIER_WORKSPACE_STORAGE_MB[SubscriptionTier.MAX]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_json_falls_back(self):
|
||||
"""A non-object LD value falls back to defaults."""
|
||||
with patch(
|
||||
"backend.util.feature_flag.get_feature_flag_value",
|
||||
new_callable=AsyncMock,
|
||||
return_value="broken",
|
||||
):
|
||||
result = await get_workspace_storage_limits_mb()
|
||||
assert result == {
|
||||
t.value: mb for t, mb in _DEFAULT_TIER_WORKSPACE_STORAGE_MB.items()
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_tier_key_and_invalid_values_skipped(self):
|
||||
"""Unknown tiers and invalid values degrade to defaults per key."""
|
||||
with patch(
|
||||
"backend.util.feature_flag.get_feature_flag_value",
|
||||
new_callable=AsyncMock,
|
||||
return_value={"NO_TIER": 300, "BOGUS": 999, "MAX": -1, "BUSINESS": "nope"},
|
||||
):
|
||||
result = await get_workspace_storage_limits_mb()
|
||||
assert result["NO_TIER"] == 300
|
||||
assert result["MAX"] == _DEFAULT_TIER_WORKSPACE_STORAGE_MB[SubscriptionTier.MAX]
|
||||
assert (
|
||||
result["BUSINESS"]
|
||||
== _DEFAULT_TIER_WORKSPACE_STORAGE_MB[SubscriptionTier.BUSINESS]
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ld_failure_falls_back(self):
|
||||
"""LD lookup raising propagates to defaults, not up the call stack."""
|
||||
with patch(
|
||||
"backend.util.feature_flag.get_feature_flag_value",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=RuntimeError("LD SDK not initialized"),
|
||||
):
|
||||
result = await get_workspace_storage_limits_mb()
|
||||
assert result == {
|
||||
t.value: mb for t, mb in _DEFAULT_TIER_WORKSPACE_STORAGE_MB.items()
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_global_rate_limits — LD-flag cost limits parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -1792,3 +1875,284 @@ class TestResetUserUsage:
|
||||
):
|
||||
with pytest.raises(RedisError):
|
||||
await reset_user_usage("user-1")
|
||||
|
||||
|
||||
class TestWorkspaceStorageLimits:
|
||||
"""Tests for tier-based workspace storage limits."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_flag_cache(self):
|
||||
_fetch_workspace_storage_limits_flag.cache_clear() # type: ignore[attr-defined]
|
||||
|
||||
def test_every_subscription_tier_has_storage_limit(self):
|
||||
"""Adding a new SubscriptionTier without a storage limit should fail."""
|
||||
for tier in SubscriptionTier:
|
||||
assert tier in _DEFAULT_TIER_WORKSPACE_STORAGE_MB, (
|
||||
f"SubscriptionTier.{tier.name} has no entry in "
|
||||
f"_DEFAULT_TIER_WORKSPACE_STORAGE_MB — add one"
|
||||
)
|
||||
assert _DEFAULT_TIER_WORKSPACE_STORAGE_MB[tier] > 0
|
||||
|
||||
def test_every_subscription_tier_has_rate_limit_multiplier(self):
|
||||
"""Adding a new SubscriptionTier without a rate limit multiplier should fail."""
|
||||
for tier in SubscriptionTier:
|
||||
assert tier in _DEFAULT_TIER_MULTIPLIERS, (
|
||||
f"SubscriptionTier.{tier.name} has no entry in "
|
||||
f"_DEFAULT_TIER_MULTIPLIERS — add one"
|
||||
)
|
||||
if tier == SubscriptionTier.NO_TIER:
|
||||
assert _DEFAULT_TIER_MULTIPLIERS[tier] == 0.0
|
||||
else:
|
||||
assert _DEFAULT_TIER_MULTIPLIERS[tier] > 0
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tier,expected_mb",
|
||||
[
|
||||
(SubscriptionTier.NO_TIER, 250),
|
||||
(SubscriptionTier.BASIC, 250),
|
||||
(SubscriptionTier.PRO, 1024),
|
||||
(SubscriptionTier.MAX, 5 * 1024),
|
||||
(SubscriptionTier.BUSINESS, 15 * 1024),
|
||||
(SubscriptionTier.ENTERPRISE, 15 * 1024),
|
||||
],
|
||||
)
|
||||
def test_tier_workspace_storage_mapping_covers_all_tiers(self, tier, expected_mb):
|
||||
"""Every tier has an explicit storage limit in the mapping."""
|
||||
assert tier in _DEFAULT_TIER_WORKSPACE_STORAGE_MB
|
||||
assert _DEFAULT_TIER_WORKSPACE_STORAGE_MB[tier] == expected_mb
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"tier,expected_bytes",
|
||||
[
|
||||
(SubscriptionTier.NO_TIER, 250 * 1024 * 1024),
|
||||
(SubscriptionTier.BASIC, 250 * 1024 * 1024),
|
||||
(SubscriptionTier.PRO, 1024 * 1024 * 1024),
|
||||
(SubscriptionTier.MAX, 5 * 1024 * 1024 * 1024),
|
||||
(SubscriptionTier.BUSINESS, 15 * 1024 * 1024 * 1024),
|
||||
(SubscriptionTier.ENTERPRISE, 15 * 1024 * 1024 * 1024),
|
||||
],
|
||||
)
|
||||
async def test_get_workspace_storage_limit_bytes_per_tier(
|
||||
self, tier, expected_bytes
|
||||
):
|
||||
"""get_workspace_storage_limit_bytes returns correct bytes for each tier."""
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_user_tier",
|
||||
return_value=tier,
|
||||
):
|
||||
result = await get_workspace_storage_limit_bytes("user-1")
|
||||
assert result == expected_bytes
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_workspace_storage_limit_bytes_defaults_to_default_tier_on_unknown(
|
||||
self,
|
||||
):
|
||||
"""Unknown tier falls back to the default tier limit."""
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_user_tier",
|
||||
return_value="UNKNOWN_TIER",
|
||||
):
|
||||
result = await get_workspace_storage_limit_bytes("user-1")
|
||||
assert result == 250 * 1024 * 1024
|
||||
|
||||
|
||||
class TestWorkspaceStorageLimitsAdversarial:
|
||||
"""Adversarial edge-case tests for workspace storage limit resolution."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_flag_cache(self):
|
||||
_fetch_workspace_storage_limits_flag.cache_clear() # type: ignore[attr-defined]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ld_zero_values_are_rejected(self):
|
||||
"""LD zero values must be rejected — 0 is the 'uncapped' sentinel downstream."""
|
||||
with patch(
|
||||
"backend.util.feature_flag.get_feature_flag_value",
|
||||
new_callable=AsyncMock,
|
||||
return_value={"BASIC": 0, "PRO": 0},
|
||||
):
|
||||
result = await get_workspace_storage_limits_mb()
|
||||
# Zero rejected → falls back to code defaults
|
||||
assert (
|
||||
result["BASIC"]
|
||||
== _DEFAULT_TIER_WORKSPACE_STORAGE_MB[SubscriptionTier.BASIC]
|
||||
)
|
||||
assert result["PRO"] == _DEFAULT_TIER_WORKSPACE_STORAGE_MB[SubscriptionTier.PRO]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ld_zero_override_does_not_bypass_quota(self):
|
||||
"""A zero LD override must not result in 0 bytes (which bypasses quota)."""
|
||||
_fetch_workspace_storage_limits_flag.cache_clear() # type: ignore[attr-defined]
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.rate_limit.get_user_tier",
|
||||
return_value=SubscriptionTier.BASIC,
|
||||
),
|
||||
patch(
|
||||
"backend.util.feature_flag.get_feature_flag_value",
|
||||
new_callable=AsyncMock,
|
||||
return_value={"BASIC": 0},
|
||||
),
|
||||
):
|
||||
result = await get_workspace_storage_limit_bytes("user-1")
|
||||
# Zero rejected → falls back to default 250 MB in bytes
|
||||
assert result == 250 * 1024 * 1024
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ld_bool_values_are_rejected(self):
|
||||
"""Booleans masquerading as ints (True=1, False=0) must be rejected."""
|
||||
with patch(
|
||||
"backend.util.feature_flag.get_feature_flag_value",
|
||||
new_callable=AsyncMock,
|
||||
return_value={"PRO": True, "MAX": False},
|
||||
):
|
||||
result = await get_workspace_storage_limits_mb()
|
||||
# PRO and MAX should keep their defaults, not become 1 and 0
|
||||
assert result["PRO"] == _DEFAULT_TIER_WORKSPACE_STORAGE_MB[SubscriptionTier.PRO]
|
||||
assert result["MAX"] == _DEFAULT_TIER_WORKSPACE_STORAGE_MB[SubscriptionTier.MAX]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ld_float_values_are_rejected(self):
|
||||
"""Floats in LD payload should be rejected — must be int."""
|
||||
with patch(
|
||||
"backend.util.feature_flag.get_feature_flag_value",
|
||||
new_callable=AsyncMock,
|
||||
return_value={"PRO": 1024.5},
|
||||
):
|
||||
result = await get_workspace_storage_limits_mb()
|
||||
assert result["PRO"] == _DEFAULT_TIER_WORKSPACE_STORAGE_MB[SubscriptionTier.PRO]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ld_empty_object_falls_back_to_defaults(self):
|
||||
"""Empty LD payload {} → parsed is empty → None → defaults."""
|
||||
with patch(
|
||||
"backend.util.feature_flag.get_feature_flag_value",
|
||||
new_callable=AsyncMock,
|
||||
return_value={},
|
||||
):
|
||||
result = await get_workspace_storage_limits_mb()
|
||||
assert result == {
|
||||
t.value: mb for t, mb in _DEFAULT_TIER_WORKSPACE_STORAGE_MB.items()
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ld_list_instead_of_dict_falls_back(self):
|
||||
"""LD returns a list instead of dict → rejected → defaults."""
|
||||
with patch(
|
||||
"backend.util.feature_flag.get_feature_flag_value",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[250, 1024, 5120],
|
||||
):
|
||||
result = await get_workspace_storage_limits_mb()
|
||||
assert result == {
|
||||
t.value: mb for t, mb in _DEFAULT_TIER_WORKSPACE_STORAGE_MB.items()
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ld_extremely_large_values_accepted(self):
|
||||
"""Very large LD values shouldn't be clipped — trust LD config."""
|
||||
with patch(
|
||||
"backend.util.feature_flag.get_feature_flag_value",
|
||||
new_callable=AsyncMock,
|
||||
return_value={"ENTERPRISE": 999_999_999},
|
||||
):
|
||||
result = await get_workspace_storage_limits_mb()
|
||||
assert result["ENTERPRISE"] == 999_999_999
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ld_mixed_valid_and_garbage_keys(self):
|
||||
"""Mixed payload: valid overrides applied, garbage ignored."""
|
||||
with patch(
|
||||
"backend.util.feature_flag.get_feature_flag_value",
|
||||
new_callable=AsyncMock,
|
||||
return_value={
|
||||
"PRO": 2048,
|
||||
"": 100,
|
||||
"null": 100,
|
||||
"NO_TIER": -5,
|
||||
"BASIC": "abc",
|
||||
"MAX": True,
|
||||
"ENTERPRISE": 30720,
|
||||
},
|
||||
):
|
||||
result = await get_workspace_storage_limits_mb()
|
||||
# Only PRO and ENTERPRISE should be overridden
|
||||
assert result["PRO"] == 2048
|
||||
assert result["ENTERPRISE"] == 30720
|
||||
# Everything else keeps defaults
|
||||
assert (
|
||||
result["BASIC"]
|
||||
== _DEFAULT_TIER_WORKSPACE_STORAGE_MB[SubscriptionTier.BASIC]
|
||||
)
|
||||
assert (
|
||||
result["NO_TIER"]
|
||||
== _DEFAULT_TIER_WORKSPACE_STORAGE_MB[SubscriptionTier.NO_TIER]
|
||||
)
|
||||
assert result["MAX"] == _DEFAULT_TIER_WORKSPACE_STORAGE_MB[SubscriptionTier.MAX]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tier_transition_pro_to_no_tier_reduces_limit(self):
|
||||
"""Simulating unsubscribe: PRO user → NO_TIER gets lower limit."""
|
||||
_fetch_workspace_storage_limits_flag.cache_clear() # type: ignore[attr-defined]
|
||||
# PRO tier
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_user_tier",
|
||||
return_value=SubscriptionTier.PRO,
|
||||
):
|
||||
pro_limit = await get_workspace_storage_limit_bytes("user-1")
|
||||
# Unsubscribe → NO_TIER
|
||||
_fetch_workspace_storage_limits_flag.cache_clear() # type: ignore[attr-defined]
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_user_tier",
|
||||
return_value=SubscriptionTier.NO_TIER,
|
||||
):
|
||||
no_tier_limit = await get_workspace_storage_limit_bytes("user-1")
|
||||
assert pro_limit == 1024 * 1024 * 1024 # 1 GB
|
||||
assert no_tier_limit == 250 * 1024 * 1024 # 250 MB
|
||||
assert no_tier_limit < pro_limit
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_tiers_are_monotonically_increasing(self):
|
||||
"""Storage limits should be monotonically non-decreasing across tiers."""
|
||||
tier_order = [
|
||||
SubscriptionTier.NO_TIER,
|
||||
SubscriptionTier.BASIC,
|
||||
SubscriptionTier.PRO,
|
||||
SubscriptionTier.MAX,
|
||||
SubscriptionTier.BUSINESS,
|
||||
SubscriptionTier.ENTERPRISE,
|
||||
]
|
||||
limits = [_DEFAULT_TIER_WORKSPACE_STORAGE_MB[t] for t in tier_order]
|
||||
for i in range(1, len(limits)):
|
||||
assert limits[i] >= limits[i - 1], (
|
||||
f"{tier_order[i].name} ({limits[i]} MB) < "
|
||||
f"{tier_order[i - 1].name} ({limits[i - 1]} MB)"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_flag_fetches_share_cache(self):
|
||||
"""Multiple concurrent calls should only hit LD once (caching)."""
|
||||
call_count = 0
|
||||
|
||||
async def counting_flag(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return {"PRO": 2048}
|
||||
|
||||
with patch(
|
||||
"backend.util.feature_flag.get_feature_flag_value",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=counting_flag,
|
||||
):
|
||||
import asyncio
|
||||
|
||||
results = await asyncio.gather(
|
||||
get_workspace_storage_limits_mb(),
|
||||
get_workspace_storage_limits_mb(),
|
||||
get_workspace_storage_limits_mb(),
|
||||
)
|
||||
# All results should be identical
|
||||
assert all(r == results[0] for r in results)
|
||||
# Flag should have been fetched at most once due to caching
|
||||
assert call_count <= 1
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.api.features.store.exceptions import VirusDetectedError, VirusScanError
|
||||
from backend.copilot.context import (
|
||||
E2B_WORKDIR,
|
||||
get_current_sandbox,
|
||||
@@ -19,7 +20,6 @@ from backend.copilot.context import (
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.sandbox import make_session_path
|
||||
from backend.util.settings import Config
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
from backend.util.workspace import WorkspaceManager
|
||||
|
||||
from .base import BaseTool
|
||||
@@ -837,7 +837,6 @@ class WriteWorkspaceFileTool(BaseTool):
|
||||
)
|
||||
|
||||
try:
|
||||
await scan_content_safe(content_bytes, filename=filename)
|
||||
manager = await get_workspace_manager(user_id, session_id)
|
||||
rec = await manager.write_file(
|
||||
content=content_bytes,
|
||||
@@ -895,6 +894,12 @@ class WriteWorkspaceFileTool(BaseTool):
|
||||
message=msg,
|
||||
session_id=session_id,
|
||||
)
|
||||
except VirusDetectedError as e:
|
||||
logger.warning(f"Virus detected in uploaded file: {e.threat_name}")
|
||||
return ErrorResponse(message=str(e), session_id=session_id)
|
||||
except VirusScanError as e:
|
||||
logger.error(f"Virus scan infrastructure error: {e}", exc_info=True)
|
||||
return ErrorResponse(message=str(e), session_id=session_id)
|
||||
except ValueError as e:
|
||||
return ErrorResponse(message=str(e), session_id=session_id)
|
||||
except Exception as e:
|
||||
|
||||
@@ -623,3 +623,73 @@ async def test_read_workspace_file_no_fallback_when_resolve_succeeds(setup_test_
|
||||
# Normal workspace path must have produced a content response.
|
||||
assert isinstance(result, WorkspaceFileContentResponse)
|
||||
assert base64.b64decode(result.content_base64) == fake_content
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WriteWorkspaceFileTool exception handling (quota, virus, scan errors)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
class TestWriteWorkspaceFileToolErrorHandling:
|
||||
"""Verify WriteWorkspaceFileTool returns proper ErrorResponse for exceptions."""
|
||||
|
||||
async def _execute_write(self, side_effect, setup_test_data):
|
||||
"""Helper: run WriteWorkspaceFileTool with mocked write_file."""
|
||||
user = setup_test_data["user"]
|
||||
session = make_session(user.id)
|
||||
|
||||
mock_manager = AsyncMock()
|
||||
mock_manager.write_file = AsyncMock(side_effect=side_effect)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.workspace_files.get_workspace_manager",
|
||||
AsyncMock(return_value=mock_manager),
|
||||
):
|
||||
tool = WriteWorkspaceFileTool()
|
||||
return await tool._execute(
|
||||
user_id=user.id,
|
||||
session=session,
|
||||
filename="test.txt",
|
||||
content="hello",
|
||||
)
|
||||
|
||||
async def test_quota_exceeded_returns_error_response(self, setup_test_data):
|
||||
"""Quota exceeded (ValueError) → ErrorResponse with storage message."""
|
||||
result = await self._execute_write(
|
||||
ValueError("Storage limit exceeded: 250 MB used of 250 MB (100.0%)"),
|
||||
setup_test_data,
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "Storage limit exceeded" in result.message
|
||||
|
||||
async def test_virus_detected_returns_error_response(self, setup_test_data):
|
||||
"""VirusDetectedError → ErrorResponse with virus message."""
|
||||
from backend.api.features.store.exceptions import VirusDetectedError
|
||||
|
||||
result = await self._execute_write(
|
||||
VirusDetectedError("Eicar-Test-Signature"),
|
||||
setup_test_data,
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "Virus detected" in result.message
|
||||
|
||||
async def test_virus_scan_error_returns_error_response(self, setup_test_data):
|
||||
"""VirusScanError (infra failure) → ErrorResponse with scan message."""
|
||||
from backend.api.features.store.exceptions import VirusScanError
|
||||
|
||||
result = await self._execute_write(
|
||||
VirusScanError("ClamAV connection refused"),
|
||||
setup_test_data,
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "ClamAV" in result.message
|
||||
|
||||
async def test_file_conflict_returns_error_response(self, setup_test_data):
|
||||
"""File path conflict (ValueError) → ErrorResponse."""
|
||||
result = await self._execute_write(
|
||||
ValueError("File already exists at path: /sessions/abc/test.txt"),
|
||||
setup_test_data,
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "already exists" in result.message
|
||||
|
||||
@@ -204,6 +204,68 @@ async def test_sync_subscription_from_stripe_cancelled():
|
||||
mock_set.assert_awaited_once_with("user-1", SubscriptionTier.NO_TIER)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sync_subscription_from_stripe_cancelled_applies_no_tier_storage_limit():
|
||||
"""After unsubscribe takes effect, workspace storage resolves against NO_TIER."""
|
||||
from backend.copilot.rate_limit import get_workspace_storage_limit_bytes
|
||||
|
||||
mock_user = _make_user(tier=SubscriptionTier.PRO)
|
||||
stripe_sub = {
|
||||
"id": "sub_old",
|
||||
"customer": "cus_123",
|
||||
"status": "canceled",
|
||||
"items": {"data": []},
|
||||
}
|
||||
empty_list = MagicMock()
|
||||
empty_list.data = []
|
||||
empty_list.has_more = False
|
||||
|
||||
async def _set_tier(_user_id: str, tier: SubscriptionTier) -> None:
|
||||
mock_user.subscriptionTier = tier
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.data.credit.User.prisma",
|
||||
return_value=MagicMock(find_first=AsyncMock(return_value=mock_user)),
|
||||
),
|
||||
patch(
|
||||
"backend.data.credit.stripe.Subscription.list",
|
||||
return_value=empty_list,
|
||||
),
|
||||
patch(
|
||||
"backend.data.credit.set_subscription_tier",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=_set_tier,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.rate_limit.get_user_tier",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=lambda _user_id: mock_user.subscriptionTier,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.rate_limit.get_workspace_storage_limits_mb",
|
||||
new_callable=AsyncMock,
|
||||
return_value={
|
||||
"NO_TIER": 250,
|
||||
"BASIC": 500,
|
||||
"PRO": 1024,
|
||||
"MAX": 5 * 1024,
|
||||
"BUSINESS": 15 * 1024,
|
||||
"ENTERPRISE": 15 * 1024,
|
||||
},
|
||||
),
|
||||
patch.object(
|
||||
get_pending_subscription_change,
|
||||
"cache_delete",
|
||||
) as mock_pending_cache_delete,
|
||||
):
|
||||
await sync_subscription_from_stripe(stripe_sub)
|
||||
result = await get_workspace_storage_limit_bytes("user-1")
|
||||
|
||||
assert result == 250 * 1024 * 1024
|
||||
mock_pending_cache_delete.assert_called_once_with("user-1")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sync_subscription_from_stripe_cancelled_but_other_active_sub_exists():
|
||||
"""Cancelling sub_old must NOT downgrade the user if sub_new is still active.
|
||||
|
||||
@@ -44,6 +44,7 @@ class Flag(str, Enum):
|
||||
COPILOT_SDK = "copilot-sdk"
|
||||
COPILOT_COST_LIMITS = "copilot-cost-limits"
|
||||
COPILOT_TIER_MULTIPLIERS = "copilot-tier-multipliers"
|
||||
COPILOT_TIER_WORKSPACE_STORAGE_LIMITS = "copilot-tier-workspace-storage-limits"
|
||||
COPILOT_TIER_STRIPE_PRICES = "copilot-tier-stripe-prices"
|
||||
GRAPHITI_MEMORY = "graphiti-memory"
|
||||
# Stripe Product ID for top-up Checkout sessions. When unset (default),
|
||||
|
||||
@@ -441,13 +441,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
description="Maximum file size in MB for workspace files (1-1024 MB)",
|
||||
)
|
||||
|
||||
max_workspace_storage_mb: int = Field(
|
||||
default=500,
|
||||
ge=1,
|
||||
le=10240,
|
||||
description="Maximum total workspace storage per user in MB.",
|
||||
)
|
||||
|
||||
# AutoMod configuration
|
||||
automod_enabled: bool = Field(
|
||||
default=False,
|
||||
|
||||
@@ -5,6 +5,7 @@ This module provides a high-level interface for workspace file operations,
|
||||
combining the storage backend and database layer.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import mimetypes
|
||||
import uuid
|
||||
@@ -12,12 +13,28 @@ from typing import Optional
|
||||
|
||||
from prisma.errors import UniqueViolationError
|
||||
|
||||
from backend.copilot.rate_limit import get_workspace_storage_limit_bytes
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.data.workspace import WorkspaceFile
|
||||
from backend.data.workspace import WorkspaceFile, get_workspace_total_size
|
||||
from backend.util.settings import Config
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
from backend.util.workspace_storage import compute_file_checksum, get_workspace_storage
|
||||
|
||||
|
||||
def format_bytes(n: int) -> str:
|
||||
"""Format bytes as a human-readable string (e.g. 250 MB, 1.0 GB)."""
|
||||
KB, MB, GB = 1024, 1024**2, 1024**3
|
||||
if n < KB:
|
||||
return f"{n} B"
|
||||
if n < MB:
|
||||
kb = round(n / KB)
|
||||
return f"{n / MB:.1f} MB" if kb >= 1024 else f"{kb} KB"
|
||||
if n < GB:
|
||||
mb = round(n / MB)
|
||||
return f"{n / GB:.1f} GB" if mb >= 1024 else f"{mb} MB"
|
||||
return f"{n / GB:.1f} GB"
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -185,11 +202,7 @@ class WorkspaceManager:
|
||||
f"{Config().max_file_size_mb}MB limit"
|
||||
)
|
||||
|
||||
# Scan here — callers must NOT duplicate this scan.
|
||||
# WorkspaceManager owns virus scanning for all persisted files.
|
||||
await scan_content_safe(content, filename=filename)
|
||||
|
||||
# Determine path with session scoping
|
||||
# Determine path with session scoping (needed before quota check for overwrites)
|
||||
if path is None:
|
||||
path = f"/{filename}"
|
||||
elif not path.startswith("/"):
|
||||
@@ -198,10 +211,37 @@ class WorkspaceManager:
|
||||
# Resolve path with session prefix
|
||||
path = self._resolve_path(path)
|
||||
|
||||
# Check if file exists at path (only error for non-overwrite case)
|
||||
# For overwrite=True, we let the write proceed and handle via UniqueViolationError
|
||||
# This ensures the new file is written to storage BEFORE the old one is deleted,
|
||||
# preventing data loss if the new write fails
|
||||
# Enforce per-user workspace storage quota (tier-based).
|
||||
# For overwrites, subtract the existing file's size so replacing a file
|
||||
# with a same-size or smaller file is not rejected near the cap.
|
||||
storage_limit, current_usage = await asyncio.gather(
|
||||
get_workspace_storage_limit_bytes(self.user_id),
|
||||
get_workspace_total_size(self.workspace_id),
|
||||
)
|
||||
if overwrite:
|
||||
db = workspace_db()
|
||||
existing = await db.get_workspace_file_by_path(self.workspace_id, path)
|
||||
if existing is not None:
|
||||
current_usage = max(0, current_usage - existing.size_bytes)
|
||||
|
||||
projected_usage = current_usage + len(content)
|
||||
if storage_limit > 0 and projected_usage > storage_limit:
|
||||
raise ValueError(
|
||||
f"Storage limit exceeded. "
|
||||
f"You've used {format_bytes(current_usage)} of your "
|
||||
f"{format_bytes(storage_limit)} quota. "
|
||||
f"Delete some files or upgrade your plan for more storage."
|
||||
)
|
||||
if storage_limit > 0 and projected_usage / storage_limit >= 0.8:
|
||||
logger.warning(
|
||||
f"User {self.user_id} workspace storage at "
|
||||
f"{projected_usage / storage_limit * 100:.1f}% "
|
||||
f"({projected_usage} / {storage_limit} bytes)"
|
||||
)
|
||||
|
||||
# Check if file exists at path (only error for non-overwrite case).
|
||||
# Done before virus scanning so a cheap duplicate-path check isn't
|
||||
# blocked by a potentially slow or unavailable scanner.
|
||||
db = workspace_db()
|
||||
|
||||
if not overwrite:
|
||||
@@ -209,6 +249,10 @@ class WorkspaceManager:
|
||||
if existing is not None:
|
||||
raise ValueError(f"File already exists at path: {path}")
|
||||
|
||||
# Scan here — callers must NOT duplicate this scan.
|
||||
# WorkspaceManager owns virus scanning for all persisted files.
|
||||
await scan_content_safe(content, filename=filename)
|
||||
|
||||
# Auto-detect MIME type if not provided
|
||||
if mime_type is None:
|
||||
mime_type, _ = mimetypes.guess_type(filename)
|
||||
|
||||
@@ -88,6 +88,11 @@ async def test_write_file_no_overwrite_unique_violation_raises_and_cleans_up(
|
||||
),
|
||||
patch("backend.util.workspace.workspace_db", return_value=mock_db),
|
||||
patch("backend.util.workspace.scan_content_safe", new_callable=AsyncMock),
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_storage_limit_bytes",
|
||||
return_value=250 * 1024 * 1024,
|
||||
),
|
||||
patch("backend.util.workspace.get_workspace_total_size", return_value=0),
|
||||
):
|
||||
with pytest.raises(ValueError, match="File already exists"):
|
||||
await manager.write_file(
|
||||
@@ -115,6 +120,11 @@ async def test_write_file_overwrite_conflict_then_retry_succeeds(
|
||||
),
|
||||
patch("backend.util.workspace.workspace_db", return_value=mock_db),
|
||||
patch("backend.util.workspace.scan_content_safe", new_callable=AsyncMock),
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_storage_limit_bytes",
|
||||
return_value=250 * 1024 * 1024,
|
||||
),
|
||||
patch("backend.util.workspace.get_workspace_total_size", return_value=0),
|
||||
patch.object(manager, "delete_file", new_callable=AsyncMock) as mock_delete,
|
||||
):
|
||||
result = await manager.write_file(
|
||||
@@ -148,6 +158,11 @@ async def test_write_file_overwrite_exhausted_retries_raises_and_cleans_up(
|
||||
),
|
||||
patch("backend.util.workspace.workspace_db", return_value=mock_db),
|
||||
patch("backend.util.workspace.scan_content_safe", new_callable=AsyncMock),
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_storage_limit_bytes",
|
||||
return_value=250 * 1024 * 1024,
|
||||
),
|
||||
patch("backend.util.workspace.get_workspace_total_size", return_value=0),
|
||||
patch.object(manager, "delete_file", new_callable=AsyncMock),
|
||||
):
|
||||
with pytest.raises(ValueError, match="Unable to overwrite.*concurrent write"):
|
||||
@@ -156,3 +171,346 @@ async def test_write_file_overwrite_exhausted_retries_raises_and_cleans_up(
|
||||
)
|
||||
|
||||
mock_storage.delete.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_file_quota_exceeded_raises_value_error(
|
||||
manager, mock_storage, mock_db
|
||||
):
|
||||
"""write_file raises ValueError when workspace storage quota is exceeded."""
|
||||
mock_db.get_workspace_file_by_path.return_value = None
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_storage",
|
||||
return_value=mock_storage,
|
||||
),
|
||||
patch("backend.util.workspace.workspace_db", return_value=mock_db),
|
||||
patch(
|
||||
"backend.util.workspace.scan_content_safe", new_callable=AsyncMock
|
||||
) as mock_scan,
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_storage_limit_bytes",
|
||||
return_value=250 * 1024 * 1024, # 250 MB limit
|
||||
),
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_total_size",
|
||||
return_value=250 * 1024 * 1024, # already at limit
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError, match="Storage limit exceeded"):
|
||||
await manager.write_file(filename="test.txt", content=b"hello")
|
||||
|
||||
# Quota rejection should short-circuit before expensive virus scan
|
||||
mock_scan.assert_not_called()
|
||||
# Storage should NOT have been written to
|
||||
mock_storage.store.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_file_rejects_upload_when_usage_already_exceeds_downgraded_limit(
|
||||
manager, mock_storage, mock_db
|
||||
):
|
||||
"""Downgrading below current usage should block further uploads until usage drops."""
|
||||
mock_db.get_workspace_file_by_path.return_value = None
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_storage",
|
||||
return_value=mock_storage,
|
||||
),
|
||||
patch("backend.util.workspace.workspace_db", return_value=mock_db),
|
||||
patch(
|
||||
"backend.util.workspace.scan_content_safe", new_callable=AsyncMock
|
||||
) as mock_scan,
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_storage_limit_bytes",
|
||||
return_value=250 * 1024 * 1024,
|
||||
),
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_total_size",
|
||||
return_value=300 * 1024 * 1024,
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError, match="Storage limit exceeded"):
|
||||
await manager.write_file(filename="test.txt", content=b"hello")
|
||||
|
||||
mock_scan.assert_not_called()
|
||||
mock_storage.store.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_file_80pct_warning_logged(manager, mock_storage, mock_db, caplog):
|
||||
"""write_file logs a warning when workspace usage crosses 80%."""
|
||||
created_file = _make_workspace_file()
|
||||
mock_db.get_workspace_file_by_path.return_value = None
|
||||
mock_db.create_workspace_file.return_value = created_file
|
||||
|
||||
limit_bytes = 100 # 100 bytes total limit
|
||||
current_usage = 75 # 75 bytes used → 75% before write
|
||||
content = b"123456" # 6 bytes → 81% after write
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_storage",
|
||||
return_value=mock_storage,
|
||||
),
|
||||
patch("backend.util.workspace.workspace_db", return_value=mock_db),
|
||||
patch("backend.util.workspace.scan_content_safe", new_callable=AsyncMock),
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_storage_limit_bytes",
|
||||
return_value=limit_bytes,
|
||||
),
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_total_size",
|
||||
return_value=current_usage,
|
||||
),
|
||||
):
|
||||
import logging
|
||||
|
||||
with caplog.at_level(logging.WARNING, logger="backend.util.workspace"):
|
||||
await manager.write_file(filename="test.txt", content=content)
|
||||
|
||||
assert any("workspace storage at" in record.message for record in caplog.records)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_file_overwrite_not_double_counted(manager, mock_storage, mock_db):
|
||||
"""Overwriting a file subtracts the old file size from usage check."""
|
||||
existing_file = _make_workspace_file(size_bytes=50)
|
||||
created_file = _make_workspace_file()
|
||||
mock_db.get_workspace_file_by_path.return_value = existing_file
|
||||
mock_db.create_workspace_file.return_value = created_file
|
||||
|
||||
limit_bytes = 100
|
||||
current_usage = 90 # 90 bytes used, 50 of which is the file being replaced
|
||||
content = b"x" * 50 # replacing with same-size file — should succeed
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_storage",
|
||||
return_value=mock_storage,
|
||||
),
|
||||
patch("backend.util.workspace.workspace_db", return_value=mock_db),
|
||||
patch("backend.util.workspace.scan_content_safe", new_callable=AsyncMock),
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_storage_limit_bytes",
|
||||
return_value=limit_bytes,
|
||||
),
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_total_size",
|
||||
return_value=current_usage,
|
||||
),
|
||||
):
|
||||
# Should NOT raise — net usage after overwrite is 90 - 50 + 50 = 90, under 100
|
||||
result = await manager.write_file(
|
||||
filename="test.txt", content=content, overwrite=True
|
||||
)
|
||||
assert result == created_file
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_file_zero_limit_bypasses_quota_check(
|
||||
manager, mock_storage, mock_db
|
||||
):
|
||||
"""When limit is 0 (internal sentinel, not reachable via LD), quota is skipped."""
|
||||
created_file = _make_workspace_file()
|
||||
mock_db.get_workspace_file_by_path.return_value = None
|
||||
mock_db.create_workspace_file.return_value = created_file
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_storage",
|
||||
return_value=mock_storage,
|
||||
),
|
||||
patch("backend.util.workspace.workspace_db", return_value=mock_db),
|
||||
patch("backend.util.workspace.scan_content_safe", new_callable=AsyncMock),
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_storage_limit_bytes",
|
||||
return_value=0, # Zero limit → uncapped
|
||||
),
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_total_size",
|
||||
return_value=999_999_999, # Huge existing usage
|
||||
),
|
||||
):
|
||||
# Should NOT raise — zero limit means no enforcement
|
||||
result = await manager.write_file(filename="big.txt", content=b"data")
|
||||
assert result == created_file
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_file_exactly_at_limit_is_rejected(manager, mock_storage, mock_db):
|
||||
"""Writing a file that puts usage at exactly the limit should be rejected
|
||||
because projected_usage > storage_limit (not >=)."""
|
||||
mock_db.get_workspace_file_by_path.return_value = None
|
||||
|
||||
limit = 100
|
||||
current = 95
|
||||
content = b"x" * 6 # 95 + 6 = 101 > 100
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_storage",
|
||||
return_value=mock_storage,
|
||||
),
|
||||
patch("backend.util.workspace.workspace_db", return_value=mock_db),
|
||||
patch("backend.util.workspace.scan_content_safe", new_callable=AsyncMock),
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_storage_limit_bytes",
|
||||
return_value=limit,
|
||||
),
|
||||
patch("backend.util.workspace.get_workspace_total_size", return_value=current),
|
||||
):
|
||||
with pytest.raises(ValueError, match="Storage limit exceeded"):
|
||||
await manager.write_file(filename="test.txt", content=content)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_file_exactly_at_limit_boundary_succeeds(
|
||||
manager, mock_storage, mock_db
|
||||
):
|
||||
"""Writing a file that puts usage at exactly the limit should succeed
|
||||
because the guard is > not >=."""
|
||||
created_file = _make_workspace_file()
|
||||
mock_db.get_workspace_file_by_path.return_value = None
|
||||
mock_db.create_workspace_file.return_value = created_file
|
||||
|
||||
limit = 100
|
||||
current = 95
|
||||
content = b"x" * 5 # 95 + 5 = 100 == limit → NOT > limit → passes
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_storage",
|
||||
return_value=mock_storage,
|
||||
),
|
||||
patch("backend.util.workspace.workspace_db", return_value=mock_db),
|
||||
patch("backend.util.workspace.scan_content_safe", new_callable=AsyncMock),
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_storage_limit_bytes",
|
||||
return_value=limit,
|
||||
),
|
||||
patch("backend.util.workspace.get_workspace_total_size", return_value=current),
|
||||
):
|
||||
result = await manager.write_file(filename="test.txt", content=content)
|
||||
assert result == created_file
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_file_overwrite_larger_replacement_rejected(
|
||||
manager, mock_storage, mock_db
|
||||
):
|
||||
"""Replacing a small file with a much larger one near quota is rejected."""
|
||||
existing_file = _make_workspace_file(size_bytes=10)
|
||||
mock_db.get_workspace_file_by_path.return_value = existing_file
|
||||
|
||||
limit = 100
|
||||
current = 90 # 90 bytes used, existing file is 10 of those
|
||||
content = b"x" * 25 # net: 90 - 10 + 25 = 105 > 100
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_storage",
|
||||
return_value=mock_storage,
|
||||
),
|
||||
patch("backend.util.workspace.workspace_db", return_value=mock_db),
|
||||
patch("backend.util.workspace.scan_content_safe", new_callable=AsyncMock),
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_storage_limit_bytes",
|
||||
return_value=limit,
|
||||
),
|
||||
patch("backend.util.workspace.get_workspace_total_size", return_value=current),
|
||||
):
|
||||
with pytest.raises(ValueError, match="Storage limit exceeded"):
|
||||
await manager.write_file(
|
||||
filename="test.txt", content=content, overwrite=True
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_file_overwrite_smaller_replacement_succeeds(
|
||||
manager, mock_storage, mock_db
|
||||
):
|
||||
"""Replacing a large file with a smaller one near quota succeeds."""
|
||||
existing_file = _make_workspace_file(size_bytes=40)
|
||||
created_file = _make_workspace_file()
|
||||
mock_db.get_workspace_file_by_path.return_value = existing_file
|
||||
mock_db.create_workspace_file.return_value = created_file
|
||||
|
||||
limit = 100
|
||||
current = 90 # 90 bytes used, existing file is 40 of those
|
||||
content = b"x" * 30 # net: 90 - 40 + 30 = 80 < 100
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_storage",
|
||||
return_value=mock_storage,
|
||||
),
|
||||
patch("backend.util.workspace.workspace_db", return_value=mock_db),
|
||||
patch("backend.util.workspace.scan_content_safe", new_callable=AsyncMock),
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_storage_limit_bytes",
|
||||
return_value=limit,
|
||||
),
|
||||
patch("backend.util.workspace.get_workspace_total_size", return_value=current),
|
||||
):
|
||||
result = await manager.write_file(
|
||||
filename="test.txt", content=content, overwrite=True
|
||||
)
|
||||
assert result == created_file
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_file_quota_rejection_skips_virus_scan_and_storage(
|
||||
manager, mock_storage, mock_db
|
||||
):
|
||||
"""Quota rejection must short-circuit BEFORE expensive virus scan and storage."""
|
||||
mock_db.get_workspace_file_by_path.return_value = None
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_storage",
|
||||
return_value=mock_storage,
|
||||
),
|
||||
patch("backend.util.workspace.workspace_db", return_value=mock_db),
|
||||
patch(
|
||||
"backend.util.workspace.scan_content_safe", new_callable=AsyncMock
|
||||
) as mock_scan,
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_storage_limit_bytes",
|
||||
return_value=100,
|
||||
),
|
||||
patch("backend.util.workspace.get_workspace_total_size", return_value=100),
|
||||
):
|
||||
with pytest.raises(ValueError, match="Storage limit exceeded"):
|
||||
await manager.write_file(filename="test.txt", content=b"data")
|
||||
|
||||
mock_scan.assert_not_called()
|
||||
mock_storage.store.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_file_empty_content_near_limit_succeeds(
|
||||
manager, mock_storage, mock_db
|
||||
):
|
||||
"""Empty file (0 bytes) should always fit even when at the limit."""
|
||||
created_file = _make_workspace_file()
|
||||
mock_db.get_workspace_file_by_path.return_value = None
|
||||
mock_db.create_workspace_file.return_value = created_file
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_storage",
|
||||
return_value=mock_storage,
|
||||
),
|
||||
patch("backend.util.workspace.workspace_db", return_value=mock_db),
|
||||
patch("backend.util.workspace.scan_content_safe", new_callable=AsyncMock),
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_storage_limit_bytes",
|
||||
return_value=100,
|
||||
),
|
||||
patch("backend.util.workspace.get_workspace_total_size", return_value=100),
|
||||
):
|
||||
result = await manager.write_file(filename="empty.txt", content=b"")
|
||||
assert result == created_file
|
||||
|
||||
@@ -5,6 +5,7 @@ import { cn } from "@/lib/utils";
|
||||
import Link from "next/link";
|
||||
import { formatCents, formatResetTime } from "../usageHelpers";
|
||||
import { useResetRateLimit } from "../../hooks/useResetRateLimit";
|
||||
import { useWorkspaceStorage } from "./useWorkspaceStorage";
|
||||
|
||||
export { formatResetTime };
|
||||
|
||||
@@ -70,6 +71,65 @@ function UsageBar({
|
||||
);
|
||||
}
|
||||
|
||||
export function formatBytes(bytes: number): string {
|
||||
const KB = 1024;
|
||||
const MB = KB * 1024;
|
||||
const GB = MB * 1024;
|
||||
if (bytes < KB) return `${bytes} B`;
|
||||
if (bytes < MB) {
|
||||
const kb = Math.round(bytes / KB);
|
||||
return kb >= 1024 ? `${(bytes / MB).toFixed(1)} MB` : `${kb} KB`;
|
||||
}
|
||||
if (bytes < GB) {
|
||||
const mb = Math.round(bytes / MB);
|
||||
return mb >= 1024 ? `${(bytes / GB).toFixed(1)} GB` : `${mb} MB`;
|
||||
}
|
||||
return `${(bytes / GB).toFixed(1)} GB`;
|
||||
}
|
||||
|
||||
function StorageBar({
|
||||
usedBytes,
|
||||
limitBytes,
|
||||
fileCount,
|
||||
}: {
|
||||
usedBytes: number;
|
||||
limitBytes: number;
|
||||
fileCount: number;
|
||||
}) {
|
||||
if (limitBytes <= 0) return null;
|
||||
|
||||
const rawPercent = (usedBytes / limitBytes) * 100;
|
||||
const percent = Math.min(100, Math.round(rawPercent));
|
||||
const isHigh = percent >= 80;
|
||||
const percentLabel =
|
||||
usedBytes > 0 && percent === 0 ? "<1% used" : `${percent}% used`;
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-1">
|
||||
<div className="flex items-baseline justify-between">
|
||||
<span className="text-xs font-medium text-neutral-700">
|
||||
File storage
|
||||
</span>
|
||||
<span className="text-[11px] tabular-nums text-neutral-500">
|
||||
{percentLabel}
|
||||
</span>
|
||||
</div>
|
||||
<div className="text-[10px] text-neutral-400">
|
||||
{formatBytes(usedBytes)} of {formatBytes(limitBytes)} ·{" "}
|
||||
{fileCount} {fileCount === 1 ? "file" : "files"}
|
||||
</div>
|
||||
<div className="h-2 w-full overflow-hidden rounded-full bg-neutral-200">
|
||||
<div
|
||||
className={`h-full rounded-full transition-[width] duration-300 ease-out ${
|
||||
isHigh ? "bg-orange-500" : "bg-blue-500"
|
||||
}`}
|
||||
style={{ width: `${Math.max(usedBytes > 0 ? 1 : 0, percent)}%` }}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function ResetButton({
|
||||
cost,
|
||||
onCreditChange,
|
||||
@@ -94,6 +154,19 @@ function ResetButton({
|
||||
);
|
||||
}
|
||||
|
||||
function WorkspaceStorageSection() {
|
||||
const { data: storage } = useWorkspaceStorage();
|
||||
if (!storage || storage.limit_bytes <= 0) return null;
|
||||
|
||||
return (
|
||||
<StorageBar
|
||||
usedBytes={storage.used_bytes}
|
||||
limitBytes={storage.limit_bytes}
|
||||
fileCount={storage.file_count}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
export function UsagePanelContent({
|
||||
usage,
|
||||
showHeader = true,
|
||||
@@ -119,9 +192,12 @@ export function UsagePanelContent({
|
||||
|
||||
if (!daily && !weekly) {
|
||||
return (
|
||||
<Text as="span" variant="small" className="text-neutral-500">
|
||||
No usage limits configured
|
||||
</Text>
|
||||
<div className="flex flex-col gap-3">
|
||||
<Text as="span" variant="small" className="text-neutral-500">
|
||||
No usage limits configured
|
||||
</Text>
|
||||
<WorkspaceStorageSection />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -163,6 +239,7 @@ export function UsagePanelContent({
|
||||
size={size}
|
||||
/>
|
||||
)}
|
||||
<WorkspaceStorageSection />
|
||||
{isDailyExhausted &&
|
||||
!isWeeklyExhausted &&
|
||||
resetCost > 0 &&
|
||||
|
||||
@@ -4,8 +4,8 @@ import {
|
||||
cleanup,
|
||||
fireEvent,
|
||||
} from "@/tests/integrations/test-utils";
|
||||
import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
import { UsagePanelContent } from "../UsagePanelContent";
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { UsagePanelContent, formatBytes } from "../UsagePanelContent";
|
||||
import type { CoPilotUsagePublic } from "@/app/api/__generated__/models/coPilotUsagePublic";
|
||||
|
||||
const mockResetUsage = vi.fn();
|
||||
@@ -13,9 +13,20 @@ vi.mock("../../../hooks/useResetRateLimit", () => ({
|
||||
useResetRateLimit: () => ({ resetUsage: mockResetUsage, isPending: false }),
|
||||
}));
|
||||
|
||||
const mockStorageData = vi.fn();
|
||||
vi.mock("../useWorkspaceStorage", () => ({
|
||||
useWorkspaceStorage: () => mockStorageData(),
|
||||
}));
|
||||
|
||||
afterEach(() => {
|
||||
cleanup();
|
||||
mockResetUsage.mockReset();
|
||||
mockStorageData.mockReset();
|
||||
});
|
||||
|
||||
// Default: no storage data (most existing tests don't need it)
|
||||
beforeEach(() => {
|
||||
mockStorageData.mockReturnValue({ data: undefined });
|
||||
});
|
||||
|
||||
function makeUsage(
|
||||
@@ -47,6 +58,48 @@ function makeUsage(
|
||||
} as CoPilotUsagePublic;
|
||||
}
|
||||
|
||||
describe("formatBytes", () => {
|
||||
it.each([
|
||||
[0, "0 B"],
|
||||
[512, "512 B"],
|
||||
[1024, "1 KB"],
|
||||
[250 * 1024, "250 KB"],
|
||||
[1023 * 1024, "1023 KB"],
|
||||
[1000 * 1024, "1000 KB"],
|
||||
[1024 * 1024, "1 MB"],
|
||||
[250 * 1024 * 1024, "250 MB"],
|
||||
[1000 * 1024 * 1024, "1000 MB"],
|
||||
[1024 * 1024 * 1024, "1.0 GB"],
|
||||
[5 * 1024 * 1024 * 1024, "5.0 GB"],
|
||||
[15 * 1024 * 1024 * 1024, "15.0 GB"],
|
||||
])("formats %d bytes as %s", (input, expected) => {
|
||||
expect(formatBytes(input)).toBe(expected);
|
||||
});
|
||||
|
||||
// Adversarial edge cases
|
||||
it("handles 1 byte", () => {
|
||||
expect(formatBytes(1)).toBe("1 B");
|
||||
});
|
||||
|
||||
it("handles exactly 1023 bytes (just under 1 KB)", () => {
|
||||
expect(formatBytes(1023)).toBe("1023 B");
|
||||
});
|
||||
|
||||
it("auto-promotes 1 MB - 1 byte to MB (rounds up to 1024 KB → 1.0 MB)", () => {
|
||||
// 1048575 / 1024 = 1023.999 → Math.round = 1024 → kb >= 1024 → promotes to MB
|
||||
expect(formatBytes(1048575)).toBe("1.0 MB");
|
||||
});
|
||||
|
||||
it("auto-promotes 1 GB - 1 byte to GB (rounds up to 1024 MB → 1.0 GB)", () => {
|
||||
// 1073741823 / (1024*1024) = 1023.999 → Math.round = 1024 → promotes to GB
|
||||
expect(formatBytes(1073741823)).toBe("1.0 GB");
|
||||
});
|
||||
|
||||
it("handles very large values (1 TB)", () => {
|
||||
expect(formatBytes(1024 * 1024 * 1024 * 1024)).toBe("1024.0 GB");
|
||||
});
|
||||
});
|
||||
|
||||
describe("UsagePanelContent", () => {
|
||||
it("renders 'No usage limits configured' when both windows are null", () => {
|
||||
render(
|
||||
@@ -57,6 +110,26 @@ describe("UsagePanelContent", () => {
|
||||
expect(screen.getByText("No usage limits configured")).toBeDefined();
|
||||
});
|
||||
|
||||
it("still renders file storage when usage windows are null", () => {
|
||||
mockStorageData.mockReturnValue({
|
||||
data: {
|
||||
used_bytes: 100 * 1024 * 1024,
|
||||
limit_bytes: 250 * 1024 * 1024,
|
||||
used_percent: 40,
|
||||
file_count: 5,
|
||||
},
|
||||
});
|
||||
|
||||
render(
|
||||
<UsagePanelContent
|
||||
usage={makeUsage({ dailyPercent: null, weeklyPercent: null })}
|
||||
/>,
|
||||
);
|
||||
|
||||
expect(screen.getByText("No usage limits configured")).toBeDefined();
|
||||
expect(screen.getByText("File storage")).toBeDefined();
|
||||
});
|
||||
|
||||
it("renders the reset button when daily limit is exhausted", () => {
|
||||
render(
|
||||
<UsagePanelContent
|
||||
@@ -109,4 +182,195 @@ describe("UsagePanelContent", () => {
|
||||
render(<UsagePanelContent usage={makeUsage({ dailyPercent: 0.3 })} />);
|
||||
expect(screen.getByText("<1% used")).toBeDefined();
|
||||
});
|
||||
|
||||
it("renders file storage bar when workspace data is available", () => {
|
||||
mockStorageData.mockReturnValue({
|
||||
data: {
|
||||
used_bytes: 100 * 1024 * 1024,
|
||||
limit_bytes: 250 * 1024 * 1024,
|
||||
used_percent: 40,
|
||||
file_count: 5,
|
||||
},
|
||||
});
|
||||
|
||||
render(<UsagePanelContent usage={makeUsage()} />);
|
||||
expect(screen.getByText("File storage")).toBeDefined();
|
||||
expect(screen.getByText(/100 MB of 250 MB/)).toBeDefined();
|
||||
expect(screen.getByText(/5 files/)).toBeDefined();
|
||||
});
|
||||
|
||||
it("hides file storage bar when no workspace data", () => {
|
||||
mockStorageData.mockReturnValue({ data: undefined });
|
||||
|
||||
render(<UsagePanelContent usage={makeUsage()} />);
|
||||
expect(screen.queryByText("File storage")).toBeNull();
|
||||
});
|
||||
|
||||
it("hides file storage bar when limit is zero", () => {
|
||||
mockStorageData.mockReturnValue({
|
||||
data: {
|
||||
used_bytes: 0,
|
||||
limit_bytes: 0,
|
||||
used_percent: 0,
|
||||
file_count: 0,
|
||||
},
|
||||
});
|
||||
|
||||
render(<UsagePanelContent usage={makeUsage()} />);
|
||||
expect(screen.queryByText("File storage")).toBeNull();
|
||||
});
|
||||
|
||||
it("shows orange bar when storage usage is at or above 80%", () => {
|
||||
mockStorageData.mockReturnValue({
|
||||
data: {
|
||||
used_bytes: 210 * 1024 * 1024,
|
||||
limit_bytes: 250 * 1024 * 1024,
|
||||
used_percent: 84,
|
||||
file_count: 3,
|
||||
},
|
||||
});
|
||||
|
||||
render(<UsagePanelContent usage={makeUsage()} />);
|
||||
expect(screen.getByText("File storage")).toBeDefined();
|
||||
expect(screen.getByText("84% used")).toBeDefined();
|
||||
});
|
||||
|
||||
it("shows singular 'file' for single file", () => {
|
||||
mockStorageData.mockReturnValue({
|
||||
data: {
|
||||
used_bytes: 1024,
|
||||
limit_bytes: 250 * 1024 * 1024,
|
||||
used_percent: 0,
|
||||
file_count: 1,
|
||||
},
|
||||
});
|
||||
|
||||
render(<UsagePanelContent usage={makeUsage()} />);
|
||||
expect(screen.getByText(/1 file$/)).toBeDefined();
|
||||
});
|
||||
|
||||
it("shows storage '<1% used' when usage is tiny", () => {
|
||||
mockStorageData.mockReturnValue({
|
||||
data: {
|
||||
used_bytes: 100,
|
||||
limit_bytes: 250 * 1024 * 1024,
|
||||
used_percent: 0.001,
|
||||
file_count: 1,
|
||||
},
|
||||
});
|
||||
|
||||
render(<UsagePanelContent usage={makeUsage()} />);
|
||||
expect(screen.getByText("File storage")).toBeDefined();
|
||||
});
|
||||
|
||||
it("renders header with tier label", () => {
|
||||
render(<UsagePanelContent usage={makeUsage({ tier: "PRO" })} />);
|
||||
expect(screen.getByText("Pro plan")).toBeDefined();
|
||||
});
|
||||
|
||||
it("hides header when showHeader is false", () => {
|
||||
render(<UsagePanelContent usage={makeUsage()} showHeader={false} />);
|
||||
expect(screen.queryByText("Usage limits")).toBeNull();
|
||||
});
|
||||
|
||||
// Adversarial edge cases
|
||||
|
||||
it("hides storage bar when limit is negative", () => {
|
||||
mockStorageData.mockReturnValue({
|
||||
data: {
|
||||
used_bytes: 100,
|
||||
limit_bytes: -1,
|
||||
used_percent: 0,
|
||||
file_count: 1,
|
||||
},
|
||||
});
|
||||
|
||||
render(<UsagePanelContent usage={makeUsage()} />);
|
||||
expect(screen.queryByText("File storage")).toBeNull();
|
||||
});
|
||||
|
||||
it("handles storage at exactly 100% used", () => {
|
||||
mockStorageData.mockReturnValue({
|
||||
data: {
|
||||
used_bytes: 250 * 1024 * 1024,
|
||||
limit_bytes: 250 * 1024 * 1024,
|
||||
used_percent: 100,
|
||||
file_count: 10,
|
||||
},
|
||||
});
|
||||
|
||||
render(<UsagePanelContent usage={makeUsage()} />);
|
||||
expect(screen.getByText("100% used")).toBeDefined();
|
||||
expect(screen.getByText(/250 MB of 250 MB/)).toBeDefined();
|
||||
});
|
||||
|
||||
it("clamps storage above 100% to 100% display", () => {
|
||||
mockStorageData.mockReturnValue({
|
||||
data: {
|
||||
used_bytes: 300 * 1024 * 1024,
|
||||
limit_bytes: 250 * 1024 * 1024,
|
||||
used_percent: 120,
|
||||
file_count: 15,
|
||||
},
|
||||
});
|
||||
|
||||
render(<UsagePanelContent usage={makeUsage()} />);
|
||||
// Should show "100% used", not "120% used"
|
||||
expect(screen.getByText("100% used")).toBeDefined();
|
||||
expect(screen.getByText("File storage")).toBeDefined();
|
||||
});
|
||||
|
||||
it("handles zero files with zero usage", () => {
|
||||
mockStorageData.mockReturnValue({
|
||||
data: {
|
||||
used_bytes: 0,
|
||||
limit_bytes: 250 * 1024 * 1024,
|
||||
used_percent: 0,
|
||||
file_count: 0,
|
||||
},
|
||||
});
|
||||
|
||||
render(<UsagePanelContent usage={makeUsage()} />);
|
||||
expect(screen.getByText("File storage")).toBeDefined();
|
||||
expect(screen.getByText("0% used")).toBeDefined();
|
||||
expect(screen.getByText(/0 files/)).toBeDefined();
|
||||
});
|
||||
|
||||
it("renders billing link by default", () => {
|
||||
render(<UsagePanelContent usage={makeUsage()} />);
|
||||
expect(screen.getByText("Learn more about usage limits")).toBeDefined();
|
||||
});
|
||||
|
||||
it("hides billing link when showBillingLink is false", () => {
|
||||
render(<UsagePanelContent usage={makeUsage()} showBillingLink={false} />);
|
||||
expect(screen.queryByText("Learn more about usage limits")).toBeNull();
|
||||
});
|
||||
|
||||
it("renders only daily bar when weekly is null", () => {
|
||||
render(
|
||||
<UsagePanelContent
|
||||
usage={makeUsage({ dailyPercent: 50, weeklyPercent: null })}
|
||||
/>,
|
||||
);
|
||||
expect(screen.getByText("Today")).toBeDefined();
|
||||
expect(screen.queryByText("This week")).toBeNull();
|
||||
});
|
||||
|
||||
it("renders only weekly bar when daily is null", () => {
|
||||
render(
|
||||
<UsagePanelContent
|
||||
usage={makeUsage({ dailyPercent: null, weeklyPercent: 30 })}
|
||||
/>,
|
||||
);
|
||||
expect(screen.queryByText("Today")).toBeNull();
|
||||
expect(screen.getByText("This week")).toBeDefined();
|
||||
});
|
||||
|
||||
it("does not show tier label when tier is missing", () => {
|
||||
const usage = makeUsage();
|
||||
(usage as Record<string, unknown>).tier = null;
|
||||
render(<UsagePanelContent usage={usage} />);
|
||||
expect(screen.queryByText(/plan$/)).toBeNull();
|
||||
expect(screen.getByText("Usage limits")).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
import { useGetWorkspaceStorageUsage } from "@/app/api/__generated__/endpoints/workspace/workspace";
|
||||
import type { StorageUsageResponse } from "@/app/api/__generated__/models/storageUsageResponse";
|
||||
|
||||
export function useWorkspaceStorage() {
|
||||
return useGetWorkspaceStorageUsage({
|
||||
query: {
|
||||
select: (res) => res.data as StorageUsageResponse,
|
||||
staleTime: 30000,
|
||||
refetchInterval: 60000,
|
||||
},
|
||||
});
|
||||
}
|
||||
@@ -40,7 +40,7 @@ function CoPilotUsageSection() {
|
||||
|
||||
return (
|
||||
<div className="my-6 space-y-4">
|
||||
<h3 className="text-lg font-medium">AutoPilot Usage Limits</h3>
|
||||
<h3 className="text-lg font-medium">AutoPilot Usage & Storage</h3>
|
||||
<div className="rounded-lg border border-neutral-200 p-4">
|
||||
<UsagePanelContent usage={usage} showBillingLink={false} />
|
||||
</div>
|
||||
|
||||
@@ -40,8 +40,18 @@ export async function uploadFileDirect(
|
||||
});
|
||||
|
||||
if (!res.ok) {
|
||||
const detail = await res.text().catch(() => res.statusText);
|
||||
throw new Error(`Upload failed (${res.status}): ${detail}`);
|
||||
let message: string;
|
||||
try {
|
||||
const body = await res.json();
|
||||
// Backend returns { detail: "..." } or { detail: { message: "..." } }
|
||||
message =
|
||||
typeof body.detail === "string"
|
||||
? body.detail
|
||||
: (body.detail?.message ?? res.statusText);
|
||||
} catch {
|
||||
message = res.statusText;
|
||||
}
|
||||
throw new Error(message);
|
||||
}
|
||||
|
||||
return res.json();
|
||||
|
||||
Reference in New Issue
Block a user