Compare commits

..

2 Commits

Author SHA1 Message Date
Zamil Majdy
7cc1edc61f dx(pr-polish): use --json bucket instead of awk text-column parsing (#12951)
## Why

`/pr-polish` was prematurely emitting `CLEAN-POLL` while CI was still
pending, because the polish-polling loop's CI gate parsed `gh pr checks
$PR` text columns with `awk '{print $2}'`. That works fine for plain job
names, but breaks on jobs with spaces or parens like `test (3.11)`,
`Analyze (python)`, where column 2 is the version `(3.11)` — so `grep -q
"pending"` matched on column 2 of OTHER rows but missed the actual
pending entries. Real symptom on PR #12948: the orchestrator reported
`ORCHESTRATOR:DONE` while `test (3.11/3.12/3.13)` and `Check PR Status`
were still running.

## What

Add a "Concrete CI fetch" subsection right after the polish-polling
pseudocode block, showing the `--json bucket` shape that bypasses the
column-parsing trap entirely. Also flag the `bucket` vs `conclusion`
gotcha (the REST API uses `conclusion`; `gh pr checks --json` only
exposes `bucket`).

## How

Surgical additive edit — the existing pseudocode + state machine is
preserved; the new subsection just translates the abstract
`fetch_check_runs(PR)` into a concrete one-liner so the next implementer
doesn't reach for `awk` again.

## Test plan

- [x] Verified the regression against PR #12948: bucket-based polling
correctly identified 4 pending checks the awk path missed
- [x] Confirmed `gh pr checks {N} --json conclusion` errors with
`Unknown JSON field: "conclusion"` (this gotcha is now noted in the
skill)
2026-04-30 12:44:25 +07:00
Zamil Majdy
4a1741cc15 fix(platform): cancel-banner copy + clearer 422 on currency mismatch (#12947)
## Why

Two regressions surfaced after
[#12933](https://github.com/Significant-Gravitas/AutoGPT/pull/12933)
merged to `dev`:

1. **Cancel-pending banner shows wrong copy.** The merged PR moved
cancel-at-period-end from `BASIC` → `NO_TIER`, but
`PendingChangeBanner.isCancellation` was still keyed on `"BASIC"`. As a
result, a user who cancels their sub now sees *"Scheduled to downgrade
to No subscription on …"* instead of the intended *"Scheduled to cancel
your subscription on …"*. Caught by Sentry on the merged PR.

2. **Currency-mismatch downgrade returns 502 (looks like outage).** A
user with an existing GBP-active sub (Max Price has
`currency_options.gbp`) tried to downgrade to Pro and got 502. The
backend logs show:
   ```
stripe._error.InvalidRequestError: The price specified only supports
`usd`.
   This doesn't match the expected currency: `gbp`.
   ```
The Pro Price is USD-only; Stripe rejects `SubscriptionSchedule.modify`
because phases must share currency. Wrapping that in a generic 502 hid
the real cause and made it read like a Stripe outage.

## What

* Frontend: flip `PendingChangeBanner.isCancellation` from `pendingTier
=== "BASIC"` to `"NO_TIER"`. Update both component and page-level tests
that exercised the cancellation branch.
* Backend: catch `stripe.InvalidRequestError` whose message mentions
`currency` in `update_subscription_tier`, and return **422** with *"Tier
change unavailable for your current billing currency. Cancel your
subscription and re-subscribe at the target tier, or contact support."*
— so users see the actual reason, not a misleading outage message. Other
`StripeError` paths still return 502.
* New backend test asserts the currency-mismatch branch returns 422 with
the new copy.

## How

* `PendingChangeBanner.tsx` line 28: 1-char change (`"BASIC"` →
`"NO_TIER"`).
* `subscription_routes_test.py` and `PendingChangeBanner.test.tsx`
updated to use `NO_TIER` for the cancellation fixture.
* `v1.py` `update_subscription_tier` adds a typed `except
stripe.InvalidRequestError` branch ahead of the generic `StripeError`;
only currency-mismatch messages get the special 422, everything else
falls through to the existing 502.

## The real fix lives in Stripe config

The defensive 422 here is just a clearer error surface. To actually
unblock GBP/EUR users from changing tiers, the per-tier Stripe Prices
(Pro, and Basic if priced) need `currency_options` for GBP added — Max
already has this, which is why Max checkout shows the £/$ toggle. Stripe
locks `currency_options` after a Price has been transacted, so the
procedure is: create a new Price with USD + GBP from the start → update
the `stripe-price-ids` LD flag to the new Price ID. No further code
change required; same Price ID stays per tier, multiple currencies
inside it.

## Checklist

- [x] Component test for new banner copy
- [x] Backend test for 422 currency-mismatch branch
- [x] Format / lint / types pass
- [x] No protected route added — N/A
2026-04-30 10:25:02 +07:00
22 changed files with 270 additions and 1498 deletions

View File

@@ -160,6 +160,24 @@ while clean_polls < required_clean:
Only after `clean_polls == 2` do you report `ORCHESTRATOR:DONE`.
### Concrete CI fetch (don't parse `gh pr checks` text columns)
The `fetch_check_runs(PR)` step above must use `--json`, not the default text output. Job names can contain spaces and parentheses (e.g. `test (3.11)`, `Analyze (python)`), so `gh pr checks $PR | awk '{print $2}'` extracts `(3.11)` instead of the status — leading to a clean-poll firing while jobs are still pending.
```bash
# Reliable: use --json so columns are unambiguous.
ci_json=$(gh pr checks $PR --repo Significant-Gravitas/AutoGPT --json name,state,bucket)
pending=$(echo "$ci_json" | jq '[.[] | select(.bucket == "pending")] | length')
failed=$(echo "$ci_json" | jq '[.[] | select(.bucket == "fail" or .bucket == "cancel")] | length')
# Buckets are: pass | fail | pending | cancel | skipping
# (NOTE: gh pr checks does NOT expose `conclusion` as a JSON field —
# only `bucket`. Don't confuse with the GitHub REST API's check_runs
# endpoint, which DOES use conclusion.)
```
Map back to the pseudocode above: `bucket == "pending"` is `ci.conclusion is None (still in_progress)`; `bucket in {"fail", "cancel"}` is `ci.conclusion in NON_SUCCESS_TERMINAL`; `bucket in {"pass", "skipping"}` is clean.
### Why 2 clean polls, not 1
A single green snapshot can be misleading — the final CI check often completes ~30s before a bot posts its delayed review. One quiet cycle does not prove the PR is stable; two consecutive cycles with no new threads, reviews, or issue comments arriving gives high confidence nothing else is incoming.
@@ -196,6 +214,18 @@ The child skill returning is a **loop iteration boundary**, not a conversation t
If the user needs to approve a risky action mid-loop (e.g., a force-push or a destructive git operation), pause there — but not at the routine "round N finished, round N+1 needed" boundary. Those are silent transitions.
### **Run /pr-polish in the foreground — never in a background agent**
Spawning `/pr-polish` inside an `Agent(subagent_type="general-purpose")` background task **does not work**. Background agents don't inherit the parent's slash-command registry, so `Skill(skill="pr-review")` and `Skill(skill="pr-address")` calls aren't available — the agent has to manually replicate the child skills' logic, which is fragile and tends to stall on the first network or rate-limit hiccup. Symptom: the background task reports `stalled: no progress for 600s` mid-review.
Run `/pr-polish` inline in the foreground conversation. If the user asks for "/pr-polish + /pr-test in parallel", split them: foreground `/pr-polish`, and ONLY then can the test step go to a background agent (because `/pr-test` doesn't itself need to invoke skills).
### **You MUST invoke `Skill(pr-review)` every round — even when bot reviews already exist**
A common failure mode: CodeRabbit / autogpt-reviewer / Sentry have already posted findings on the PR, and the orchestrator skips the `Skill(pr-review)` step on the assumption that "review has been done." That's wrong — the outer loop's purpose is to layer **the agent's own review** on top of the bot reviews, catching issues the bots miss (architecture, naming, cross-file invariants, hidden coupling). If the orchestrator only addresses bot findings without ever running its own review, the loop converges to "bot-clean" but not "agent-reviewed-clean," and the user reasonably asks "did /pr-polish even read the diff?"
**Self-check before reporting `ORCHESTRATOR:DONE`:** confirm at least one `Skill(skill="pr-review")` call appears in the current orchestration. If none, the loop is incomplete — go back and run one round.
## GitHub rate limits
This skill issues many GraphQL calls (one review-thread query per outer iteration plus per-poll queries inside polish polling). Expect the GraphQL budget to be tight on large PRs. When `gh api rate_limit --jq .resources.graphql.remaining` drops below ~200, back off:

View File

@@ -416,6 +416,98 @@ def test_update_subscription_tier_paid_requires_urls(
assert response.status_code == 422
def test_update_subscription_tier_currency_mismatch_returns_422(
client: fastapi.testclient.TestClient,
mocker: pytest_mock.MockFixture,
) -> None:
"""Stripe rejects a SubscriptionSchedule whose phases mix currencies (e.g.
GBP-checkout sub trying to schedule a USD-only target Price). The handler
must convert that into a specific 422 instead of the generic 502 so the
caller can tell the difference between a currency-config bug and a Stripe
outage."""
mock_user = Mock()
mock_user.subscription_tier = SubscriptionTier.MAX
async def mock_feature_enabled(*args, **kwargs):
return True
mocker.patch(
"backend.api.features.v1.get_user_by_id",
new_callable=AsyncMock,
return_value=mock_user,
)
mocker.patch(
"backend.api.features.v1.is_feature_enabled",
side_effect=mock_feature_enabled,
)
mocker.patch(
"backend.api.features.v1.modify_stripe_subscription_for_tier",
side_effect=stripe.InvalidRequestError(
"The price specified only supports `usd`. This doesn't match the"
" expected currency: `gbp`.",
param="phases",
),
)
response = client.post(
"/credits/subscription",
json={
"tier": "PRO",
"success_url": f"{TEST_FRONTEND_ORIGIN}/success",
"cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel",
},
)
assert response.status_code == 422
detail = response.json()["detail"]
assert "billing currency" in detail.lower()
assert "contact support" in detail.lower()
def test_update_subscription_tier_non_currency_invalid_request_returns_502(
client: fastapi.testclient.TestClient,
mocker: pytest_mock.MockFixture,
) -> None:
"""Locks the contract that *only* currency-mismatch InvalidRequestErrors
translate to 422 — every other Stripe InvalidRequestError must still
surface as the generic 502 so that widening the conditional later is
caught by the suite."""
mock_user = Mock()
mock_user.subscription_tier = SubscriptionTier.MAX
async def mock_feature_enabled(*args, **kwargs):
return True
mocker.patch(
"backend.api.features.v1.get_user_by_id",
new_callable=AsyncMock,
return_value=mock_user,
)
mocker.patch(
"backend.api.features.v1.is_feature_enabled",
side_effect=mock_feature_enabled,
)
mocker.patch(
"backend.api.features.v1.modify_stripe_subscription_for_tier",
side_effect=stripe.InvalidRequestError(
"No such price: 'price_does_not_exist'",
param="items[0][price]",
),
)
response = client.post(
"/credits/subscription",
json={
"tier": "PRO",
"success_url": f"{TEST_FRONTEND_ORIGIN}/success",
"cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel",
},
)
assert response.status_code == 502
assert "billing currency" not in response.json()["detail"].lower()
def test_update_subscription_tier_creates_checkout(
client: fastapi.testclient.TestClient,
mocker: pytest_mock.MockFixture,

View File

@@ -1003,6 +1003,35 @@ async def update_subscription_tier(
return await get_subscription_status(user_id)
except ValueError as e:
raise HTTPException(status_code=422, detail=str(e))
except stripe.InvalidRequestError as e:
# Stripe rejects schedule modify when phases mix currencies, e.g. the
# active sub was checked out in GBP but the target tier's Price is
# USD-only. 502 reads as outage; surface a 422 with a specific message
# so the user/admin can see what to fix in Stripe.
msg = str(e)
if "currency" in msg.lower():
logger.warning(
"Currency mismatch on tier change for user %s: %s", user_id, msg
)
raise HTTPException(
status_code=422,
detail=(
"Tier change unavailable for your current billing currency."
" Please contact support — the target tier needs to be"
" configured for your currency in Stripe before this"
" change can go through."
),
)
logger.exception(
"Stripe error modifying subscription for user %s: %s", user_id, e
)
raise HTTPException(
status_code=502,
detail=(
"Unable to update your subscription right now. "
"Please try again or contact support."
),
)
except stripe.StripeError as e:
logger.exception(
"Stripe error modifying subscription for user %s: %s", user_id, e

View File

@@ -2,7 +2,6 @@
Workspace API routes for managing user file storage.
"""
import asyncio
import logging
import os
import re
@@ -15,8 +14,6 @@ from fastapi import Query, UploadFile
from fastapi.responses import Response
from pydantic import BaseModel, Field
from backend.api.features.store.exceptions import VirusDetectedError, VirusScanError
from backend.copilot.rate_limit import get_workspace_storage_limit_bytes
from backend.data.workspace import (
WorkspaceFile,
count_workspace_files,
@@ -27,7 +24,8 @@ from backend.data.workspace import (
soft_delete_workspace_file,
)
from backend.util.settings import Config
from backend.util.workspace import WorkspaceManager, format_bytes
from backend.util.virus_scanner import scan_content_safe
from backend.util.workspace import WorkspaceManager
from backend.util.workspace_storage import get_workspace_storage
@@ -251,28 +249,50 @@ async def upload_file(
# Get or create workspace
workspace = await get_or_create_workspace(user_id)
# Write file via WorkspaceManager (handles virus scan, per-file size,
# and per-user tier-based storage quota internally).
# Pre-write storage cap check (soft check — final enforcement is post-write)
storage_limit_bytes = config.max_workspace_storage_mb * 1024 * 1024
current_usage = await get_workspace_total_size(workspace.id)
if storage_limit_bytes and current_usage + len(content) > storage_limit_bytes:
used_percent = (current_usage / storage_limit_bytes) * 100
raise fastapi.HTTPException(
status_code=413,
detail={
"message": "Storage limit exceeded",
"used_bytes": current_usage,
"limit_bytes": storage_limit_bytes,
"used_percent": round(used_percent, 1),
},
)
# Warn at 80% usage
if (
storage_limit_bytes
and (usage_ratio := (current_usage + len(content)) / storage_limit_bytes) >= 0.8
):
logger.warning(
f"User {user_id} workspace storage at {usage_ratio * 100:.1f}% "
f"({current_usage + len(content)} / {storage_limit_bytes} bytes)"
)
# Virus scan
await scan_content_safe(content, filename=filename)
# Write file via WorkspaceManager
manager = WorkspaceManager(user_id, workspace.id, session_id)
try:
workspace_file = await manager.write_file(
content, filename, overwrite=overwrite, metadata={"origin": "user-upload"}
)
except VirusDetectedError as e:
raise fastapi.HTTPException(status_code=400, detail=str(e)) from e
except VirusScanError as e:
raise fastapi.HTTPException(status_code=500, detail=str(e)) from e
except ValueError as e:
# write_file raises ValueError for path-conflict, size-limit, and
# storage-quota cases; map each to its correct HTTP status.
# write_file raises ValueError for both path-conflict and size-limit
# cases; map each to its correct HTTP status.
message = str(e)
if message.startswith(("File too large", "Storage limit exceeded")):
if message.startswith("File too large"):
raise fastapi.HTTPException(status_code=413, detail=message) from e
raise fastapi.HTTPException(status_code=409, detail=message) from e
# Post-write storage check — eliminates TOCTOU race on the quota.
# If a concurrent upload pushed us over the limit, undo this write.
storage_limit_bytes = await get_workspace_storage_limit_bytes(user_id)
new_total = await get_workspace_total_size(workspace.id)
if storage_limit_bytes and new_total > storage_limit_bytes:
try:
@@ -284,12 +304,11 @@ async def upload_file(
)
raise fastapi.HTTPException(
status_code=413,
detail=(
f"Storage limit exceeded. "
f"You've used {format_bytes(new_total)} of your "
f"{format_bytes(storage_limit_bytes)} quota. "
f"Delete some files or upgrade your plan for more storage."
),
detail={
"message": "Storage limit exceeded (concurrent upload)",
"used_bytes": new_total,
"limit_bytes": storage_limit_bytes,
},
)
return UploadFileResponse(
@@ -312,13 +331,12 @@ async def get_storage_usage(
"""
Get storage usage information for the user's workspace.
"""
config = Config()
workspace = await get_or_create_workspace(user_id)
used_bytes, file_count, limit_bytes = await asyncio.gather(
get_workspace_total_size(workspace.id),
count_workspace_files(workspace.id),
get_workspace_storage_limit_bytes(user_id),
)
used_bytes = await get_workspace_total_size(workspace.id)
file_count = await count_workspace_files(workspace.id)
limit_bytes = config.max_workspace_storage_mb * 1024 * 1024
return StorageUsageResponse(
used_bytes=used_bytes,

View File

@@ -151,16 +151,15 @@ def test_list_files_null_metadata_coerced_to_empty_dict(
# -- upload_file metadata tests --
@patch("backend.api.features.workspace.routes.get_workspace_storage_limit_bytes")
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
@patch("backend.api.features.workspace.routes.get_workspace_total_size")
@patch("backend.api.features.workspace.routes.scan_content_safe")
@patch("backend.api.features.workspace.routes.WorkspaceManager")
def test_upload_passes_user_upload_origin_metadata(
mock_manager_cls, mock_total_size, mock_get_workspace, mock_storage_limit
mock_manager_cls, mock_scan, mock_total_size, mock_get_workspace
):
mock_get_workspace.return_value = _make_workspace()
mock_total_size.return_value = 100
mock_storage_limit.return_value = 250 * 1024 * 1024
written = _make_file(id="new-file", name="doc.pdf")
mock_instance = AsyncMock()
mock_instance.write_file.return_value = written
@@ -179,9 +178,10 @@ def test_upload_passes_user_upload_origin_metadata(
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
@patch("backend.api.features.workspace.routes.get_workspace_total_size")
@patch("backend.api.features.workspace.routes.scan_content_safe")
@patch("backend.api.features.workspace.routes.WorkspaceManager")
def test_upload_returns_409_on_file_conflict(
mock_manager_cls, mock_total_size, mock_get_workspace
mock_manager_cls, mock_scan, mock_total_size, mock_get_workspace
):
mock_get_workspace.return_value = _make_workspace()
mock_total_size.return_value = 100
@@ -234,8 +234,8 @@ def test_upload_happy_path(mocker):
return_value=0,
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_storage_limit_bytes",
return_value=250 * 1024 * 1024,
"backend.api.features.workspace.routes.scan_content_safe",
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
@@ -256,24 +256,20 @@ def test_upload_exceeds_max_file_size(mocker):
"""Files larger than max_file_size_mb should be rejected with 413."""
cfg = mocker.patch("backend.api.features.workspace.routes.Config")
cfg.return_value.max_file_size_mb = 0 # 0 MB → any content is too big
cfg.return_value.max_workspace_storage_mb = 500
response = _upload(content=b"x" * 1024)
assert response.status_code == 413
def test_upload_storage_quota_exceeded(mocker):
"""WorkspaceManager.write_file raises ValueError when quota exceeded → 413."""
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=_make_workspace(),
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(
side_effect=ValueError("Storage limit exceeded: 500 MB used of 250 MB (200.0%)")
)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
"backend.api.features.workspace.routes.get_workspace_total_size",
return_value=500 * 1024 * 1024,
)
response = _upload()
@@ -287,14 +283,13 @@ def test_upload_post_write_quota_race(mocker):
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=_make_workspace(),
)
# Post-write total exceeds the tier-based limit (250 MB for FREE).
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
return_value=600 * 1024 * 1024,
side_effect=[0, 600 * 1024 * 1024],
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_storage_limit_bytes",
return_value=250 * 1024 * 1024,
"backend.api.features.workspace.routes.scan_content_safe",
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
@@ -323,8 +318,8 @@ def test_upload_any_extension(mocker):
return_value=0,
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_storage_limit_bytes",
return_value=250 * 1024 * 1024,
"backend.api.features.workspace.routes.scan_content_safe",
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
@@ -338,17 +333,23 @@ def test_upload_any_extension(mocker):
def test_upload_blocked_by_virus_scan(mocker):
"""Files flagged by ClamAV should be rejected via WorkspaceManager."""
"""Files flagged by ClamAV should be rejected and never written to storage."""
from backend.api.features.store.exceptions import VirusDetectedError
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=_make_workspace(),
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
return_value=0,
)
mocker.patch(
"backend.api.features.workspace.routes.scan_content_safe",
side_effect=VirusDetectedError("Eicar-Test-Signature"),
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
@@ -356,6 +357,7 @@ def test_upload_blocked_by_virus_scan(mocker):
response = _upload(filename="evil.exe", content=b"X5O!P%@AP...")
assert response.status_code == 400
mock_manager.write_file.assert_not_called()
def test_upload_file_without_extension(mocker):
@@ -369,8 +371,8 @@ def test_upload_file_without_extension(mocker):
return_value=0,
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_storage_limit_bytes",
return_value=250 * 1024 * 1024,
"backend.api.features.workspace.routes.scan_content_safe",
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
@@ -400,8 +402,8 @@ def test_upload_strips_path_components(mocker):
return_value=0,
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_storage_limit_bytes",
return_value=250 * 1024 * 1024,
"backend.api.features.workspace.routes.scan_content_safe",
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
@@ -486,6 +488,14 @@ def test_upload_write_file_too_large_returns_413(mocker):
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=_make_workspace(),
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
return_value=0,
)
mocker.patch(
"backend.api.features.workspace.routes.scan_content_safe",
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(
side_effect=ValueError("File too large: 900 bytes exceeds 1MB limit")
@@ -506,6 +516,14 @@ def test_upload_write_file_conflict_returns_409(mocker):
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=_make_workspace(),
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
return_value=0,
)
mocker.patch(
"backend.api.features.workspace.routes.scan_content_safe",
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(
side_effect=ValueError("File already exists at path: /sessions/x/a.txt")
@@ -584,54 +602,6 @@ def test_list_files_offset_is_echoed_back(mock_manager_cls, mock_get_workspace):
)
def test_upload_virus_scan_infrastructure_error_returns_500(mocker):
"""VirusScanError (ClamAV outage) should return 500, not 409."""
from backend.api.features.store.exceptions import VirusScanError
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=_make_workspace(),
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(
side_effect=VirusScanError("ClamAV connection refused"),
)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
)
response = _upload()
assert response.status_code == 500
def test_get_storage_usage_returns_tier_based_limit(mocker):
"""get_storage_usage should return the user's tier-based limit, not a static config."""
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=_make_workspace(),
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
return_value=100 * 1024 * 1024, # 100 MB used
)
mocker.patch(
"backend.api.features.workspace.routes.count_workspace_files",
return_value=5,
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_storage_limit_bytes",
return_value=1024 * 1024 * 1024, # 1 GB (PRO tier)
)
response = client.get("/storage/usage")
assert response.status_code == 200
data = response.json()
assert data["limit_bytes"] == 1024 * 1024 * 1024
assert data["used_bytes"] == 100 * 1024 * 1024
assert data["file_count"] == 5
# -- _sanitize_filename_for_header tests --

View File

@@ -111,19 +111,6 @@ TIER_MULTIPLIERS = _DEFAULT_TIER_MULTIPLIERS
DEFAULT_TIER = SubscriptionTier.NO_TIER
# Per-tier workspace storage caps in MB. NO_TIER keeps the same baseline as
# BASIC so users who cancel retain a small quota and see a real overage cap,
# while LaunchDarkly can still tune tiers without a deploy.
_DEFAULT_TIER_WORKSPACE_STORAGE_MB: dict[SubscriptionTier, int] = {
SubscriptionTier.NO_TIER: 250, # 250 MB
SubscriptionTier.BASIC: 250, # 250 MB
SubscriptionTier.PRO: 1024, # 1 GB
SubscriptionTier.MAX: 5 * 1024, # 5 GB
SubscriptionTier.BUSINESS: 15 * 1024, # 15 GB
SubscriptionTier.ENTERPRISE: 15 * 1024, # 15 GB
}
@cached(ttl_seconds=60, maxsize=1, cache_none=False)
async def _fetch_tier_multipliers_flag() -> dict[SubscriptionTier, float] | None:
"""Fetch the ``copilot-tier-multipliers`` LD flag and parse it.
@@ -245,82 +232,6 @@ async def get_tier_multipliers() -> dict[str, float]:
return {tier.value: multiplier for tier, multiplier in merged.items()}
@cached(ttl_seconds=60, maxsize=1, cache_none=False)
async def _fetch_workspace_storage_limits_flag() -> dict[SubscriptionTier, int] | None:
"""Fetch the ``copilot-tier-workspace-storage-limits`` LD flag and parse it.
Returns a sparse ``{tier: megabytes}`` map built from whichever keys are
valid in the flag payload, or ``None`` when the flag is unset / invalid /
LD is unavailable. Callers merge whatever survives into
:data:`_DEFAULT_TIER_WORKSPACE_STORAGE_MB`.
The LD value is expected to be a JSON object keyed by tier enum name
(``{"NO_TIER": 250, "PRO": 1024, "BUSINESS": 15360}``). Non-int,
negative, or zero values are skipped so a broken key degrades to the
code default instead of wiping out the limit. Zero is rejected because
downstream guards treat ``storage_limit == 0`` as "uncapped".
"""
# Lazy import: rate_limit -> feature_flag -> settings -> ... -> rate_limit.
from backend.util.feature_flag import Flag, get_feature_flag_value
raw = await get_feature_flag_value(
Flag.COPILOT_TIER_WORKSPACE_STORAGE_LIMITS.value, "system", None
)
if raw is None:
return None
if not isinstance(raw, dict):
logger.warning(
"Invalid LD value for copilot-tier-workspace-storage-limits "
"(expected JSON object): %r",
raw,
)
return None
parsed: dict[SubscriptionTier, int] = {}
for key, value in raw.items():
try:
tier = SubscriptionTier(key)
except ValueError:
continue
if isinstance(value, bool) or not isinstance(value, int):
logger.warning(
"Invalid LD value for copilot-tier-workspace-storage-limits[%s]: %r",
key,
value,
)
continue
if value <= 0:
logger.warning(
"Non-positive LD value for copilot-tier-workspace-storage-limits[%s]: %r",
key,
value,
)
continue
parsed[tier] = value
return parsed or None
async def get_workspace_storage_limits_mb() -> dict[str, int]:
"""Return the effective ``{tier_value: megabytes}`` workspace limit map.
Honours the ``copilot-tier-workspace-storage-limits`` LD flag when set;
missing tiers inherit :data:`_DEFAULT_TIER_WORKSPACE_STORAGE_MB`.
Unparseable flag values or LD fetch failures fall back to the defaults.
"""
try:
override = await _fetch_workspace_storage_limits_flag()
except Exception:
logger.warning(
"get_workspace_storage_limits_mb: LD lookup failed", exc_info=True
)
override = None
merged: dict[SubscriptionTier, int] = dict(_DEFAULT_TIER_WORKSPACE_STORAGE_MB)
if override:
merged.update(override)
return {tier.value: megabytes for tier, megabytes in merged.items()}
class UsageWindow(BaseModel):
"""Usage within a single time window.
@@ -750,19 +661,6 @@ get_user_tier.cache_clear = _fetch_user_tier.cache_clear # type: ignore[attr-de
get_user_tier.cache_delete = _fetch_user_tier.cache_delete # type: ignore[attr-defined]
async def get_workspace_storage_limit_bytes(user_id: str) -> int:
"""Return the workspace storage cap in bytes for the user's subscription tier."""
tier = await get_user_tier(user_id)
limits_mb = await get_workspace_storage_limits_mb()
tier_key = getattr(tier, "value", str(tier))
fallback_mb = limits_mb.get(
DEFAULT_TIER.value,
_DEFAULT_TIER_WORKSPACE_STORAGE_MB[DEFAULT_TIER],
)
mb = limits_mb.get(tier_key, fallback_mb)
return mb * 1024 * 1024
async def set_user_tier(user_id: str, tier: SubscriptionTier) -> None:
"""Persist the user's rate-limit tier to the database.

View File

@@ -8,7 +8,6 @@ from redis.exceptions import RedisError
from .rate_limit import (
_DEFAULT_TIER_MULTIPLIERS,
_DEFAULT_TIER_WORKSPACE_STORAGE_MB,
DEFAULT_TIER,
TIER_MULTIPLIERS,
CoPilotUsageStatus,
@@ -19,7 +18,6 @@ from .rate_limit import (
_daily_reset_time,
_fetch_cost_limits_flag,
_fetch_tier_multipliers_flag,
_fetch_workspace_storage_limits_flag,
_weekly_key,
_weekly_reset_time,
acquire_reset_lock,
@@ -29,8 +27,6 @@ from .rate_limit import (
get_tier_multipliers,
get_usage_status,
get_user_tier,
get_workspace_storage_limit_bytes,
get_workspace_storage_limits_mb,
increment_daily_reset_count,
record_cost_usage,
release_reset_lock,
@@ -475,85 +471,6 @@ class TestGetTierMultipliers:
assert result == {t.value: m for t, m in _DEFAULT_TIER_MULTIPLIERS.items()}
class TestGetWorkspaceStorageLimits:
@pytest.fixture(autouse=True)
def _clear_flag_cache(self):
"""Clear the LD flag cache between tests so patches don't leak."""
_fetch_workspace_storage_limits_flag.cache_clear() # type: ignore[attr-defined]
@pytest.mark.asyncio
async def test_defaults_when_flag_unset(self):
"""With no LD override, the resolver returns the default map."""
with patch(
"backend.util.feature_flag.get_feature_flag_value",
new_callable=AsyncMock,
return_value=None,
):
result = await get_workspace_storage_limits_mb()
assert result == {
t.value: mb for t, mb in _DEFAULT_TIER_WORKSPACE_STORAGE_MB.items()
}
@pytest.mark.asyncio
async def test_ld_override(self):
"""LD override populates targeted tiers; others inherit defaults."""
with patch(
"backend.util.feature_flag.get_feature_flag_value",
new_callable=AsyncMock,
return_value={"NO_TIER": 300, "PRO": 2048},
):
result = await get_workspace_storage_limits_mb()
assert result["NO_TIER"] == 300
assert result["PRO"] == 2048
assert (
result["BASIC"]
== _DEFAULT_TIER_WORKSPACE_STORAGE_MB[SubscriptionTier.BASIC]
)
assert result["MAX"] == _DEFAULT_TIER_WORKSPACE_STORAGE_MB[SubscriptionTier.MAX]
@pytest.mark.asyncio
async def test_invalid_json_falls_back(self):
"""A non-object LD value falls back to defaults."""
with patch(
"backend.util.feature_flag.get_feature_flag_value",
new_callable=AsyncMock,
return_value="broken",
):
result = await get_workspace_storage_limits_mb()
assert result == {
t.value: mb for t, mb in _DEFAULT_TIER_WORKSPACE_STORAGE_MB.items()
}
@pytest.mark.asyncio
async def test_unknown_tier_key_and_invalid_values_skipped(self):
"""Unknown tiers and invalid values degrade to defaults per key."""
with patch(
"backend.util.feature_flag.get_feature_flag_value",
new_callable=AsyncMock,
return_value={"NO_TIER": 300, "BOGUS": 999, "MAX": -1, "BUSINESS": "nope"},
):
result = await get_workspace_storage_limits_mb()
assert result["NO_TIER"] == 300
assert result["MAX"] == _DEFAULT_TIER_WORKSPACE_STORAGE_MB[SubscriptionTier.MAX]
assert (
result["BUSINESS"]
== _DEFAULT_TIER_WORKSPACE_STORAGE_MB[SubscriptionTier.BUSINESS]
)
@pytest.mark.asyncio
async def test_ld_failure_falls_back(self):
"""LD lookup raising propagates to defaults, not up the call stack."""
with patch(
"backend.util.feature_flag.get_feature_flag_value",
new_callable=AsyncMock,
side_effect=RuntimeError("LD SDK not initialized"),
):
result = await get_workspace_storage_limits_mb()
assert result == {
t.value: mb for t, mb in _DEFAULT_TIER_WORKSPACE_STORAGE_MB.items()
}
# ---------------------------------------------------------------------------
# get_global_rate_limits — LD-flag cost limits parsing
# ---------------------------------------------------------------------------
@@ -1875,284 +1792,3 @@ class TestResetUserUsage:
):
with pytest.raises(RedisError):
await reset_user_usage("user-1")
class TestWorkspaceStorageLimits:
"""Tests for tier-based workspace storage limits."""
@pytest.fixture(autouse=True)
def _clear_flag_cache(self):
_fetch_workspace_storage_limits_flag.cache_clear() # type: ignore[attr-defined]
def test_every_subscription_tier_has_storage_limit(self):
"""Adding a new SubscriptionTier without a storage limit should fail."""
for tier in SubscriptionTier:
assert tier in _DEFAULT_TIER_WORKSPACE_STORAGE_MB, (
f"SubscriptionTier.{tier.name} has no entry in "
f"_DEFAULT_TIER_WORKSPACE_STORAGE_MB — add one"
)
assert _DEFAULT_TIER_WORKSPACE_STORAGE_MB[tier] > 0
def test_every_subscription_tier_has_rate_limit_multiplier(self):
"""Adding a new SubscriptionTier without a rate limit multiplier should fail."""
for tier in SubscriptionTier:
assert tier in _DEFAULT_TIER_MULTIPLIERS, (
f"SubscriptionTier.{tier.name} has no entry in "
f"_DEFAULT_TIER_MULTIPLIERS — add one"
)
if tier == SubscriptionTier.NO_TIER:
assert _DEFAULT_TIER_MULTIPLIERS[tier] == 0.0
else:
assert _DEFAULT_TIER_MULTIPLIERS[tier] > 0
@pytest.mark.parametrize(
"tier,expected_mb",
[
(SubscriptionTier.NO_TIER, 250),
(SubscriptionTier.BASIC, 250),
(SubscriptionTier.PRO, 1024),
(SubscriptionTier.MAX, 5 * 1024),
(SubscriptionTier.BUSINESS, 15 * 1024),
(SubscriptionTier.ENTERPRISE, 15 * 1024),
],
)
def test_tier_workspace_storage_mapping_covers_all_tiers(self, tier, expected_mb):
"""Every tier has an explicit storage limit in the mapping."""
assert tier in _DEFAULT_TIER_WORKSPACE_STORAGE_MB
assert _DEFAULT_TIER_WORKSPACE_STORAGE_MB[tier] == expected_mb
@pytest.mark.asyncio
@pytest.mark.parametrize(
"tier,expected_bytes",
[
(SubscriptionTier.NO_TIER, 250 * 1024 * 1024),
(SubscriptionTier.BASIC, 250 * 1024 * 1024),
(SubscriptionTier.PRO, 1024 * 1024 * 1024),
(SubscriptionTier.MAX, 5 * 1024 * 1024 * 1024),
(SubscriptionTier.BUSINESS, 15 * 1024 * 1024 * 1024),
(SubscriptionTier.ENTERPRISE, 15 * 1024 * 1024 * 1024),
],
)
async def test_get_workspace_storage_limit_bytes_per_tier(
self, tier, expected_bytes
):
"""get_workspace_storage_limit_bytes returns correct bytes for each tier."""
with patch(
"backend.copilot.rate_limit.get_user_tier",
return_value=tier,
):
result = await get_workspace_storage_limit_bytes("user-1")
assert result == expected_bytes
@pytest.mark.asyncio
async def test_get_workspace_storage_limit_bytes_defaults_to_default_tier_on_unknown(
self,
):
"""Unknown tier falls back to the default tier limit."""
with patch(
"backend.copilot.rate_limit.get_user_tier",
return_value="UNKNOWN_TIER",
):
result = await get_workspace_storage_limit_bytes("user-1")
assert result == 250 * 1024 * 1024
class TestWorkspaceStorageLimitsAdversarial:
"""Adversarial edge-case tests for workspace storage limit resolution."""
@pytest.fixture(autouse=True)
def _clear_flag_cache(self):
_fetch_workspace_storage_limits_flag.cache_clear() # type: ignore[attr-defined]
@pytest.mark.asyncio
async def test_ld_zero_values_are_rejected(self):
"""LD zero values must be rejected — 0 is the 'uncapped' sentinel downstream."""
with patch(
"backend.util.feature_flag.get_feature_flag_value",
new_callable=AsyncMock,
return_value={"BASIC": 0, "PRO": 0},
):
result = await get_workspace_storage_limits_mb()
# Zero rejected → falls back to code defaults
assert (
result["BASIC"]
== _DEFAULT_TIER_WORKSPACE_STORAGE_MB[SubscriptionTier.BASIC]
)
assert result["PRO"] == _DEFAULT_TIER_WORKSPACE_STORAGE_MB[SubscriptionTier.PRO]
@pytest.mark.asyncio
async def test_ld_zero_override_does_not_bypass_quota(self):
"""A zero LD override must not result in 0 bytes (which bypasses quota)."""
_fetch_workspace_storage_limits_flag.cache_clear() # type: ignore[attr-defined]
with (
patch(
"backend.copilot.rate_limit.get_user_tier",
return_value=SubscriptionTier.BASIC,
),
patch(
"backend.util.feature_flag.get_feature_flag_value",
new_callable=AsyncMock,
return_value={"BASIC": 0},
),
):
result = await get_workspace_storage_limit_bytes("user-1")
# Zero rejected → falls back to default 250 MB in bytes
assert result == 250 * 1024 * 1024
@pytest.mark.asyncio
async def test_ld_bool_values_are_rejected(self):
"""Booleans masquerading as ints (True=1, False=0) must be rejected."""
with patch(
"backend.util.feature_flag.get_feature_flag_value",
new_callable=AsyncMock,
return_value={"PRO": True, "MAX": False},
):
result = await get_workspace_storage_limits_mb()
# PRO and MAX should keep their defaults, not become 1 and 0
assert result["PRO"] == _DEFAULT_TIER_WORKSPACE_STORAGE_MB[SubscriptionTier.PRO]
assert result["MAX"] == _DEFAULT_TIER_WORKSPACE_STORAGE_MB[SubscriptionTier.MAX]
@pytest.mark.asyncio
async def test_ld_float_values_are_rejected(self):
"""Floats in LD payload should be rejected — must be int."""
with patch(
"backend.util.feature_flag.get_feature_flag_value",
new_callable=AsyncMock,
return_value={"PRO": 1024.5},
):
result = await get_workspace_storage_limits_mb()
assert result["PRO"] == _DEFAULT_TIER_WORKSPACE_STORAGE_MB[SubscriptionTier.PRO]
@pytest.mark.asyncio
async def test_ld_empty_object_falls_back_to_defaults(self):
"""Empty LD payload {} → parsed is empty → None → defaults."""
with patch(
"backend.util.feature_flag.get_feature_flag_value",
new_callable=AsyncMock,
return_value={},
):
result = await get_workspace_storage_limits_mb()
assert result == {
t.value: mb for t, mb in _DEFAULT_TIER_WORKSPACE_STORAGE_MB.items()
}
@pytest.mark.asyncio
async def test_ld_list_instead_of_dict_falls_back(self):
"""LD returns a list instead of dict → rejected → defaults."""
with patch(
"backend.util.feature_flag.get_feature_flag_value",
new_callable=AsyncMock,
return_value=[250, 1024, 5120],
):
result = await get_workspace_storage_limits_mb()
assert result == {
t.value: mb for t, mb in _DEFAULT_TIER_WORKSPACE_STORAGE_MB.items()
}
@pytest.mark.asyncio
async def test_ld_extremely_large_values_accepted(self):
"""Very large LD values shouldn't be clipped — trust LD config."""
with patch(
"backend.util.feature_flag.get_feature_flag_value",
new_callable=AsyncMock,
return_value={"ENTERPRISE": 999_999_999},
):
result = await get_workspace_storage_limits_mb()
assert result["ENTERPRISE"] == 999_999_999
@pytest.mark.asyncio
async def test_ld_mixed_valid_and_garbage_keys(self):
"""Mixed payload: valid overrides applied, garbage ignored."""
with patch(
"backend.util.feature_flag.get_feature_flag_value",
new_callable=AsyncMock,
return_value={
"PRO": 2048,
"": 100,
"null": 100,
"NO_TIER": -5,
"BASIC": "abc",
"MAX": True,
"ENTERPRISE": 30720,
},
):
result = await get_workspace_storage_limits_mb()
# Only PRO and ENTERPRISE should be overridden
assert result["PRO"] == 2048
assert result["ENTERPRISE"] == 30720
# Everything else keeps defaults
assert (
result["BASIC"]
== _DEFAULT_TIER_WORKSPACE_STORAGE_MB[SubscriptionTier.BASIC]
)
assert (
result["NO_TIER"]
== _DEFAULT_TIER_WORKSPACE_STORAGE_MB[SubscriptionTier.NO_TIER]
)
assert result["MAX"] == _DEFAULT_TIER_WORKSPACE_STORAGE_MB[SubscriptionTier.MAX]
@pytest.mark.asyncio
async def test_tier_transition_pro_to_no_tier_reduces_limit(self):
"""Simulating unsubscribe: PRO user → NO_TIER gets lower limit."""
_fetch_workspace_storage_limits_flag.cache_clear() # type: ignore[attr-defined]
# PRO tier
with patch(
"backend.copilot.rate_limit.get_user_tier",
return_value=SubscriptionTier.PRO,
):
pro_limit = await get_workspace_storage_limit_bytes("user-1")
# Unsubscribe → NO_TIER
_fetch_workspace_storage_limits_flag.cache_clear() # type: ignore[attr-defined]
with patch(
"backend.copilot.rate_limit.get_user_tier",
return_value=SubscriptionTier.NO_TIER,
):
no_tier_limit = await get_workspace_storage_limit_bytes("user-1")
assert pro_limit == 1024 * 1024 * 1024 # 1 GB
assert no_tier_limit == 250 * 1024 * 1024 # 250 MB
assert no_tier_limit < pro_limit
@pytest.mark.asyncio
async def test_all_tiers_are_monotonically_increasing(self):
"""Storage limits should be monotonically non-decreasing across tiers."""
tier_order = [
SubscriptionTier.NO_TIER,
SubscriptionTier.BASIC,
SubscriptionTier.PRO,
SubscriptionTier.MAX,
SubscriptionTier.BUSINESS,
SubscriptionTier.ENTERPRISE,
]
limits = [_DEFAULT_TIER_WORKSPACE_STORAGE_MB[t] for t in tier_order]
for i in range(1, len(limits)):
assert limits[i] >= limits[i - 1], (
f"{tier_order[i].name} ({limits[i]} MB) < "
f"{tier_order[i - 1].name} ({limits[i - 1]} MB)"
)
@pytest.mark.asyncio
async def test_concurrent_flag_fetches_share_cache(self):
"""Multiple concurrent calls should only hit LD once (caching)."""
call_count = 0
async def counting_flag(*args, **kwargs):
nonlocal call_count
call_count += 1
return {"PRO": 2048}
with patch(
"backend.util.feature_flag.get_feature_flag_value",
new_callable=AsyncMock,
side_effect=counting_flag,
):
import asyncio
results = await asyncio.gather(
get_workspace_storage_limits_mb(),
get_workspace_storage_limits_mb(),
get_workspace_storage_limits_mb(),
)
# All results should be identical
assert all(r == results[0] for r in results)
# Flag should have been fetched at most once due to caching
assert call_count <= 1

View File

@@ -8,7 +8,6 @@ from typing import Any, Optional
from pydantic import BaseModel
from backend.api.features.store.exceptions import VirusDetectedError, VirusScanError
from backend.copilot.context import (
E2B_WORKDIR,
get_current_sandbox,
@@ -20,6 +19,7 @@ from backend.copilot.context import (
from backend.copilot.model import ChatSession
from backend.copilot.tools.sandbox import make_session_path
from backend.util.settings import Config
from backend.util.virus_scanner import scan_content_safe
from backend.util.workspace import WorkspaceManager
from .base import BaseTool
@@ -837,6 +837,7 @@ class WriteWorkspaceFileTool(BaseTool):
)
try:
await scan_content_safe(content_bytes, filename=filename)
manager = await get_workspace_manager(user_id, session_id)
rec = await manager.write_file(
content=content_bytes,
@@ -894,12 +895,6 @@ class WriteWorkspaceFileTool(BaseTool):
message=msg,
session_id=session_id,
)
except VirusDetectedError as e:
logger.warning(f"Virus detected in uploaded file: {e.threat_name}")
return ErrorResponse(message=str(e), session_id=session_id)
except VirusScanError as e:
logger.error(f"Virus scan infrastructure error: {e}", exc_info=True)
return ErrorResponse(message=str(e), session_id=session_id)
except ValueError as e:
return ErrorResponse(message=str(e), session_id=session_id)
except Exception as e:

View File

@@ -623,73 +623,3 @@ async def test_read_workspace_file_no_fallback_when_resolve_succeeds(setup_test_
# Normal workspace path must have produced a content response.
assert isinstance(result, WorkspaceFileContentResponse)
assert base64.b64decode(result.content_base64) == fake_content
# ---------------------------------------------------------------------------
# WriteWorkspaceFileTool exception handling (quota, virus, scan errors)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio(loop_scope="session")
class TestWriteWorkspaceFileToolErrorHandling:
"""Verify WriteWorkspaceFileTool returns proper ErrorResponse for exceptions."""
async def _execute_write(self, side_effect, setup_test_data):
"""Helper: run WriteWorkspaceFileTool with mocked write_file."""
user = setup_test_data["user"]
session = make_session(user.id)
mock_manager = AsyncMock()
mock_manager.write_file = AsyncMock(side_effect=side_effect)
with patch(
"backend.copilot.tools.workspace_files.get_workspace_manager",
AsyncMock(return_value=mock_manager),
):
tool = WriteWorkspaceFileTool()
return await tool._execute(
user_id=user.id,
session=session,
filename="test.txt",
content="hello",
)
async def test_quota_exceeded_returns_error_response(self, setup_test_data):
"""Quota exceeded (ValueError) → ErrorResponse with storage message."""
result = await self._execute_write(
ValueError("Storage limit exceeded: 250 MB used of 250 MB (100.0%)"),
setup_test_data,
)
assert isinstance(result, ErrorResponse)
assert "Storage limit exceeded" in result.message
async def test_virus_detected_returns_error_response(self, setup_test_data):
"""VirusDetectedError → ErrorResponse with virus message."""
from backend.api.features.store.exceptions import VirusDetectedError
result = await self._execute_write(
VirusDetectedError("Eicar-Test-Signature"),
setup_test_data,
)
assert isinstance(result, ErrorResponse)
assert "Virus detected" in result.message
async def test_virus_scan_error_returns_error_response(self, setup_test_data):
"""VirusScanError (infra failure) → ErrorResponse with scan message."""
from backend.api.features.store.exceptions import VirusScanError
result = await self._execute_write(
VirusScanError("ClamAV connection refused"),
setup_test_data,
)
assert isinstance(result, ErrorResponse)
assert "ClamAV" in result.message
async def test_file_conflict_returns_error_response(self, setup_test_data):
"""File path conflict (ValueError) → ErrorResponse."""
result = await self._execute_write(
ValueError("File already exists at path: /sessions/abc/test.txt"),
setup_test_data,
)
assert isinstance(result, ErrorResponse)
assert "already exists" in result.message

View File

@@ -204,68 +204,6 @@ async def test_sync_subscription_from_stripe_cancelled():
mock_set.assert_awaited_once_with("user-1", SubscriptionTier.NO_TIER)
@pytest.mark.asyncio
async def test_sync_subscription_from_stripe_cancelled_applies_no_tier_storage_limit():
"""After unsubscribe takes effect, workspace storage resolves against NO_TIER."""
from backend.copilot.rate_limit import get_workspace_storage_limit_bytes
mock_user = _make_user(tier=SubscriptionTier.PRO)
stripe_sub = {
"id": "sub_old",
"customer": "cus_123",
"status": "canceled",
"items": {"data": []},
}
empty_list = MagicMock()
empty_list.data = []
empty_list.has_more = False
async def _set_tier(_user_id: str, tier: SubscriptionTier) -> None:
mock_user.subscriptionTier = tier
with (
patch(
"backend.data.credit.User.prisma",
return_value=MagicMock(find_first=AsyncMock(return_value=mock_user)),
),
patch(
"backend.data.credit.stripe.Subscription.list",
return_value=empty_list,
),
patch(
"backend.data.credit.set_subscription_tier",
new_callable=AsyncMock,
side_effect=_set_tier,
),
patch(
"backend.copilot.rate_limit.get_user_tier",
new_callable=AsyncMock,
side_effect=lambda _user_id: mock_user.subscriptionTier,
),
patch(
"backend.copilot.rate_limit.get_workspace_storage_limits_mb",
new_callable=AsyncMock,
return_value={
"NO_TIER": 250,
"BASIC": 500,
"PRO": 1024,
"MAX": 5 * 1024,
"BUSINESS": 15 * 1024,
"ENTERPRISE": 15 * 1024,
},
),
patch.object(
get_pending_subscription_change,
"cache_delete",
) as mock_pending_cache_delete,
):
await sync_subscription_from_stripe(stripe_sub)
result = await get_workspace_storage_limit_bytes("user-1")
assert result == 250 * 1024 * 1024
mock_pending_cache_delete.assert_called_once_with("user-1")
@pytest.mark.asyncio
async def test_sync_subscription_from_stripe_cancelled_but_other_active_sub_exists():
"""Cancelling sub_old must NOT downgrade the user if sub_new is still active.

View File

@@ -44,7 +44,6 @@ class Flag(str, Enum):
COPILOT_SDK = "copilot-sdk"
COPILOT_COST_LIMITS = "copilot-cost-limits"
COPILOT_TIER_MULTIPLIERS = "copilot-tier-multipliers"
COPILOT_TIER_WORKSPACE_STORAGE_LIMITS = "copilot-tier-workspace-storage-limits"
COPILOT_TIER_STRIPE_PRICES = "copilot-tier-stripe-prices"
GRAPHITI_MEMORY = "graphiti-memory"
# Stripe Product ID for top-up Checkout sessions. When unset (default),

View File

@@ -441,6 +441,13 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
description="Maximum file size in MB for workspace files (1-1024 MB)",
)
max_workspace_storage_mb: int = Field(
default=500,
ge=1,
le=10240,
description="Maximum total workspace storage per user in MB.",
)
# AutoMod configuration
automod_enabled: bool = Field(
default=False,

View File

@@ -5,7 +5,6 @@ This module provides a high-level interface for workspace file operations,
combining the storage backend and database layer.
"""
import asyncio
import logging
import mimetypes
import uuid
@@ -13,28 +12,12 @@ from typing import Optional
from prisma.errors import UniqueViolationError
from backend.copilot.rate_limit import get_workspace_storage_limit_bytes
from backend.data.db_accessors import workspace_db
from backend.data.workspace import WorkspaceFile, get_workspace_total_size
from backend.data.workspace import WorkspaceFile
from backend.util.settings import Config
from backend.util.virus_scanner import scan_content_safe
from backend.util.workspace_storage import compute_file_checksum, get_workspace_storage
def format_bytes(n: int) -> str:
"""Format bytes as a human-readable string (e.g. 250 MB, 1.0 GB)."""
KB, MB, GB = 1024, 1024**2, 1024**3
if n < KB:
return f"{n} B"
if n < MB:
kb = round(n / KB)
return f"{n / MB:.1f} MB" if kb >= 1024 else f"{kb} KB"
if n < GB:
mb = round(n / MB)
return f"{n / GB:.1f} GB" if mb >= 1024 else f"{mb} MB"
return f"{n / GB:.1f} GB"
logger = logging.getLogger(__name__)
@@ -202,7 +185,11 @@ class WorkspaceManager:
f"{Config().max_file_size_mb}MB limit"
)
# Determine path with session scoping (needed before quota check for overwrites)
# Scan here — callers must NOT duplicate this scan.
# WorkspaceManager owns virus scanning for all persisted files.
await scan_content_safe(content, filename=filename)
# Determine path with session scoping
if path is None:
path = f"/{filename}"
elif not path.startswith("/"):
@@ -211,37 +198,10 @@ class WorkspaceManager:
# Resolve path with session prefix
path = self._resolve_path(path)
# Enforce per-user workspace storage quota (tier-based).
# For overwrites, subtract the existing file's size so replacing a file
# with a same-size or smaller file is not rejected near the cap.
storage_limit, current_usage = await asyncio.gather(
get_workspace_storage_limit_bytes(self.user_id),
get_workspace_total_size(self.workspace_id),
)
if overwrite:
db = workspace_db()
existing = await db.get_workspace_file_by_path(self.workspace_id, path)
if existing is not None:
current_usage = max(0, current_usage - existing.size_bytes)
projected_usage = current_usage + len(content)
if storage_limit > 0 and projected_usage > storage_limit:
raise ValueError(
f"Storage limit exceeded. "
f"You've used {format_bytes(current_usage)} of your "
f"{format_bytes(storage_limit)} quota. "
f"Delete some files or upgrade your plan for more storage."
)
if storage_limit > 0 and projected_usage / storage_limit >= 0.8:
logger.warning(
f"User {self.user_id} workspace storage at "
f"{projected_usage / storage_limit * 100:.1f}% "
f"({projected_usage} / {storage_limit} bytes)"
)
# Check if file exists at path (only error for non-overwrite case).
# Done before virus scanning so a cheap duplicate-path check isn't
# blocked by a potentially slow or unavailable scanner.
# Check if file exists at path (only error for non-overwrite case)
# For overwrite=True, we let the write proceed and handle via UniqueViolationError
# This ensures the new file is written to storage BEFORE the old one is deleted,
# preventing data loss if the new write fails
db = workspace_db()
if not overwrite:
@@ -249,10 +209,6 @@ class WorkspaceManager:
if existing is not None:
raise ValueError(f"File already exists at path: {path}")
# Scan here — callers must NOT duplicate this scan.
# WorkspaceManager owns virus scanning for all persisted files.
await scan_content_safe(content, filename=filename)
# Auto-detect MIME type if not provided
if mime_type is None:
mime_type, _ = mimetypes.guess_type(filename)

View File

@@ -88,11 +88,6 @@ async def test_write_file_no_overwrite_unique_violation_raises_and_cleans_up(
),
patch("backend.util.workspace.workspace_db", return_value=mock_db),
patch("backend.util.workspace.scan_content_safe", new_callable=AsyncMock),
patch(
"backend.util.workspace.get_workspace_storage_limit_bytes",
return_value=250 * 1024 * 1024,
),
patch("backend.util.workspace.get_workspace_total_size", return_value=0),
):
with pytest.raises(ValueError, match="File already exists"):
await manager.write_file(
@@ -120,11 +115,6 @@ async def test_write_file_overwrite_conflict_then_retry_succeeds(
),
patch("backend.util.workspace.workspace_db", return_value=mock_db),
patch("backend.util.workspace.scan_content_safe", new_callable=AsyncMock),
patch(
"backend.util.workspace.get_workspace_storage_limit_bytes",
return_value=250 * 1024 * 1024,
),
patch("backend.util.workspace.get_workspace_total_size", return_value=0),
patch.object(manager, "delete_file", new_callable=AsyncMock) as mock_delete,
):
result = await manager.write_file(
@@ -158,11 +148,6 @@ async def test_write_file_overwrite_exhausted_retries_raises_and_cleans_up(
),
patch("backend.util.workspace.workspace_db", return_value=mock_db),
patch("backend.util.workspace.scan_content_safe", new_callable=AsyncMock),
patch(
"backend.util.workspace.get_workspace_storage_limit_bytes",
return_value=250 * 1024 * 1024,
),
patch("backend.util.workspace.get_workspace_total_size", return_value=0),
patch.object(manager, "delete_file", new_callable=AsyncMock),
):
with pytest.raises(ValueError, match="Unable to overwrite.*concurrent write"):
@@ -171,346 +156,3 @@ async def test_write_file_overwrite_exhausted_retries_raises_and_cleans_up(
)
mock_storage.delete.assert_called_once()
@pytest.mark.asyncio
async def test_write_file_quota_exceeded_raises_value_error(
manager, mock_storage, mock_db
):
"""write_file raises ValueError when workspace storage quota is exceeded."""
mock_db.get_workspace_file_by_path.return_value = None
with (
patch(
"backend.util.workspace.get_workspace_storage",
return_value=mock_storage,
),
patch("backend.util.workspace.workspace_db", return_value=mock_db),
patch(
"backend.util.workspace.scan_content_safe", new_callable=AsyncMock
) as mock_scan,
patch(
"backend.util.workspace.get_workspace_storage_limit_bytes",
return_value=250 * 1024 * 1024, # 250 MB limit
),
patch(
"backend.util.workspace.get_workspace_total_size",
return_value=250 * 1024 * 1024, # already at limit
),
):
with pytest.raises(ValueError, match="Storage limit exceeded"):
await manager.write_file(filename="test.txt", content=b"hello")
# Quota rejection should short-circuit before expensive virus scan
mock_scan.assert_not_called()
# Storage should NOT have been written to
mock_storage.store.assert_not_called()
@pytest.mark.asyncio
async def test_write_file_rejects_upload_when_usage_already_exceeds_downgraded_limit(
manager, mock_storage, mock_db
):
"""Downgrading below current usage should block further uploads until usage drops."""
mock_db.get_workspace_file_by_path.return_value = None
with (
patch(
"backend.util.workspace.get_workspace_storage",
return_value=mock_storage,
),
patch("backend.util.workspace.workspace_db", return_value=mock_db),
patch(
"backend.util.workspace.scan_content_safe", new_callable=AsyncMock
) as mock_scan,
patch(
"backend.util.workspace.get_workspace_storage_limit_bytes",
return_value=250 * 1024 * 1024,
),
patch(
"backend.util.workspace.get_workspace_total_size",
return_value=300 * 1024 * 1024,
),
):
with pytest.raises(ValueError, match="Storage limit exceeded"):
await manager.write_file(filename="test.txt", content=b"hello")
mock_scan.assert_not_called()
mock_storage.store.assert_not_called()
@pytest.mark.asyncio
async def test_write_file_80pct_warning_logged(manager, mock_storage, mock_db, caplog):
"""write_file logs a warning when workspace usage crosses 80%."""
created_file = _make_workspace_file()
mock_db.get_workspace_file_by_path.return_value = None
mock_db.create_workspace_file.return_value = created_file
limit_bytes = 100 # 100 bytes total limit
current_usage = 75 # 75 bytes used → 75% before write
content = b"123456" # 6 bytes → 81% after write
with (
patch(
"backend.util.workspace.get_workspace_storage",
return_value=mock_storage,
),
patch("backend.util.workspace.workspace_db", return_value=mock_db),
patch("backend.util.workspace.scan_content_safe", new_callable=AsyncMock),
patch(
"backend.util.workspace.get_workspace_storage_limit_bytes",
return_value=limit_bytes,
),
patch(
"backend.util.workspace.get_workspace_total_size",
return_value=current_usage,
),
):
import logging
with caplog.at_level(logging.WARNING, logger="backend.util.workspace"):
await manager.write_file(filename="test.txt", content=content)
assert any("workspace storage at" in record.message for record in caplog.records)
@pytest.mark.asyncio
async def test_write_file_overwrite_not_double_counted(manager, mock_storage, mock_db):
"""Overwriting a file subtracts the old file size from usage check."""
existing_file = _make_workspace_file(size_bytes=50)
created_file = _make_workspace_file()
mock_db.get_workspace_file_by_path.return_value = existing_file
mock_db.create_workspace_file.return_value = created_file
limit_bytes = 100
current_usage = 90 # 90 bytes used, 50 of which is the file being replaced
content = b"x" * 50 # replacing with same-size file — should succeed
with (
patch(
"backend.util.workspace.get_workspace_storage",
return_value=mock_storage,
),
patch("backend.util.workspace.workspace_db", return_value=mock_db),
patch("backend.util.workspace.scan_content_safe", new_callable=AsyncMock),
patch(
"backend.util.workspace.get_workspace_storage_limit_bytes",
return_value=limit_bytes,
),
patch(
"backend.util.workspace.get_workspace_total_size",
return_value=current_usage,
),
):
# Should NOT raise — net usage after overwrite is 90 - 50 + 50 = 90, under 100
result = await manager.write_file(
filename="test.txt", content=content, overwrite=True
)
assert result == created_file
@pytest.mark.asyncio
async def test_write_file_zero_limit_bypasses_quota_check(
manager, mock_storage, mock_db
):
"""When limit is 0 (internal sentinel, not reachable via LD), quota is skipped."""
created_file = _make_workspace_file()
mock_db.get_workspace_file_by_path.return_value = None
mock_db.create_workspace_file.return_value = created_file
with (
patch(
"backend.util.workspace.get_workspace_storage",
return_value=mock_storage,
),
patch("backend.util.workspace.workspace_db", return_value=mock_db),
patch("backend.util.workspace.scan_content_safe", new_callable=AsyncMock),
patch(
"backend.util.workspace.get_workspace_storage_limit_bytes",
return_value=0, # Zero limit → uncapped
),
patch(
"backend.util.workspace.get_workspace_total_size",
return_value=999_999_999, # Huge existing usage
),
):
# Should NOT raise — zero limit means no enforcement
result = await manager.write_file(filename="big.txt", content=b"data")
assert result == created_file
@pytest.mark.asyncio
async def test_write_file_exactly_at_limit_is_rejected(manager, mock_storage, mock_db):
"""Writing a file that puts usage at exactly the limit should be rejected
because projected_usage > storage_limit (not >=)."""
mock_db.get_workspace_file_by_path.return_value = None
limit = 100
current = 95
content = b"x" * 6 # 95 + 6 = 101 > 100
with (
patch(
"backend.util.workspace.get_workspace_storage",
return_value=mock_storage,
),
patch("backend.util.workspace.workspace_db", return_value=mock_db),
patch("backend.util.workspace.scan_content_safe", new_callable=AsyncMock),
patch(
"backend.util.workspace.get_workspace_storage_limit_bytes",
return_value=limit,
),
patch("backend.util.workspace.get_workspace_total_size", return_value=current),
):
with pytest.raises(ValueError, match="Storage limit exceeded"):
await manager.write_file(filename="test.txt", content=content)
@pytest.mark.asyncio
async def test_write_file_exactly_at_limit_boundary_succeeds(
manager, mock_storage, mock_db
):
"""Writing a file that puts usage at exactly the limit should succeed
because the guard is > not >=."""
created_file = _make_workspace_file()
mock_db.get_workspace_file_by_path.return_value = None
mock_db.create_workspace_file.return_value = created_file
limit = 100
current = 95
content = b"x" * 5 # 95 + 5 = 100 == limit → NOT > limit → passes
with (
patch(
"backend.util.workspace.get_workspace_storage",
return_value=mock_storage,
),
patch("backend.util.workspace.workspace_db", return_value=mock_db),
patch("backend.util.workspace.scan_content_safe", new_callable=AsyncMock),
patch(
"backend.util.workspace.get_workspace_storage_limit_bytes",
return_value=limit,
),
patch("backend.util.workspace.get_workspace_total_size", return_value=current),
):
result = await manager.write_file(filename="test.txt", content=content)
assert result == created_file
@pytest.mark.asyncio
async def test_write_file_overwrite_larger_replacement_rejected(
manager, mock_storage, mock_db
):
"""Replacing a small file with a much larger one near quota is rejected."""
existing_file = _make_workspace_file(size_bytes=10)
mock_db.get_workspace_file_by_path.return_value = existing_file
limit = 100
current = 90 # 90 bytes used, existing file is 10 of those
content = b"x" * 25 # net: 90 - 10 + 25 = 105 > 100
with (
patch(
"backend.util.workspace.get_workspace_storage",
return_value=mock_storage,
),
patch("backend.util.workspace.workspace_db", return_value=mock_db),
patch("backend.util.workspace.scan_content_safe", new_callable=AsyncMock),
patch(
"backend.util.workspace.get_workspace_storage_limit_bytes",
return_value=limit,
),
patch("backend.util.workspace.get_workspace_total_size", return_value=current),
):
with pytest.raises(ValueError, match="Storage limit exceeded"):
await manager.write_file(
filename="test.txt", content=content, overwrite=True
)
@pytest.mark.asyncio
async def test_write_file_overwrite_smaller_replacement_succeeds(
manager, mock_storage, mock_db
):
"""Replacing a large file with a smaller one near quota succeeds."""
existing_file = _make_workspace_file(size_bytes=40)
created_file = _make_workspace_file()
mock_db.get_workspace_file_by_path.return_value = existing_file
mock_db.create_workspace_file.return_value = created_file
limit = 100
current = 90 # 90 bytes used, existing file is 40 of those
content = b"x" * 30 # net: 90 - 40 + 30 = 80 < 100
with (
patch(
"backend.util.workspace.get_workspace_storage",
return_value=mock_storage,
),
patch("backend.util.workspace.workspace_db", return_value=mock_db),
patch("backend.util.workspace.scan_content_safe", new_callable=AsyncMock),
patch(
"backend.util.workspace.get_workspace_storage_limit_bytes",
return_value=limit,
),
patch("backend.util.workspace.get_workspace_total_size", return_value=current),
):
result = await manager.write_file(
filename="test.txt", content=content, overwrite=True
)
assert result == created_file
@pytest.mark.asyncio
async def test_write_file_quota_rejection_skips_virus_scan_and_storage(
manager, mock_storage, mock_db
):
"""Quota rejection must short-circuit BEFORE expensive virus scan and storage."""
mock_db.get_workspace_file_by_path.return_value = None
with (
patch(
"backend.util.workspace.get_workspace_storage",
return_value=mock_storage,
),
patch("backend.util.workspace.workspace_db", return_value=mock_db),
patch(
"backend.util.workspace.scan_content_safe", new_callable=AsyncMock
) as mock_scan,
patch(
"backend.util.workspace.get_workspace_storage_limit_bytes",
return_value=100,
),
patch("backend.util.workspace.get_workspace_total_size", return_value=100),
):
with pytest.raises(ValueError, match="Storage limit exceeded"):
await manager.write_file(filename="test.txt", content=b"data")
mock_scan.assert_not_called()
mock_storage.store.assert_not_called()
@pytest.mark.asyncio
async def test_write_file_empty_content_near_limit_succeeds(
manager, mock_storage, mock_db
):
"""Empty file (0 bytes) should always fit even when at the limit."""
created_file = _make_workspace_file()
mock_db.get_workspace_file_by_path.return_value = None
mock_db.create_workspace_file.return_value = created_file
with (
patch(
"backend.util.workspace.get_workspace_storage",
return_value=mock_storage,
),
patch("backend.util.workspace.workspace_db", return_value=mock_db),
patch("backend.util.workspace.scan_content_safe", new_callable=AsyncMock),
patch(
"backend.util.workspace.get_workspace_storage_limit_bytes",
return_value=100,
),
patch("backend.util.workspace.get_workspace_total_size", return_value=100),
):
result = await manager.write_file(filename="empty.txt", content=b"")
assert result == created_file

View File

@@ -5,7 +5,6 @@ import { cn } from "@/lib/utils";
import Link from "next/link";
import { formatCents, formatResetTime } from "../usageHelpers";
import { useResetRateLimit } from "../../hooks/useResetRateLimit";
import { useWorkspaceStorage } from "./useWorkspaceStorage";
export { formatResetTime };
@@ -71,65 +70,6 @@ function UsageBar({
);
}
export function formatBytes(bytes: number): string {
const KB = 1024;
const MB = KB * 1024;
const GB = MB * 1024;
if (bytes < KB) return `${bytes} B`;
if (bytes < MB) {
const kb = Math.round(bytes / KB);
return kb >= 1024 ? `${(bytes / MB).toFixed(1)} MB` : `${kb} KB`;
}
if (bytes < GB) {
const mb = Math.round(bytes / MB);
return mb >= 1024 ? `${(bytes / GB).toFixed(1)} GB` : `${mb} MB`;
}
return `${(bytes / GB).toFixed(1)} GB`;
}
function StorageBar({
usedBytes,
limitBytes,
fileCount,
}: {
usedBytes: number;
limitBytes: number;
fileCount: number;
}) {
if (limitBytes <= 0) return null;
const rawPercent = (usedBytes / limitBytes) * 100;
const percent = Math.min(100, Math.round(rawPercent));
const isHigh = percent >= 80;
const percentLabel =
usedBytes > 0 && percent === 0 ? "<1% used" : `${percent}% used`;
return (
<div className="flex flex-col gap-1">
<div className="flex items-baseline justify-between">
<span className="text-xs font-medium text-neutral-700">
File storage
</span>
<span className="text-[11px] tabular-nums text-neutral-500">
{percentLabel}
</span>
</div>
<div className="text-[10px] text-neutral-400">
{formatBytes(usedBytes)} of {formatBytes(limitBytes)} &middot;{" "}
{fileCount} {fileCount === 1 ? "file" : "files"}
</div>
<div className="h-2 w-full overflow-hidden rounded-full bg-neutral-200">
<div
className={`h-full rounded-full transition-[width] duration-300 ease-out ${
isHigh ? "bg-orange-500" : "bg-blue-500"
}`}
style={{ width: `${Math.max(usedBytes > 0 ? 1 : 0, percent)}%` }}
/>
</div>
</div>
);
}
function ResetButton({
cost,
onCreditChange,
@@ -154,19 +94,6 @@ function ResetButton({
);
}
function WorkspaceStorageSection() {
const { data: storage } = useWorkspaceStorage();
if (!storage || storage.limit_bytes <= 0) return null;
return (
<StorageBar
usedBytes={storage.used_bytes}
limitBytes={storage.limit_bytes}
fileCount={storage.file_count}
/>
);
}
export function UsagePanelContent({
usage,
showHeader = true,
@@ -192,12 +119,9 @@ export function UsagePanelContent({
if (!daily && !weekly) {
return (
<div className="flex flex-col gap-3">
<Text as="span" variant="small" className="text-neutral-500">
No usage limits configured
</Text>
<WorkspaceStorageSection />
</div>
<Text as="span" variant="small" className="text-neutral-500">
No usage limits configured
</Text>
);
}
@@ -239,7 +163,6 @@ export function UsagePanelContent({
size={size}
/>
)}
<WorkspaceStorageSection />
{isDailyExhausted &&
!isWeeklyExhausted &&
resetCost > 0 &&

View File

@@ -4,8 +4,8 @@ import {
cleanup,
fireEvent,
} from "@/tests/integrations/test-utils";
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import { UsagePanelContent, formatBytes } from "../UsagePanelContent";
import { afterEach, describe, expect, it, vi } from "vitest";
import { UsagePanelContent } from "../UsagePanelContent";
import type { CoPilotUsagePublic } from "@/app/api/__generated__/models/coPilotUsagePublic";
const mockResetUsage = vi.fn();
@@ -13,20 +13,9 @@ vi.mock("../../../hooks/useResetRateLimit", () => ({
useResetRateLimit: () => ({ resetUsage: mockResetUsage, isPending: false }),
}));
const mockStorageData = vi.fn();
vi.mock("../useWorkspaceStorage", () => ({
useWorkspaceStorage: () => mockStorageData(),
}));
afterEach(() => {
cleanup();
mockResetUsage.mockReset();
mockStorageData.mockReset();
});
// Default: no storage data (most existing tests don't need it)
beforeEach(() => {
mockStorageData.mockReturnValue({ data: undefined });
});
function makeUsage(
@@ -58,48 +47,6 @@ function makeUsage(
} as CoPilotUsagePublic;
}
describe("formatBytes", () => {
it.each([
[0, "0 B"],
[512, "512 B"],
[1024, "1 KB"],
[250 * 1024, "250 KB"],
[1023 * 1024, "1023 KB"],
[1000 * 1024, "1000 KB"],
[1024 * 1024, "1 MB"],
[250 * 1024 * 1024, "250 MB"],
[1000 * 1024 * 1024, "1000 MB"],
[1024 * 1024 * 1024, "1.0 GB"],
[5 * 1024 * 1024 * 1024, "5.0 GB"],
[15 * 1024 * 1024 * 1024, "15.0 GB"],
])("formats %d bytes as %s", (input, expected) => {
expect(formatBytes(input)).toBe(expected);
});
// Adversarial edge cases
it("handles 1 byte", () => {
expect(formatBytes(1)).toBe("1 B");
});
it("handles exactly 1023 bytes (just under 1 KB)", () => {
expect(formatBytes(1023)).toBe("1023 B");
});
it("auto-promotes 1 MB - 1 byte to MB (rounds up to 1024 KB → 1.0 MB)", () => {
// 1048575 / 1024 = 1023.999 → Math.round = 1024 → kb >= 1024 → promotes to MB
expect(formatBytes(1048575)).toBe("1.0 MB");
});
it("auto-promotes 1 GB - 1 byte to GB (rounds up to 1024 MB → 1.0 GB)", () => {
// 1073741823 / (1024*1024) = 1023.999 → Math.round = 1024 → promotes to GB
expect(formatBytes(1073741823)).toBe("1.0 GB");
});
it("handles very large values (1 TB)", () => {
expect(formatBytes(1024 * 1024 * 1024 * 1024)).toBe("1024.0 GB");
});
});
describe("UsagePanelContent", () => {
it("renders 'No usage limits configured' when both windows are null", () => {
render(
@@ -110,26 +57,6 @@ describe("UsagePanelContent", () => {
expect(screen.getByText("No usage limits configured")).toBeDefined();
});
it("still renders file storage when usage windows are null", () => {
mockStorageData.mockReturnValue({
data: {
used_bytes: 100 * 1024 * 1024,
limit_bytes: 250 * 1024 * 1024,
used_percent: 40,
file_count: 5,
},
});
render(
<UsagePanelContent
usage={makeUsage({ dailyPercent: null, weeklyPercent: null })}
/>,
);
expect(screen.getByText("No usage limits configured")).toBeDefined();
expect(screen.getByText("File storage")).toBeDefined();
});
it("renders the reset button when daily limit is exhausted", () => {
render(
<UsagePanelContent
@@ -182,195 +109,4 @@ describe("UsagePanelContent", () => {
render(<UsagePanelContent usage={makeUsage({ dailyPercent: 0.3 })} />);
expect(screen.getByText("<1% used")).toBeDefined();
});
it("renders file storage bar when workspace data is available", () => {
mockStorageData.mockReturnValue({
data: {
used_bytes: 100 * 1024 * 1024,
limit_bytes: 250 * 1024 * 1024,
used_percent: 40,
file_count: 5,
},
});
render(<UsagePanelContent usage={makeUsage()} />);
expect(screen.getByText("File storage")).toBeDefined();
expect(screen.getByText(/100 MB of 250 MB/)).toBeDefined();
expect(screen.getByText(/5 files/)).toBeDefined();
});
it("hides file storage bar when no workspace data", () => {
mockStorageData.mockReturnValue({ data: undefined });
render(<UsagePanelContent usage={makeUsage()} />);
expect(screen.queryByText("File storage")).toBeNull();
});
it("hides file storage bar when limit is zero", () => {
mockStorageData.mockReturnValue({
data: {
used_bytes: 0,
limit_bytes: 0,
used_percent: 0,
file_count: 0,
},
});
render(<UsagePanelContent usage={makeUsage()} />);
expect(screen.queryByText("File storage")).toBeNull();
});
it("shows orange bar when storage usage is at or above 80%", () => {
mockStorageData.mockReturnValue({
data: {
used_bytes: 210 * 1024 * 1024,
limit_bytes: 250 * 1024 * 1024,
used_percent: 84,
file_count: 3,
},
});
render(<UsagePanelContent usage={makeUsage()} />);
expect(screen.getByText("File storage")).toBeDefined();
expect(screen.getByText("84% used")).toBeDefined();
});
it("shows singular 'file' for single file", () => {
mockStorageData.mockReturnValue({
data: {
used_bytes: 1024,
limit_bytes: 250 * 1024 * 1024,
used_percent: 0,
file_count: 1,
},
});
render(<UsagePanelContent usage={makeUsage()} />);
expect(screen.getByText(/1 file$/)).toBeDefined();
});
it("shows storage '<1% used' when usage is tiny", () => {
mockStorageData.mockReturnValue({
data: {
used_bytes: 100,
limit_bytes: 250 * 1024 * 1024,
used_percent: 0.001,
file_count: 1,
},
});
render(<UsagePanelContent usage={makeUsage()} />);
expect(screen.getByText("File storage")).toBeDefined();
});
it("renders header with tier label", () => {
render(<UsagePanelContent usage={makeUsage({ tier: "PRO" })} />);
expect(screen.getByText("Pro plan")).toBeDefined();
});
it("hides header when showHeader is false", () => {
render(<UsagePanelContent usage={makeUsage()} showHeader={false} />);
expect(screen.queryByText("Usage limits")).toBeNull();
});
// Adversarial edge cases
it("hides storage bar when limit is negative", () => {
mockStorageData.mockReturnValue({
data: {
used_bytes: 100,
limit_bytes: -1,
used_percent: 0,
file_count: 1,
},
});
render(<UsagePanelContent usage={makeUsage()} />);
expect(screen.queryByText("File storage")).toBeNull();
});
it("handles storage at exactly 100% used", () => {
mockStorageData.mockReturnValue({
data: {
used_bytes: 250 * 1024 * 1024,
limit_bytes: 250 * 1024 * 1024,
used_percent: 100,
file_count: 10,
},
});
render(<UsagePanelContent usage={makeUsage()} />);
expect(screen.getByText("100% used")).toBeDefined();
expect(screen.getByText(/250 MB of 250 MB/)).toBeDefined();
});
it("clamps storage above 100% to 100% display", () => {
mockStorageData.mockReturnValue({
data: {
used_bytes: 300 * 1024 * 1024,
limit_bytes: 250 * 1024 * 1024,
used_percent: 120,
file_count: 15,
},
});
render(<UsagePanelContent usage={makeUsage()} />);
// Should show "100% used", not "120% used"
expect(screen.getByText("100% used")).toBeDefined();
expect(screen.getByText("File storage")).toBeDefined();
});
it("handles zero files with zero usage", () => {
mockStorageData.mockReturnValue({
data: {
used_bytes: 0,
limit_bytes: 250 * 1024 * 1024,
used_percent: 0,
file_count: 0,
},
});
render(<UsagePanelContent usage={makeUsage()} />);
expect(screen.getByText("File storage")).toBeDefined();
expect(screen.getByText("0% used")).toBeDefined();
expect(screen.getByText(/0 files/)).toBeDefined();
});
it("renders billing link by default", () => {
render(<UsagePanelContent usage={makeUsage()} />);
expect(screen.getByText("Learn more about usage limits")).toBeDefined();
});
it("hides billing link when showBillingLink is false", () => {
render(<UsagePanelContent usage={makeUsage()} showBillingLink={false} />);
expect(screen.queryByText("Learn more about usage limits")).toBeNull();
});
it("renders only daily bar when weekly is null", () => {
render(
<UsagePanelContent
usage={makeUsage({ dailyPercent: 50, weeklyPercent: null })}
/>,
);
expect(screen.getByText("Today")).toBeDefined();
expect(screen.queryByText("This week")).toBeNull();
});
it("renders only weekly bar when daily is null", () => {
render(
<UsagePanelContent
usage={makeUsage({ dailyPercent: null, weeklyPercent: 30 })}
/>,
);
expect(screen.queryByText("Today")).toBeNull();
expect(screen.getByText("This week")).toBeDefined();
});
it("does not show tier label when tier is missing", () => {
const usage = makeUsage();
(usage as Record<string, unknown>).tier = null;
render(<UsagePanelContent usage={usage} />);
expect(screen.queryByText(/plan$/)).toBeNull();
expect(screen.getByText("Usage limits")).toBeDefined();
});
});

View File

@@ -1,12 +0,0 @@
import { useGetWorkspaceStorageUsage } from "@/app/api/__generated__/endpoints/workspace/workspace";
import type { StorageUsageResponse } from "@/app/api/__generated__/models/storageUsageResponse";
export function useWorkspaceStorage() {
return useGetWorkspaceStorageUsage({
query: {
select: (res) => res.data as StorageUsageResponse,
staleTime: 30000,
refetchInterval: 60000,
},
});
}

View File

@@ -736,24 +736,19 @@ describe("SubscriptionTierSection", () => {
).toBeDefined();
});
it("renders BASIC cancellation copy in banner when pending_tier is BASIC", () => {
it("renders cancellation copy in banner when pending_tier is NO_TIER", () => {
setupMocks({
subscription: makeSubscription({
tier: "MAX",
pendingTier: "BASIC",
// Noon UTC so the local-formatted date lands on the same day
// regardless of the runner's timezone (midnight UTC drifts to
// the prior day in any timezone west of UTC).
pendingTier: "NO_TIER",
pendingTierEffectiveAt: new Date("2026-05-15T12:00:00Z"),
}),
});
render(<SubscriptionTierSection />);
// Cancellation copy — distinct from the generic downgrade phrasing.
expect(
screen.getByText(/scheduled to cancel your subscription on/i),
).toBeDefined();
expect(screen.getByText(/May 15, 2026/)).toBeDefined();
// Must NOT render the "downgrade to" phrasing on BASIC cancellation.
expect(screen.queryByText(/scheduled to downgrade to/i)).toBeNull();
});
});

View File

@@ -25,7 +25,7 @@ export function PendingChangeBanner({
const currentLabel = getTierLabel(currentTier);
const dateText = formatPendingDate(pendingEffectiveAt);
const isCancellation = pendingTier === "BASIC";
const isCancellation = pendingTier === "NO_TIER";
return (
<div

View File

@@ -7,7 +7,7 @@ import { PendingChangeBanner } from "../PendingChangeBanner";
describe("PendingChangeBanner", () => {
const baseProps = {
currentTier: "PRO",
pendingTier: "BASIC",
pendingTier: "NO_TIER",
// Use noon UTC so the formatted local date lands on the same day
// regardless of the host timezone (important for CI runners).
pendingEffectiveAt: "2026-05-01T12:00:00Z",
@@ -25,7 +25,7 @@ describe("PendingChangeBanner", () => {
expect(container.firstChild).toBeNull();
});
it("shows cancellation copy when pending tier is BASIC", () => {
it("shows cancellation copy when pending tier is NO_TIER", () => {
render(<PendingChangeBanner {...baseProps} />);
expect(screen.getByText(/cancel your subscription on/i)).toBeDefined();
expect(screen.getByText("May 1, 2026")).toBeDefined();

View File

@@ -40,7 +40,7 @@ function CoPilotUsageSection() {
return (
<div className="my-6 space-y-4">
<h3 className="text-lg font-medium">AutoPilot Usage & Storage</h3>
<h3 className="text-lg font-medium">AutoPilot Usage Limits</h3>
<div className="rounded-lg border border-neutral-200 p-4">
<UsagePanelContent usage={usage} showBillingLink={false} />
</div>

View File

@@ -40,18 +40,8 @@ export async function uploadFileDirect(
});
if (!res.ok) {
let message: string;
try {
const body = await res.json();
// Backend returns { detail: "..." } or { detail: { message: "..." } }
message =
typeof body.detail === "string"
? body.detail
: (body.detail?.message ?? res.statusText);
} catch {
message = res.statusText;
}
throw new Error(message);
const detail = await res.text().catch(() => res.statusText);
throw new Error(`Upload failed (${res.status}): ${detail}`);
}
return res.json();