fix: remove extraneous changes accidentally picked up from PR #12385

Revert copilot, chat routes, and frontend files to match origin/dev.
This PR should only contain store search CamelCase splitting changes.
This commit is contained in:
Zamil Majdy
2026-03-14 10:25:02 +07:00
parent 19e79dd236
commit 378bd3afcc
24 changed files with 91 additions and 1562 deletions

View File

@@ -27,12 +27,6 @@ from backend.copilot.model import (
get_user_sessions,
update_session_title,
)
from backend.copilot.rate_limit import (
CoPilotUsageStatus,
RateLimitExceeded,
check_rate_limit,
get_usage_status,
)
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
from backend.copilot.tools.e2b_sandbox import kill_sandbox
from backend.copilot.tools.models import (
@@ -126,8 +120,6 @@ class SessionDetailResponse(BaseModel):
user_id: str | None
messages: list[dict]
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
total_prompt_tokens: int = 0
total_completion_tokens: int = 0
class SessionSummaryResponse(BaseModel):
@@ -397,10 +389,6 @@ async def get_session(
last_message_id=last_message_id,
)
# Sum token usage from session
total_prompt = sum(u.prompt_tokens for u in session.usage)
total_completion = sum(u.completion_tokens for u in session.usage)
return SessionDetailResponse(
id=session.session_id,
created_at=session.started_at.isoformat(),
@@ -408,26 +396,6 @@ async def get_session(
user_id=session.user_id or None,
messages=messages,
active_stream=active_stream_info,
total_prompt_tokens=total_prompt,
total_completion_tokens=total_completion,
)
@router.get("/usage")
async def get_copilot_usage(
user_id: Annotated[str | None, Depends(auth.get_user_id)],
) -> CoPilotUsageStatus:
"""Get CoPilot usage status for the authenticated user.
Returns current token usage vs limits for daily and weekly windows.
"""
if not user_id:
raise HTTPException(status_code=401, detail="Authentication required")
return await get_usage_status(
user_id=user_id,
daily_token_limit=config.daily_token_limit,
weekly_token_limit=config.weekly_token_limit,
)
@@ -528,17 +496,6 @@ async def stream_chat_post(
},
)
# Pre-turn rate limit check (token-based)
if user_id and (config.daily_token_limit > 0 or config.weekly_token_limit > 0):
try:
await check_rate_limit(
user_id=user_id,
daily_token_limit=config.daily_token_limit,
weekly_token_limit=config.weekly_token_limit,
)
except RateLimitExceeded as e:
raise HTTPException(status_code=429, detail=str(e)) from e
# Enrich message with file metadata if file_ids are provided.
# Also sanitise file_ids so only validated, workspace-scoped IDs are
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).

View File

