mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
fix(backend): address CodeRabbit review feedback
- Derive session reset time from Redis TTL instead of hardcoded 3h - Add description to UsageWindow.limit documenting 0 = unlimited - Compare balance < cost instead of balance <= 0 in pre-exec check - Document TOCTOU behavior in check_rate_limit docstring
This commit is contained in:
@@ -8,7 +8,7 @@ unavailable to avoid blocking users.
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
@@ -17,12 +17,17 @@ logger = logging.getLogger(__name__)
|
||||
# Redis key prefixes
|
||||
_PREFIX = "copilot:usage"
|
||||
|
||||
# Session keys expire after 12 hours of inactivity
|
||||
_SESSION_TTL_SECONDS = 43200 # 12 hours
|
||||
|
||||
|
||||
class UsageWindow(BaseModel):
|
||||
"""Usage within a single time window."""
|
||||
|
||||
used: int
|
||||
limit: int
|
||||
limit: int = Field(
|
||||
description="Maximum tokens allowed in this window. 0 means unlimited."
|
||||
)
|
||||
resets_at: datetime
|
||||
|
||||
|
||||
@@ -73,9 +78,21 @@ def _weekly_reset_time() -> datetime:
|
||||
)
|
||||
|
||||
|
||||
def _session_reset_time() -> datetime:
|
||||
"""Session limits reset after 3 hours of inactivity (matching session TTL)."""
|
||||
return datetime.now(UTC) + timedelta(hours=3)
|
||||
async def _session_reset_from_ttl(
|
||||
redis: object, user_id: str, session_id: str
|
||||
) -> datetime:
|
||||
"""Derive session reset time from the Redis key's actual TTL.
|
||||
|
||||
Falls back to the configured TTL if the key doesn't exist or has no expiry.
|
||||
"""
|
||||
try:
|
||||
ttl: int = await redis.ttl(_session_key(user_id, session_id)) # type: ignore[union-attr]
|
||||
if ttl > 0:
|
||||
return datetime.now(UTC) + timedelta(seconds=ttl)
|
||||
except Exception:
|
||||
pass
|
||||
# Key doesn't exist or has no TTL — use the configured TTL
|
||||
return datetime.now(UTC) + timedelta(seconds=_SESSION_TTL_SECONDS)
|
||||
|
||||
|
||||
async def get_usage_status(
|
||||
@@ -89,16 +106,18 @@ async def get_usage_status(
|
||||
Args:
|
||||
user_id: The user's ID.
|
||||
session_id: The current session ID.
|
||||
session_token_limit: Max tokens per session.
|
||||
weekly_token_limit: Max tokens per week.
|
||||
session_token_limit: Max tokens per session (0 = unlimited).
|
||||
weekly_token_limit: Max tokens per week (0 = unlimited).
|
||||
|
||||
Returns:
|
||||
CoPilotUsageStatus with current usage and limits.
|
||||
"""
|
||||
session_resets_at = datetime.now(UTC) + timedelta(seconds=_SESSION_TTL_SECONDS)
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
session_used = int(await redis.get(_session_key(user_id, session_id)) or 0)
|
||||
weekly_used = int(await redis.get(_weekly_key(user_id)) or 0)
|
||||
session_resets_at = await _session_reset_from_ttl(redis, user_id, session_id)
|
||||
except Exception:
|
||||
logger.warning("Redis unavailable for usage status, returning zeros")
|
||||
session_used = 0
|
||||
@@ -108,7 +127,7 @@ async def get_usage_status(
|
||||
session=UsageWindow(
|
||||
used=session_used,
|
||||
limit=session_token_limit,
|
||||
resets_at=_session_reset_time(),
|
||||
resets_at=session_resets_at,
|
||||
),
|
||||
weekly=UsageWindow(
|
||||
used=weekly_used,
|
||||
@@ -126,6 +145,12 @@ async def check_rate_limit(
|
||||
) -> None:
|
||||
"""Check if user is within rate limits. Raises RateLimitExceeded if not.
|
||||
|
||||
This is a pre-turn soft check. The authoritative usage counter is updated
|
||||
by ``record_token_usage()`` after the turn completes. Under concurrency,
|
||||
two parallel turns may both pass this check against the same snapshot.
|
||||
This is acceptable because token-based limits are approximate by nature
|
||||
(the exact token count is unknown until after generation).
|
||||
|
||||
Fails open: if Redis is unavailable, allows the request.
|
||||
"""
|
||||
try:
|
||||
@@ -137,7 +162,8 @@ async def check_rate_limit(
|
||||
return
|
||||
|
||||
if session_token_limit > 0 and session_used >= session_token_limit:
|
||||
raise RateLimitExceeded("session", _session_reset_time())
|
||||
resets_at = await _session_reset_from_ttl(redis, user_id, session_id)
|
||||
raise RateLimitExceeded("session", resets_at)
|
||||
|
||||
if weekly_token_limit > 0 and weekly_used >= weekly_token_limit:
|
||||
raise RateLimitExceeded("weekly", _weekly_reset_time())
|
||||
@@ -165,10 +191,10 @@ async def record_token_usage(
|
||||
redis = await get_redis_async()
|
||||
pipe = redis.pipeline()
|
||||
|
||||
# Session counter (reset with session TTL — 12 hours)
|
||||
# Session counter (expires after configured TTL)
|
||||
s_key = _session_key(user_id, session_id)
|
||||
pipe.incrby(s_key, total)
|
||||
pipe.expire(s_key, 43200) # 12 hours
|
||||
pipe.expire(s_key, _SESSION_TTL_SECONDS)
|
||||
|
||||
# Weekly counter (expires end of week)
|
||||
w_key = _weekly_key(user_id)
|
||||
|
||||
@@ -54,6 +54,7 @@ class TestGetUsageStatus:
|
||||
async def test_returns_redis_values(self):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["500", "2000"])
|
||||
mock_redis.ttl = AsyncMock(return_value=7200) # 2 hours remaining
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
@@ -107,6 +108,7 @@ class TestCheckRateLimit:
|
||||
async def test_raises_when_session_limit_exceeded(self):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["10000", "200"])
|
||||
mock_redis.ttl = AsyncMock(return_value=3600) # 1 hour remaining
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
|
||||
@@ -118,7 +118,7 @@ async def execute_block(
|
||||
if cost > 0:
|
||||
credit_model = await get_user_credit_model(user_id)
|
||||
balance = await credit_model.get_credits(user_id)
|
||||
if balance <= 0:
|
||||
if balance < cost:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Insufficient credits to run '{block.name}'. "
|
||||
|
||||
@@ -81,10 +81,10 @@ class TestExecuteBlockCreditCharging:
|
||||
assert call_kwargs["metadata"].reason == "copilot_block_execution"
|
||||
|
||||
async def test_returns_error_when_insufficient_credits_before_exec(self):
|
||||
"""Pre-execution check should return ErrorResponse when balance <= 0."""
|
||||
"""Pre-execution check should return ErrorResponse when balance < cost."""
|
||||
block = _make_block()
|
||||
mock_credit = AsyncMock()
|
||||
mock_credit.get_credits = AsyncMock(return_value=0)
|
||||
mock_credit.get_credits = AsyncMock(return_value=5) # balance < cost (10)
|
||||
|
||||
with (
|
||||
_patch_workspace(),
|
||||
@@ -145,10 +145,10 @@ class TestExecuteBlockCreditCharging:
|
||||
|
||||
block = _make_block()
|
||||
mock_credit = AsyncMock()
|
||||
mock_credit.get_credits = AsyncMock(return_value=5)
|
||||
mock_credit.get_credits = AsyncMock(return_value=15) # passes pre-check
|
||||
mock_credit.spend_credits = AsyncMock(
|
||||
side_effect=InsufficientBalanceError("Low balance", _USER, 5, 10)
|
||||
)
|
||||
) # fails during actual charge (race with concurrent spend)
|
||||
|
||||
with (
|
||||
_patch_workspace(),
|
||||
|
||||
@@ -14256,7 +14256,11 @@
|
||||
"UsageWindow": {
|
||||
"properties": {
|
||||
"used": { "type": "integer", "title": "Used" },
|
||||
"limit": { "type": "integer", "title": "Limit" },
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"title": "Limit",
|
||||
"description": "Maximum tokens allowed in this window. 0 means unlimited."
|
||||
},
|
||||
"resets_at": {
|
||||
"type": "string",
|
||||
"format": "date-time",
|
||||
|
||||
Reference in New Issue
Block a user