mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
fix: address PR review - cache tier lookups, return tier from get_global_rate_limits, fix error handling
- Add @cached(ttl_seconds=300) to get_user_tier() to avoid a DB hit on every chat turn
- Change get_global_rate_limits() to return a 3-tuple (daily, weekly, tier) so callers don't need redundant get_user_tier() calls
- Remove redundant get_user_tier() calls from admin routes and chat /usage endpoint
- Simplify `except (ValueError, Exception)` to `except Exception`
- Handle prisma.errors.RecordNotFoundError in set_user_tier admin endpoint (404 vs 500)
- Add test for user-not-found case on set_user_tier endpoint
- Clear tier cache between tests to prevent stale cached results
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
import logging
|
||||
|
||||
import prisma.errors
|
||||
from autogpt_libs.auth import get_user_id, requires_admin_user
|
||||
from fastapi import APIRouter, Body, HTTPException, Security
|
||||
from pydantic import BaseModel
|
||||
@@ -58,11 +59,10 @@ async def get_user_rate_limit(
|
||||
"""Get a user's current usage and effective rate limits. Admin-only."""
|
||||
logger.info(f"Admin {admin_user_id} checking rate limit for user {user_id}")
|
||||
|
||||
daily_limit, weekly_limit = await get_global_rate_limits(
|
||||
daily_limit, weekly_limit, tier = await get_global_rate_limits(
|
||||
user_id, config.daily_token_limit, config.weekly_token_limit
|
||||
)
|
||||
usage = await get_usage_status(user_id, daily_limit, weekly_limit)
|
||||
tier = await get_user_tier(user_id)
|
||||
|
||||
return UserRateLimitResponse(
|
||||
user_id=user_id,
|
||||
@@ -96,11 +96,10 @@ async def reset_user_rate_limit(
|
||||
logger.exception("Failed to reset user usage")
|
||||
raise HTTPException(status_code=500, detail="Failed to reset usage") from e
|
||||
|
||||
daily_limit, weekly_limit = await get_global_rate_limits(
|
||||
daily_limit, weekly_limit, tier = await get_global_rate_limits(
|
||||
user_id, config.daily_token_limit, config.weekly_token_limit
|
||||
)
|
||||
usage = await get_usage_status(user_id, daily_limit, weekly_limit)
|
||||
tier = await get_user_tier(user_id)
|
||||
|
||||
return UserRateLimitResponse(
|
||||
user_id=user_id,
|
||||
@@ -143,6 +142,10 @@ async def set_user_rate_limit_tier(
|
||||
)
|
||||
try:
|
||||
await set_user_tier(request.user_id, request.tier)
|
||||
except prisma.errors.RecordNotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"User {request.user_id} not found"
|
||||
) from e
|
||||
except Exception as e:
|
||||
logger.exception("Failed to set user tier")
|
||||
raise HTTPException(status_code=500, detail="Failed to set tier") from e
|
||||
|
||||
@@ -53,18 +53,13 @@ def test_get_rate_limit(
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_global_rate_limits",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(2_500_000, 12_500_000),
|
||||
return_value=(2_500_000, 12_500_000, RateLimitTier.STANDARD),
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_usage_status",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_mock_usage_status(),
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_tier",
|
||||
new_callable=AsyncMock,
|
||||
return_value=RateLimitTier.STANDARD,
|
||||
)
|
||||
|
||||
response = client.get("/admin/rate_limit", params={"user_id": target_user_id})
|
||||
|
||||
@@ -96,18 +91,13 @@ def test_reset_user_usage_daily_only(
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_global_rate_limits",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(2_500_000, 12_500_000),
|
||||
return_value=(2_500_000, 12_500_000, RateLimitTier.STANDARD),
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_usage_status",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_mock_usage_status(daily_used=0, weekly_used=3_000_000),
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_tier",
|
||||
new_callable=AsyncMock,
|
||||
return_value=RateLimitTier.STANDARD,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/reset",
|
||||
@@ -142,18 +132,13 @@ def test_reset_user_usage_daily_and_weekly(
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_global_rate_limits",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(2_500_000, 12_500_000),
|
||||
return_value=(2_500_000, 12_500_000, RateLimitTier.STANDARD),
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_usage_status",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_mock_usage_status(daily_used=0, weekly_used=0),
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_tier",
|
||||
new_callable=AsyncMock,
|
||||
return_value=RateLimitTier.STANDARD,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/reset",
|
||||
@@ -265,6 +250,27 @@ def test_set_user_tier_invalid_tier(
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_set_user_tier_user_not_found(
    mocker: pytest_mock.MockerFixture,
    target_user_id: str,
) -> None:
    """Test that setting tier for a nonexistent user returns 404.

    ``set_user_tier`` is patched to raise
    ``prisma.errors.RecordNotFoundError``; the admin endpoint is expected to
    translate that into an HTTP 404 whose detail names the missing user.
    """
    import prisma.errors

    mocker.patch(
        f"{_MOCK_MODULE}.set_user_tier",
        new_callable=AsyncMock,
        side_effect=prisma.errors.RecordNotFoundError({"error": "Record not found"}),
    )

    response = client.post(
        "/admin/rate_limit/tier",
        json={"user_id": target_user_id, "tier": "pro"},
    )

    assert response.status_code == 404
    # A bare status check would also pass on a routing 404 (wrong path or
    # unregistered route). Checking the detail pins the response to the
    # endpoint's user-not-found branch.
    assert response.json()["detail"] == f"User {target_user_id} not found"
|
||||
|
||||
|
||||
def test_set_user_tier_db_failure(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
|
||||
@@ -33,7 +33,6 @@ from backend.copilot.rate_limit import (
|
||||
check_rate_limit,
|
||||
get_global_rate_limits,
|
||||
get_usage_status,
|
||||
get_user_tier,
|
||||
)
|
||||
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
|
||||
from backend.copilot.tools.e2b_sandbox import kill_sandbox
|
||||
@@ -427,10 +426,9 @@ async def get_copilot_usage(
|
||||
Global defaults sourced from LaunchDarkly (falling back to config).
|
||||
Includes the user's rate-limit tier.
|
||||
"""
|
||||
daily_limit, weekly_limit = await get_global_rate_limits(
|
||||
daily_limit, weekly_limit, tier = await get_global_rate_limits(
|
||||
user_id, config.daily_token_limit, config.weekly_token_limit
|
||||
)
|
||||
tier = await get_user_tier(user_id)
|
||||
status = await get_usage_status(
|
||||
user_id=user_id,
|
||||
daily_token_limit=daily_limit,
|
||||
@@ -540,7 +538,7 @@ async def stream_chat_post(
|
||||
# Global defaults sourced from LaunchDarkly, falling back to config.
|
||||
if user_id:
|
||||
try:
|
||||
daily_limit, weekly_limit = await get_global_rate_limits(
|
||||
daily_limit, weekly_limit, _tier = await get_global_rate_limits(
|
||||
user_id, config.daily_token_limit, config.weekly_token_limit
|
||||
)
|
||||
await check_rate_limit(
|
||||
|
||||
@@ -16,6 +16,7 @@ from pydantic import BaseModel, Field
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.util.cache import cached
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -257,9 +258,13 @@ async def record_token_usage(
|
||||
)
|
||||
|
||||
|
||||
@cached(maxsize=1000, ttl_seconds=300)
|
||||
async def get_user_tier(user_id: str) -> RateLimitTier:
|
||||
"""Look up the user's rate-limit tier from the database.
|
||||
|
||||
Results are cached for 5 minutes to avoid a DB round-trip on every
|
||||
rate-limit check (called on every chat turn).
|
||||
|
||||
Falls back to ``DEFAULT_TIER`` when the user record is missing,
|
||||
the field is ``None``, or the stored value is unrecognised.
|
||||
"""
|
||||
@@ -267,7 +272,7 @@ async def get_user_tier(user_id: str) -> RateLimitTier:
|
||||
user = await PrismaUser.prisma().find_unique(where={"id": user_id})
|
||||
if user and user.rateLimitTier:
|
||||
return RateLimitTier(user.rateLimitTier)
|
||||
except (ValueError, Exception) as exc:
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to resolve rate-limit tier for user %s, defaulting to %s: %s",
|
||||
user_id[:8],
|
||||
@@ -278,7 +283,11 @@ async def get_user_tier(user_id: str) -> RateLimitTier:
|
||||
|
||||
|
||||
async def set_user_tier(user_id: str, tier: RateLimitTier) -> None:
|
||||
"""Persist the user's rate-limit tier to the database."""
|
||||
"""Persist the user's rate-limit tier to the database.
|
||||
|
||||
Raises:
|
||||
prisma.errors.RecordNotFoundError: If the user does not exist.
|
||||
"""
|
||||
await PrismaUser.prisma().update(
|
||||
where={"id": user_id},
|
||||
data={"rateLimitTier": tier.value},
|
||||
@@ -289,7 +298,7 @@ async def get_global_rate_limits(
|
||||
user_id: str,
|
||||
config_daily: int,
|
||||
config_weekly: int,
|
||||
) -> tuple[int, int]:
|
||||
) -> tuple[int, int, RateLimitTier]:
|
||||
"""Resolve global rate limits from LaunchDarkly, falling back to config.
|
||||
|
||||
The base limits (from LD or config) are multiplied by the user's
|
||||
@@ -302,7 +311,7 @@ async def get_global_rate_limits(
|
||||
config_weekly: Fallback weekly limit from ChatConfig.
|
||||
|
||||
Returns:
|
||||
(daily_token_limit, weekly_token_limit) tuple.
|
||||
(daily_token_limit, weekly_token_limit, tier) 3-tuple.
|
||||
"""
|
||||
# Lazy import to avoid circular dependency:
|
||||
# rate_limit -> feature_flag -> settings -> ... -> rate_limit
|
||||
@@ -332,7 +341,7 @@ async def get_global_rate_limits(
|
||||
daily = daily * multiplier
|
||||
weekly = weekly * multiplier
|
||||
|
||||
return daily, weekly
|
||||
return daily, weekly, tier
|
||||
|
||||
|
||||
async def reset_user_usage(user_id: str, *, reset_weekly: bool = False) -> None:
|
||||
|
||||
@@ -384,6 +384,11 @@ class TestRateLimitTier:
|
||||
|
||||
|
||||
class TestGetUserTier:
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_tier_cache(self):
|
||||
"""Clear the get_user_tier cache before each test."""
|
||||
get_user_tier.cache_clear()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_tier_from_db(self):
|
||||
"""Should return the tier stored in the user record."""
|
||||
@@ -498,10 +503,13 @@ class TestGetGlobalRateLimitsWithTiers:
|
||||
side_effect=self._ld_side_effect(2_500_000, 12_500_000),
|
||||
),
|
||||
):
|
||||
daily, weekly = await get_global_rate_limits(_USER, 2_500_000, 12_500_000)
|
||||
daily, weekly, tier = await get_global_rate_limits(
|
||||
_USER, 2_500_000, 12_500_000
|
||||
)
|
||||
|
||||
assert daily == 2_500_000
|
||||
assert weekly == 12_500_000
|
||||
assert tier == RateLimitTier.STANDARD
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pro_tier_5x_multiplier(self):
|
||||
@@ -517,10 +525,13 @@ class TestGetGlobalRateLimitsWithTiers:
|
||||
side_effect=self._ld_side_effect(2_500_000, 12_500_000),
|
||||
),
|
||||
):
|
||||
daily, weekly = await get_global_rate_limits(_USER, 2_500_000, 12_500_000)
|
||||
daily, weekly, tier = await get_global_rate_limits(
|
||||
_USER, 2_500_000, 12_500_000
|
||||
)
|
||||
|
||||
assert daily == 12_500_000
|
||||
assert weekly == 62_500_000
|
||||
assert tier == RateLimitTier.PRO
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_tier_25x_multiplier(self):
|
||||
@@ -536,7 +547,10 @@ class TestGetGlobalRateLimitsWithTiers:
|
||||
side_effect=self._ld_side_effect(2_500_000, 12_500_000),
|
||||
),
|
||||
):
|
||||
daily, weekly = await get_global_rate_limits(_USER, 2_500_000, 12_500_000)
|
||||
daily, weekly, tier = await get_global_rate_limits(
|
||||
_USER, 2_500_000, 12_500_000
|
||||
)
|
||||
|
||||
assert daily == 62_500_000
|
||||
assert weekly == 312_500_000
|
||||
assert tier == RateLimitTier.MAX
|
||||
|
||||
Reference in New Issue
Block a user