@@ -1,6 +1,5 @@
"""Tests for chat API routes: session title update, file attachment validation, usage, and suggested prompts."""
"""Tests for chat API routes: session title update, file attachment validation, and suggested prompts."""
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock
import fastapi
@@ -252,74 +251,6 @@ def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
assert call_kwargs["where"]["isDeleted"] is False
# ─── Usage endpoint ───────────────────────────────────────────────────
def _mock_usage(
mocker: pytest_mock.MockerFixture,
*,
daily_used: int = 500,
weekly_used: int = 2000,
) -> AsyncMock:
"""Mock get_usage_status to return a predictable CoPilotUsageStatus."""
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
resets_at = datetime.now(UTC) + timedelta(days=1)
status = CoPilotUsageStatus(
daily=UsageWindow(used=daily_used, limit=10000, resets_at=resets_at),
weekly=UsageWindow(used=weekly_used, limit=50000, resets_at=resets_at),
)
return mocker.patch(
"backend.api.features.chat.routes.get_usage_status",
new_callable=AsyncMock,
return_value=status,
)
def test_usage_returns_daily_and_weekly(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""GET /usage returns daily and weekly usage."""
mock_get = _mock_usage(mocker, daily_used=500, weekly_used=2000)
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
response = client.get("/usage")
assert response.status_code == 200
data = response.json()
assert data["daily"]["used"] == 500
assert data["weekly"]["used"] == 2000
mock_get.assert_called_once_with(
user_id=test_user_id,
daily_token_limit=10000,
weekly_token_limit=50000,
)
def test_usage_uses_config_limits(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""The endpoint forwards daily_token_limit and weekly_token_limit from config."""
mock_get = _mock_usage(mocker)
mocker.patch.object(chat_routes.config, "daily_token_limit", 99999)
mocker.patch.object(chat_routes.config, "weekly_token_limit", 77777)
response = client.get("/usage")
assert response.status_code == 200
mock_get.assert_called_once_with(
user_id=test_user_id,
daily_token_limit=99999,
weekly_token_limit=77777,
)
# ─── Suggested prompts endpoint ──────────────────────────────────────

View File

@@ -18,13 +18,11 @@ from langfuse import propagate_attributes
from backend.copilot.model import (
ChatMessage,
ChatSession,
Usage,
get_chat_session,
update_session_title,
upsert_chat_session,
)
from backend.copilot.prompting import get_baseline_supplement
from backend.copilot.rate_limit import record_token_usage
from backend.copilot.response_model import (
StreamBaseResponse,
StreamError,
@@ -38,7 +36,6 @@ from backend.copilot.response_model import (
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
StreamUsage,
)
from backend.copilot.service import (
_build_system_prompt,
@@ -49,11 +46,7 @@ from backend.copilot.service import (
from backend.copilot.tools import execute_tool, get_available_tools
from backend.copilot.tracking import track_user_message
from backend.util.exceptions import NotFoundError
from backend.util.prompt import (
compress_context,
estimate_token_count,
estimate_token_count_str,
)
from backend.util.prompt import compress_context
logger = logging.getLogger(__name__)
@@ -228,9 +221,6 @@ async def stream_chat_completion_baseline(
text_block_id = str(uuid.uuid4())
text_started = False
step_open = False
# Token usage accumulators — populated from streaming chunks
turn_prompt_tokens = 0
turn_completion_tokens = 0
try:
for _round in range(_MAX_TOOL_ROUNDS):
# Open a new step for each LLM round
@@ -242,7 +232,6 @@ async def stream_chat_completion_baseline(
model=config.model,
messages=openai_messages,
stream=True,
stream_options={"include_usage": True},
)
if tools:
create_kwargs["tools"] = tools
@@ -253,18 +242,7 @@ async def stream_chat_completion_baseline(
tool_calls_by_index: dict[int, dict[str, str]] = {}
async for chunk in response:
# Capture token usage from the streaming chunk.
# OpenRouter normalises all providers into OpenAI format
# where prompt_tokens already includes cached tokens
# (unlike Anthropic's native API). Use += to sum all
# tool-call rounds since each API call is independent.
if chunk.usage:
turn_prompt_tokens += chunk.usage.prompt_tokens or 0
turn_completion_tokens += chunk.usage.completion_tokens or 0
if not chunk.choices:
continue
delta = chunk.choices[0].delta
delta = chunk.choices[0].delta if chunk.choices else None
if not delta:
continue
@@ -433,57 +411,6 @@ async def stream_chat_completion_baseline(
except Exception:
logger.warning("[Baseline] Langfuse trace context teardown failed")
# Fallback: estimate tokens via tiktoken when the provider does
# not honour stream_options={"include_usage": True}.
# Count the full message list (system + history + turn) since
# each API call sends the complete context window.
if turn_prompt_tokens == 0 and turn_completion_tokens == 0:
turn_prompt_tokens = max(
estimate_token_count(openai_messages, model=config.model), 1
)
turn_completion_tokens = max(
estimate_token_count_str(assistant_text, model=config.model), 1
)
logger.info(
"[Baseline] No streaming usage reported; estimated tokens: "
"prompt=%d, completion=%d",
turn_prompt_tokens,
turn_completion_tokens,
)
# Emit token usage and update session for persistence
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
total_tokens = turn_prompt_tokens + turn_completion_tokens
session.usage.append(
Usage(
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
total_tokens=total_tokens,
)
)
logger.info(
"[Baseline] Turn usage: prompt=%d, completion=%d, total=%d",
turn_prompt_tokens,
turn_completion_tokens,
total_tokens,
)
# Record for rate limiting counters
if user_id:
try:
await record_token_usage(
user_id=user_id,
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
)
except Exception as usage_err:
logger.error(
"[Baseline] Failed to record token usage "
"(tokens=%d, user=%s): %s",
turn_prompt_tokens + turn_completion_tokens,
user_id[:8] if user_id else "?",
usage_err,
)
# Persist assistant response
if assistant_text:
session.messages.append(
@@ -494,16 +421,4 @@ async def stream_chat_completion_baseline(
except Exception as persist_err:
logger.error("[Baseline] Failed to persist session: %s", persist_err)
# Yield usage and finish AFTER try/finally (not inside finally).
# PEP 525 prohibits yielding from finally in async generators during
# aclose() — doing so raises RuntimeError on client disconnect.
# On GeneratorExit the client is already gone, so unreachable yields
# are harmless; on normal completion they reach the SSE stream.
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
yield StreamUsage(
promptTokens=turn_prompt_tokens,
completionTokens=turn_completion_tokens,
totalTokens=turn_prompt_tokens + turn_completion_tokens,
)
yield StreamFinish()

View File

@@ -70,18 +70,6 @@ class ChatConfig(BaseSettings):
description="Cache TTL in seconds for Langfuse prompt (0 to disable caching)",
)
# Rate limiting — token-based limits per day and per week.
# Each CoPilot turn consumes ~10-15K tokens (system prompt + tool schemas + response),
# so 2.5M daily allows ~170-250 turns/day which is reasonable for normal use.
daily_token_limit: int = Field(
default=2_500_000,
description="Max tokens per day, resets at midnight UTC (0 = unlimited)",
)
weekly_token_limit: int = Field(
default=12_500_000,
description="Max tokens per week, resets Monday 00:00 UTC (0 = unlimited)",
)
# Claude Agent SDK Configuration
use_claude_agent_sdk: bool = Field(
default=True,

View File

@@ -73,9 +73,6 @@ class Usage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
# Cache breakdown (Anthropic-specific; zero for non-Anthropic models)
cache_read_tokens: int = 0
cache_creation_tokens: int = 0
class ChatSessionInfo(BaseModel):

View File

@@ -1,254 +0,0 @@
"""CoPilot rate limiting based on token usage.
Uses Redis fixed-window counters to track per-user token consumption
with configurable daily and weekly limits. Daily windows reset at
midnight UTC; weekly windows reset at ISO week boundary (Monday 00:00
UTC). Fails open when Redis is unavailable to avoid blocking users.
"""
import asyncio
import logging
from datetime import UTC, datetime, timedelta
from pydantic import BaseModel, Field
from redis.exceptions import RedisError
from backend.data.redis_client import get_redis_async
logger = logging.getLogger(__name__)
# Redis key prefixes
_PREFIX = "copilot:usage"
class UsageWindow(BaseModel):
"""Usage within a single time window."""
used: int
limit: int = Field(
description="Maximum tokens allowed in this window. 0 means unlimited."
)
resets_at: datetime
class CoPilotUsageStatus(BaseModel):
"""Current usage status for a user across all windows."""
daily: UsageWindow
weekly: UsageWindow
class RateLimitExceeded(Exception):
"""Raised when a user exceeds their CoPilot usage limit."""
def __init__(self, window: str, resets_at: datetime):
self.window = window
self.resets_at = resets_at
delta = resets_at - datetime.now(UTC)
total_secs = delta.total_seconds()
if total_secs <= 0:
time_str = "now"
else:
hours = int(total_secs // 3600)
minutes = int((total_secs % 3600) // 60)
time_str = f"{hours}h {minutes}m" if hours > 0 else f"{minutes}m"
super().__init__(
f"You've reached your {window} usage limit. Resets in {time_str}."
)
def _daily_key(user_id: str, now: datetime | None = None) -> str:
if now is None:
now = datetime.now(UTC)
return f"{_PREFIX}:daily:{user_id}:{now.strftime('%Y-%m-%d')}"
def _weekly_key(user_id: str, now: datetime | None = None) -> str:
if now is None:
now = datetime.now(UTC)
year, week, _ = now.isocalendar()
return f"{_PREFIX}:weekly:{user_id}:{year}-W{week:02d}"
def _daily_reset_time(now: datetime | None = None) -> datetime:
"""Calculate when the current daily window resets (next midnight UTC)."""
if now is None:
now = datetime.now(UTC)
return now.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(days=1)
def _weekly_reset_time(now: datetime | None = None) -> datetime:
"""Calculate when the current weekly window resets (next Monday 00:00 UTC).
On Monday itself, ``(7 - weekday) % 7`` is 0 and the ``or 7`` fallback
pushes to *next* Monday. This means Monday's usage counts toward the
current week (which started the previous Monday), matching ISO week semantics.
"""
if now is None:
now = datetime.now(UTC)
days_until_monday = (7 - now.weekday()) % 7 or 7
return now.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(
days=days_until_monday
)
async def _fetch_usage_counters(user_id: str, now: datetime) -> tuple[int, int]:
"""Fetch daily and weekly token counters from Redis.
Returns:
(daily_used, weekly_used) parsed as ints.
Raises:
RedisError / ConnectionError / OSError on Redis failure.
"""
redis = await get_redis_async()
daily_raw, weekly_raw = await asyncio.gather(
redis.get(_daily_key(user_id, now=now)),
redis.get(_weekly_key(user_id, now=now)),
)
return int(daily_raw or 0), int(weekly_raw or 0)
async def get_usage_status(
user_id: str,
daily_token_limit: int,
weekly_token_limit: int,
) -> CoPilotUsageStatus:
"""Get current usage status for a user.
Args:
user_id: The user's ID.
daily_token_limit: Max tokens per day (0 = unlimited).
weekly_token_limit: Max tokens per week (0 = unlimited).
Returns:
CoPilotUsageStatus with current usage and limits.
"""
now = datetime.now(UTC)
try:
daily_used, weekly_used = await _fetch_usage_counters(user_id, now)
except (RedisError, ConnectionError, OSError):
logger.warning("Redis unavailable for usage status, returning zeros")
daily_used, weekly_used = 0, 0
return CoPilotUsageStatus(
daily=UsageWindow(
used=daily_used,
limit=daily_token_limit,
resets_at=_daily_reset_time(now=now),
),
weekly=UsageWindow(
used=weekly_used,
limit=weekly_token_limit,
resets_at=_weekly_reset_time(now=now),
),
)
async def check_rate_limit(
user_id: str,
daily_token_limit: int,
weekly_token_limit: int,
) -> None:
"""Check if user is within rate limits. Raises RateLimitExceeded if not.
This is a pre-turn soft check. The authoritative usage counter is updated
by ``record_token_usage()`` after the turn completes. Under concurrency,
two parallel turns may both pass this check against the same snapshot.
This is acceptable because token-based limits are approximate by nature
(the exact token count is unknown until after generation).
Fails open: if Redis is unavailable, allows the request.
"""
now = datetime.now(UTC)
try:
daily_used, weekly_used = await _fetch_usage_counters(user_id, now)
except (RedisError, ConnectionError, OSError):
logger.warning("Redis unavailable for rate limit check, allowing request")
return
if daily_token_limit > 0 and daily_used >= daily_token_limit:
raise RateLimitExceeded("daily", _daily_reset_time(now=now))
if weekly_token_limit > 0 and weekly_used >= weekly_token_limit:
raise RateLimitExceeded("weekly", _weekly_reset_time(now=now))
async def record_token_usage(
user_id: str,
prompt_tokens: int,
completion_tokens: int,
*,
cache_read_tokens: int = 0,
cache_creation_tokens: int = 0,
) -> None:
"""Record token usage for a user across all windows.
Uses cost-weighted counting so cached tokens don't unfairly penalise
multi-turn conversations. Anthropic's pricing:
- uncached input: 100%
- cache creation: 25%
- cache read: 10%
- output: 100%
``prompt_tokens`` should be the *uncached* input count (``input_tokens``
from the API response). Cache counts are passed separately.
Args:
user_id: The user's ID.
prompt_tokens: Uncached input tokens.
completion_tokens: Output tokens.
cache_read_tokens: Tokens served from prompt cache (10% cost).
cache_creation_tokens: Tokens written to prompt cache (25% cost).
"""
weighted_input = (
prompt_tokens
+ round(cache_creation_tokens * 0.25)
+ round(cache_read_tokens * 0.1)
)
total = weighted_input + completion_tokens
if total <= 0:
return
raw_total = (
prompt_tokens + cache_read_tokens + cache_creation_tokens + completion_tokens
)
logger.info(
"Recording token usage for %s: raw=%d, weighted=%d "
"(uncached=%d, cache_read=%d@10%%, cache_create=%d@25%%, output=%d)",
user_id[:8],
raw_total,
total,
prompt_tokens,
cache_read_tokens,
cache_creation_tokens,
completion_tokens,
)
now = datetime.now(UTC)
try:
redis = await get_redis_async()
pipe = redis.pipeline(transaction=False)
# Daily counter (expires at next midnight UTC)
d_key = _daily_key(user_id, now=now)
pipe.incrby(d_key, total)
seconds_until_daily_reset = int(
(_daily_reset_time(now=now) - now).total_seconds()
)
pipe.expire(d_key, max(seconds_until_daily_reset, 1))
# Weekly counter (expires end of week)
w_key = _weekly_key(user_id, now=now)
pipe.incrby(w_key, total)
seconds_until_weekly_reset = int(
(_weekly_reset_time(now=now) - now).total_seconds()
)
pipe.expire(w_key, max(seconds_until_weekly_reset, 1))
await pipe.execute()
except (RedisError, ConnectionError, OSError):
logger.warning(
"Redis unavailable for recording token usage (tokens=%d)",
total,
)

View File

@@ -1,334 +0,0 @@
"""Unit tests for CoPilot rate limiting."""
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from redis.exceptions import RedisError
from .rate_limit import (
CoPilotUsageStatus,
RateLimitExceeded,
check_rate_limit,
get_usage_status,
record_token_usage,
)
_USER = "test-user-rl"
# ---------------------------------------------------------------------------
# RateLimitExceeded
# ---------------------------------------------------------------------------
class TestRateLimitExceeded:
def test_message_contains_window_name(self):
exc = RateLimitExceeded("daily", datetime.now(UTC) + timedelta(hours=1))
assert "daily" in str(exc)
def test_message_contains_reset_time(self):
exc = RateLimitExceeded(
"weekly", datetime.now(UTC) + timedelta(hours=2, minutes=30)
)
msg = str(exc)
# Allow for slight timing drift (29m or 30m)
assert "2h " in msg
assert "Resets in" in msg
def test_message_minutes_only_when_under_one_hour(self):
exc = RateLimitExceeded("daily", datetime.now(UTC) + timedelta(minutes=15))
msg = str(exc)
assert "Resets in" in msg
# Should not have "0h"
assert "0h" not in msg
def test_message_says_now_when_resets_at_is_in_the_past(self):
"""Negative delta (clock skew / stale TTL) should say 'now', not '-1h -30m'."""
exc = RateLimitExceeded("daily", datetime.now(UTC) - timedelta(minutes=5))
assert "Resets in now" in str(exc)
# ---------------------------------------------------------------------------
# get_usage_status
# ---------------------------------------------------------------------------
class TestGetUsageStatus:
@pytest.mark.asyncio
async def test_returns_redis_values(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["500", "2000"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert isinstance(status, CoPilotUsageStatus)
assert status.daily.used == 500
assert status.daily.limit == 10000
assert status.weekly.used == 2000
assert status.weekly.limit == 50000
@pytest.mark.asyncio
async def test_returns_zeros_when_redis_unavailable(self):
with patch(
"backend.copilot.rate_limit.get_redis_async",
side_effect=ConnectionError("Redis down"),
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert status.daily.used == 0
assert status.weekly.used == 0
@pytest.mark.asyncio
async def test_partial_none_daily_counter(self):
"""Daily counter is None (new day), weekly has usage."""
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=[None, "3000"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert status.daily.used == 0
assert status.weekly.used == 3000
@pytest.mark.asyncio
async def test_partial_none_weekly_counter(self):
"""Weekly counter is None (start of week), daily has usage."""
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["500", None])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert status.daily.used == 500
assert status.weekly.used == 0
@pytest.mark.asyncio
async def test_resets_at_daily_is_next_midnight_utc(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["0", "0"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
now = datetime.now(UTC)
# Daily reset should be within 24h
assert status.daily.resets_at > now
assert status.daily.resets_at <= now + timedelta(hours=24, seconds=5)
# ---------------------------------------------------------------------------
# check_rate_limit
# ---------------------------------------------------------------------------
class TestCheckRateLimit:
@pytest.mark.asyncio
async def test_allows_when_under_limit(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["100", "200"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
# Should not raise
await check_rate_limit(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
@pytest.mark.asyncio
async def test_raises_when_daily_limit_exceeded(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["10000", "200"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
with pytest.raises(RateLimitExceeded) as exc_info:
await check_rate_limit(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert exc_info.value.window == "daily"
@pytest.mark.asyncio
async def test_raises_when_weekly_limit_exceeded(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["100", "50000"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
with pytest.raises(RateLimitExceeded) as exc_info:
await check_rate_limit(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert exc_info.value.window == "weekly"
@pytest.mark.asyncio
async def test_allows_when_redis_unavailable(self):
"""Fail-open: allow requests when Redis is down."""
with patch(
"backend.copilot.rate_limit.get_redis_async",
side_effect=ConnectionError("Redis down"),
):
# Should not raise
await check_rate_limit(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
@pytest.mark.asyncio
async def test_skips_check_when_limit_is_zero(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["999999", "999999"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
# Should not raise — limits of 0 mean unlimited
await check_rate_limit(_USER, daily_token_limit=0, weekly_token_limit=0)
# ---------------------------------------------------------------------------
# record_token_usage
# ---------------------------------------------------------------------------
class TestRecordTokenUsage:
@staticmethod
def _make_pipeline_mock() -> MagicMock:
"""Create a pipeline mock with sync methods and async execute."""
pipe = MagicMock()
pipe.execute = AsyncMock(return_value=[])
return pipe
@pytest.mark.asyncio
async def test_increments_redis_counters(self):
mock_pipe = self._make_pipeline_mock()
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
# Should call incrby twice (daily + weekly) with total=150
incrby_calls = mock_pipe.incrby.call_args_list
assert len(incrby_calls) == 2
assert incrby_calls[0].args[1] == 150 # daily
assert incrby_calls[1].args[1] == 150 # weekly
@pytest.mark.asyncio
async def test_skips_when_zero_tokens(self):
mock_redis = AsyncMock()
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await record_token_usage(_USER, prompt_tokens=0, completion_tokens=0)
# Should not call pipeline at all
mock_redis.pipeline.assert_not_called()
@pytest.mark.asyncio
async def test_sets_expire_on_both_keys(self):
"""Pipeline should call expire for both daily and weekly keys."""
mock_pipe = self._make_pipeline_mock()
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
expire_calls = mock_pipe.expire.call_args_list
assert len(expire_calls) == 2
# Daily key TTL should be positive (seconds until next midnight)
daily_ttl = expire_calls[0].args[1]
assert daily_ttl >= 1
# Weekly key TTL should be positive (seconds until next Monday)
weekly_ttl = expire_calls[1].args[1]
assert weekly_ttl >= 1
@pytest.mark.asyncio
async def test_handles_redis_failure_gracefully(self):
"""Should not raise when Redis is unavailable."""
with patch(
"backend.copilot.rate_limit.get_redis_async",
side_effect=ConnectionError("Redis down"),
):
# Should not raise
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
@pytest.mark.asyncio
async def test_cost_weighted_counting(self):
"""Cached tokens should be weighted: cache_read=10%, cache_create=25%."""
mock_pipe = self._make_pipeline_mock()
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await record_token_usage(
_USER,
prompt_tokens=100, # uncached → 100
completion_tokens=50, # output → 50
cache_read_tokens=10000, # 10% → 1000
cache_creation_tokens=400, # 25% → 100
)
# Expected weighted total: 100 + 1000 + 100 + 50 = 1250
incrby_calls = mock_pipe.incrby.call_args_list
assert len(incrby_calls) == 2
assert incrby_calls[0].args[1] == 1250 # daily
assert incrby_calls[1].args[1] == 1250 # weekly
@pytest.mark.asyncio
async def test_handles_redis_error_during_pipeline_execute(self):
"""Should not raise when pipeline.execute() fails with RedisError."""
mock_pipe = self._make_pipeline_mock()
mock_pipe.execute = AsyncMock(side_effect=RedisError("Pipeline failed"))
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
# Should not raise — fail-open
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)

View File

@@ -186,29 +186,12 @@ class StreamToolOutputAvailable(StreamBaseResponse):
class StreamUsage(StreamBaseResponse):
"""Token usage statistics.
Emitted as an SSE comment so the Vercel AI SDK parser ignores it
(it uses z.strictObject() and rejects unknown event types).
Usage data is recorded server-side (session DB + Redis counters).
"""
"""Token usage statistics."""
type: ResponseType = ResponseType.USAGE
promptTokens: int = Field(..., description="Number of uncached prompt tokens")
promptTokens: int = Field(..., description="Number of prompt tokens")
completionTokens: int = Field(..., description="Number of completion tokens")
totalTokens: int = Field(
..., description="Total number of tokens (raw, not weighted)"
)
cacheReadTokens: int = Field(
default=0, description="Prompt tokens served from cache (10% cost)"
)
cacheCreationTokens: int = Field(
default=0, description="Prompt tokens written to cache (25% cost)"
)
def to_sse(self) -> str:
"""Emit as SSE comment so the AI SDK parser ignores it."""
return f": usage {self.model_dump_json(exclude_none=True)}\n\n"
totalTokens: int = Field(..., description="Total number of tokens")
class StreamError(StreamBaseResponse):

View File

@@ -40,13 +40,11 @@ from ..constants import COPILOT_ERROR_PREFIX, COPILOT_SYSTEM_PREFIX
from ..model import (
ChatMessage,
ChatSession,
Usage,
get_chat_session,
update_session_title,
upsert_chat_session,
)
from ..prompting import get_sdk_supplement
from ..rate_limit import record_token_usage
from ..response_model import (
StreamBaseResponse,
StreamError,
@@ -56,7 +54,6 @@ from ..response_model import (
StreamTextDelta,
StreamToolInputAvailable,
StreamToolOutputAvailable,
StreamUsage,
)
from ..service import (
_build_system_prompt,
@@ -739,14 +736,6 @@ async def stream_chat_completion_sdk(
_otel_ctx: Any = None
# Make sure there is no more code between the lock acquisition and try-block.
# Token usage accumulators — populated from ResultMessage at end of turn
turn_prompt_tokens = 0 # uncached input tokens only
turn_completion_tokens = 0
turn_cache_read_tokens = 0
turn_cache_creation_tokens = 0
turn_cost_usd: float | None = None
total_tokens = 0
try:
# Build system prompt (reuses non-SDK path with Langfuse support).
# Pre-compute the cwd here so the exact working directory path can be
@@ -1123,7 +1112,7 @@ async def stream_chat_completion_sdk(
- len(adapter.resolved_tool_calls),
)
# Log ResultMessage details and capture token usage
# Log ResultMessage details for debugging
if isinstance(sdk_msg, ResultMessage):
logger.info(
"%s Received: ResultMessage %s "
@@ -1142,33 +1131,6 @@ async def stream_chat_completion_sdk(
sdk_msg.result or "(no error message provided)",
)
# Capture token usage from ResultMessage.
# Anthropic reports cached tokens separately:
# input_tokens = uncached only
# cache_read_input_tokens = served from cache
# cache_creation_input_tokens = written to cache
if sdk_msg.usage:
turn_prompt_tokens += sdk_msg.usage.get("input_tokens", 0)
turn_cache_read_tokens += sdk_msg.usage.get(
"cache_read_input_tokens", 0
)
turn_cache_creation_tokens += sdk_msg.usage.get(
"cache_creation_input_tokens", 0
)
turn_completion_tokens += sdk_msg.usage.get(
"output_tokens", 0
)
logger.info(
"%s Token usage: uncached=%d, cache_read=%d, cache_create=%d, output=%d",
log_prefix,
turn_prompt_tokens,
turn_cache_read_tokens,
turn_cache_creation_tokens,
turn_completion_tokens,
)
if sdk_msg.total_cost_usd is not None:
turn_cost_usd = sdk_msg.total_cost_usd
# Emit compaction end if SDK finished compacting.
# When compaction ends, sync TranscriptBuilder with the
# CLI's active context so they stay identical.
@@ -1385,25 +1347,6 @@ async def stream_chat_completion_sdk(
) and not has_appended_assistant:
session.messages.append(assistant_response)
# Emit token usage to the client (must be in try to reach SSE stream).
# Session persistence of usage is in finally to stay consistent with
# rate-limit recording even if an exception interrupts between here
# and the finally block.
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
total_tokens = (
turn_prompt_tokens
+ turn_cache_read_tokens
+ turn_cache_creation_tokens
+ turn_completion_tokens
)
yield StreamUsage(
promptTokens=turn_prompt_tokens,
completionTokens=turn_completion_tokens,
totalTokens=total_tokens,
cacheReadTokens=turn_cache_read_tokens,
cacheCreationTokens=turn_cache_creation_tokens,
)
# Transcript upload is handled exclusively in the finally block
# to avoid double-uploads (the success path used to upload the
# old resume file, then the finally block overwrote it with the
@@ -1468,56 +1411,6 @@ async def stream_chat_completion_sdk(
except Exception:
logger.warning("OTEL context teardown failed", exc_info=True)
# --- Persist token usage to session + rate-limit counters ---
# Both must live in finally so they stay consistent even when an
# exception interrupts the try block after StreamUsage was yielded.
# total_tokens is computed once in the try block (for StreamUsage)
# and reused here to keep the formula DRY.
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
if not total_tokens:
total_tokens = (
turn_prompt_tokens
+ turn_cache_read_tokens
+ turn_cache_creation_tokens
+ turn_completion_tokens
)
if session is not None:
session.usage.append(
Usage(
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
total_tokens=total_tokens,
cache_read_tokens=turn_cache_read_tokens,
cache_creation_tokens=turn_cache_creation_tokens,
)
)
logger.info(
"%s Turn usage: uncached=%d, cache_read=%d, cache_create=%d, "
"output=%d, total=%d, cost_usd=%s",
log_prefix,
turn_prompt_tokens,
turn_cache_read_tokens,
turn_cache_creation_tokens,
turn_completion_tokens,
total_tokens,
turn_cost_usd,
)
if user_id and (turn_prompt_tokens > 0 or turn_completion_tokens > 0):
try:
await record_token_usage(
user_id=user_id,
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
cache_read_tokens=turn_cache_read_tokens,
cache_creation_tokens=turn_cache_creation_tokens,
)
except Exception as usage_err:
logger.warning(
"%s Failed to record token usage: %s",
log_prefix,
usage_err,
)
# --- Persist session messages ---
# This MUST run in finally to persist messages even when the generator
# is stopped early (e.g., user clicks stop, processor breaks stream loop).

View File

@@ -8,15 +8,11 @@ from pydantic_core import PydanticUndefined
from backend.blocks._base import AnyBlockSchema
from backend.copilot.constants import COPILOT_NODE_PREFIX, COPILOT_SESSION_PREFIX
from backend.data import db
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from backend.data.db_accessors import workspace_db
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
from backend.executor.utils import block_usage_cost
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.clients import get_database_manager_async_client
from backend.util.exceptions import BlockError, InsufficientBalanceError
from backend.util.exceptions import BlockError
from backend.util.type import coerce_inputs_to_schema
from .models import BlockOutputResponse, ErrorResponse, ToolResponseBase
@@ -25,26 +21,6 @@ from .utils import match_credentials_to_requirements
logger = logging.getLogger(__name__)
async def _get_credits(user_id: str) -> int:
"""Get user credits using the adapter pattern (RPC when Prisma unavailable)."""
if not db.is_connected():
return await get_database_manager_async_client().get_credits(user_id)
credit_model = await get_user_credit_model(user_id)
return await credit_model.get_credits(user_id)
async def _spend_credits(
user_id: str, cost: int, metadata: UsageTransactionMetadata
) -> int:
"""Spend user credits using the adapter pattern (RPC when Prisma unavailable)."""
if not db.is_connected():
return await get_database_manager_async_client().spend_credits(
user_id, cost, metadata
)
credit_model = await get_user_credit_model(user_id)
return await credit_model.spend_credits(user_id, cost, metadata)
def get_inputs_from_schema(
input_schema: dict[str, Any],
exclude_fields: set[str] | None = None,
@@ -139,37 +115,6 @@ async def execute_block(
# Coerce non-matching data types to the expected input schema.
coerce_inputs_to_schema(input_data, block.input_schema)
# Charge credits BEFORE execution to prevent TOCTOU race.
# If a concurrent spend drains balance between check and charge,
# _spend_credits raises InsufficientBalanceError before any side
# effects (API calls, data mutations) happen.
cost, cost_filter = block_usage_cost(block, input_data)
has_cost = cost > 0
if has_cost:
try:
await _spend_credits(
user_id=user_id,
cost=cost,
metadata=UsageTransactionMetadata(
graph_exec_id=synthetic_graph_id,
graph_id=synthetic_graph_id,
node_id=synthetic_node_id,
node_exec_id=node_exec_id,
block_id=block_id,
block=block.name,
input=cost_filter,
reason="copilot_block_execution",
),
)
except InsufficientBalanceError:
return ErrorResponse(
message=(
f"Insufficient credits to run '{block.name}'. "
"Please top up your credits to continue."
),
session_id=session_id,
)
# Execute the block and collect outputs
outputs: dict[str, list[Any]] = defaultdict(list)
async for output_name, output_data in block.execute(
@@ -188,16 +133,16 @@ async def execute_block(
)
except BlockError as e:
logger.warning("Block execution failed: %s", e)
logger.warning(f"Block execution failed: {e}")
return ErrorResponse(
message=f"Block execution failed: {e}",
error=str(e),
session_id=session_id,
)
except Exception as e:
logger.error("Unexpected error executing block: %s", e, exc_info=True)
logger.error(f"Unexpected error executing block: {e}", exc_info=True)
return ErrorResponse(
message="An internal error occurred while executing the block.",
message=f"Failed to execute block: {str(e)}",
error=str(e),
session_id=session_id,
)

View File

@@ -1,193 +1,24 @@
"""Tests for execute_block — credit charging and type coercion."""
"""Tests for execute_block type coercion in helpers.py.
Verifies that execute_block() coerces string input values to match the block's
expected input types, mirroring the executor's validate_exec() logic.
This is critical for @@agptfile: expansion, where file content is always a string
but the block may expect structured types (e.g. list[list[str]]).
"""
from collections.abc import AsyncIterator
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.blocks._base import BlockType
from backend.copilot.tools.helpers import execute_block
from backend.copilot.tools.models import BlockOutputResponse, ErrorResponse
_USER = "test-user-helpers"
_SESSION = "test-session-helpers"
def _make_block(block_id: str = "block-1", name: str = "TestBlock"):
"""Create a minimal mock block for execute_block()."""
mock = MagicMock()
mock.id = block_id
mock.name = name
mock.block_type = BlockType.STANDARD
mock.input_schema = MagicMock()
mock.input_schema.get_credentials_fields_info.return_value = {}
async def _execute(
input_data: dict, **kwargs: Any
) -> AsyncIterator[tuple[str, Any]]:
yield "result", "ok"
mock.execute = _execute
return mock
def _patch_workspace():
"""Patch workspace_db to return a mock workspace."""
mock_workspace = MagicMock()
mock_workspace.id = "ws-1"
mock_ws_db = MagicMock()
mock_ws_db.get_or_create_workspace = AsyncMock(return_value=mock_workspace)
return patch("backend.copilot.tools.helpers.workspace_db", return_value=mock_ws_db)
# ---------------------------------------------------------------------------
# Credit charging tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
class TestExecuteBlockCreditCharging:
async def test_charges_credits_when_cost_is_positive(self):
"""Block with cost > 0 should call spend_credits before execution."""
block = _make_block()
mock_spend = AsyncMock()
with (
_patch_workspace(),
patch(
"backend.copilot.tools.helpers.block_usage_cost",
return_value=(10, {"key": "val"}),
),
patch(
"backend.copilot.tools.helpers._spend_credits",
new_callable=AsyncMock,
side_effect=mock_spend,
),
):
result = await execute_block(
block=block,
block_id="block-1",
input_data={"text": "hello"},
user_id=_USER,
session_id=_SESSION,
node_exec_id="exec-1",
matched_credentials={},
)
assert isinstance(result, BlockOutputResponse)
assert result.success is True
mock_spend.assert_awaited_once()
call_kwargs = mock_spend.call_args.kwargs
assert call_kwargs["cost"] == 10
assert call_kwargs["metadata"].reason == "copilot_block_execution"
async def test_returns_error_when_insufficient_credits_before_exec(self):
"""Pre-execution charge should return ErrorResponse on InsufficientBalanceError."""
from backend.util.exceptions import InsufficientBalanceError
block = _make_block()
with (
_patch_workspace(),
patch(
"backend.copilot.tools.helpers.block_usage_cost",
return_value=(10, {}),
),
patch(
"backend.copilot.tools.helpers._spend_credits",
new_callable=AsyncMock,
side_effect=InsufficientBalanceError("Low balance", _USER, 5, 10),
),
):
result = await execute_block(
block=block,
block_id="block-1",
input_data={},
user_id=_USER,
session_id=_SESSION,
node_exec_id="exec-1",
matched_credentials={},
)
assert isinstance(result, ErrorResponse)
assert "Insufficient credits" in result.message
async def test_no_charge_when_cost_is_zero(self):
"""Block with cost 0 should not call spend_credits."""
block = _make_block()
with (
_patch_workspace(),
patch(
"backend.copilot.tools.helpers.block_usage_cost",
return_value=(0, {}),
),
patch(
"backend.copilot.tools.helpers._get_credits",
) as mock_get_credits,
patch(
"backend.copilot.tools.helpers._spend_credits",
) as mock_spend_credits,
):
result = await execute_block(
block=block,
block_id="block-1",
input_data={},
user_id=_USER,
session_id=_SESSION,
node_exec_id="exec-1",
matched_credentials={},
)
assert isinstance(result, BlockOutputResponse)
assert result.success is True
# Credit functions should not be called at all for zero-cost blocks
mock_get_credits.assert_not_awaited()
mock_spend_credits.assert_not_awaited()
async def test_returns_error_when_spend_credits_raises_insufficient_balance(self):
"""Pre-exec charge failure returns ErrorResponse -- block never executes."""
from backend.util.exceptions import InsufficientBalanceError
block = _make_block()
with (
_patch_workspace(),
patch(
"backend.copilot.tools.helpers.block_usage_cost",
return_value=(10, {}),
),
patch(
"backend.copilot.tools.helpers._spend_credits",
new_callable=AsyncMock,
side_effect=InsufficientBalanceError("Low balance", _USER, 5, 10),
),
):
result = await execute_block(
block=block,
block_id="block-1",
input_data={},
user_id=_USER,
session_id=_SESSION,
node_exec_id="exec-1",
matched_credentials={},
)
# Credits charged before execution -- insufficient balance prevents execution
assert isinstance(result, ErrorResponse)
assert "Insufficient credits" in result.message
# ---------------------------------------------------------------------------
# Type coercion tests
# ---------------------------------------------------------------------------
from backend.copilot.tools.models import BlockOutputResponse
def _make_block_schema(annotations: dict[str, Any]) -> MagicMock:
"""Create a mock input_schema with model_fields matching the given annotations."""
schema = MagicMock()
# coerce_inputs_to_schema uses model_fields (Pydantic v2 API)
model_fields = {}
for name, ann in annotations.items():
field = MagicMock()
@@ -197,7 +28,7 @@ def _make_block_schema(annotations: dict[str, Any]) -> MagicMock:
return schema
def _make_coerce_block(
def _make_block(
block_id: str,
name: str,
annotations: dict[str, Any],
@@ -229,7 +60,7 @@ _TEST_USER_ID = "test-user-coerce"
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_json_string_to_nested_list():
"""JSON string → list[list[str]] (Google Sheets CSV import case)."""
block = _make_coerce_block(
block = _make_block(
"sheets-write",
"Google Sheets Write",
{"values": list[list[str]], "spreadsheet_id": str},
@@ -259,6 +90,7 @@ async def test_coerce_json_string_to_nested_list():
assert isinstance(response, BlockOutputResponse)
assert response.success is True
# Verify the input was coerced from string to list[list[str]]
assert block._captured_inputs["values"] == [
["Name", "Score"],
["Alice", "90"],
@@ -271,7 +103,7 @@ async def test_coerce_json_string_to_nested_list():
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_json_string_to_list():
"""JSON string → list[str]."""
block = _make_coerce_block(
block = _make_block(
"list-block",
"List Block",
{"items": list[str]},
@@ -303,7 +135,7 @@ async def test_coerce_json_string_to_list():
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_json_string_to_dict():
"""JSON string → dict[str, str]."""
block = _make_coerce_block(
block = _make_block(
"dict-block",
"Dict Block",
{"config": dict[str, str]},
@@ -335,7 +167,7 @@ async def test_coerce_json_string_to_dict():
@pytest.mark.asyncio(loop_scope="session")
async def test_no_coercion_when_type_matches():
"""Already-correct types pass through without coercion."""
block = _make_coerce_block(
block = _make_block(
"pass-through",
"Pass Through",
{"values": list[list[str]], "name": str},
@@ -369,7 +201,7 @@ async def test_no_coercion_when_type_matches():
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_string_to_int():
"""String number → int."""
block = _make_coerce_block(
block = _make_block(
"int-block",
"Int Block",
{"count": int},
@@ -402,7 +234,7 @@ async def test_coerce_string_to_int():
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_skips_none_values():
"""None values are not coerced (they may be optional fields)."""
block = _make_coerce_block(
block = _make_block(
"optional-block",
"Optional Block",
{"data": list[str], "label": str},
@@ -428,13 +260,14 @@ async def test_coerce_skips_none_values():
)
assert isinstance(response, BlockOutputResponse)
# 'data' was not provided, so it should not appear in captured inputs
assert "data" not in block._captured_inputs
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_union_type_preserves_valid_member():
"""Union-typed fields should not be coerced when the value matches a member."""
block = _make_coerce_block(
block = _make_block(
"union-block",
"Union Block",
{"content": str | list[str]},
@@ -460,6 +293,7 @@ async def test_coerce_union_type_preserves_valid_member():
)
assert isinstance(response, BlockOutputResponse)
# list[str] should NOT be stringified to '["a", "b"]'
assert block._captured_inputs["content"] == ["a", "b"]
assert isinstance(block._captured_inputs["content"], list)
@@ -467,7 +301,7 @@ async def test_coerce_union_type_preserves_valid_member():
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_inner_elements_of_generic():
"""Inner elements of generic containers are recursively coerced."""
block = _make_coerce_block(
block = _make_block(
"inner-coerce",
"Inner Coerce",
{"values": list[str]},
@@ -485,6 +319,7 @@ async def test_coerce_inner_elements_of_generic():
response = await execute_block(
block=block,
block_id="inner-coerce",
# Inner elements are ints, but target is list[str]
input_data={"values": [1, 2, 3]},
user_id=_TEST_USER_ID,
session_id=_TEST_SESSION_ID,
@@ -493,5 +328,6 @@ async def test_coerce_inner_elements_of_generic():
)
assert isinstance(response, BlockOutputResponse)
# Inner elements should be coerced from int to str
assert block._captured_inputs["values"] == ["1", "2", "3"]
assert all(isinstance(v, str) for v in block._captured_inputs["values"])

View File

@@ -512,10 +512,6 @@ class DatabaseManagerAsyncClient(AppServiceClient):
list_workspace_files = d.list_workspace_files
soft_delete_workspace_file = d.soft_delete_workspace_file
# ============ Credits ============ #
spend_credits = d.spend_credits
get_credits = d.get_credits
# ============ Understanding ============ #
get_business_understanding = d.get_business_understanding
upsert_business_understanding = d.upsert_business_understanding

View File

@@ -1,8 +1,14 @@
"use client";
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuTrigger,
} from "@/components/molecules/DropdownMenu/DropdownMenu";
import { SidebarProvider } from "@/components/ui/sidebar";
import { cn } from "@/lib/utils";
import { UploadSimple } from "@phosphor-icons/react";
import { DotsThree, UploadSimple } from "@phosphor-icons/react";
import { useCallback, useRef, useState } from "react";
import { ChatContainer } from "./components/ChatContainer/ChatContainer";
import { ChatSidebar } from "./components/ChatSidebar/ChatSidebar";
@@ -86,6 +92,7 @@ export function CopilotPage() {
// Delete functionality
sessionToDelete,
isDeleting,
handleDeleteClick,
handleConfirmDelete,
handleCancelDelete,
} = useCopilotPage();
@@ -141,6 +148,38 @@ export function CopilotPage() {
isUploadingFiles={isUploadingFiles}
droppedFiles={droppedFiles}
onDroppedFilesConsumed={handleDroppedFilesConsumed}
headerSlot={
isMobile && sessionId ? (
<div className="flex justify-end">
<DropdownMenu>
<DropdownMenuTrigger asChild>
<button
className="rounded p-1.5 hover:bg-neutral-100"
aria-label="More actions"
>
<DotsThree className="h-5 w-5 text-neutral-600" />
</button>
</DropdownMenuTrigger>
<DropdownMenuContent align="end">
<DropdownMenuItem
onClick={() => {
const session = sessions.find(
(s) => s.id === sessionId,
);
if (session) {
handleDeleteClick(session.id, session.title);
}
}}
disabled={isDeleting}
className="text-red-600 focus:bg-red-50 focus:text-red-600"
>
Delete chat
</DropdownMenuItem>
</DropdownMenuContent>
</DropdownMenu>
</div>
) : undefined
}
/>
</div>
</div>

View File

@@ -2,6 +2,7 @@
import { ChatInput } from "@/app/(platform)/copilot/components/ChatInput/ChatInput";
import { UIDataTypes, UIMessage, UITools } from "ai";
import { LayoutGroup, motion } from "framer-motion";
import { ReactNode } from "react";
import { ChatMessagesContainer } from "../ChatMessagesContainer/ChatMessagesContainer";
import { CopilotChatActionsProvider } from "../CopilotChatActionsProvider/CopilotChatActionsProvider";
import { EmptySession } from "../EmptySession/EmptySession";
@@ -20,6 +21,7 @@ export interface ChatContainerProps {
onSend: (message: string, files?: File[]) => void | Promise<void>;
onStop: () => void;
isUploadingFiles?: boolean;
headerSlot?: ReactNode;
/** Files dropped onto the chat window. */
droppedFiles?: File[];
/** Called after droppedFiles have been consumed by ChatInput. */
@@ -38,6 +40,7 @@ export const ChatContainer = ({
onSend,
onStop,
isUploadingFiles,
headerSlot,
droppedFiles,
onDroppedFilesConsumed,
}: ChatContainerProps) => {
@@ -60,6 +63,7 @@ export const ChatContainer = ({
status={status}
error={error}
isLoading={isLoadingSession}
headerSlot={headerSlot}
sessionID={sessionId}
/>
<motion.div

View File

@@ -30,6 +30,7 @@ interface Props {
status: string;
error: Error | undefined;
isLoading: boolean;
headerSlot?: React.ReactNode;
sessionID?: string | null;
}
@@ -101,6 +102,7 @@ export function ChatMessagesContainer({
status,
error,
isLoading,
headerSlot,
sessionID,
}: Props) {
const lastMessage = messages[messages.length - 1];
@@ -133,6 +135,7 @@ export function ChatMessagesContainer({
return (
<Conversation className="min-h-0 flex-1">
<ConversationContent className="flex flex-1 flex-col gap-6 px-3 py-6">
{headerSlot}
{isLoading && messages.length === 0 && (
<div
className="flex flex-1 items-center justify-center"

View File

@@ -37,7 +37,6 @@ import { useCopilotUIStore } from "../../store";
import { NotificationToggle } from "./components/NotificationToggle/NotificationToggle";
import { DeleteChatDialog } from "../DeleteChatDialog/DeleteChatDialog";
import { PulseLoader } from "../PulseLoader/PulseLoader";
import { UsageLimits } from "../UsageLimits/UsageLimits";
export function ChatSidebar() {
const { state } = useSidebar();
@@ -257,10 +256,11 @@ export function ChatSidebar() {
<Text variant="h3" size="body-medium">
Your chats
</Text>
<div className="flex items-center">
<UsageLimits />
<div className="relative left-5 flex items-center gap-1">
<NotificationToggle />
<SidebarTrigger />
<div className="relative left-1">
<SidebarTrigger />
</div>
</div>
</div>
{sessionId ? (

View File

@@ -7,7 +7,6 @@ import {
PopoverTrigger,
} from "@/components/molecules/Popover/Popover";
import { toast } from "@/components/molecules/Toast/use-toast";
import { Button } from "@/components/ui/button";
import { cn } from "@/lib/utils";
import { Bell, BellRinging, BellSlash } from "@phosphor-icons/react";
import { useCopilotUIStore } from "../../../../store";
@@ -49,7 +48,10 @@ export function NotificationToggle() {
return (
<Popover>
<PopoverTrigger asChild>
<Button variant="ghost" size="icon" aria-label="Notification settings">
<button
className="rounded p-1 text-black transition-colors hover:bg-zinc-50"
aria-label="Notification settings"
>
{!isNotificationsEnabled ? (
<BellSlash className="!size-5" />
) : isSoundEnabled ? (
@@ -57,7 +59,7 @@ export function NotificationToggle() {
) : (
<Bell className="!size-5" />
)}
</Button>
</button>
</PopoverTrigger>
<PopoverContent align="start" className="w-56 p-3">
<div className="flex flex-col gap-3">

View File

@@ -1,147 +0,0 @@
import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus";
import {
Popover,
PopoverContent,
PopoverTrigger,
} from "@/components/molecules/Popover/Popover";
import { Button } from "@/components/ui/button";
import { ChartBar } from "@phosphor-icons/react";
import { useUsageLimits } from "./useUsageLimits";
const MS_PER_MINUTE = 60_000;
const MS_PER_HOUR = 3_600_000;
const HOURS_PER_DAY = 24;
function formatResetTime(resetsAt: Date | string): string {
const resetDate =
typeof resetsAt === "string" ? new Date(resetsAt) : resetsAt;
const now = new Date();
const diffMs = resetDate.getTime() - now.getTime();
if (diffMs <= 0) return "now";
const hours = Math.floor(diffMs / MS_PER_HOUR);
// Under 24h: show relative time ("in 4h 23m")
if (hours < HOURS_PER_DAY) {
const minutes = Math.floor((diffMs % MS_PER_HOUR) / MS_PER_MINUTE);
if (hours > 0) return `in ${hours}h ${minutes}m`;
return `in ${minutes}m`;
}
// Over 24h: show day and time in local timezone ("Mon 12:00 AM PST")
return resetDate.toLocaleString(undefined, {
weekday: "short",
hour: "numeric",
minute: "2-digit",
timeZoneName: "short",
});
}
function UsageBar({
label,
used,
limit,
resetsAt,
}: {
label: string;
used: number;
limit: number;
resetsAt: Date | string;
}) {
if (limit <= 0) return null;
const rawPercent = (used / limit) * 100;
const percent = Math.min(100, Math.round(rawPercent));
const isHigh = percent >= 80;
const percentLabel =
used > 0 && percent === 0 ? "<1% used" : `${percent}% used`;
return (
<div className="flex flex-col gap-1">
<div className="flex items-baseline justify-between">
<span className="text-xs font-medium text-neutral-700">{label}</span>
<span className="text-[11px] tabular-nums text-neutral-500">
{percentLabel}
</span>
</div>
<div className="text-[10px] text-neutral-400">
Resets {formatResetTime(resetsAt)}
</div>
<div className="h-2 w-full overflow-hidden rounded-full bg-neutral-200">
<div
className={`h-full rounded-full transition-[width] duration-300 ease-out ${
isHigh ? "bg-orange-500" : "bg-blue-500"
}`}
style={{ width: `${Math.max(used > 0 ? 1 : 0, percent)}%` }}
/>
</div>
</div>
);
}
export function UsagePanelContent({
usage,
showBillingLink = true,
}: {
usage: CoPilotUsageStatus;
showBillingLink?: boolean;
}) {
const hasDailyLimit = usage.daily.limit > 0;
const hasWeeklyLimit = usage.weekly.limit > 0;
if (!hasDailyLimit && !hasWeeklyLimit) {
return (
<div className="text-xs text-neutral-500">No usage limits configured</div>
);
}
return (
<div className="flex flex-col gap-3">
<div className="text-xs font-semibold text-neutral-800">Usage limits</div>
{hasDailyLimit && (
<UsageBar
label="Today"
used={usage.daily.used}
limit={usage.daily.limit}
resetsAt={usage.daily.resets_at}
/>
)}
{hasWeeklyLimit && (
<UsageBar
label="This week"
used={usage.weekly.used}
limit={usage.weekly.limit}
resetsAt={usage.weekly.resets_at}
/>
)}
{showBillingLink && (
<a
href="/profile/credits"
className="text-[11px] text-blue-600 hover:underline"
>
Learn more about usage limits
</a>
)}
</div>
);
}
export function UsageLimits() {
const { data: usage, isLoading } = useUsageLimits();
if (isLoading || !usage) return null;
if (usage.daily.limit <= 0 && usage.weekly.limit <= 0) return null;
return (
<Popover>
<PopoverTrigger asChild>
<Button variant="ghost" size="icon" aria-label="Usage limits">
<ChartBar className="!size-5" weight="light" />
</Button>
</PopoverTrigger>
<PopoverContent align="start" className="w-64 p-3">
<UsagePanelContent usage={usage} />
</PopoverContent>
</Popover>
);
}

View File

@@ -1,121 +0,0 @@
import { render, screen, cleanup } from "@/tests/integrations/test-utils";
import { afterEach, describe, expect, it, vi } from "vitest";
import { UsageLimits } from "../UsageLimits";
// Mock the useUsageLimits hook
const mockUseUsageLimits = vi.fn();
vi.mock("../useUsageLimits", () => ({
useUsageLimits: () => mockUseUsageLimits(),
}));
// Mock Popover to render children directly (Radix portals don't work in happy-dom)
vi.mock("@/components/molecules/Popover/Popover", () => ({
Popover: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
PopoverTrigger: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
PopoverContent: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
}));
afterEach(() => {
cleanup();
mockUseUsageLimits.mockReset();
});
function makeUsage({
dailyUsed = 500,
dailyLimit = 10000,
weeklyUsed = 2000,
weeklyLimit = 50000,
}: {
dailyUsed?: number;
dailyLimit?: number;
weeklyUsed?: number;
weeklyLimit?: number;
} = {}) {
const future = new Date(Date.now() + 3600 * 1000); // 1h from now
return {
daily: { used: dailyUsed, limit: dailyLimit, resets_at: future },
weekly: { used: weeklyUsed, limit: weeklyLimit, resets_at: future },
};
}
describe("UsageLimits", () => {
it("renders nothing while loading", () => {
mockUseUsageLimits.mockReturnValue({ data: undefined, isLoading: true });
const { container } = render(<UsageLimits />);
expect(container.innerHTML).toBe("");
});
it("renders nothing when no limits are configured", () => {
mockUseUsageLimits.mockReturnValue({
data: makeUsage({ dailyLimit: 0, weeklyLimit: 0 }),
isLoading: false,
});
const { container } = render(<UsageLimits />);
expect(container.innerHTML).toBe("");
});
it("renders the usage button when limits exist", () => {
mockUseUsageLimits.mockReturnValue({
data: makeUsage(),
isLoading: false,
});
render(<UsageLimits />);
expect(screen.getByRole("button", { name: /usage limits/i })).toBeDefined();
});
it("displays daily and weekly usage percentages", () => {
mockUseUsageLimits.mockReturnValue({
data: makeUsage({ dailyUsed: 5000, dailyLimit: 10000 }),
isLoading: false,
});
render(<UsageLimits />);
expect(screen.getByText("50% used")).toBeDefined();
expect(screen.getByText("Today")).toBeDefined();
expect(screen.getByText("This week")).toBeDefined();
expect(screen.getByText("Usage limits")).toBeDefined();
});
it("shows only weekly bar when daily limit is 0", () => {
mockUseUsageLimits.mockReturnValue({
data: makeUsage({
dailyLimit: 0,
weeklyUsed: 25000,
weeklyLimit: 50000,
}),
isLoading: false,
});
render(<UsageLimits />);
expect(screen.getByText("This week")).toBeDefined();
expect(screen.queryByText("Today")).toBeNull();
});
it("caps percentage at 100% when over limit", () => {
mockUseUsageLimits.mockReturnValue({
data: makeUsage({ dailyUsed: 15000, dailyLimit: 10000 }),
isLoading: false,
});
render(<UsageLimits />);
expect(screen.getByText("100% used")).toBeDefined();
});
it("shows learn more link to credits page", () => {
mockUseUsageLimits.mockReturnValue({
data: makeUsage(),
isLoading: false,
});
render(<UsageLimits />);
const link = screen.getByText("Learn more about usage limits");
expect(link).toBeDefined();
expect(link.closest("a")?.getAttribute("href")).toBe("/profile/credits");
});
});

View File

@@ -1,12 +0,0 @@
import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus";
import { useGetV2GetCopilotUsage } from "@/app/api/__generated__/endpoints/chat/chat";
export function useUsageLimits() {
return useGetV2GetCopilotUsage({
query: {
select: (res) => res.data as CoPilotUsageStatus,
refetchInterval: 30000,
staleTime: 10000,
},
});
}

View File

@@ -1,5 +1,4 @@
import {
getGetV2GetCopilotUsageQueryKey,
getGetV2GetSessionQueryKey,
postV2CancelSessionTask,
} from "@/app/api/__generated__/endpoints/chat/chat";
@@ -308,9 +307,6 @@ export function useCopilotStream({
queryClient.invalidateQueries({
queryKey: getGetV2GetSessionQueryKey(sessionId),
});
queryClient.invalidateQueries({
queryKey: getGetV2GetCopilotUsageQueryKey(),
});
if (status === "ready") {
reconnectAttemptsRef.current = 0;
hasShownDisconnectToast.current = false;

View File

@@ -11,8 +11,6 @@ import {
import { RefundModal } from "./RefundModal";
import { CreditTransaction } from "@/lib/autogpt-server-api";
import { UsagePanelContent } from "@/app/(platform)/copilot/components/UsageLimits/UsageLimits";
import { useUsageLimits } from "@/app/(platform)/copilot/components/UsageLimits/useUsageLimits";
import {
Table,
@@ -23,28 +21,6 @@ import {
TableRow,
} from "@/components/__legacy__/ui/table";
function CoPilotUsageSection() {
const { data: usage, isLoading } = useUsageLimits();
if (isLoading || !usage) return null;
if (usage.daily.limit <= 0 && usage.weekly.limit <= 0) return null;
return (
<div className="my-6 space-y-4">
<h3 className="text-lg font-medium">CoPilot Usage Limits</h3>
<div className="rounded-lg border border-neutral-200 p-4 dark:border-neutral-700">
<UsagePanelContent usage={usage} showBillingLink={false} />
</div>
<Button
className="w-full"
onClick={() => (window.location.href = "/copilot")}
>
Open CoPilot
</Button>
</div>
);
}
export default function CreditsPage() {
const api = useBackendAPI();
const {
@@ -261,13 +237,11 @@ export default function CreditsPage() {
</Button>
)}
</form>
{/* CoPilot Usage Limits */}
<CoPilotUsageSection />
</div>
<div className="my-6 space-y-4">
{/* Payment Portal */}
<h3 className="text-lg font-medium">Manage Your Payment Methods</h3>
<p className="text-neutral-600">
You can manage your cards and see your payment history in the

View File

@@ -1382,28 +1382,6 @@
"security": [{ "HTTPBearerJWT": [] }]
}
},
"/api/chat/usage": {
"get": {
"tags": ["v2", "chat", "chat"],
"summary": "Get Copilot Usage",
"description": "Get CoPilot usage status for the authenticated user.\n\nReturns current token usage vs limits for daily and weekly windows.",
"operationId": "getV2GetCopilotUsage",
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/CoPilotUsageStatus" }
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
}
},
"security": [{ "HTTPBearerJWT": [] }]
}
},
"/api/credits": {
"get": {
"tags": ["v1", "credits"],
@@ -8477,16 +8455,6 @@
"title": "ClarifyingQuestion",
"description": "A question that needs user clarification."
},
"CoPilotUsageStatus": {
"properties": {
"daily": { "$ref": "#/components/schemas/UsageWindow" },
"weekly": { "$ref": "#/components/schemas/UsageWindow" }
},
"type": "object",
"required": ["daily", "weekly"],
"title": "CoPilotUsageStatus",
"description": "Current usage status for a user across all windows."
},
"ContentType": {
"type": "string",
"enum": [
@@ -12222,16 +12190,6 @@
{ "$ref": "#/components/schemas/ActiveStreamInfo" },
{ "type": "null" }
]
},
"total_prompt_tokens": {
"type": "integer",
"title": "Total Prompt Tokens",
"default": 0
},
"total_completion_tokens": {
"type": "integer",
"title": "Total Completion Tokens",
"default": 0
}
},
"type": "object",
@@ -14629,25 +14587,6 @@
"required": ["timezone"],
"title": "UpdateTimezoneRequest"
},
"UsageWindow": {
"properties": {
"used": { "type": "integer", "title": "Used" },
"limit": {
"type": "integer",
"title": "Limit",
"description": "Maximum tokens allowed in this window. 0 means unlimited."
},
"resets_at": {
"type": "string",
"format": "date-time",
"title": "Resets At"
}
},
"type": "object",
"required": ["used", "limit", "resets_at"],
"title": "UsageWindow",
"description": "Usage within a single time window."
},
"UserHistoryResponse": {
"properties": {
"history": {

View File

@@ -288,7 +288,6 @@ const SidebarTrigger = React.forwardRef<
ref={ref}
data-sidebar="trigger"
variant="ghost"
size="icon"
onClick={(event) => {
onClick?.(event);
toggleSidebar();