fix: address PR review - cache tier lookups, return tier from get_global_rate_limits, fix error handling

- Add @cached(ttl_seconds=300) to get_user_tier() to avoid DB hit on every chat turn
- Change get_global_rate_limits() to return 3-tuple (daily, weekly, tier) so callers
  don't need redundant get_user_tier() calls
- Remove redundant get_user_tier() calls from admin routes and chat /usage endpoint
- Simplify `except (ValueError, Exception)` to `except Exception`
- Handle prisma.errors.RecordNotFoundError in set_user_tier admin endpoint (404 vs 500)
- Add test for user-not-found case on set_user_tier endpoint
- Clear tier cache between tests to prevent stale cached results
This commit is contained in:
Zamil Majdy
2026-03-26 20:42:01 +07:00
parent 432ef5ab5e
commit ffb8d366d6
5 changed files with 64 additions and 34 deletions

View File

@@ -2,6 +2,7 @@
import logging
import prisma.errors
from autogpt_libs.auth import get_user_id, requires_admin_user
from fastapi import APIRouter, Body, HTTPException, Security
from pydantic import BaseModel
@@ -58,11 +59,10 @@ async def get_user_rate_limit(
"""Get a user's current usage and effective rate limits. Admin-only."""
logger.info(f"Admin {admin_user_id} checking rate limit for user {user_id}")
daily_limit, weekly_limit = await get_global_rate_limits(
daily_limit, weekly_limit, tier = await get_global_rate_limits(
user_id, config.daily_token_limit, config.weekly_token_limit
)
usage = await get_usage_status(user_id, daily_limit, weekly_limit)
tier = await get_user_tier(user_id)
return UserRateLimitResponse(
user_id=user_id,
@@ -96,11 +96,10 @@ async def reset_user_rate_limit(
logger.exception("Failed to reset user usage")
raise HTTPException(status_code=500, detail="Failed to reset usage") from e
daily_limit, weekly_limit = await get_global_rate_limits(
daily_limit, weekly_limit, tier = await get_global_rate_limits(
user_id, config.daily_token_limit, config.weekly_token_limit
)
usage = await get_usage_status(user_id, daily_limit, weekly_limit)
tier = await get_user_tier(user_id)
return UserRateLimitResponse(
user_id=user_id,
@@ -143,6 +142,10 @@ async def set_user_rate_limit_tier(
)
try:
await set_user_tier(request.user_id, request.tier)
except prisma.errors.RecordNotFoundError as e:
raise HTTPException(
status_code=404, detail=f"User {request.user_id} not found"
) from e
except Exception as e:
logger.exception("Failed to set user tier")
raise HTTPException(status_code=500, detail="Failed to set tier") from e

View File

@@ -53,18 +53,13 @@ def test_get_rate_limit(
mocker.patch(
f"{_MOCK_MODULE}.get_global_rate_limits",
new_callable=AsyncMock,
return_value=(2_500_000, 12_500_000),
return_value=(2_500_000, 12_500_000, RateLimitTier.STANDARD),
)
mocker.patch(
f"{_MOCK_MODULE}.get_usage_status",
new_callable=AsyncMock,
return_value=_mock_usage_status(),
)
mocker.patch(
f"{_MOCK_MODULE}.get_user_tier",
new_callable=AsyncMock,
return_value=RateLimitTier.STANDARD,
)
response = client.get("/admin/rate_limit", params={"user_id": target_user_id})
@@ -96,18 +91,13 @@ def test_reset_user_usage_daily_only(
mocker.patch(
f"{_MOCK_MODULE}.get_global_rate_limits",
new_callable=AsyncMock,
return_value=(2_500_000, 12_500_000),
return_value=(2_500_000, 12_500_000, RateLimitTier.STANDARD),
)
mocker.patch(
f"{_MOCK_MODULE}.get_usage_status",
new_callable=AsyncMock,
return_value=_mock_usage_status(daily_used=0, weekly_used=3_000_000),
)
mocker.patch(
f"{_MOCK_MODULE}.get_user_tier",
new_callable=AsyncMock,
return_value=RateLimitTier.STANDARD,
)
response = client.post(
"/admin/rate_limit/reset",
@@ -142,18 +132,13 @@ def test_reset_user_usage_daily_and_weekly(
mocker.patch(
f"{_MOCK_MODULE}.get_global_rate_limits",
new_callable=AsyncMock,
return_value=(2_500_000, 12_500_000),
return_value=(2_500_000, 12_500_000, RateLimitTier.STANDARD),
)
mocker.patch(
f"{_MOCK_MODULE}.get_usage_status",
new_callable=AsyncMock,
return_value=_mock_usage_status(daily_used=0, weekly_used=0),
)
mocker.patch(
f"{_MOCK_MODULE}.get_user_tier",
new_callable=AsyncMock,
return_value=RateLimitTier.STANDARD,
)
response = client.post(
"/admin/rate_limit/reset",
@@ -265,6 +250,27 @@ def test_set_user_tier_invalid_tier(
assert response.status_code == 422
def test_set_user_tier_user_not_found(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
) -> None:
"""Test that setting tier for nonexistent user returns 404."""
import prisma.errors
mocker.patch(
f"{_MOCK_MODULE}.set_user_tier",
new_callable=AsyncMock,
side_effect=prisma.errors.RecordNotFoundError({"error": "Record not found"}),
)
response = client.post(
"/admin/rate_limit/tier",
json={"user_id": target_user_id, "tier": "pro"},
)
assert response.status_code == 404
def test_set_user_tier_db_failure(
mocker: pytest_mock.MockerFixture,
target_user_id: str,

View File

@@ -33,7 +33,6 @@ from backend.copilot.rate_limit import (
check_rate_limit,
get_global_rate_limits,
get_usage_status,
get_user_tier,
)
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
from backend.copilot.tools.e2b_sandbox import kill_sandbox
@@ -427,10 +426,9 @@ async def get_copilot_usage(
Global defaults sourced from LaunchDarkly (falling back to config).
Includes the user's rate-limit tier.
"""
daily_limit, weekly_limit = await get_global_rate_limits(
daily_limit, weekly_limit, tier = await get_global_rate_limits(
user_id, config.daily_token_limit, config.weekly_token_limit
)
tier = await get_user_tier(user_id)
status = await get_usage_status(
user_id=user_id,
daily_token_limit=daily_limit,
@@ -540,7 +538,7 @@ async def stream_chat_post(
# Global defaults sourced from LaunchDarkly, falling back to config.
if user_id:
try:
daily_limit, weekly_limit = await get_global_rate_limits(
daily_limit, weekly_limit, _tier = await get_global_rate_limits(
user_id, config.daily_token_limit, config.weekly_token_limit
)
await check_rate_limit(

View File

@@ -16,6 +16,7 @@ from pydantic import BaseModel, Field
from redis.exceptions import RedisError
from backend.data.redis_client import get_redis_async
from backend.util.cache import cached
logger = logging.getLogger(__name__)
@@ -257,9 +258,13 @@ async def record_token_usage(
)
@cached(maxsize=1000, ttl_seconds=300)
async def get_user_tier(user_id: str) -> RateLimitTier:
"""Look up the user's rate-limit tier from the database.
Results are cached for 5 minutes to avoid a DB round-trip on every
rate-limit check (called on every chat turn).
Falls back to ``DEFAULT_TIER`` when the user record is missing,
the field is ``None``, or the stored value is unrecognised.
"""
@@ -267,7 +272,7 @@ async def get_user_tier(user_id: str) -> RateLimitTier:
user = await PrismaUser.prisma().find_unique(where={"id": user_id})
if user and user.rateLimitTier:
return RateLimitTier(user.rateLimitTier)
except (ValueError, Exception) as exc:
except Exception as exc:
logger.warning(
"Failed to resolve rate-limit tier for user %s, defaulting to %s: %s",
user_id[:8],
@@ -278,7 +283,11 @@ async def get_user_tier(user_id: str) -> RateLimitTier:
async def set_user_tier(user_id: str, tier: RateLimitTier) -> None:
"""Persist the user's rate-limit tier to the database."""
"""Persist the user's rate-limit tier to the database.
Raises:
prisma.errors.RecordNotFoundError: If the user does not exist.
"""
await PrismaUser.prisma().update(
where={"id": user_id},
data={"rateLimitTier": tier.value},
@@ -289,7 +298,7 @@ async def get_global_rate_limits(
user_id: str,
config_daily: int,
config_weekly: int,
) -> tuple[int, int]:
) -> tuple[int, int, RateLimitTier]:
"""Resolve global rate limits from LaunchDarkly, falling back to config.
The base limits (from LD or config) are multiplied by the user's
@@ -302,7 +311,7 @@ async def get_global_rate_limits(
config_weekly: Fallback weekly limit from ChatConfig.
Returns:
(daily_token_limit, weekly_token_limit) tuple.
(daily_token_limit, weekly_token_limit, tier) 3-tuple.
"""
# Lazy import to avoid circular dependency:
# rate_limit -> feature_flag -> settings -> ... -> rate_limit
@@ -332,7 +341,7 @@ async def get_global_rate_limits(
daily = daily * multiplier
weekly = weekly * multiplier
return daily, weekly
return daily, weekly, tier
async def reset_user_usage(user_id: str, *, reset_weekly: bool = False) -> None:

View File

@@ -384,6 +384,11 @@ class TestRateLimitTier:
class TestGetUserTier:
@pytest.fixture(autouse=True)
def _clear_tier_cache(self):
"""Clear the get_user_tier cache before each test."""
get_user_tier.cache_clear()
@pytest.mark.asyncio
async def test_returns_tier_from_db(self):
"""Should return the tier stored in the user record."""
@@ -498,10 +503,13 @@ class TestGetGlobalRateLimitsWithTiers:
side_effect=self._ld_side_effect(2_500_000, 12_500_000),
),
):
daily, weekly = await get_global_rate_limits(_USER, 2_500_000, 12_500_000)
daily, weekly, tier = await get_global_rate_limits(
_USER, 2_500_000, 12_500_000
)
assert daily == 2_500_000
assert weekly == 12_500_000
assert tier == RateLimitTier.STANDARD
@pytest.mark.asyncio
async def test_pro_tier_5x_multiplier(self):
@@ -517,10 +525,13 @@ class TestGetGlobalRateLimitsWithTiers:
side_effect=self._ld_side_effect(2_500_000, 12_500_000),
),
):
daily, weekly = await get_global_rate_limits(_USER, 2_500_000, 12_500_000)
daily, weekly, tier = await get_global_rate_limits(
_USER, 2_500_000, 12_500_000
)
assert daily == 12_500_000
assert weekly == 62_500_000
assert tier == RateLimitTier.PRO
@pytest.mark.asyncio
async def test_max_tier_25x_multiplier(self):
@@ -536,7 +547,10 @@ class TestGetGlobalRateLimitsWithTiers:
side_effect=self._ld_side_effect(2_500_000, 12_500_000),
),
):
daily, weekly = await get_global_rate_limits(_USER, 2_500_000, 12_500_000)
daily, weekly, tier = await get_global_rate_limits(
_USER, 2_500_000, 12_500_000
)
assert daily == 62_500_000
assert weekly == 312_500_000
assert tier == RateLimitTier.MAX