fix(backend/copilot): don't cache DEFAULT_TIER for non-existent users

When `_fetch_user_tier` is called for a user that doesn't exist yet, it
was returning `DEFAULT_TIER` (FREE) which the `@cached` decorator would
store for 5 minutes. If the user was then created with a higher tier
(e.g. PRO), they'd receive the stale cached FREE tier until TTL expiry.

Fix: raise `_UserNotFoundError` instead of returning `DEFAULT_TIER` when
the user record is missing or has no subscription tier. The `@cached`
decorator only caches successful returns, not exceptions. The outer
`get_user_tier` wrapper catches the exception and returns `DEFAULT_TIER`
without caching, so the next call re-queries the database.

Adds a regression test verifying that a not-found result is not cached
and a subsequent lookup after user creation returns the correct tier.
This commit is contained in:
Zamil Majdy
2026-04-02 11:33:10 +02:00
parent 0b77af29aa
commit 705e97ec46
2 changed files with 53 additions and 8 deletions

View File

@@ -383,6 +383,18 @@ async def record_token_usage(
)
class _UserNotFoundError(Exception):
"""Raised when a user record is missing or has no subscription tier.
Used internally by ``_fetch_user_tier`` to signal a cache-miss condition:
by raising instead of returning ``DEFAULT_TIER``, we prevent the ``@cached``
decorator from storing the fallback value. This avoids a race condition
where a non-existent user's DEFAULT_TIER is cached, then the user is
created with a higher tier but receives the stale cached FREE tier for
up to 5 minutes.
"""
@cached(maxsize=1000, ttl_seconds=300, shared_cache=True)
async def _fetch_user_tier(user_id: str) -> SubscriptionTier:
"""Fetch the user's rate-limit tier from the database (cached via Redis).
@@ -390,18 +402,16 @@ async def _fetch_user_tier(user_id: str) -> SubscriptionTier:
Uses ``shared_cache=True`` so that tier changes propagate across all pods
immediately when the cache entry is invalidated (via ``cache_delete``).
Only successful DB lookups are cached. Raises on DB errors so the
``@cached`` decorator does **not** store a fallback value.
Note: when the user is not found or ``subscriptionTier`` is ``None``,
``DEFAULT_TIER`` (FREE) is returned and **cached**. The Prisma schema
enforces ``@default(PRO)`` on the column, so ``None`` only occurs in
edge cases (e.g. partial row creation).
Only successful DB lookups of existing users with a valid tier are cached.
Raises ``_UserNotFoundError`` when the user is missing or has no tier, so
the ``@cached`` decorator does **not** store a fallback value. This
prevents a race condition where a non-existent user's ``DEFAULT_TIER`` is
cached and then persists after the user is created with a higher tier.
"""
user = await PrismaUser.prisma().find_unique(where={"id": user_id})
if user and user.subscriptionTier: # type: ignore[reportAttributeAccessIssue]
return SubscriptionTier(user.subscriptionTier) # type: ignore[reportAttributeAccessIssue]
return DEFAULT_TIER
raise _UserNotFoundError(user_id)
async def get_user_tier(user_id: str) -> SubscriptionTier:

View File

@@ -503,6 +503,41 @@ class TestGetUserTier:
assert tier == DEFAULT_TIER
@pytest.mark.asyncio
async def test_user_not_found_is_not_cached(self):
"""Non-existent user should NOT cache DEFAULT_TIER.
Regression test: when ``get_user_tier`` is called before a user record
exists, the DEFAULT_TIER fallback must not be cached. Otherwise, a
newly created user with a higher tier (e.g. PRO) would receive the
stale cached FREE tier for up to 5 minutes.
"""
# First call: user does not exist yet
missing_prisma = AsyncMock()
missing_prisma.find_unique = AsyncMock(return_value=None)
with patch(
"backend.copilot.rate_limit.PrismaUser.prisma",
return_value=missing_prisma,
):
tier1 = await get_user_tier(_USER)
assert tier1 == DEFAULT_TIER
# Second call: user now exists with PRO tier
mock_user = MagicMock()
mock_user.subscriptionTier = "PRO"
ok_prisma = AsyncMock()
ok_prisma.find_unique = AsyncMock(return_value=mock_user)
with patch(
"backend.copilot.rate_limit.PrismaUser.prisma",
return_value=ok_prisma,
):
tier2 = await get_user_tier(_USER)
# Should get PRO — the not-found result was not cached
assert tier2 == SubscriptionTier.PRO
# ---------------------------------------------------------------------------
# set_user_tier