fix(platform): address PR review items for rate-limit tiering

- Change DEFAULT_TIER from PRO to FREE (fail-closed on DB errors)
- Use shared_cache=True (Redis-backed) for _fetch_user_tier so tier
  changes propagate across all pods as soon as the cache entry is
  invalidated
- Use TIER_MULTIPLIERS.get(tier, 1) to avoid KeyError on unknown tiers
- Rename _tier to tier in routes.py where the variable is used, and
  to _ where it is truly unused
- Add minimum 3-char query length for search_users to prevent user
  table enumeration
- Use generated API client (getV2SearchUsersByNameOrEmail) instead of
  raw fetch() in useRateLimitManager
- Remove unnecessary cast and fallback in RateLimitDisplay
- Replace the fragile call-count-based _ld_side_effect in tests with
  one that dispatches on flag_key
- Update test assertion for DEFAULT_TIER change (FREE not PRO)
This commit is contained in:
Zamil Majdy
2026-04-02 06:28:36 +02:00
parent f4571cb9e1
commit 1de2a7fb09
8 changed files with 47 additions and 35 deletions

View File

@@ -227,6 +227,11 @@ async def admin_search_users(
Queries the User table directly — returns results even for users
without credit transaction history.
"""
if len(query.strip()) < 3:
raise HTTPException(
status_code=400,
detail="Search query must be at least 3 characters.",
)
logger.info("Admin %s searching users with query=%r", admin_user_id, query)
results = await search_users(query, limit=min(limit, 50))
return [UserSearchResult(user_id=uid, user_email=email) for uid, email in results]

View File

@@ -518,7 +518,7 @@ async def reset_copilot_usage(
detail="Rate limit reset is not available (credit system is disabled).",
)
daily_limit, weekly_limit, _tier = await get_global_rate_limits(
daily_limit, weekly_limit, tier = await get_global_rate_limits(
user_id, config.daily_token_limit, config.weekly_token_limit
)
@@ -556,7 +556,7 @@ async def reset_copilot_usage(
user_id=user_id,
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
tier=_tier,
tier=tier,
)
if daily_limit > 0 and usage_status.daily.used < daily_limit:
raise HTTPException(
@@ -632,7 +632,7 @@ async def reset_copilot_usage(
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
rate_limit_reset_cost=config.rate_limit_reset_cost,
tier=_tier,
tier=tier,
)
return RateLimitResetResponse(
@@ -743,7 +743,7 @@ async def stream_chat_post(
# Global defaults sourced from LaunchDarkly, falling back to config.
if user_id:
try:
daily_limit, weekly_limit, _tier = await get_global_rate_limits(
daily_limit, weekly_limit, _ = await get_global_rate_limits(
user_id, config.daily_token_limit, config.weekly_token_limit
)
await check_rate_limit(

View File

@@ -49,7 +49,7 @@ TIER_MULTIPLIERS: dict[SubscriptionTier, int] = {
SubscriptionTier.ENTERPRISE: 60,
}
DEFAULT_TIER = SubscriptionTier.PRO
DEFAULT_TIER = SubscriptionTier.FREE
class UsageWindow(BaseModel):
@@ -377,18 +377,20 @@ async def record_token_usage(
)
@cached(maxsize=1000, ttl_seconds=300)
@cached(maxsize=1000, ttl_seconds=300, shared_cache=True)
async def _fetch_user_tier(user_id: str) -> SubscriptionTier:
"""Fetch the user's rate-limit tier from the database (cached).
"""Fetch the user's rate-limit tier from the database (cached via Redis).
Uses ``shared_cache=True`` so that tier changes propagate across all pods
immediately when the cache entry is invalidated (via ``cache_delete``).
Only successful DB lookups are cached. Raises on DB errors so the
``@cached`` decorator does **not** store a fallback value.
Note: when the user is not found or ``subscriptionTier`` is ``None``,
``DEFAULT_TIER`` is returned and **cached**. This is acceptable because
the Prisma schema enforces ``@default(PRO)`` on the column, so ``None``
only occurs in edge cases (e.g. partial row creation) and caching PRO
for 5 minutes is safe.
``DEFAULT_TIER`` (FREE) is returned and **cached**. The Prisma schema
enforces ``@default(PRO)`` on the column, so ``None`` only occurs in
edge cases (e.g. partial row creation).
"""
user = await PrismaUser.prisma().find_unique(where={"id": user_id})
if user and user.subscriptionTier:
@@ -483,7 +485,7 @@ async def get_global_rate_limits(
# Apply tier multiplier
tier = await get_user_tier(user_id)
multiplier = TIER_MULTIPLIERS[tier]
multiplier = TIER_MULTIPLIERS.get(tier, 1)
if multiplier != 1:
daily = daily * multiplier
weekly = weekly * multiplier

View File

@@ -360,8 +360,8 @@ class TestSubscriptionTier:
assert TIER_MULTIPLIERS[SubscriptionTier.BUSINESS] == 20
assert TIER_MULTIPLIERS[SubscriptionTier.ENTERPRISE] == 60
def test_default_tier_is_pro(self):
assert DEFAULT_TIER == SubscriptionTier.PRO
def test_default_tier_is_free(self):
assert DEFAULT_TIER == SubscriptionTier.FREE
def test_usage_status_includes_tier(self):
now = datetime.now(UTC)
@@ -601,15 +601,14 @@ class TestSetUserTier:
class TestGetGlobalRateLimitsWithTiers:
@staticmethod
def _ld_side_effect(daily: int, weekly: int):
"""Return an async side_effect that returns daily on first call, weekly on second."""
call_count = 0
"""Return an async side_effect that dispatches by flag_key."""
async def _side_effect(flag_key, user_id, default):
nonlocal call_count
call_count += 1
if call_count == 1:
async def _side_effect(flag_key: str, _uid: str, default: int) -> int:
if "daily" in flag_key.lower():
return daily
return weekly
if "weekly" in flag_key.lower():
return weekly
return default
return _side_effect
@@ -717,12 +716,13 @@ class TestTierLimitsRespected:
@staticmethod
def _ld_side_effect(daily: int, weekly: int):
call_count = 0
async def _side_effect(flag_key, user_id, default):
nonlocal call_count
call_count += 1
return daily if call_count == 1 else weekly
async def _side_effect(flag_key: str, _uid: str, default: int) -> int:
if "daily" in flag_key.lower():
return daily
if "weekly" in flag_key.lower():
return weekly
return default
return _side_effect

View File

@@ -88,7 +88,8 @@ async def search_users(query: str, limit: int = 20) -> list[tuple[str, str | Non
Returns a list of ``(user_id, email)`` tuples, up to *limit* results.
Searches the User table directly — no dependency on credit history.
"""
if not query or not query.strip():
query = query.strip()
if not query or len(query) < 3:
return []
users = await prisma.user.find_many(
where={

View File

@@ -42,7 +42,7 @@ export function RateLimitDisplay({
const [isChangingTier, setIsChangingTier] = useState(false);
const { toast } = useToast();
const currentTier = (data.tier as Tier) ?? "PRO";
const currentTier = data.tier as Tier;
async function handleReset() {
const msg = resetWeekly

View File

@@ -6,6 +6,7 @@ import type { SetUserTierRequest } from "@/app/api/__generated__/models/setUserT
import type { UserRateLimitResponse } from "@/app/api/__generated__/models/userRateLimitResponse";
import {
getV2GetUserRateLimit,
getV2SearchUsersByNameOrEmail,
postV2ResetUserRateLimitUsage,
postV2SetUserRateLimitTier,
} from "@/app/api/__generated__/endpoints/admin/admin";
@@ -78,7 +79,6 @@ export function useRateLimitManager() {
}
}
/** Search users by partial name/email via the User table. */
async function handleFuzzySearch(trimmed: string) {
setIsSearching(true);
setSearchResults([]);
@@ -86,14 +86,18 @@ export function useRateLimitManager() {
setRateLimitData(null);
try {
const response = await fetch(
`/api/proxy/api/copilot/admin/rate_limit/search_users?query=${encodeURIComponent(trimmed)}&limit=20`,
);
if (!response.ok) {
const response = await getV2SearchUsersByNameOrEmail({
query: trimmed,
limit: 20,
});
if (response.status !== 200) {
throw new Error("Failed to search users");
}
const users: UserOption[] = await response.json();
const users = (response.data ?? []).map((u) => ({
user_id: u.user_id,
user_email: u.user_email ?? u.user_id,
}));
if (users.length === 0) {
toast({ title: "No results", description: "No users found." });
}

View File

@@ -8578,7 +8578,7 @@
"weekly": { "$ref": "#/components/schemas/UsageWindow" },
"tier": {
"$ref": "#/components/schemas/SubscriptionTier",
"default": "PRO"
"default": "FREE"
},
"reset_cost": {
"type": "integer",