fix(backend/copilot): address autogpt-reviewer should-fix items for rate limiting and token tracking

- Normalize total_tokens semantics: both baseline and SDK now compute
  total_tokens = prompt + completion, with cache fields kept separate
- Resolve credit model once in helpers.py execute_block to avoid
  duplicate get_user_credit_model() calls per block execution
- Add 429 test coverage: weekly rate limit and reset-time assertion tests
- Extract shared token tracking into token_tracking.py to DRY ~50 lines
  of duplicated usage persistence + rate-limit recording logic
- Add early return in check_rate_limit when both limits are 0 (unlimited)
  to skip unnecessary Redis round-trip
- Switch all chat route auth from Depends to Security for OpenAPI spec
  consistency
- Use transaction=True for Redis pipeline incrby+expire to ensure
  atomicity
This commit is contained in:
Zamil Majdy
2026-03-14 22:59:32 +07:00
parent b9be577904
commit a0d534f24b
8 changed files with 255 additions and 105 deletions

View File

@@ -8,7 +8,7 @@ from typing import Annotated
from uuid import uuid4
from autogpt_libs import auth
from fastapi import APIRouter, Depends, HTTPException, Query, Response, Security
from fastapi import APIRouter, HTTPException, Query, Response, Security
from fastapi.responses import StreamingResponse
from prisma.models import UserWorkspaceFile
from pydantic import BaseModel, Field, field_validator
@@ -235,9 +235,10 @@ async def list_sessions(
@router.post(
"/sessions",
dependencies=[Security(auth.requires_user)],
)
async def create_session(
user_id: Annotated[str, Depends(auth.get_user_id)],
user_id: Annotated[str, Security(auth.get_user_id)],
) -> CreateSessionResponse:
"""
Create a new chat session.
@@ -356,7 +357,7 @@ async def update_session_title_route(
)
async def get_session(
session_id: str,
user_id: Annotated[str | None, Depends(auth.get_user_id)],
user_id: Annotated[str | None, Security(auth.get_user_id)],
) -> SessionDetailResponse:
"""
Retrieve the details of a specific chat session.
@@ -437,7 +438,7 @@ async def get_copilot_usage(
)
async def cancel_session_task(
session_id: str,
user_id: Annotated[str | None, Depends(auth.get_user_id)],
user_id: Annotated[str | None, Security(auth.get_user_id)],
) -> CancelSessionResponse:
"""Cancel the active streaming task for a session.
@@ -482,7 +483,7 @@ async def cancel_session_task(
async def stream_chat_post(
session_id: str,
request: StreamChatRequest,
user_id: str | None = Depends(auth.get_user_id),
user_id: str | None = Security(auth.get_user_id),
):
"""
Stream chat responses for a session (POST with context support).
@@ -507,6 +508,9 @@ async def stream_chat_post(
import asyncio
import time
if not user_id:
raise HTTPException(status_code=401, detail="Authentication required")
stream_start_time = time.perf_counter()
log_meta = {"component": "ChatStream", "session_id": session_id}
if user_id:
@@ -773,7 +777,7 @@ async def stream_chat_post(
)
async def resume_session_stream(
session_id: str,
user_id: str | None = Depends(auth.get_user_id),
user_id: str | None = Security(auth.get_user_id),
):
"""
Resume an active stream for a session.

View File

@@ -1,4 +1,4 @@
"""Tests for chat API routes: session title update, file attachment validation, usage, and suggested prompts."""
"""Tests for chat API routes: session title update, file attachment validation, usage, rate limiting, and suggested prompts."""
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock
@@ -252,6 +252,77 @@ def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
assert call_kwargs["where"]["isDeleted"] is False
# ─── Rate limit → 429 ─────────────────────────────────────────────────
def test_stream_chat_returns_429_on_daily_rate_limit(mocker: pytest_mock.MockFixture):
    """Daily-limit RateLimitExceeded from check_rate_limit surfaces as an HTTP 429."""
    from backend.copilot.rate_limit import RateLimitExceeded

    _mock_stream_internals(mocker)
    # Non-zero limits ensure the endpoint actually consults the rate limiter
    # instead of short-circuiting on the "unlimited" (0/0) fast path.
    mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
    mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
    limit_error = RateLimitExceeded("daily", datetime.now(UTC) + timedelta(hours=1))
    mocker.patch(
        "backend.api.features.chat.routes.check_rate_limit",
        side_effect=limit_error,
    )

    resp = client.post("/sessions/sess-1/stream", json={"message": "hello"})

    assert resp.status_code == 429
    assert "daily" in resp.json()["detail"].lower()
def test_stream_chat_returns_429_on_weekly_rate_limit(mocker: pytest_mock.MockFixture):
    """Weekly-limit RateLimitExceeded from check_rate_limit surfaces as an HTTP 429."""
    from backend.copilot.rate_limit import RateLimitExceeded

    _mock_stream_internals(mocker)
    mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
    mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
    resets_at = datetime.now(UTC) + timedelta(days=3)
    mocker.patch(
        "backend.api.features.chat.routes.check_rate_limit",
        side_effect=RateLimitExceeded("weekly", resets_at),
    )

    resp = client.post("/sessions/sess-1/stream", json={"message": "hello"})

    assert resp.status_code == 429
    lowered_detail = resp.json()["detail"].lower()
    # The detail must name the exhausted window and mention when it resets.
    assert "weekly" in lowered_detail
    assert "resets in" in lowered_detail
def test_stream_chat_429_includes_reset_time(mocker: pytest_mock.MockFixture):
    """The 429 response detail should include the human-readable reset time."""
    from backend.copilot.rate_limit import RateLimitExceeded

    _mock_stream_internals(mocker)
    mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
    mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
    # 2h30m out: the formatted detail is expected to round/truncate to "2h".
    reset_moment = datetime.now(UTC) + timedelta(hours=2, minutes=30)
    mocker.patch(
        "backend.api.features.chat.routes.check_rate_limit",
        side_effect=RateLimitExceeded("daily", reset_moment),
    )

    resp = client.post("/sessions/sess-1/stream", json={"message": "hello"})

    assert resp.status_code == 429
    detail = resp.json()["detail"]
    assert "2h" in detail
    assert "Resets in" in detail
# ─── Usage endpoint ───────────────────────────────────────────────────

View File

@@ -18,13 +18,11 @@ from langfuse import propagate_attributes
from backend.copilot.model import (
ChatMessage,
ChatSession,
Usage,
get_chat_session,
update_session_title,
upsert_chat_session,
)
from backend.copilot.prompting import get_baseline_supplement
from backend.copilot.rate_limit import record_token_usage
from backend.copilot.response_model import (
StreamBaseResponse,
StreamError,
@@ -46,6 +44,7 @@ from backend.copilot.service import (
client,
config,
)
from backend.copilot.token_tracking import persist_and_record_usage
from backend.copilot.tools import execute_tool, get_available_tools
from backend.copilot.tracking import track_user_message
from backend.util.exceptions import NotFoundError
@@ -460,37 +459,17 @@ async def stream_chat_completion_baseline(
turn_completion_tokens,
)
# Emit token usage and update session for persistence
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
total_tokens = turn_prompt_tokens + turn_completion_tokens
session.usage.append(
Usage(
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
total_tokens=total_tokens,
)
)
logger.info(
"[Baseline] Turn usage: prompt=%d, completion=%d, total=%d",
turn_prompt_tokens,
turn_completion_tokens,
total_tokens,
)
# Record for rate limiting counters
# NOTE: OpenRouter folds cached tokens into prompt_tokens, so we cannot
# break out cache_read/cache_creation weights. Users on the baseline
# path may be slightly over-counted vs the SDK path.
if user_id:
try:
await record_token_usage(
user_id=user_id,
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
)
except Exception as usage_err:
logger.warning(
"[Baseline] Failed to record token usage: %s", usage_err
)
# Persist token usage to session and record for rate limiting.
# NOTE: OpenRouter folds cached tokens into prompt_tokens, so we
# cannot break out cache_read/cache_creation weights. Users on the
# baseline path may be slightly over-counted vs the SDK path.
await persist_and_record_usage(
session=session,
user_id=user_id,
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
log_prefix="[Baseline]",
)
# Persist assistant response
if assistant_text:

View File

@@ -101,7 +101,10 @@ class ChatSessionInfo(BaseModel):
prisma_session.successfulAgentSchedules, default={}
)
# Calculate usage from token counts
# Calculate usage from token counts.
# NOTE: Per-turn cache_read_tokens / cache_creation_tokens breakdown
# is lost after persistence — the DB only stores aggregate prompt and
# completion totals. This is a known limitation.
usage = []
if prisma_session.totalPromptTokens or prisma_session.totalCompletionTokens:
usage.append(

View File

@@ -145,6 +145,11 @@ async def check_rate_limit(
Fails open: if Redis is unavailable, allows the request.
"""
# Short-circuit: when both limits are 0 (unlimited) skip the Redis
# round-trip entirely.
if daily_token_limit <= 0 and weekly_token_limit <= 0:
return
now = datetime.now(UTC)
try:
redis = await get_redis_async()
@@ -158,6 +163,7 @@ async def check_rate_limit(
logger.warning("Redis unavailable for rate limit check, allowing request")
return
# Worst-case overshoot: N concurrent requests × ~15K tokens each.
if daily_token_limit > 0 and daily_used >= daily_token_limit:
raise RateLimitExceeded("daily", _daily_reset_time(now=now))
@@ -192,6 +198,11 @@ async def record_token_usage(
cache_read_tokens: Tokens served from prompt cache (10% cost).
cache_creation_tokens: Tokens written to prompt cache (25% cost).
"""
prompt_tokens = max(0, prompt_tokens)
completion_tokens = max(0, completion_tokens)
cache_read_tokens = max(0, cache_read_tokens)
cache_creation_tokens = max(0, cache_creation_tokens)
weighted_input = (
prompt_tokens
+ round(cache_creation_tokens * 0.25)
@@ -219,7 +230,10 @@ async def record_token_usage(
now = datetime.now(UTC)
try:
redis = await get_redis_async()
pipe = redis.pipeline(transaction=False)
# Use transaction=True (MULTI/EXEC) so that each incrby+expire
# pair is atomic. With transaction=False a crash between incrby
# and expire could leave a key without a TTL, leaking memory.
pipe = redis.pipeline(transaction=True)
# Daily counter (expires at next midnight UTC)
d_key = _daily_key(user_id, now=now)

View File

@@ -40,13 +40,11 @@ from ..constants import COPILOT_ERROR_PREFIX, COPILOT_SYSTEM_PREFIX
from ..model import (
ChatMessage,
ChatSession,
Usage,
get_chat_session,
update_session_title,
upsert_chat_session,
)
from ..prompting import get_sdk_supplement
from ..rate_limit import record_token_usage
from ..response_model import (
StreamBaseResponse,
StreamError,
@@ -63,6 +61,7 @@ from ..service import (
_generate_session_title,
_is_langfuse_configured,
)
from ..token_tracking import persist_and_record_usage
from ..tools.e2b_sandbox import get_or_create_sandbox, pause_sandbox_direct
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
from ..tools.workspace_files import get_manager
@@ -1389,12 +1388,13 @@ async def stream_chat_completion_sdk(
# rate-limit recording even if an exception interrupts between here
# and the finally block.
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
total_tokens = (
turn_prompt_tokens
+ turn_cache_read_tokens
+ turn_cache_creation_tokens
+ turn_completion_tokens
)
# total_tokens = prompt (uncached input) + completion (output).
# Cache tokens are tracked separately and excluded from total
# so that the semantics match the baseline path (OpenRouter)
# which folds cache into prompt_tokens. Keeping total_tokens
# = prompt + completion everywhere makes cross-path comparisons
# and session-level aggregation consistent.
total_tokens = turn_prompt_tokens + turn_completion_tokens
yield StreamUsage(
promptTokens=turn_prompt_tokens,
completionTokens=turn_completion_tokens,
@@ -1470,49 +1470,16 @@ async def stream_chat_completion_sdk(
# --- Persist token usage to session + rate-limit counters ---
# Both must live in finally so they stay consistent even when an
# exception interrupts the try block after StreamUsage was yielded.
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
total_tokens = (
turn_prompt_tokens
+ turn_cache_read_tokens
+ turn_cache_creation_tokens
+ turn_completion_tokens
)
if session is not None:
session.usage.append(
Usage(
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
total_tokens=total_tokens,
cache_read_tokens=turn_cache_read_tokens,
cache_creation_tokens=turn_cache_creation_tokens,
)
)
logger.info(
"%s Turn usage: uncached=%d, cache_read=%d, cache_create=%d, "
"output=%d, total=%d, cost_usd=%s",
log_prefix,
turn_prompt_tokens,
turn_cache_read_tokens,
turn_cache_creation_tokens,
turn_completion_tokens,
total_tokens,
turn_cost_usd,
)
if user_id and (turn_prompt_tokens > 0 or turn_completion_tokens > 0):
try:
await record_token_usage(
user_id=user_id,
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
cache_read_tokens=turn_cache_read_tokens,
cache_creation_tokens=turn_cache_creation_tokens,
)
except Exception as usage_err:
logger.warning(
"%s Failed to record token usage: %s",
log_prefix,
usage_err,
)
await persist_and_record_usage(
session=session,
user_id=user_id,
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
cache_read_tokens=turn_cache_read_tokens,
cache_creation_tokens=turn_cache_creation_tokens,
log_prefix=log_prefix,
cost_usd=turn_cost_usd,
)
# --- Persist session messages ---
# This MUST run in finally to persist messages even when the generator

View File

@@ -0,0 +1,101 @@
"""Shared token-usage persistence and rate-limit recording.
Both the baseline (OpenRouter) and SDK (Anthropic) service layers need to:
1. Append a ``Usage`` record to the session.
2. Log the turn's token counts.
3. Record weighted usage in Redis for rate-limiting.
This module extracts that common logic so both paths stay in sync.
"""
import logging
from .model import ChatSession, Usage
from .rate_limit import record_token_usage
logger = logging.getLogger(__name__)
async def persist_and_record_usage(
*,
session: ChatSession | None,
user_id: str | None,
prompt_tokens: int,
completion_tokens: int,
cache_read_tokens: int = 0,
cache_creation_tokens: int = 0,
log_prefix: str = "",
cost_usd: float | str | None = None,
) -> int:
"""Persist token usage to session and record for rate limiting.
Args:
session: The chat session to append usage to (may be None on error).
user_id: User ID for rate-limit counters (skipped if None).
prompt_tokens: Uncached input tokens.
completion_tokens: Output tokens.
cache_read_tokens: Tokens served from prompt cache (Anthropic only).
cache_creation_tokens: Tokens written to prompt cache (Anthropic only).
log_prefix: Prefix for log messages (e.g. "[SDK]", "[Baseline]").
cost_usd: Optional cost for logging (float from SDK, str otherwise).
Returns:
The computed total_tokens (prompt + completion; cache excluded).
"""
if prompt_tokens <= 0 and completion_tokens <= 0:
return 0
# total_tokens = prompt + completion. Cache tokens are tracked
# separately and excluded from total so both baseline and SDK
# paths share the same semantics.
total_tokens = prompt_tokens + completion_tokens
if session is not None:
session.usage.append(
Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
cache_read_tokens=cache_read_tokens,
cache_creation_tokens=cache_creation_tokens,
)
)
if cache_read_tokens or cache_creation_tokens:
logger.info(
"%s Turn usage: uncached=%d, cache_read=%d, cache_create=%d, "
"output=%d, total=%d, cost_usd=%s",
log_prefix,
prompt_tokens,
cache_read_tokens,
cache_creation_tokens,
completion_tokens,
total_tokens,
cost_usd,
)
else:
logger.info(
"%s Turn usage: prompt=%d, completion=%d, total=%d",
log_prefix,
prompt_tokens,
completion_tokens,
total_tokens,
)
if user_id:
try:
await record_token_usage(
user_id=user_id,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
cache_read_tokens=cache_read_tokens,
cache_creation_tokens=cache_creation_tokens,
)
except Exception as usage_err:
logger.warning(
"%s Failed to record token usage: %s",
log_prefix,
usage_err,
)
return total_tokens

View File

@@ -8,8 +8,8 @@ from pydantic_core import PydanticUndefined
from backend.blocks._base import AnyBlockSchema
from backend.copilot.constants import COPILOT_NODE_PREFIX, COPILOT_SESSION_PREFIX
from backend.data.credit import UsageTransactionMetadata
from backend.data.db_accessors import credit_db, workspace_db
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from backend.data.db_accessors import workspace_db
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
from backend.executor.utils import block_usage_cost
@@ -123,9 +123,12 @@ async def execute_block(
# between the balance check and the post-execution charge.
cost, cost_filter = block_usage_cost(block, input_data)
has_cost = cost > 0
credits = credit_db()
# Resolve the credit model once and reuse for both the balance
# check and the post-execution charge to avoid redundant
# LaunchDarkly flag evaluations.
credit_model = await get_user_credit_model(user_id)
if has_cost:
balance = await credits.get_credits(user_id)
balance = await credit_model.get_credits(user_id)
if balance < cost:
return ErrorResponse(
message=(
@@ -146,7 +149,7 @@ async def execute_block(
# Charge credits for block execution
if has_cost:
try:
await credits.spend_credits(
await credit_model.spend_credits(
user_id=user_id,
cost=cost,
metadata=UsageTransactionMetadata(
@@ -160,17 +163,25 @@ async def execute_block(
reason="copilot_block_execution",
),
)
except InsufficientBalanceError:
except InsufficientBalanceError as e:
# Concurrent spend drained balance after our pre-check passed.
# Block already executed (with possible side effects), so return
# its output but log the billing leak for monitoring.
logger.warning(
"BILLING_LEAK: Post-exec credit charge failed for block %s "
"(cost=%d, user=%s, node_exec=%s)",
block.name,
cost,
user_id[:8],
logger.error(
"BILLING_LEAK: block executed but credit charge failed "
"user_id=%s, block_id=%s, node_exec_id=%s, cost=%s: %s",
user_id,
block_id,
node_exec_id,
cost,
e,
extra={
"json_fields": {
"billing_leak": True,
"user_id": user_id,
"cost": str(cost),
}
},
)
return BlockOutputResponse(