fix(backend): address CodeRabbit review feedback

- Derive session reset time from Redis TTL instead of hardcoded 3h
- Add description to UsageWindow.limit documenting 0 = unlimited
- Compare balance < cost instead of balance <= 0 in pre-exec check
- Document TOCTOU behavior in check_rate_limit docstring
This commit is contained in:
Zamil Majdy
2026-03-12 21:10:52 +07:00
parent b6d863fcd2
commit c589cd0c43
5 changed files with 49 additions and 17 deletions

View File

@@ -8,7 +8,7 @@ unavailable to avoid blocking users.
import logging
from datetime import UTC, datetime, timedelta
from pydantic import BaseModel
from pydantic import BaseModel, Field
from backend.data.redis_client import get_redis_async
@@ -17,12 +17,17 @@ logger = logging.getLogger(__name__)
# Redis key prefixes
_PREFIX = "copilot:usage"
# Session keys expire after 12 hours of inactivity
_SESSION_TTL_SECONDS = 43200 # 12 hours
class UsageWindow(BaseModel):
"""Usage within a single time window."""
used: int
limit: int
limit: int = Field(
description="Maximum tokens allowed in this window. 0 means unlimited."
)
resets_at: datetime
@@ -73,9 +78,21 @@ def _weekly_reset_time() -> datetime:
)
def _session_reset_time() -> datetime:
"""Session limits reset after 3 hours of inactivity (matching session TTL)."""
return datetime.now(UTC) + timedelta(hours=3)
async def _session_reset_from_ttl(
redis: object, user_id: str, session_id: str
) -> datetime:
"""Derive session reset time from the Redis key's actual TTL.
Falls back to the configured TTL if the key doesn't exist or has no expiry.
"""
try:
ttl: int = await redis.ttl(_session_key(user_id, session_id)) # type: ignore[union-attr]
if ttl > 0:
return datetime.now(UTC) + timedelta(seconds=ttl)
except Exception:
pass
# Key doesn't exist or has no TTL — use the configured TTL
return datetime.now(UTC) + timedelta(seconds=_SESSION_TTL_SECONDS)
async def get_usage_status(
@@ -89,16 +106,18 @@ async def get_usage_status(
Args:
user_id: The user's ID.
session_id: The current session ID.
session_token_limit: Max tokens per session.
weekly_token_limit: Max tokens per week.
session_token_limit: Max tokens per session (0 = unlimited).
weekly_token_limit: Max tokens per week (0 = unlimited).
Returns:
CoPilotUsageStatus with current usage and limits.
"""
session_resets_at = datetime.now(UTC) + timedelta(seconds=_SESSION_TTL_SECONDS)
try:
redis = await get_redis_async()
session_used = int(await redis.get(_session_key(user_id, session_id)) or 0)
weekly_used = int(await redis.get(_weekly_key(user_id)) or 0)
session_resets_at = await _session_reset_from_ttl(redis, user_id, session_id)
except Exception:
logger.warning("Redis unavailable for usage status, returning zeros")
session_used = 0
@@ -108,7 +127,7 @@ async def get_usage_status(
session=UsageWindow(
used=session_used,
limit=session_token_limit,
resets_at=_session_reset_time(),
resets_at=session_resets_at,
),
weekly=UsageWindow(
used=weekly_used,
@@ -126,6 +145,12 @@ async def check_rate_limit(
) -> None:
"""Check if user is within rate limits. Raises RateLimitExceeded if not.
This is a pre-turn soft check. The authoritative usage counter is updated
by ``record_token_usage()`` after the turn completes. Under concurrency,
two parallel turns may both pass this check against the same snapshot.
This is acceptable because token-based limits are approximate by nature
(the exact token count is unknown until after generation).
Fails open: if Redis is unavailable, allows the request.
"""
try:
@@ -137,7 +162,8 @@ async def check_rate_limit(
return
if session_token_limit > 0 and session_used >= session_token_limit:
raise RateLimitExceeded("session", _session_reset_time())
resets_at = await _session_reset_from_ttl(redis, user_id, session_id)
raise RateLimitExceeded("session", resets_at)
if weekly_token_limit > 0 and weekly_used >= weekly_token_limit:
raise RateLimitExceeded("weekly", _weekly_reset_time())
@@ -165,10 +191,10 @@ async def record_token_usage(
redis = await get_redis_async()
pipe = redis.pipeline()
# Session counter (reset with session TTL — 12 hours)
# Session counter (expires after configured TTL)
s_key = _session_key(user_id, session_id)
pipe.incrby(s_key, total)
pipe.expire(s_key, 43200) # 12 hours
pipe.expire(s_key, _SESSION_TTL_SECONDS)
# Weekly counter (expires end of week)
w_key = _weekly_key(user_id)

View File

@@ -54,6 +54,7 @@ class TestGetUsageStatus:
async def test_returns_redis_values(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["500", "2000"])
mock_redis.ttl = AsyncMock(return_value=7200) # 2 hours remaining
with patch(
"backend.copilot.rate_limit.get_redis_async",
@@ -107,6 +108,7 @@ class TestCheckRateLimit:
async def test_raises_when_session_limit_exceeded(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["10000", "200"])
mock_redis.ttl = AsyncMock(return_value=3600) # 1 hour remaining
with patch(
"backend.copilot.rate_limit.get_redis_async",

View File

@@ -118,7 +118,7 @@ async def execute_block(
if cost > 0:
credit_model = await get_user_credit_model(user_id)
balance = await credit_model.get_credits(user_id)
if balance <= 0:
if balance < cost:
return ErrorResponse(
message=(
f"Insufficient credits to run '{block.name}'. "

View File

@@ -81,10 +81,10 @@ class TestExecuteBlockCreditCharging:
assert call_kwargs["metadata"].reason == "copilot_block_execution"
async def test_returns_error_when_insufficient_credits_before_exec(self):
"""Pre-execution check should return ErrorResponse when balance <= 0."""
"""Pre-execution check should return ErrorResponse when balance < cost."""
block = _make_block()
mock_credit = AsyncMock()
mock_credit.get_credits = AsyncMock(return_value=0)
mock_credit.get_credits = AsyncMock(return_value=5) # balance < cost (10)
with (
_patch_workspace(),
@@ -145,10 +145,10 @@ class TestExecuteBlockCreditCharging:
block = _make_block()
mock_credit = AsyncMock()
mock_credit.get_credits = AsyncMock(return_value=5)
mock_credit.get_credits = AsyncMock(return_value=15) # passes pre-check
mock_credit.spend_credits = AsyncMock(
side_effect=InsufficientBalanceError("Low balance", _USER, 5, 10)
)
) # fails during actual charge (race with concurrent spend)
with (
_patch_workspace(),

View File

@@ -14256,7 +14256,11 @@
"UsageWindow": {
"properties": {
"used": { "type": "integer", "title": "Used" },
"limit": { "type": "integer", "title": "Limit" },
"limit": {
"type": "integer",
"title": "Limit",
"description": "Maximum tokens allowed in this window. 0 means unlimited."
},
"resets_at": {
"type": "string",
"format": "date-time",