mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
fix(platform): address PR review items for rate-limit tiering
- Change DEFAULT_TIER from PRO to FREE (fail-closed on DB errors) - Use shared_cache=True (Redis-backed) for _fetch_user_tier so tier changes propagate across pods immediately - Use TIER_MULTIPLIERS.get(tier, 1) to avoid KeyError on unknown tiers - Rename _tier to tier in routes.py where the variable is used, and to _ where it is truly unused - Add minimum 3-char query length for search_users to prevent user table enumeration - Use generated API client (getV2SearchUsersByNameOrEmail) instead of raw fetch() in useRateLimitManager - Remove unnecessary cast and fallback in RateLimitDisplay - Fix fragile call-count-based _ld_side_effect in tests to use flag_key matching pattern - Update test assertion for DEFAULT_TIER change (FREE not PRO)
This commit is contained in:
@@ -227,6 +227,11 @@ async def admin_search_users(
|
||||
Queries the User table directly — returns results even for users
|
||||
without credit transaction history.
|
||||
"""
|
||||
if len(query.strip()) < 3:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Search query must be at least 3 characters.",
|
||||
)
|
||||
logger.info("Admin %s searching users with query=%r", admin_user_id, query)
|
||||
results = await search_users(query, limit=min(limit, 50))
|
||||
return [UserSearchResult(user_id=uid, user_email=email) for uid, email in results]
|
||||
|
||||
@@ -518,7 +518,7 @@ async def reset_copilot_usage(
|
||||
detail="Rate limit reset is not available (credit system is disabled).",
|
||||
)
|
||||
|
||||
daily_limit, weekly_limit, _tier = await get_global_rate_limits(
|
||||
daily_limit, weekly_limit, tier = await get_global_rate_limits(
|
||||
user_id, config.daily_token_limit, config.weekly_token_limit
|
||||
)
|
||||
|
||||
@@ -556,7 +556,7 @@ async def reset_copilot_usage(
|
||||
user_id=user_id,
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
tier=_tier,
|
||||
tier=tier,
|
||||
)
|
||||
if daily_limit > 0 and usage_status.daily.used < daily_limit:
|
||||
raise HTTPException(
|
||||
@@ -632,7 +632,7 @@ async def reset_copilot_usage(
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
rate_limit_reset_cost=config.rate_limit_reset_cost,
|
||||
tier=_tier,
|
||||
tier=tier,
|
||||
)
|
||||
|
||||
return RateLimitResetResponse(
|
||||
@@ -743,7 +743,7 @@ async def stream_chat_post(
|
||||
# Global defaults sourced from LaunchDarkly, falling back to config.
|
||||
if user_id:
|
||||
try:
|
||||
daily_limit, weekly_limit, _tier = await get_global_rate_limits(
|
||||
daily_limit, weekly_limit, _ = await get_global_rate_limits(
|
||||
user_id, config.daily_token_limit, config.weekly_token_limit
|
||||
)
|
||||
await check_rate_limit(
|
||||
|
||||
@@ -49,7 +49,7 @@ TIER_MULTIPLIERS: dict[SubscriptionTier, int] = {
|
||||
SubscriptionTier.ENTERPRISE: 60,
|
||||
}
|
||||
|
||||
DEFAULT_TIER = SubscriptionTier.PRO
|
||||
DEFAULT_TIER = SubscriptionTier.FREE
|
||||
|
||||
|
||||
class UsageWindow(BaseModel):
|
||||
@@ -377,18 +377,20 @@ async def record_token_usage(
|
||||
)
|
||||
|
||||
|
||||
@cached(maxsize=1000, ttl_seconds=300)
|
||||
@cached(maxsize=1000, ttl_seconds=300, shared_cache=True)
|
||||
async def _fetch_user_tier(user_id: str) -> SubscriptionTier:
|
||||
"""Fetch the user's rate-limit tier from the database (cached).
|
||||
"""Fetch the user's rate-limit tier from the database (cached via Redis).
|
||||
|
||||
Uses ``shared_cache=True`` so that tier changes propagate across all pods
|
||||
immediately when the cache entry is invalidated (via ``cache_delete``).
|
||||
|
||||
Only successful DB lookups are cached. Raises on DB errors so the
|
||||
``@cached`` decorator does **not** store a fallback value.
|
||||
|
||||
Note: when the user is not found or ``subscriptionTier`` is ``None``,
|
||||
``DEFAULT_TIER`` is returned and **cached**. This is acceptable because
|
||||
the Prisma schema enforces ``@default(PRO)`` on the column, so ``None``
|
||||
only occurs in edge cases (e.g. partial row creation) and caching PRO
|
||||
for 5 minutes is safe.
|
||||
``DEFAULT_TIER`` (FREE) is returned and **cached**. The Prisma schema
|
||||
enforces ``@default(PRO)`` on the column, so ``None`` only occurs in
|
||||
edge cases (e.g. partial row creation).
|
||||
"""
|
||||
user = await PrismaUser.prisma().find_unique(where={"id": user_id})
|
||||
if user and user.subscriptionTier:
|
||||
@@ -483,7 +485,7 @@ async def get_global_rate_limits(
|
||||
|
||||
# Apply tier multiplier
|
||||
tier = await get_user_tier(user_id)
|
||||
multiplier = TIER_MULTIPLIERS[tier]
|
||||
multiplier = TIER_MULTIPLIERS.get(tier, 1)
|
||||
if multiplier != 1:
|
||||
daily = daily * multiplier
|
||||
weekly = weekly * multiplier
|
||||
|
||||
@@ -360,8 +360,8 @@ class TestSubscriptionTier:
|
||||
assert TIER_MULTIPLIERS[SubscriptionTier.BUSINESS] == 20
|
||||
assert TIER_MULTIPLIERS[SubscriptionTier.ENTERPRISE] == 60
|
||||
|
||||
def test_default_tier_is_pro(self):
|
||||
assert DEFAULT_TIER == SubscriptionTier.PRO
|
||||
def test_default_tier_is_free(self):
|
||||
assert DEFAULT_TIER == SubscriptionTier.FREE
|
||||
|
||||
def test_usage_status_includes_tier(self):
|
||||
now = datetime.now(UTC)
|
||||
@@ -601,15 +601,14 @@ class TestSetUserTier:
|
||||
class TestGetGlobalRateLimitsWithTiers:
|
||||
@staticmethod
|
||||
def _ld_side_effect(daily: int, weekly: int):
|
||||
"""Return an async side_effect that returns daily on first call, weekly on second."""
|
||||
call_count = 0
|
||||
"""Return an async side_effect that dispatches by flag_key."""
|
||||
|
||||
async def _side_effect(flag_key, user_id, default):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
async def _side_effect(flag_key: str, _uid: str, default: int) -> int:
|
||||
if "daily" in flag_key.lower():
|
||||
return daily
|
||||
return weekly
|
||||
if "weekly" in flag_key.lower():
|
||||
return weekly
|
||||
return default
|
||||
|
||||
return _side_effect
|
||||
|
||||
@@ -717,12 +716,13 @@ class TestTierLimitsRespected:
|
||||
|
||||
@staticmethod
|
||||
def _ld_side_effect(daily: int, weekly: int):
|
||||
call_count = 0
|
||||
|
||||
async def _side_effect(flag_key, user_id, default):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return daily if call_count == 1 else weekly
|
||||
async def _side_effect(flag_key: str, _uid: str, default: int) -> int:
|
||||
if "daily" in flag_key.lower():
|
||||
return daily
|
||||
if "weekly" in flag_key.lower():
|
||||
return weekly
|
||||
return default
|
||||
|
||||
return _side_effect
|
||||
|
||||
|
||||
@@ -88,7 +88,8 @@ async def search_users(query: str, limit: int = 20) -> list[tuple[str, str | Non
|
||||
Returns a list of ``(user_id, email)`` tuples, up to *limit* results.
|
||||
Searches the User table directly — no dependency on credit history.
|
||||
"""
|
||||
if not query or not query.strip():
|
||||
query = query.strip()
|
||||
if not query or len(query) < 3:
|
||||
return []
|
||||
users = await prisma.user.find_many(
|
||||
where={
|
||||
|
||||
@@ -42,7 +42,7 @@ export function RateLimitDisplay({
|
||||
const [isChangingTier, setIsChangingTier] = useState(false);
|
||||
const { toast } = useToast();
|
||||
|
||||
const currentTier = (data.tier as Tier) ?? "PRO";
|
||||
const currentTier = data.tier as Tier;
|
||||
|
||||
async function handleReset() {
|
||||
const msg = resetWeekly
|
||||
|
||||
@@ -6,6 +6,7 @@ import type { SetUserTierRequest } from "@/app/api/__generated__/models/setUserT
|
||||
import type { UserRateLimitResponse } from "@/app/api/__generated__/models/userRateLimitResponse";
|
||||
import {
|
||||
getV2GetUserRateLimit,
|
||||
getV2SearchUsersByNameOrEmail,
|
||||
postV2ResetUserRateLimitUsage,
|
||||
postV2SetUserRateLimitTier,
|
||||
} from "@/app/api/__generated__/endpoints/admin/admin";
|
||||
@@ -78,7 +79,6 @@ export function useRateLimitManager() {
|
||||
}
|
||||
}
|
||||
|
||||
/** Search users by partial name/email via the User table. */
|
||||
async function handleFuzzySearch(trimmed: string) {
|
||||
setIsSearching(true);
|
||||
setSearchResults([]);
|
||||
@@ -86,14 +86,18 @@ export function useRateLimitManager() {
|
||||
setRateLimitData(null);
|
||||
|
||||
try {
|
||||
const response = await fetch(
|
||||
`/api/proxy/api/copilot/admin/rate_limit/search_users?query=${encodeURIComponent(trimmed)}&limit=20`,
|
||||
);
|
||||
if (!response.ok) {
|
||||
const response = await getV2SearchUsersByNameOrEmail({
|
||||
query: trimmed,
|
||||
limit: 20,
|
||||
});
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to search users");
|
||||
}
|
||||
|
||||
const users: UserOption[] = await response.json();
|
||||
const users = (response.data ?? []).map((u) => ({
|
||||
user_id: u.user_id,
|
||||
user_email: u.user_email ?? u.user_id,
|
||||
}));
|
||||
if (users.length === 0) {
|
||||
toast({ title: "No results", description: "No users found." });
|
||||
}
|
||||
|
||||
@@ -8578,7 +8578,7 @@
|
||||
"weekly": { "$ref": "#/components/schemas/UsageWindow" },
|
||||
"tier": {
|
||||
"$ref": "#/components/schemas/SubscriptionTier",
|
||||
"default": "PRO"
|
||||
"default": "FREE"
|
||||
},
|
||||
"reset_cost": {
|
||||
"type": "integer",
|
||||
|
||||
Reference in New Issue
Block a user