mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
feat(backend): pull workspace storage limits from LaunchDarkly
- Add _DEFAULT_TIER_WORKSPACE_STORAGE_MB with explicit NO_TIER entry (250 MB) - Add _fetch_workspace_storage_limits_flag() and get_workspace_storage_limits_mb() mirroring the chat-limit LD pattern - Add Flag.COPILOT_TIER_WORKSPACE_STORAGE_LIMITS enum entry - Update get_workspace_storage_limit_bytes() to use LD-backed map - Add tests: LD resolution, NO_TIER behavior, unsubscribe downgrade, upload rejection when over cap, frontend null-usage-windows rendering Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -111,6 +111,19 @@ TIER_MULTIPLIERS = _DEFAULT_TIER_MULTIPLIERS
|
||||
DEFAULT_TIER = SubscriptionTier.NO_TIER
|
||||
|
||||
|
||||
# Per-tier workspace storage caps in MB. NO_TIER keeps the same baseline as
|
||||
# BASIC so users who cancel retain a small quota and see a real overage cap,
|
||||
# while LaunchDarkly can still tune tiers without a deploy.
|
||||
_DEFAULT_TIER_WORKSPACE_STORAGE_MB: dict[SubscriptionTier, int] = {
|
||||
SubscriptionTier.NO_TIER: 250, # 250 MB
|
||||
SubscriptionTier.BASIC: 250, # 250 MB
|
||||
SubscriptionTier.PRO: 1024, # 1 GB
|
||||
SubscriptionTier.MAX: 5 * 1024, # 5 GB
|
||||
SubscriptionTier.BUSINESS: 15 * 1024, # 15 GB
|
||||
SubscriptionTier.ENTERPRISE: 15 * 1024, # 15 GB
|
||||
}
|
||||
|
||||
|
||||
@cached(ttl_seconds=60, maxsize=1, cache_none=False)
|
||||
async def _fetch_tier_multipliers_flag() -> dict[SubscriptionTier, float] | None:
|
||||
"""Fetch the ``copilot-tier-multipliers`` LD flag and parse it.
|
||||
@@ -232,14 +245,79 @@ async def get_tier_multipliers() -> dict[str, float]:
|
||||
return {tier.value: multiplier for tier, multiplier in merged.items()}
|
||||
|
||||
|
||||
# Per-tier workspace storage caps in MB.
|
||||
TIER_WORKSPACE_STORAGE_MB: dict[SubscriptionTier, int] = {
|
||||
SubscriptionTier.BASIC: 250, # 250 MB
|
||||
SubscriptionTier.PRO: 1024, # 1 GB
|
||||
SubscriptionTier.MAX: 5 * 1024, # 5 GB
|
||||
SubscriptionTier.BUSINESS: 15 * 1024, # 15 GB
|
||||
SubscriptionTier.ENTERPRISE: 15 * 1024, # 15 GB
|
||||
}
|
||||
@cached(ttl_seconds=60, maxsize=1, cache_none=False)
|
||||
async def _fetch_workspace_storage_limits_flag() -> dict[SubscriptionTier, int] | None:
|
||||
"""Fetch the ``copilot-tier-workspace-storage-limits`` LD flag and parse it.
|
||||
|
||||
Returns a sparse ``{tier: megabytes}`` map built from whichever keys are
|
||||
valid in the flag payload, or ``None`` when the flag is unset / invalid /
|
||||
LD is unavailable. Callers merge whatever survives into
|
||||
:data:`_DEFAULT_TIER_WORKSPACE_STORAGE_MB`.
|
||||
|
||||
The LD value is expected to be a JSON object keyed by tier enum name
|
||||
(``{"NO_TIER": 250, "PRO": 1024, "BUSINESS": 15360}``). Non-int or
|
||||
negative values are skipped so a broken key degrades to the code default
|
||||
instead of wiping out the limit.
|
||||
"""
|
||||
# Lazy import: rate_limit -> feature_flag -> settings -> ... -> rate_limit.
|
||||
from backend.util.feature_flag import Flag, get_feature_flag_value
|
||||
|
||||
raw = await get_feature_flag_value(
|
||||
Flag.COPILOT_TIER_WORKSPACE_STORAGE_LIMITS.value, "system", None
|
||||
)
|
||||
if raw is None:
|
||||
return None
|
||||
if not isinstance(raw, dict):
|
||||
logger.warning(
|
||||
"Invalid LD value for copilot-tier-workspace-storage-limits "
|
||||
"(expected JSON object): %r",
|
||||
raw,
|
||||
)
|
||||
return None
|
||||
|
||||
parsed: dict[SubscriptionTier, int] = {}
|
||||
for key, value in raw.items():
|
||||
try:
|
||||
tier = SubscriptionTier(key)
|
||||
except ValueError:
|
||||
continue
|
||||
if isinstance(value, bool) or not isinstance(value, int):
|
||||
logger.warning(
|
||||
"Invalid LD value for copilot-tier-workspace-storage-limits[%s]: %r",
|
||||
key,
|
||||
value,
|
||||
)
|
||||
continue
|
||||
if value < 0:
|
||||
logger.warning(
|
||||
"Negative LD value for copilot-tier-workspace-storage-limits[%s]: %r",
|
||||
key,
|
||||
value,
|
||||
)
|
||||
continue
|
||||
parsed[tier] = value
|
||||
return parsed or None
|
||||
|
||||
|
||||
async def get_workspace_storage_limits_mb() -> dict[str, int]:
|
||||
"""Return the effective ``{tier_value: megabytes}`` workspace limit map.
|
||||
|
||||
Honours the ``copilot-tier-workspace-storage-limits`` LD flag when set;
|
||||
missing tiers inherit :data:`_DEFAULT_TIER_WORKSPACE_STORAGE_MB`.
|
||||
Unparseable flag values or LD fetch failures fall back to the defaults.
|
||||
"""
|
||||
try:
|
||||
override = await _fetch_workspace_storage_limits_flag()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"get_workspace_storage_limits_mb: LD lookup failed", exc_info=True
|
||||
)
|
||||
override = None
|
||||
|
||||
merged: dict[SubscriptionTier, int] = dict(_DEFAULT_TIER_WORKSPACE_STORAGE_MB)
|
||||
if override:
|
||||
merged.update(override)
|
||||
return {tier.value: megabytes for tier, megabytes in merged.items()}
|
||||
|
||||
|
||||
class UsageWindow(BaseModel):
|
||||
@@ -674,7 +752,13 @@ get_user_tier.cache_delete = _fetch_user_tier.cache_delete # type: ignore[attr-
|
||||
async def get_workspace_storage_limit_bytes(user_id: str) -> int:
|
||||
"""Return the workspace storage cap in bytes for the user's subscription tier."""
|
||||
tier = await get_user_tier(user_id)
|
||||
mb = TIER_WORKSPACE_STORAGE_MB.get(tier, TIER_WORKSPACE_STORAGE_MB[DEFAULT_TIER])
|
||||
limits_mb = await get_workspace_storage_limits_mb()
|
||||
tier_key = getattr(tier, "value", str(tier))
|
||||
fallback_mb = limits_mb.get(
|
||||
DEFAULT_TIER.value,
|
||||
_DEFAULT_TIER_WORKSPACE_STORAGE_MB[DEFAULT_TIER],
|
||||
)
|
||||
mb = limits_mb.get(tier_key, fallback_mb)
|
||||
return mb * 1024 * 1024
|
||||
|
||||
|
||||
|
||||
@@ -7,10 +7,10 @@ import pytest
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from .rate_limit import (
|
||||
_DEFAULT_TIER_WORKSPACE_STORAGE_MB,
|
||||
_DEFAULT_TIER_MULTIPLIERS,
|
||||
DEFAULT_TIER,
|
||||
TIER_MULTIPLIERS,
|
||||
TIER_WORKSPACE_STORAGE_MB,
|
||||
CoPilotUsageStatus,
|
||||
RateLimitExceeded,
|
||||
SubscriptionTier,
|
||||
@@ -19,6 +19,7 @@ from .rate_limit import (
|
||||
_daily_reset_time,
|
||||
_fetch_cost_limits_flag,
|
||||
_fetch_tier_multipliers_flag,
|
||||
_fetch_workspace_storage_limits_flag,
|
||||
_weekly_key,
|
||||
_weekly_reset_time,
|
||||
acquire_reset_lock,
|
||||
@@ -26,6 +27,7 @@ from .rate_limit import (
|
||||
get_daily_reset_count,
|
||||
get_global_rate_limits,
|
||||
get_tier_multipliers,
|
||||
get_workspace_storage_limits_mb,
|
||||
get_usage_status,
|
||||
get_user_tier,
|
||||
get_workspace_storage_limit_bytes,
|
||||
@@ -473,6 +475,85 @@ class TestGetTierMultipliers:
|
||||
assert result == {t.value: m for t, m in _DEFAULT_TIER_MULTIPLIERS.items()}
|
||||
|
||||
|
||||
class TestGetWorkspaceStorageLimits:
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_flag_cache(self):
|
||||
"""Clear the LD flag cache between tests so patches don't leak."""
|
||||
_fetch_workspace_storage_limits_flag.cache_clear() # type: ignore[attr-defined]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_defaults_when_flag_unset(self):
|
||||
"""With no LD override, the resolver returns the default map."""
|
||||
with patch(
|
||||
"backend.util.feature_flag.get_feature_flag_value",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
result = await get_workspace_storage_limits_mb()
|
||||
assert result == {
|
||||
t.value: mb for t, mb in _DEFAULT_TIER_WORKSPACE_STORAGE_MB.items()
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ld_override(self):
|
||||
"""LD override populates targeted tiers; others inherit defaults."""
|
||||
with patch(
|
||||
"backend.util.feature_flag.get_feature_flag_value",
|
||||
new_callable=AsyncMock,
|
||||
return_value={"NO_TIER": 300, "PRO": 2048},
|
||||
):
|
||||
result = await get_workspace_storage_limits_mb()
|
||||
assert result["NO_TIER"] == 300
|
||||
assert result["PRO"] == 2048
|
||||
assert (
|
||||
result["BASIC"]
|
||||
== _DEFAULT_TIER_WORKSPACE_STORAGE_MB[SubscriptionTier.BASIC]
|
||||
)
|
||||
assert result["MAX"] == _DEFAULT_TIER_WORKSPACE_STORAGE_MB[SubscriptionTier.MAX]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_json_falls_back(self):
|
||||
"""A non-object LD value falls back to defaults."""
|
||||
with patch(
|
||||
"backend.util.feature_flag.get_feature_flag_value",
|
||||
new_callable=AsyncMock,
|
||||
return_value="broken",
|
||||
):
|
||||
result = await get_workspace_storage_limits_mb()
|
||||
assert result == {
|
||||
t.value: mb for t, mb in _DEFAULT_TIER_WORKSPACE_STORAGE_MB.items()
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_tier_key_and_invalid_values_skipped(self):
|
||||
"""Unknown tiers and invalid values degrade to defaults per key."""
|
||||
with patch(
|
||||
"backend.util.feature_flag.get_feature_flag_value",
|
||||
new_callable=AsyncMock,
|
||||
return_value={"NO_TIER": 300, "BOGUS": 999, "MAX": -1, "BUSINESS": "nope"},
|
||||
):
|
||||
result = await get_workspace_storage_limits_mb()
|
||||
assert result["NO_TIER"] == 300
|
||||
assert result["MAX"] == _DEFAULT_TIER_WORKSPACE_STORAGE_MB[SubscriptionTier.MAX]
|
||||
assert (
|
||||
result["BUSINESS"]
|
||||
== _DEFAULT_TIER_WORKSPACE_STORAGE_MB[SubscriptionTier.BUSINESS]
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ld_failure_falls_back(self):
|
||||
"""LD lookup raising propagates to defaults, not up the call stack."""
|
||||
with patch(
|
||||
"backend.util.feature_flag.get_feature_flag_value",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=RuntimeError("LD SDK not initialized"),
|
||||
):
|
||||
result = await get_workspace_storage_limits_mb()
|
||||
assert result == {
|
||||
t.value: mb for t, mb in _DEFAULT_TIER_WORKSPACE_STORAGE_MB.items()
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_global_rate_limits — LD-flag cost limits parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -1799,14 +1880,18 @@ class TestResetUserUsage:
|
||||
class TestWorkspaceStorageLimits:
|
||||
"""Tests for tier-based workspace storage limits."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_flag_cache(self):
|
||||
_fetch_workspace_storage_limits_flag.cache_clear() # type: ignore[attr-defined]
|
||||
|
||||
def test_every_subscription_tier_has_storage_limit(self):
|
||||
"""Adding a new SubscriptionTier without a storage limit should fail."""
|
||||
for tier in SubscriptionTier:
|
||||
assert tier in TIER_WORKSPACE_STORAGE_MB, (
|
||||
assert tier in _DEFAULT_TIER_WORKSPACE_STORAGE_MB, (
|
||||
f"SubscriptionTier.{tier.name} has no entry in "
|
||||
f"TIER_WORKSPACE_STORAGE_MB — add one"
|
||||
f"_DEFAULT_TIER_WORKSPACE_STORAGE_MB — add one"
|
||||
)
|
||||
assert TIER_WORKSPACE_STORAGE_MB[tier] > 0
|
||||
assert _DEFAULT_TIER_WORKSPACE_STORAGE_MB[tier] > 0
|
||||
|
||||
def test_every_subscription_tier_has_rate_limit_multiplier(self):
|
||||
"""Adding a new SubscriptionTier without a rate limit multiplier should fail."""
|
||||
@@ -1815,11 +1900,15 @@ class TestWorkspaceStorageLimits:
|
||||
f"SubscriptionTier.{tier.name} has no entry in "
|
||||
f"_DEFAULT_TIER_MULTIPLIERS — add one"
|
||||
)
|
||||
assert _DEFAULT_TIER_MULTIPLIERS[tier] > 0
|
||||
if tier == SubscriptionTier.NO_TIER:
|
||||
assert _DEFAULT_TIER_MULTIPLIERS[tier] == 0.0
|
||||
else:
|
||||
assert _DEFAULT_TIER_MULTIPLIERS[tier] > 0
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tier,expected_mb",
|
||||
[
|
||||
(SubscriptionTier.NO_TIER, 250),
|
||||
(SubscriptionTier.BASIC, 250),
|
||||
(SubscriptionTier.PRO, 1024),
|
||||
(SubscriptionTier.MAX, 5 * 1024),
|
||||
@@ -1829,13 +1918,14 @@ class TestWorkspaceStorageLimits:
|
||||
)
|
||||
def test_tier_workspace_storage_mapping_covers_all_tiers(self, tier, expected_mb):
|
||||
"""Every tier has an explicit storage limit in the mapping."""
|
||||
assert tier in TIER_WORKSPACE_STORAGE_MB
|
||||
assert TIER_WORKSPACE_STORAGE_MB[tier] == expected_mb
|
||||
assert tier in _DEFAULT_TIER_WORKSPACE_STORAGE_MB
|
||||
assert _DEFAULT_TIER_WORKSPACE_STORAGE_MB[tier] == expected_mb
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"tier,expected_bytes",
|
||||
[
|
||||
(SubscriptionTier.NO_TIER, 250 * 1024 * 1024),
|
||||
(SubscriptionTier.BASIC, 250 * 1024 * 1024),
|
||||
(SubscriptionTier.PRO, 1024 * 1024 * 1024),
|
||||
(SubscriptionTier.MAX, 5 * 1024 * 1024 * 1024),
|
||||
@@ -1855,8 +1945,10 @@ class TestWorkspaceStorageLimits:
|
||||
assert result == expected_bytes
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_workspace_storage_limit_bytes_defaults_to_basic_on_unknown(self):
|
||||
"""Unknown tier falls back to BASIC tier limit."""
|
||||
async def test_get_workspace_storage_limit_bytes_defaults_to_default_tier_on_unknown(
|
||||
self,
|
||||
):
|
||||
"""Unknown tier falls back to the default tier limit."""
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_user_tier",
|
||||
return_value="UNKNOWN_TIER",
|
||||
|
||||
@@ -204,6 +204,68 @@ async def test_sync_subscription_from_stripe_cancelled():
|
||||
mock_set.assert_awaited_once_with("user-1", SubscriptionTier.NO_TIER)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sync_subscription_from_stripe_cancelled_applies_no_tier_storage_limit():
|
||||
"""After unsubscribe takes effect, workspace storage resolves against NO_TIER."""
|
||||
from backend.copilot.rate_limit import get_workspace_storage_limit_bytes
|
||||
|
||||
mock_user = _make_user(tier=SubscriptionTier.PRO)
|
||||
stripe_sub = {
|
||||
"id": "sub_old",
|
||||
"customer": "cus_123",
|
||||
"status": "canceled",
|
||||
"items": {"data": []},
|
||||
}
|
||||
empty_list = MagicMock()
|
||||
empty_list.data = []
|
||||
empty_list.has_more = False
|
||||
|
||||
async def _set_tier(_user_id: str, tier: SubscriptionTier) -> None:
|
||||
mock_user.subscriptionTier = tier
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.data.credit.User.prisma",
|
||||
return_value=MagicMock(find_first=AsyncMock(return_value=mock_user)),
|
||||
),
|
||||
patch(
|
||||
"backend.data.credit.stripe.Subscription.list",
|
||||
return_value=empty_list,
|
||||
),
|
||||
patch(
|
||||
"backend.data.credit.set_subscription_tier",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=_set_tier,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.rate_limit.get_user_tier",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=lambda _user_id: mock_user.subscriptionTier,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.rate_limit.get_workspace_storage_limits_mb",
|
||||
new_callable=AsyncMock,
|
||||
return_value={
|
||||
"NO_TIER": 250,
|
||||
"BASIC": 500,
|
||||
"PRO": 1024,
|
||||
"MAX": 5 * 1024,
|
||||
"BUSINESS": 15 * 1024,
|
||||
"ENTERPRISE": 15 * 1024,
|
||||
},
|
||||
),
|
||||
patch.object(
|
||||
get_pending_subscription_change,
|
||||
"cache_delete",
|
||||
) as mock_pending_cache_delete,
|
||||
):
|
||||
await sync_subscription_from_stripe(stripe_sub)
|
||||
result = await get_workspace_storage_limit_bytes("user-1")
|
||||
|
||||
assert result == 250 * 1024 * 1024
|
||||
mock_pending_cache_delete.assert_called_once_with("user-1")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sync_subscription_from_stripe_cancelled_but_other_active_sub_exists():
|
||||
"""Cancelling sub_old must NOT downgrade the user if sub_new is still active.
|
||||
|
||||
@@ -44,6 +44,7 @@ class Flag(str, Enum):
|
||||
COPILOT_SDK = "copilot-sdk"
|
||||
COPILOT_COST_LIMITS = "copilot-cost-limits"
|
||||
COPILOT_TIER_MULTIPLIERS = "copilot-tier-multipliers"
|
||||
COPILOT_TIER_WORKSPACE_STORAGE_LIMITS = "copilot-tier-workspace-storage-limits"
|
||||
COPILOT_TIER_STRIPE_PRICES = "copilot-tier-stripe-prices"
|
||||
GRAPHITI_MEMORY = "graphiti-memory"
|
||||
# Stripe Product ID for top-up Checkout sessions. When unset (default),
|
||||
|
||||
@@ -207,6 +207,38 @@ async def test_write_file_quota_exceeded_raises_value_error(
|
||||
mock_storage.store.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_file_rejects_upload_when_usage_already_exceeds_downgraded_limit(
|
||||
manager, mock_storage, mock_db
|
||||
):
|
||||
"""Downgrading below current usage should block further uploads until usage drops."""
|
||||
mock_db.get_workspace_file_by_path.return_value = None
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_storage",
|
||||
return_value=mock_storage,
|
||||
),
|
||||
patch("backend.util.workspace.workspace_db", return_value=mock_db),
|
||||
patch(
|
||||
"backend.util.workspace.scan_content_safe", new_callable=AsyncMock
|
||||
) as mock_scan,
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_storage_limit_bytes",
|
||||
return_value=250 * 1024 * 1024,
|
||||
),
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_total_size",
|
||||
return_value=300 * 1024 * 1024,
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError, match="Storage limit exceeded"):
|
||||
await manager.write_file(filename="test.txt", content=b"hello")
|
||||
|
||||
mock_scan.assert_not_called()
|
||||
mock_storage.store.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_file_80pct_warning_logged(manager, mock_storage, mock_db, caplog):
|
||||
"""write_file logs a warning when workspace usage crosses 80%."""
|
||||
|
||||
@@ -87,6 +87,26 @@ describe("UsagePanelContent", () => {
|
||||
expect(screen.getByText("No usage limits configured")).toBeDefined();
|
||||
});
|
||||
|
||||
it("still renders file storage when usage windows are null", () => {
|
||||
mockStorageData.mockReturnValue({
|
||||
data: {
|
||||
used_bytes: 100 * 1024 * 1024,
|
||||
limit_bytes: 250 * 1024 * 1024,
|
||||
used_percent: 40,
|
||||
file_count: 5,
|
||||
},
|
||||
});
|
||||
|
||||
render(
|
||||
<UsagePanelContent
|
||||
usage={makeUsage({ dailyPercent: null, weeklyPercent: null })}
|
||||
/>,
|
||||
);
|
||||
|
||||
expect(screen.getByText("No usage limits configured")).toBeDefined();
|
||||
expect(screen.getByText("File storage")).toBeDefined();
|
||||
});
|
||||
|
||||
it("renders the reset button when daily limit is exhausted", () => {
|
||||
render(
|
||||
<UsagePanelContent
|
||||
|
||||
Reference in New Issue
Block a user