mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
fix(backend/copilot): don't cache DEFAULT_TIER for non-existent users
When `_fetch_user_tier` is called for a user that doesn't exist yet, it was returning `DEFAULT_TIER` (FREE) which the `@cached` decorator would store for 5 minutes. If the user was then created with a higher tier (e.g. PRO), they'd receive the stale cached FREE tier until TTL expiry. Fix: raise `_UserNotFoundError` instead of returning `DEFAULT_TIER` when the user record is missing or has no subscription tier. The `@cached` decorator only caches successful returns, not exceptions. The outer `get_user_tier` wrapper catches the exception and returns `DEFAULT_TIER` without caching, so the next call re-queries the database. Adds a regression test verifying that a not-found result is not cached and a subsequent lookup after user creation returns the correct tier.
This commit is contained in:
@@ -383,6 +383,18 @@ async def record_token_usage(
|
||||
)
|
||||
|
||||
|
||||
class _UserNotFoundError(Exception):
|
||||
"""Raised when a user record is missing or has no subscription tier.
|
||||
|
||||
Used internally by ``_fetch_user_tier`` to signal a cache-miss condition:
|
||||
by raising instead of returning ``DEFAULT_TIER``, we prevent the ``@cached``
|
||||
decorator from storing the fallback value. This avoids a race condition
|
||||
where a non-existent user's DEFAULT_TIER is cached, then the user is
|
||||
created with a higher tier but receives the stale cached FREE tier for
|
||||
up to 5 minutes.
|
||||
"""
|
||||
|
||||
|
||||
@cached(maxsize=1000, ttl_seconds=300, shared_cache=True)
|
||||
async def _fetch_user_tier(user_id: str) -> SubscriptionTier:
|
||||
"""Fetch the user's rate-limit tier from the database (cached via Redis).
|
||||
@@ -390,18 +402,16 @@ async def _fetch_user_tier(user_id: str) -> SubscriptionTier:
|
||||
Uses ``shared_cache=True`` so that tier changes propagate across all pods
|
||||
immediately when the cache entry is invalidated (via ``cache_delete``).
|
||||
|
||||
Only successful DB lookups are cached. Raises on DB errors so the
|
||||
``@cached`` decorator does **not** store a fallback value.
|
||||
|
||||
Note: when the user is not found or ``subscriptionTier`` is ``None``,
|
||||
``DEFAULT_TIER`` (FREE) is returned and **cached**. The Prisma schema
|
||||
enforces ``@default(PRO)`` on the column, so ``None`` only occurs in
|
||||
edge cases (e.g. partial row creation).
|
||||
Only successful DB lookups of existing users with a valid tier are cached.
|
||||
Raises ``_UserNotFoundError`` when the user is missing or has no tier, so
|
||||
the ``@cached`` decorator does **not** store a fallback value. This
|
||||
prevents a race condition where a non-existent user's ``DEFAULT_TIER`` is
|
||||
cached and then persists after the user is created with a higher tier.
|
||||
"""
|
||||
user = await PrismaUser.prisma().find_unique(where={"id": user_id})
|
||||
if user and user.subscriptionTier: # type: ignore[reportAttributeAccessIssue]
|
||||
return SubscriptionTier(user.subscriptionTier) # type: ignore[reportAttributeAccessIssue]
|
||||
return DEFAULT_TIER
|
||||
raise _UserNotFoundError(user_id)
|
||||
|
||||
|
||||
async def get_user_tier(user_id: str) -> SubscriptionTier:
|
||||
|
||||
@@ -503,6 +503,41 @@ class TestGetUserTier:
|
||||
|
||||
assert tier == DEFAULT_TIER
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_not_found_is_not_cached(self):
|
||||
"""Non-existent user should NOT cache DEFAULT_TIER.
|
||||
|
||||
Regression test: when ``get_user_tier`` is called before a user record
|
||||
exists, the DEFAULT_TIER fallback must not be cached. Otherwise, a
|
||||
newly created user with a higher tier (e.g. PRO) would receive the
|
||||
stale cached FREE tier for up to 5 minutes.
|
||||
"""
|
||||
# First call: user does not exist yet
|
||||
missing_prisma = AsyncMock()
|
||||
missing_prisma.find_unique = AsyncMock(return_value=None)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.PrismaUser.prisma",
|
||||
return_value=missing_prisma,
|
||||
):
|
||||
tier1 = await get_user_tier(_USER)
|
||||
assert tier1 == DEFAULT_TIER
|
||||
|
||||
# Second call: user now exists with PRO tier
|
||||
mock_user = MagicMock()
|
||||
mock_user.subscriptionTier = "PRO"
|
||||
ok_prisma = AsyncMock()
|
||||
ok_prisma.find_unique = AsyncMock(return_value=mock_user)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.PrismaUser.prisma",
|
||||
return_value=ok_prisma,
|
||||
):
|
||||
tier2 = await get_user_tier(_USER)
|
||||
|
||||
# Should get PRO — the not-found result was not cached
|
||||
assert tier2 == SubscriptionTier.PRO
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# set_user_tier
|
||||
|
||||
Reference in New Issue
Block a user