mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-03-17 03:00:27 -04:00
Compare commits
49 Commits
fix/block-
...
fix/copilo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9684b99949 | ||
|
|
d00059dc94 | ||
|
|
3a14077d52 | ||
|
|
4446be94ae | ||
|
|
983aed2b0a | ||
|
|
66a8cf69be | ||
|
|
d1b8766fa4 | ||
|
|
628b779128 | ||
|
|
90b7edf1f1 | ||
|
|
8b970c4c3d | ||
|
|
601fed93b8 | ||
|
|
e3f9fa3648 | ||
|
|
809ba56f0b | ||
|
|
9a467e1dba | ||
|
|
0200748225 | ||
|
|
1704214394 | ||
|
|
f2efd3ad7f | ||
|
|
ee841d1515 | ||
|
|
5966d3669d | ||
|
|
c81ab1fc3b | ||
|
|
5446c7f18f | ||
|
|
2b0c9ba703 | ||
|
|
195c7011ae | ||
|
|
d4944fb22b | ||
|
|
a5ed8fefa9 | ||
|
|
a52a777b29 | ||
|
|
8bec7a6933 | ||
|
|
e73791efed | ||
|
|
2d161ce2b9 | ||
|
|
6fc4989654 | ||
|
|
976443bf6e | ||
|
|
4ceb15b3f1 | ||
|
|
3096f94996 | ||
|
|
6f90729612 | ||
|
|
ebf89dde8b | ||
|
|
5d057e97e5 | ||
|
|
1d2f641a26 | ||
|
|
dcb71ab0b9 | ||
|
|
8136b90860 | ||
|
|
4d179a7c37 | ||
|
|
f78adcdc65 | ||
|
|
40388b7520 | ||
|
|
dd7be1158b | ||
|
|
c0e59f0a6b | ||
|
|
104d1f1bf4 | ||
|
|
d9e9cd4c98 | ||
|
|
ca416300ec | ||
|
|
c589cd0c43 | ||
|
|
b6d863fcd2 |
@@ -27,6 +27,12 @@ from backend.copilot.model import (
|
||||
get_user_sessions,
|
||||
update_session_title,
|
||||
)
|
||||
from backend.copilot.rate_limit import (
|
||||
CoPilotUsageStatus,
|
||||
RateLimitExceeded,
|
||||
check_rate_limit,
|
||||
get_usage_status,
|
||||
)
|
||||
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
|
||||
from backend.copilot.tools.e2b_sandbox import kill_sandbox
|
||||
from backend.copilot.tools.models import (
|
||||
@@ -120,6 +126,8 @@ class SessionDetailResponse(BaseModel):
|
||||
user_id: str | None
|
||||
messages: list[dict]
|
||||
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
|
||||
total_prompt_tokens: int = 0
|
||||
total_completion_tokens: int = 0
|
||||
|
||||
|
||||
class SessionSummaryResponse(BaseModel):
|
||||
@@ -389,6 +397,10 @@ async def get_session(
|
||||
last_message_id=last_message_id,
|
||||
)
|
||||
|
||||
# Sum token usage from session
|
||||
total_prompt = sum(u.prompt_tokens for u in session.usage)
|
||||
total_completion = sum(u.completion_tokens for u in session.usage)
|
||||
|
||||
return SessionDetailResponse(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
@@ -396,6 +408,26 @@ async def get_session(
|
||||
user_id=session.user_id or None,
|
||||
messages=messages,
|
||||
active_stream=active_stream_info,
|
||||
total_prompt_tokens=total_prompt,
|
||||
total_completion_tokens=total_completion,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/usage")
|
||||
async def get_copilot_usage(
|
||||
user_id: Annotated[str | None, Depends(auth.get_user_id)],
|
||||
) -> CoPilotUsageStatus:
|
||||
"""Get CoPilot usage status for the authenticated user.
|
||||
|
||||
Returns current token usage vs limits for daily and weekly windows.
|
||||
"""
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
return await get_usage_status(
|
||||
user_id=user_id,
|
||||
daily_token_limit=config.daily_token_limit,
|
||||
weekly_token_limit=config.weekly_token_limit,
|
||||
)
|
||||
|
||||
|
||||
@@ -496,6 +528,17 @@ async def stream_chat_post(
|
||||
},
|
||||
)
|
||||
|
||||
# Pre-turn rate limit check (token-based)
|
||||
if user_id and (config.daily_token_limit > 0 or config.weekly_token_limit > 0):
|
||||
try:
|
||||
await check_rate_limit(
|
||||
user_id=user_id,
|
||||
daily_token_limit=config.daily_token_limit,
|
||||
weekly_token_limit=config.weekly_token_limit,
|
||||
)
|
||||
except RateLimitExceeded as e:
|
||||
raise HTTPException(status_code=429, detail=str(e)) from e
|
||||
|
||||
# Enrich message with file metadata if file_ids are provided.
|
||||
# Also sanitise file_ids so only validated, workspace-scoped IDs are
|
||||
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Tests for chat API routes: session title update, file attachment validation, and suggested prompts."""
|
||||
"""Tests for chat API routes: session title update, file attachment validation, usage, and suggested prompts."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import fastapi
|
||||
@@ -251,6 +252,74 @@ def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
|
||||
assert call_kwargs["where"]["isDeleted"] is False
|
||||
|
||||
|
||||
# ─── Usage endpoint ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _mock_usage(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
*,
|
||||
daily_used: int = 500,
|
||||
weekly_used: int = 2000,
|
||||
) -> AsyncMock:
|
||||
"""Mock get_usage_status to return a predictable CoPilotUsageStatus."""
|
||||
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
|
||||
|
||||
resets_at = datetime.now(UTC) + timedelta(days=1)
|
||||
status = CoPilotUsageStatus(
|
||||
daily=UsageWindow(used=daily_used, limit=10000, resets_at=resets_at),
|
||||
weekly=UsageWindow(used=weekly_used, limit=50000, resets_at=resets_at),
|
||||
)
|
||||
return mocker.patch(
|
||||
"backend.api.features.chat.routes.get_usage_status",
|
||||
new_callable=AsyncMock,
|
||||
return_value=status,
|
||||
)
|
||||
|
||||
|
||||
def test_usage_returns_daily_and_weekly(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""GET /usage returns daily and weekly usage."""
|
||||
mock_get = _mock_usage(mocker, daily_used=500, weekly_used=2000)
|
||||
|
||||
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
|
||||
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
|
||||
|
||||
response = client.get("/usage")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["daily"]["used"] == 500
|
||||
assert data["weekly"]["used"] == 2000
|
||||
|
||||
mock_get.assert_called_once_with(
|
||||
user_id=test_user_id,
|
||||
daily_token_limit=10000,
|
||||
weekly_token_limit=50000,
|
||||
)
|
||||
|
||||
|
||||
def test_usage_uses_config_limits(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""The endpoint forwards daily_token_limit and weekly_token_limit from config."""
|
||||
mock_get = _mock_usage(mocker)
|
||||
|
||||
mocker.patch.object(chat_routes.config, "daily_token_limit", 99999)
|
||||
mocker.patch.object(chat_routes.config, "weekly_token_limit", 77777)
|
||||
|
||||
response = client.get("/usage")
|
||||
|
||||
assert response.status_code == 200
|
||||
mock_get.assert_called_once_with(
|
||||
user_id=test_user_id,
|
||||
daily_token_limit=99999,
|
||||
weekly_token_limit=77777,
|
||||
)
|
||||
|
||||
|
||||
# ─── Suggested prompts endpoint ──────────────────────────────────────
|
||||
|
||||
|
||||
|
||||
@@ -18,11 +18,13 @@ from langfuse import propagate_attributes
|
||||
from backend.copilot.model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
Usage,
|
||||
get_chat_session,
|
||||
update_session_title,
|
||||
upsert_chat_session,
|
||||
)
|
||||
from backend.copilot.prompting import get_baseline_supplement
|
||||
from backend.copilot.rate_limit import record_token_usage
|
||||
from backend.copilot.response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
@@ -36,6 +38,7 @@ from backend.copilot.response_model import (
|
||||
StreamToolInputAvailable,
|
||||
StreamToolInputStart,
|
||||
StreamToolOutputAvailable,
|
||||
StreamUsage,
|
||||
)
|
||||
from backend.copilot.service import (
|
||||
_build_system_prompt,
|
||||
@@ -46,7 +49,11 @@ from backend.copilot.service import (
|
||||
from backend.copilot.tools import execute_tool, get_available_tools
|
||||
from backend.copilot.tracking import track_user_message
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.prompt import compress_context
|
||||
from backend.util.prompt import (
|
||||
compress_context,
|
||||
estimate_token_count,
|
||||
estimate_token_count_str,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -221,6 +228,9 @@ async def stream_chat_completion_baseline(
|
||||
text_block_id = str(uuid.uuid4())
|
||||
text_started = False
|
||||
step_open = False
|
||||
# Token usage accumulators — populated from streaming chunks
|
||||
turn_prompt_tokens = 0
|
||||
turn_completion_tokens = 0
|
||||
try:
|
||||
for _round in range(_MAX_TOOL_ROUNDS):
|
||||
# Open a new step for each LLM round
|
||||
@@ -232,6 +242,7 @@ async def stream_chat_completion_baseline(
|
||||
model=config.model,
|
||||
messages=openai_messages,
|
||||
stream=True,
|
||||
stream_options={"include_usage": True},
|
||||
)
|
||||
if tools:
|
||||
create_kwargs["tools"] = tools
|
||||
@@ -242,7 +253,18 @@ async def stream_chat_completion_baseline(
|
||||
tool_calls_by_index: dict[int, dict[str, str]] = {}
|
||||
|
||||
async for chunk in response:
|
||||
delta = chunk.choices[0].delta if chunk.choices else None
|
||||
# Capture token usage from the streaming chunk.
|
||||
# OpenRouter normalises all providers into OpenAI format
|
||||
# where prompt_tokens already includes cached tokens
|
||||
# (unlike Anthropic's native API). Use += to sum all
|
||||
# tool-call rounds since each API call is independent.
|
||||
if chunk.usage:
|
||||
turn_prompt_tokens += chunk.usage.prompt_tokens or 0
|
||||
turn_completion_tokens += chunk.usage.completion_tokens or 0
|
||||
|
||||
if not chunk.choices:
|
||||
continue
|
||||
delta = chunk.choices[0].delta
|
||||
if not delta:
|
||||
continue
|
||||
|
||||
@@ -411,6 +433,53 @@ async def stream_chat_completion_baseline(
|
||||
except Exception:
|
||||
logger.warning("[Baseline] Langfuse trace context teardown failed")
|
||||
|
||||
# Fallback: estimate tokens via tiktoken when the provider does
|
||||
# not honour stream_options={"include_usage": True}.
|
||||
# Count the full message list (system + history + turn) since
|
||||
# each API call sends the complete context window.
|
||||
if turn_prompt_tokens == 0 and turn_completion_tokens == 0:
|
||||
turn_prompt_tokens = max(
|
||||
estimate_token_count(openai_messages, model=config.model), 0
|
||||
)
|
||||
turn_completion_tokens = max(
|
||||
estimate_token_count_str(assistant_text, model=config.model), 0
|
||||
)
|
||||
logger.info(
|
||||
"[Baseline] No streaming usage reported; estimated tokens: "
|
||||
"prompt=%d, completion=%d",
|
||||
turn_prompt_tokens,
|
||||
turn_completion_tokens,
|
||||
)
|
||||
|
||||
# Emit token usage and update session for persistence
|
||||
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
|
||||
total_tokens = turn_prompt_tokens + turn_completion_tokens
|
||||
session.usage.append(
|
||||
Usage(
|
||||
prompt_tokens=turn_prompt_tokens,
|
||||
completion_tokens=turn_completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
"[Baseline] Turn usage: prompt=%d, completion=%d, total=%d",
|
||||
turn_prompt_tokens,
|
||||
turn_completion_tokens,
|
||||
total_tokens,
|
||||
)
|
||||
# Record for rate limiting counters
|
||||
if user_id:
|
||||
try:
|
||||
await record_token_usage(
|
||||
user_id=user_id,
|
||||
prompt_tokens=turn_prompt_tokens,
|
||||
completion_tokens=turn_completion_tokens,
|
||||
)
|
||||
except Exception as usage_err:
|
||||
logger.warning(
|
||||
"[Baseline] Failed to record token usage: %s", usage_err
|
||||
)
|
||||
|
||||
# Persist assistant response
|
||||
if assistant_text:
|
||||
session.messages.append(
|
||||
@@ -421,4 +490,16 @@ async def stream_chat_completion_baseline(
|
||||
except Exception as persist_err:
|
||||
logger.error("[Baseline] Failed to persist session: %s", persist_err)
|
||||
|
||||
# Yield usage and finish AFTER try/finally (not inside finally).
|
||||
# PEP 525 prohibits yielding from finally in async generators during
|
||||
# aclose() — doing so raises RuntimeError on client disconnect.
|
||||
# On GeneratorExit the client is already gone, so unreachable yields
|
||||
# are harmless; on normal completion they reach the SSE stream.
|
||||
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
|
||||
yield StreamUsage(
|
||||
promptTokens=turn_prompt_tokens,
|
||||
completionTokens=turn_completion_tokens,
|
||||
totalTokens=turn_prompt_tokens + turn_completion_tokens,
|
||||
)
|
||||
|
||||
yield StreamFinish()
|
||||
|
||||
@@ -70,6 +70,20 @@ class ChatConfig(BaseSettings):
|
||||
description="Cache TTL in seconds for Langfuse prompt (0 to disable caching)",
|
||||
)
|
||||
|
||||
# Rate limiting — token-based limits per day and per week.
|
||||
# Each CoPilot turn consumes ~10-15K tokens (system prompt + tool schemas + response),
|
||||
# so 2.5M daily allows ~170-250 turns/day which is reasonable for normal use.
|
||||
# TODO: These are global deploy-time constants. For per-user or per-plan limits,
|
||||
# move to the database (e.g. UserPlan table) and look up in get_usage_status.
|
||||
daily_token_limit: int = Field(
|
||||
default=2_500_000,
|
||||
description="Max tokens per day, resets at midnight UTC (0 = unlimited)",
|
||||
)
|
||||
weekly_token_limit: int = Field(
|
||||
default=12_500_000,
|
||||
description="Max tokens per week, resets Monday 00:00 UTC (0 = unlimited)",
|
||||
)
|
||||
|
||||
# Claude Agent SDK Configuration
|
||||
use_claude_agent_sdk: bool = Field(
|
||||
default=True,
|
||||
|
||||
@@ -73,6 +73,9 @@ class Usage(BaseModel):
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
# Cache breakdown (Anthropic-specific; zero for non-Anthropic models)
|
||||
cache_read_tokens: int = 0
|
||||
cache_creation_tokens: int = 0
|
||||
|
||||
|
||||
class ChatSessionInfo(BaseModel):
|
||||
|
||||
253
autogpt_platform/backend/backend/copilot/rate_limit.py
Normal file
253
autogpt_platform/backend/backend/copilot/rate_limit.py
Normal file
@@ -0,0 +1,253 @@
|
||||
"""CoPilot rate limiting based on token usage.
|
||||
|
||||
Uses Redis fixed-window counters to track per-user token consumption
|
||||
with configurable daily and weekly limits. Daily windows reset at
|
||||
midnight UTC; weekly windows reset at ISO week boundary (Monday 00:00
|
||||
UTC). Fails open when Redis is unavailable to avoid blocking users.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Redis key prefixes
|
||||
_PREFIX = "copilot:usage"
|
||||
|
||||
|
||||
class UsageWindow(BaseModel):
|
||||
"""Usage within a single time window."""
|
||||
|
||||
used: int
|
||||
limit: int = Field(
|
||||
description="Maximum tokens allowed in this window. 0 means unlimited."
|
||||
)
|
||||
resets_at: datetime
|
||||
|
||||
|
||||
class CoPilotUsageStatus(BaseModel):
|
||||
"""Current usage status for a user across all windows."""
|
||||
|
||||
daily: UsageWindow
|
||||
weekly: UsageWindow
|
||||
|
||||
|
||||
class RateLimitExceeded(Exception):
|
||||
"""Raised when a user exceeds their CoPilot usage limit."""
|
||||
|
||||
def __init__(self, window: str, resets_at: datetime):
|
||||
self.window = window
|
||||
self.resets_at = resets_at
|
||||
delta = resets_at - datetime.now(UTC)
|
||||
total_secs = delta.total_seconds()
|
||||
if total_secs <= 0:
|
||||
time_str = "now"
|
||||
else:
|
||||
hours = int(total_secs // 3600)
|
||||
minutes = int((total_secs % 3600) // 60)
|
||||
time_str = f"{hours}h {minutes}m" if hours > 0 else f"{minutes}m"
|
||||
super().__init__(
|
||||
f"You've reached your {window} usage limit. Resets in {time_str}."
|
||||
)
|
||||
|
||||
|
||||
def _daily_key(user_id: str, now: datetime | None = None) -> str:
|
||||
if now is None:
|
||||
now = datetime.now(UTC)
|
||||
return f"{_PREFIX}:daily:{user_id}:{now.strftime('%Y-%m-%d')}"
|
||||
|
||||
|
||||
def _weekly_key(user_id: str, now: datetime | None = None) -> str:
|
||||
if now is None:
|
||||
now = datetime.now(UTC)
|
||||
year, week, _ = now.isocalendar()
|
||||
return f"{_PREFIX}:weekly:{user_id}:{year}-W{week:02d}"
|
||||
|
||||
|
||||
def _daily_reset_time(now: datetime | None = None) -> datetime:
|
||||
"""Calculate when the current daily window resets (next midnight UTC)."""
|
||||
if now is None:
|
||||
now = datetime.now(UTC)
|
||||
return now.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(days=1)
|
||||
|
||||
|
||||
def _weekly_reset_time(now: datetime | None = None) -> datetime:
|
||||
"""Calculate when the current weekly window resets (next Monday 00:00 UTC).
|
||||
|
||||
On Monday itself, ``(7 - weekday) % 7`` is 0; the ``or 7`` fallback
|
||||
pushes to *next* Monday so the current week's window stays open.
|
||||
"""
|
||||
if now is None:
|
||||
now = datetime.now(UTC)
|
||||
days_until_monday = (7 - now.weekday()) % 7 or 7
|
||||
return now.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(
|
||||
days=days_until_monday
|
||||
)
|
||||
|
||||
|
||||
async def _fetch_counters(user_id: str, now: datetime) -> tuple[int, int]:
|
||||
"""Fetch daily and weekly token counters from Redis.
|
||||
|
||||
Returns (daily_used, weekly_used). Returns (0, 0) if Redis is unavailable.
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
daily_raw, weekly_raw = await asyncio.gather(
|
||||
redis.get(_daily_key(user_id, now=now)),
|
||||
redis.get(_weekly_key(user_id, now=now)),
|
||||
)
|
||||
return int(daily_raw or 0), int(weekly_raw or 0)
|
||||
|
||||
|
||||
async def get_usage_status(
|
||||
user_id: str,
|
||||
daily_token_limit: int,
|
||||
weekly_token_limit: int,
|
||||
) -> CoPilotUsageStatus:
|
||||
"""Get current usage status for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID.
|
||||
daily_token_limit: Max tokens per day (0 = unlimited).
|
||||
weekly_token_limit: Max tokens per week (0 = unlimited).
|
||||
|
||||
Returns:
|
||||
CoPilotUsageStatus with current usage and limits.
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
try:
|
||||
daily_used, weekly_used = await _fetch_counters(user_id, now)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Redis unavailable for usage status, returning zeros", exc_info=True
|
||||
)
|
||||
daily_used, weekly_used = 0, 0
|
||||
|
||||
return CoPilotUsageStatus(
|
||||
daily=UsageWindow(
|
||||
used=daily_used,
|
||||
limit=daily_token_limit,
|
||||
resets_at=_daily_reset_time(now=now),
|
||||
),
|
||||
weekly=UsageWindow(
|
||||
used=weekly_used,
|
||||
limit=weekly_token_limit,
|
||||
resets_at=_weekly_reset_time(now=now),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def check_rate_limit(
|
||||
user_id: str,
|
||||
daily_token_limit: int,
|
||||
weekly_token_limit: int,
|
||||
) -> None:
|
||||
"""Check if user is within rate limits. Raises RateLimitExceeded if not.
|
||||
|
||||
This is a pre-turn soft check. The authoritative usage counter is updated
|
||||
by ``record_token_usage()`` after the turn completes. Under concurrency,
|
||||
two parallel turns may both pass this check against the same snapshot.
|
||||
This is acceptable because token-based limits are approximate by nature
|
||||
(the exact token count is unknown until after generation).
|
||||
|
||||
Fails open: if Redis is unavailable, allows the request.
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
try:
|
||||
daily_used, weekly_used = await _fetch_counters(user_id, now)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Redis unavailable for rate limit check, allowing request", exc_info=True
|
||||
)
|
||||
return
|
||||
|
||||
if daily_token_limit > 0 and daily_used >= daily_token_limit:
|
||||
raise RateLimitExceeded("daily", _daily_reset_time(now=now))
|
||||
|
||||
if weekly_token_limit > 0 and weekly_used >= weekly_token_limit:
|
||||
raise RateLimitExceeded("weekly", _weekly_reset_time(now=now))
|
||||
|
||||
|
||||
async def record_token_usage(
|
||||
user_id: str,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
*,
|
||||
cache_read_tokens: int = 0,
|
||||
cache_creation_tokens: int = 0,
|
||||
) -> None:
|
||||
"""Record token usage for a user across all windows.
|
||||
|
||||
Uses cost-weighted counting so cached tokens don't unfairly penalise
|
||||
multi-turn conversations. Anthropic's pricing:
|
||||
- uncached input: 100%
|
||||
- cache creation: 25%
|
||||
- cache read: 10%
|
||||
- output: 100%
|
||||
|
||||
``prompt_tokens`` should be the *uncached* input count (``input_tokens``
|
||||
from the API response). Cache counts are passed separately.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID.
|
||||
prompt_tokens: Uncached input tokens.
|
||||
completion_tokens: Output tokens.
|
||||
cache_read_tokens: Tokens served from prompt cache (10% cost).
|
||||
cache_creation_tokens: Tokens written to prompt cache (25% cost).
|
||||
"""
|
||||
weighted_input = (
|
||||
prompt_tokens
|
||||
+ round(cache_creation_tokens * 0.25)
|
||||
+ round(cache_read_tokens * 0.1)
|
||||
)
|
||||
total = weighted_input + completion_tokens
|
||||
if total <= 0:
|
||||
return
|
||||
|
||||
raw_total = (
|
||||
prompt_tokens + cache_read_tokens + cache_creation_tokens + completion_tokens
|
||||
)
|
||||
logger.info(
|
||||
"Recording token usage for %s: raw=%d, weighted=%d "
|
||||
"(uncached=%d, cache_read=%d@10%%, cache_create=%d@25%%, output=%d)",
|
||||
user_id[:8],
|
||||
raw_total,
|
||||
total,
|
||||
prompt_tokens,
|
||||
cache_read_tokens,
|
||||
cache_creation_tokens,
|
||||
completion_tokens,
|
||||
)
|
||||
|
||||
now = datetime.now(UTC)
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
pipe = redis.pipeline(transaction=False)
|
||||
|
||||
# Daily counter (expires at next midnight UTC)
|
||||
d_key = _daily_key(user_id, now=now)
|
||||
pipe.incrby(d_key, total)
|
||||
seconds_until_daily_reset = int(
|
||||
(_daily_reset_time(now=now) - now).total_seconds()
|
||||
)
|
||||
pipe.expire(d_key, max(seconds_until_daily_reset, 1))
|
||||
|
||||
# Weekly counter (expires end of week)
|
||||
w_key = _weekly_key(user_id, now=now)
|
||||
pipe.incrby(w_key, total)
|
||||
seconds_until_weekly_reset = int(
|
||||
(_weekly_reset_time(now=now) - now).total_seconds()
|
||||
)
|
||||
pipe.expire(w_key, max(seconds_until_weekly_reset, 1))
|
||||
|
||||
await pipe.execute()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Redis unavailable for recording token usage (tokens=%d)",
|
||||
total,
|
||||
exc_info=True,
|
||||
)
|
||||
334
autogpt_platform/backend/backend/copilot/rate_limit_test.py
Normal file
334
autogpt_platform/backend/backend/copilot/rate_limit_test.py
Normal file
@@ -0,0 +1,334 @@
|
||||
"""Unit tests for CoPilot rate limiting."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from .rate_limit import (
|
||||
CoPilotUsageStatus,
|
||||
RateLimitExceeded,
|
||||
check_rate_limit,
|
||||
get_usage_status,
|
||||
record_token_usage,
|
||||
)
|
||||
|
||||
_USER = "test-user-rl"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RateLimitExceeded
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRateLimitExceeded:
|
||||
def test_message_contains_window_name(self):
|
||||
exc = RateLimitExceeded("daily", datetime.now(UTC) + timedelta(hours=1))
|
||||
assert "daily" in str(exc)
|
||||
|
||||
def test_message_contains_reset_time(self):
|
||||
exc = RateLimitExceeded(
|
||||
"weekly", datetime.now(UTC) + timedelta(hours=2, minutes=30)
|
||||
)
|
||||
msg = str(exc)
|
||||
# Allow for slight timing drift (29m or 30m)
|
||||
assert "2h " in msg
|
||||
assert "Resets in" in msg
|
||||
|
||||
def test_message_minutes_only_when_under_one_hour(self):
|
||||
exc = RateLimitExceeded("daily", datetime.now(UTC) + timedelta(minutes=15))
|
||||
msg = str(exc)
|
||||
assert "Resets in" in msg
|
||||
# Should not have "0h"
|
||||
assert "0h" not in msg
|
||||
|
||||
def test_message_says_now_when_resets_at_is_in_the_past(self):
|
||||
"""Negative delta (clock skew / stale TTL) should say 'now', not '-1h -30m'."""
|
||||
exc = RateLimitExceeded("daily", datetime.now(UTC) - timedelta(minutes=5))
|
||||
assert "Resets in now" in str(exc)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_usage_status
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetUsageStatus:
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_redis_values(self):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["500", "2000"])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
status = await get_usage_status(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
|
||||
assert isinstance(status, CoPilotUsageStatus)
|
||||
assert status.daily.used == 500
|
||||
assert status.daily.limit == 10000
|
||||
assert status.weekly.used == 2000
|
||||
assert status.weekly.limit == 50000
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_zeros_when_redis_unavailable(self):
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
side_effect=ConnectionError("Redis down"),
|
||||
):
|
||||
status = await get_usage_status(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
|
||||
assert status.daily.used == 0
|
||||
assert status.weekly.used == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_partial_none_daily_counter(self):
|
||||
"""Daily counter is None (new day), weekly has usage."""
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=[None, "3000"])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
status = await get_usage_status(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
|
||||
assert status.daily.used == 0
|
||||
assert status.weekly.used == 3000
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_partial_none_weekly_counter(self):
|
||||
"""Weekly counter is None (start of week), daily has usage."""
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["500", None])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
status = await get_usage_status(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
|
||||
assert status.daily.used == 500
|
||||
assert status.weekly.used == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resets_at_daily_is_next_midnight_utc(self):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["0", "0"])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
status = await get_usage_status(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
|
||||
now = datetime.now(UTC)
|
||||
# Daily reset should be within 24h
|
||||
assert status.daily.resets_at > now
|
||||
assert status.daily.resets_at <= now + timedelta(hours=24, seconds=5)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_rate_limit
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCheckRateLimit:
|
||||
@pytest.mark.asyncio
|
||||
async def test_allows_when_under_limit(self):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["100", "200"])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
# Should not raise
|
||||
await check_rate_limit(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_when_daily_limit_exceeded(self):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["10000", "200"])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
with pytest.raises(RateLimitExceeded) as exc_info:
|
||||
await check_rate_limit(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
assert exc_info.value.window == "daily"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_when_weekly_limit_exceeded(self):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["100", "50000"])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
with pytest.raises(RateLimitExceeded) as exc_info:
|
||||
await check_rate_limit(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
assert exc_info.value.window == "weekly"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allows_when_redis_unavailable(self):
|
||||
"""Fail-open: allow requests when Redis is down."""
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
side_effect=ConnectionError("Redis down"),
|
||||
):
|
||||
# Should not raise
|
||||
await check_rate_limit(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_check_when_limit_is_zero(self):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["999999", "999999"])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
# Should not raise — limits of 0 mean unlimited
|
||||
await check_rate_limit(_USER, daily_token_limit=0, weekly_token_limit=0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# record_token_usage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRecordTokenUsage:
|
||||
@staticmethod
|
||||
def _make_pipeline_mock() -> MagicMock:
|
||||
"""Create a pipeline mock with sync methods and async execute."""
|
||||
pipe = MagicMock()
|
||||
pipe.execute = AsyncMock(return_value=[])
|
||||
return pipe
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_increments_redis_counters(self):
|
||||
mock_pipe = self._make_pipeline_mock()
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.pipeline = lambda **_kw: mock_pipe
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
|
||||
|
||||
# Should call incrby twice (daily + weekly) with total=150
|
||||
incrby_calls = mock_pipe.incrby.call_args_list
|
||||
assert len(incrby_calls) == 2
|
||||
assert incrby_calls[0].args[1] == 150 # daily
|
||||
assert incrby_calls[1].args[1] == 150 # weekly
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_when_zero_tokens(self):
|
||||
mock_redis = AsyncMock()
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
await record_token_usage(_USER, prompt_tokens=0, completion_tokens=0)
|
||||
|
||||
# Should not call pipeline at all
|
||||
mock_redis.pipeline.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sets_expire_on_both_keys(self):
|
||||
"""Pipeline should call expire for both daily and weekly keys."""
|
||||
mock_pipe = self._make_pipeline_mock()
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.pipeline = lambda **_kw: mock_pipe
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
|
||||
|
||||
expire_calls = mock_pipe.expire.call_args_list
|
||||
assert len(expire_calls) == 2
|
||||
|
||||
# Daily key TTL should be positive (seconds until next midnight)
|
||||
daily_ttl = expire_calls[0].args[1]
|
||||
assert daily_ttl >= 1
|
||||
|
||||
# Weekly key TTL should be positive (seconds until next Monday)
|
||||
weekly_ttl = expire_calls[1].args[1]
|
||||
assert weekly_ttl >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_redis_failure_gracefully(self):
|
||||
"""Should not raise when Redis is unavailable."""
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
side_effect=ConnectionError("Redis down"),
|
||||
):
|
||||
# Should not raise
|
||||
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cost_weighted_counting(self):
|
||||
"""Cached tokens should be weighted: cache_read=10%, cache_create=25%."""
|
||||
mock_pipe = self._make_pipeline_mock()
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.pipeline = lambda **_kw: mock_pipe
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
await record_token_usage(
|
||||
_USER,
|
||||
prompt_tokens=100, # uncached → 100
|
||||
completion_tokens=50, # output → 50
|
||||
cache_read_tokens=10000, # 10% → 1000
|
||||
cache_creation_tokens=400, # 25% → 100
|
||||
)
|
||||
|
||||
# Expected weighted total: 100 + 1000 + 100 + 50 = 1250
|
||||
incrby_calls = mock_pipe.incrby.call_args_list
|
||||
assert len(incrby_calls) == 2
|
||||
assert incrby_calls[0].args[1] == 1250 # daily
|
||||
assert incrby_calls[1].args[1] == 1250 # weekly
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_redis_error_during_pipeline_execute(self):
|
||||
"""Should not raise when pipeline.execute() fails with RedisError."""
|
||||
mock_pipe = self._make_pipeline_mock()
|
||||
mock_pipe.execute = AsyncMock(side_effect=RedisError("Pipeline failed"))
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.pipeline = lambda **_kw: mock_pipe
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
# Should not raise — fail-open
|
||||
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
|
||||
@@ -186,12 +186,29 @@ class StreamToolOutputAvailable(StreamBaseResponse):
|
||||
|
||||
|
||||
class StreamUsage(StreamBaseResponse):
|
||||
"""Token usage statistics."""
|
||||
"""Token usage statistics.
|
||||
|
||||
Emitted as an SSE comment so the Vercel AI SDK parser ignores it
|
||||
(it uses z.strictObject() and rejects unknown event types).
|
||||
Usage data is recorded server-side (session DB + Redis counters).
|
||||
"""
|
||||
|
||||
type: ResponseType = ResponseType.USAGE
|
||||
promptTokens: int = Field(..., description="Number of prompt tokens")
|
||||
promptTokens: int = Field(..., description="Number of uncached prompt tokens")
|
||||
completionTokens: int = Field(..., description="Number of completion tokens")
|
||||
totalTokens: int = Field(..., description="Total number of tokens")
|
||||
totalTokens: int = Field(
|
||||
..., description="Total number of tokens (raw, not weighted)"
|
||||
)
|
||||
cacheReadTokens: int = Field(
|
||||
default=0, description="Prompt tokens served from cache (10% cost)"
|
||||
)
|
||||
cacheCreationTokens: int = Field(
|
||||
default=0, description="Prompt tokens written to cache (25% cost)"
|
||||
)
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""Emit as SSE comment so the AI SDK parser ignores it."""
|
||||
return f": usage {self.model_dump_json(exclude_none=True)}\n\n"
|
||||
|
||||
|
||||
class StreamError(StreamBaseResponse):
|
||||
|
||||
@@ -198,6 +198,7 @@ class CompactionTracker:
|
||||
|
||||
def reset_for_query(self) -> None:
|
||||
"""Reset per-query state before a new SDK query."""
|
||||
self._compact_start.clear()
|
||||
self._done = False
|
||||
self._start_emitted = False
|
||||
self._tool_call_id = ""
|
||||
|
||||
@@ -0,0 +1,546 @@
|
||||
"""End-to-end compaction flow test.
|
||||
|
||||
Simulates the full service.py compaction lifecycle using real-format
|
||||
JSONL session files — no SDK subprocess needed. Exercises:
|
||||
|
||||
1. TranscriptBuilder loads a "downloaded" transcript
|
||||
2. User query appended, assistant response streamed
|
||||
3. PreCompact hook fires → CompactionTracker.on_compact()
|
||||
4. Next message → emit_start_if_ready() yields spinner events
|
||||
5. Message after that → emit_end_if_ready() returns end events
|
||||
6. _read_compacted_entries() reads the CLI session file
|
||||
7. TranscriptBuilder.replace_entries() syncs state
|
||||
8. More messages appended post-compaction
|
||||
9. to_jsonl() exports full state for upload
|
||||
10. Fresh builder loads the export — roundtrip verified
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.response_model import (
|
||||
StreamFinishStep,
|
||||
StreamStartStep,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolInputStart,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
from backend.copilot.sdk.compaction import CompactionTracker
|
||||
from backend.copilot.sdk.transcript import strip_progress_entries
|
||||
from backend.copilot.sdk.transcript_builder import TranscriptBuilder
|
||||
from backend.util import json
|
||||
|
||||
|
||||
def _make_jsonl(*entries: dict) -> str:
|
||||
return "\n".join(json.dumps(e) for e in entries) + "\n"
|
||||
|
||||
|
||||
def _run(coro):
|
||||
"""Run an async coroutine synchronously."""
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
def _read_compacted_entries(path: str) -> tuple[list[dict], str] | None:
|
||||
"""Test-only: read compacted entries from a session JSONL file.
|
||||
|
||||
Returns (parsed_dicts, jsonl_string) from the first ``isCompactSummary``
|
||||
entry onward, or ``None`` if no summary is found.
|
||||
"""
|
||||
content = Path(path).read_text()
|
||||
lines = content.strip().split("\n")
|
||||
compact_idx: int | None = None
|
||||
parsed: list[dict] = []
|
||||
raw_lines: list[str] = []
|
||||
for line in lines:
|
||||
if not line.strip():
|
||||
continue
|
||||
entry = json.loads(line, fallback=None)
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
parsed.append(entry)
|
||||
raw_lines.append(line.strip())
|
||||
if compact_idx is None and entry.get("isCompactSummary"):
|
||||
compact_idx = len(parsed) - 1
|
||||
if compact_idx is None:
|
||||
return None
|
||||
return parsed[compact_idx:], "\n".join(raw_lines[compact_idx:]) + "\n"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures: realistic CLI session file content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Pre-compaction conversation
|
||||
USER_1 = {
|
||||
"type": "user",
|
||||
"uuid": "u1",
|
||||
"message": {"role": "user", "content": "What files are in this project?"},
|
||||
}
|
||||
ASST_1_THINKING = {
|
||||
"type": "assistant",
|
||||
"uuid": "a1-think",
|
||||
"parentUuid": "u1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_sdk_aaa",
|
||||
"type": "message",
|
||||
"content": [{"type": "thinking", "thinking": "Let me look at the files..."}],
|
||||
"stop_reason": None,
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
ASST_1_TOOL = {
|
||||
"type": "assistant",
|
||||
"uuid": "a1-tool",
|
||||
"parentUuid": "u1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_sdk_aaa",
|
||||
"type": "message",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "tu1",
|
||||
"name": "Bash",
|
||||
"input": {"command": "ls"},
|
||||
}
|
||||
],
|
||||
"stop_reason": "tool_use",
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
TOOL_RESULT_1 = {
|
||||
"type": "user",
|
||||
"uuid": "tr1",
|
||||
"parentUuid": "a1-tool",
|
||||
"message": {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "tu1",
|
||||
"content": "file1.py\nfile2.py",
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
ASST_1_TEXT = {
|
||||
"type": "assistant",
|
||||
"uuid": "a1-text",
|
||||
"parentUuid": "tr1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_sdk_bbb",
|
||||
"type": "message",
|
||||
"content": [{"type": "text", "text": "I found file1.py and file2.py."}],
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
# Progress entries (should be stripped during upload)
|
||||
PROGRESS_1 = {
|
||||
"type": "progress",
|
||||
"uuid": "prog1",
|
||||
"parentUuid": "a1-tool",
|
||||
"data": {"type": "bash_progress", "stdout": "running ls..."},
|
||||
}
|
||||
# Second user message
|
||||
USER_2 = {
|
||||
"type": "user",
|
||||
"uuid": "u2",
|
||||
"parentUuid": "a1-text",
|
||||
"message": {"role": "user", "content": "Show me file1.py"},
|
||||
}
|
||||
ASST_2 = {
|
||||
"type": "assistant",
|
||||
"uuid": "a2",
|
||||
"parentUuid": "u2",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_sdk_ccc",
|
||||
"type": "message",
|
||||
"content": [{"type": "text", "text": "Here is file1.py content..."}],
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
|
||||
# --- Compaction summary (written by CLI after context compaction) ---
|
||||
COMPACT_SUMMARY = {
|
||||
"type": "summary",
|
||||
"uuid": "cs1",
|
||||
"isCompactSummary": True,
|
||||
"message": {
|
||||
"role": "user",
|
||||
"content": (
|
||||
"Summary: User asked about project files. Found file1.py and file2.py. "
|
||||
"User then asked to see file1.py."
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
# Post-compaction assistant response
|
||||
POST_COMPACT_ASST = {
|
||||
"type": "assistant",
|
||||
"uuid": "a3",
|
||||
"parentUuid": "cs1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_sdk_ddd",
|
||||
"type": "message",
|
||||
"content": [{"type": "text", "text": "Here is the content of file1.py..."}],
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
|
||||
# Post-compaction user follow-up
|
||||
USER_3 = {
|
||||
"type": "user",
|
||||
"uuid": "u3",
|
||||
"parentUuid": "a3",
|
||||
"message": {"role": "user", "content": "Now show file2.py"},
|
||||
}
|
||||
ASST_3 = {
|
||||
"type": "assistant",
|
||||
"uuid": "a4",
|
||||
"parentUuid": "u3",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_sdk_eee",
|
||||
"type": "message",
|
||||
"content": [{"type": "text", "text": "Here is file2.py..."}],
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# E2E test
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCompactionE2E:
|
||||
def _write_session_file(self, session_dir, entries):
|
||||
"""Write a CLI session JSONL file."""
|
||||
path = session_dir / "session.jsonl"
|
||||
path.write_text(_make_jsonl(*entries))
|
||||
return path
|
||||
|
||||
def test_full_compaction_lifecycle(self, tmp_path):
|
||||
"""Simulate the complete service.py compaction flow.
|
||||
|
||||
Timeline:
|
||||
1. Previous turn uploaded transcript with [USER_1, ASST_1, USER_2, ASST_2]
|
||||
2. Current turn: download → load_previous
|
||||
3. User sends "Now show file2.py" → append_user
|
||||
4. SDK starts streaming response
|
||||
5. Mid-stream: PreCompact hook fires (context too large)
|
||||
6. CLI writes compaction summary to session file
|
||||
7. Next SDK message → emit_start (spinner)
|
||||
8. Following message → emit_end (end events)
|
||||
9. _read_compacted_entries reads the session file
|
||||
10. replace_entries syncs TranscriptBuilder
|
||||
11. More assistant messages appended
|
||||
12. Export → upload → next turn downloads it
|
||||
"""
|
||||
session_dir = tmp_path / "session"
|
||||
session_dir.mkdir(parents=True)
|
||||
|
||||
# --- Step 1-2: Load "downloaded" transcript from previous turn ---
|
||||
previous_transcript = _make_jsonl(
|
||||
USER_1,
|
||||
ASST_1_THINKING,
|
||||
ASST_1_TOOL,
|
||||
TOOL_RESULT_1,
|
||||
ASST_1_TEXT,
|
||||
USER_2,
|
||||
ASST_2,
|
||||
)
|
||||
builder = TranscriptBuilder()
|
||||
builder.load_previous(previous_transcript)
|
||||
assert builder.entry_count == 7
|
||||
|
||||
# --- Step 3: User sends new query ---
|
||||
builder.append_user("Now show file2.py")
|
||||
assert builder.entry_count == 8
|
||||
|
||||
# --- Step 4: SDK starts streaming ---
|
||||
builder.append_assistant(
|
||||
[{"type": "thinking", "thinking": "Let me read file2.py..."}],
|
||||
model="claude-sonnet-4-20250514",
|
||||
)
|
||||
assert builder.entry_count == 9
|
||||
|
||||
# --- Step 5-6: PreCompact fires, CLI writes session file ---
|
||||
session_file = self._write_session_file(
|
||||
session_dir,
|
||||
[
|
||||
USER_1,
|
||||
ASST_1_THINKING,
|
||||
ASST_1_TOOL,
|
||||
PROGRESS_1,
|
||||
TOOL_RESULT_1,
|
||||
ASST_1_TEXT,
|
||||
USER_2,
|
||||
ASST_2,
|
||||
COMPACT_SUMMARY,
|
||||
POST_COMPACT_ASST,
|
||||
USER_3,
|
||||
ASST_3,
|
||||
],
|
||||
)
|
||||
|
||||
# --- Step 7: CompactionTracker receives PreCompact hook ---
|
||||
tracker = CompactionTracker()
|
||||
session = ChatSession.new(user_id="test-user")
|
||||
# on_compact is a property returning Event.set callable
|
||||
tracker.on_compact()
|
||||
|
||||
# --- Step 8: Next SDK message arrives → emit_start ---
|
||||
start_events = tracker.emit_start_if_ready()
|
||||
assert len(start_events) == 3
|
||||
assert isinstance(start_events[0], StreamStartStep)
|
||||
assert isinstance(start_events[1], StreamToolInputStart)
|
||||
assert isinstance(start_events[2], StreamToolInputAvailable)
|
||||
|
||||
# Verify tool_call_id is set
|
||||
tool_call_id = start_events[1].toolCallId
|
||||
assert tool_call_id.startswith("compaction-")
|
||||
|
||||
# --- Step 9: Following message → emit_end ---
|
||||
end_events = _run(tracker.emit_end_if_ready(session))
|
||||
assert len(end_events) == 2
|
||||
assert isinstance(end_events[0], StreamToolOutputAvailable)
|
||||
assert isinstance(end_events[1], StreamFinishStep)
|
||||
# Verify same tool_call_id
|
||||
assert end_events[0].toolCallId == tool_call_id
|
||||
|
||||
# Session should have compaction messages persisted
|
||||
assert len(session.messages) == 2
|
||||
assert session.messages[0].role == "assistant"
|
||||
assert session.messages[1].role == "tool"
|
||||
|
||||
# --- Step 10: _read_compacted_entries + replace_entries ---
|
||||
result = _read_compacted_entries(str(session_file))
|
||||
assert result is not None
|
||||
compacted_dicts, compacted_jsonl = result
|
||||
# Should have: COMPACT_SUMMARY + POST_COMPACT_ASST + USER_3 + ASST_3
|
||||
assert len(compacted_dicts) == 4
|
||||
assert compacted_dicts[0]["uuid"] == "cs1"
|
||||
assert compacted_dicts[0]["isCompactSummary"] is True
|
||||
|
||||
# Replace builder state with compacted JSONL
|
||||
old_count = builder.entry_count
|
||||
builder.replace_entries(compacted_jsonl)
|
||||
assert builder.entry_count == 4 # Only compacted entries
|
||||
assert builder.entry_count < old_count # Compaction reduced entries
|
||||
|
||||
# --- Step 11: More assistant messages after compaction ---
|
||||
builder.append_assistant(
|
||||
[{"type": "text", "text": "Here is file2.py:\n\ndef hello():\n pass"}],
|
||||
model="claude-sonnet-4-20250514",
|
||||
stop_reason="end_turn",
|
||||
)
|
||||
assert builder.entry_count == 5
|
||||
|
||||
# --- Step 12: Export for upload ---
|
||||
output = builder.to_jsonl()
|
||||
assert output # Not empty
|
||||
output_entries = [json.loads(line) for line in output.strip().split("\n")]
|
||||
assert len(output_entries) == 5
|
||||
|
||||
# Verify structure:
|
||||
# [COMPACT_SUMMARY, POST_COMPACT_ASST, USER_3, ASST_3, new_assistant]
|
||||
assert output_entries[0]["type"] == "summary"
|
||||
assert output_entries[0].get("isCompactSummary") is True
|
||||
assert output_entries[0]["uuid"] == "cs1"
|
||||
assert output_entries[1]["uuid"] == "a3"
|
||||
assert output_entries[2]["uuid"] == "u3"
|
||||
assert output_entries[3]["uuid"] == "a4"
|
||||
assert output_entries[4]["type"] == "assistant"
|
||||
|
||||
# Verify parent chain is intact
|
||||
assert output_entries[1]["parentUuid"] == "cs1" # a3 → cs1
|
||||
assert output_entries[2]["parentUuid"] == "a3" # u3 → a3
|
||||
assert output_entries[3]["parentUuid"] == "u3" # a4 → u3
|
||||
assert output_entries[4]["parentUuid"] == "a4" # new → a4
|
||||
|
||||
# --- Step 13: Roundtrip — next turn loads this export ---
|
||||
builder2 = TranscriptBuilder()
|
||||
builder2.load_previous(output)
|
||||
assert builder2.entry_count == 5
|
||||
|
||||
# isCompactSummary survives roundtrip
|
||||
output2 = builder2.to_jsonl()
|
||||
first_entry = json.loads(output2.strip().split("\n")[0])
|
||||
assert first_entry.get("isCompactSummary") is True
|
||||
|
||||
# Can append more messages
|
||||
builder2.append_user("What about file3.py?")
|
||||
assert builder2.entry_count == 6
|
||||
final_output = builder2.to_jsonl()
|
||||
last_entry = json.loads(final_output.strip().split("\n")[-1])
|
||||
assert last_entry["type"] == "user"
|
||||
# Parented to the last entry from previous turn
|
||||
assert last_entry["parentUuid"] == output_entries[-1]["uuid"]
|
||||
|
||||
def test_double_compaction_within_session(self, tmp_path):
|
||||
"""Two compactions in the same session (across reset_for_query)."""
|
||||
session_dir = tmp_path / "session"
|
||||
session_dir.mkdir(parents=True)
|
||||
|
||||
tracker = CompactionTracker()
|
||||
session = ChatSession.new(user_id="test")
|
||||
builder = TranscriptBuilder()
|
||||
|
||||
# --- First query with compaction ---
|
||||
builder.append_user("first question")
|
||||
builder.append_assistant([{"type": "text", "text": "first answer"}])
|
||||
|
||||
# Write session file for first compaction
|
||||
first_summary = {
|
||||
"type": "summary",
|
||||
"uuid": "cs-first",
|
||||
"isCompactSummary": True,
|
||||
"message": {"role": "user", "content": "First compaction summary"},
|
||||
}
|
||||
first_post = {
|
||||
"type": "assistant",
|
||||
"uuid": "a-first",
|
||||
"parentUuid": "cs-first",
|
||||
"message": {"role": "assistant", "content": "first post-compact"},
|
||||
}
|
||||
file1 = session_dir / "session1.jsonl"
|
||||
file1.write_text(_make_jsonl(first_summary, first_post))
|
||||
|
||||
tracker.on_compact()
|
||||
tracker.emit_start_if_ready()
|
||||
end_events1 = _run(tracker.emit_end_if_ready(session))
|
||||
assert len(end_events1) == 2 # output + finish
|
||||
|
||||
result1_entries = _read_compacted_entries(str(file1))
|
||||
assert result1_entries is not None
|
||||
_, compacted1_jsonl = result1_entries
|
||||
builder.replace_entries(compacted1_jsonl)
|
||||
assert builder.entry_count == 2
|
||||
|
||||
# --- Reset for second query ---
|
||||
tracker.reset_for_query()
|
||||
|
||||
# --- Second query with compaction ---
|
||||
builder.append_user("second question")
|
||||
builder.append_assistant([{"type": "text", "text": "second answer"}])
|
||||
|
||||
second_summary = {
|
||||
"type": "summary",
|
||||
"uuid": "cs-second",
|
||||
"isCompactSummary": True,
|
||||
"message": {"role": "user", "content": "Second compaction summary"},
|
||||
}
|
||||
second_post = {
|
||||
"type": "assistant",
|
||||
"uuid": "a-second",
|
||||
"parentUuid": "cs-second",
|
||||
"message": {"role": "assistant", "content": "second post-compact"},
|
||||
}
|
||||
file2 = session_dir / "session2.jsonl"
|
||||
file2.write_text(_make_jsonl(second_summary, second_post))
|
||||
|
||||
tracker.on_compact()
|
||||
tracker.emit_start_if_ready()
|
||||
end_events2 = _run(tracker.emit_end_if_ready(session))
|
||||
assert len(end_events2) == 2 # output + finish
|
||||
|
||||
result2_entries = _read_compacted_entries(str(file2))
|
||||
assert result2_entries is not None
|
||||
_, compacted2_jsonl = result2_entries
|
||||
builder.replace_entries(compacted2_jsonl)
|
||||
assert builder.entry_count == 2 # Only second compaction entries
|
||||
|
||||
# Export and verify
|
||||
output = builder.to_jsonl()
|
||||
entries = [json.loads(line) for line in output.strip().split("\n")]
|
||||
assert entries[0]["uuid"] == "cs-second"
|
||||
assert entries[0].get("isCompactSummary") is True
|
||||
|
||||
def test_strip_progress_then_load_then_compact_roundtrip(self, tmp_path):
|
||||
"""Full pipeline: strip → load → compact → replace → export → reload.
|
||||
|
||||
This tests the exact sequence that happens across two turns:
|
||||
Turn 1: SDK produces transcript with progress entries
|
||||
Upload: strip_progress_entries removes progress, upload to cloud
|
||||
Turn 2: Download → load_previous → compaction fires → replace → export
|
||||
Turn 3: Download the Turn 2 export → load_previous (roundtrip)
|
||||
"""
|
||||
session_dir = tmp_path / "session"
|
||||
session_dir.mkdir(parents=True)
|
||||
|
||||
# --- Turn 1: SDK produces raw transcript ---
|
||||
raw_content = _make_jsonl(
|
||||
USER_1,
|
||||
ASST_1_THINKING,
|
||||
ASST_1_TOOL,
|
||||
PROGRESS_1,
|
||||
TOOL_RESULT_1,
|
||||
ASST_1_TEXT,
|
||||
USER_2,
|
||||
ASST_2,
|
||||
)
|
||||
|
||||
# Strip progress for upload
|
||||
stripped = strip_progress_entries(raw_content)
|
||||
stripped_entries = [
|
||||
json.loads(line) for line in stripped.strip().split("\n") if line.strip()
|
||||
]
|
||||
# Progress should be gone
|
||||
assert not any(e.get("type") == "progress" for e in stripped_entries)
|
||||
assert len(stripped_entries) == 7 # 8 - 1 progress
|
||||
|
||||
# --- Turn 2: Download stripped, load, compaction happens ---
|
||||
builder = TranscriptBuilder()
|
||||
builder.load_previous(stripped)
|
||||
assert builder.entry_count == 7
|
||||
|
||||
builder.append_user("Now show file2.py")
|
||||
builder.append_assistant(
|
||||
[{"type": "text", "text": "Reading file2.py..."}],
|
||||
model="claude-sonnet-4-20250514",
|
||||
)
|
||||
|
||||
# CLI writes session file with compaction
|
||||
session_file = self._write_session_file(
|
||||
session_dir,
|
||||
[
|
||||
USER_1,
|
||||
ASST_1_TOOL,
|
||||
TOOL_RESULT_1,
|
||||
ASST_1_TEXT,
|
||||
USER_2,
|
||||
ASST_2,
|
||||
COMPACT_SUMMARY,
|
||||
POST_COMPACT_ASST,
|
||||
],
|
||||
)
|
||||
|
||||
result = _read_compacted_entries(str(session_file))
|
||||
assert result is not None
|
||||
_, compacted_jsonl = result
|
||||
builder.replace_entries(compacted_jsonl)
|
||||
|
||||
# Append post-compaction message
|
||||
builder.append_user("Thanks!")
|
||||
output = builder.to_jsonl()
|
||||
|
||||
# --- Turn 3: Fresh load of Turn 2 export ---
|
||||
builder3 = TranscriptBuilder()
|
||||
builder3.load_previous(output)
|
||||
# Should have: compact_summary + post_compact_asst + "Thanks!"
|
||||
assert builder3.entry_count == 3
|
||||
|
||||
# Compact summary survived the full pipeline
|
||||
first = json.loads(builder3.to_jsonl().strip().split("\n")[0])
|
||||
assert first.get("isCompactSummary") is True
|
||||
assert first["type"] == "summary"
|
||||
@@ -221,12 +221,12 @@ class SDKResponseAdapter:
|
||||
responses.append(StreamFinish())
|
||||
else:
|
||||
logger.warning(
|
||||
f"Unexpected ResultMessage subtype: {sdk_message.subtype}"
|
||||
"Unexpected ResultMessage subtype: %s", sdk_message.subtype
|
||||
)
|
||||
responses.append(StreamFinish())
|
||||
|
||||
else:
|
||||
logger.debug(f"Unhandled SDK message type: {type(sdk_message).__name__}")
|
||||
logger.debug("Unhandled SDK message type: %s", type(sdk_message).__name__)
|
||||
|
||||
return responses
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ def _validate_workspace_path(
|
||||
if is_allowed_local_path(path, sdk_cwd):
|
||||
return {}
|
||||
|
||||
logger.warning(f"Blocked {tool_name} outside workspace: {path}")
|
||||
logger.warning("Blocked %s outside workspace: %s", tool_name, path)
|
||||
workspace_hint = f" Allowed workspace: {sdk_cwd}" if sdk_cwd else ""
|
||||
return _deny(
|
||||
f"[SECURITY] Tool '{tool_name}' can only access files within the workspace "
|
||||
@@ -71,7 +71,7 @@ def _validate_tool_access(
|
||||
"""
|
||||
# Block forbidden tools
|
||||
if tool_name in BLOCKED_TOOLS:
|
||||
logger.warning(f"Blocked tool access attempt: {tool_name}")
|
||||
logger.warning("Blocked tool access attempt: %s", tool_name)
|
||||
return _deny(
|
||||
f"[SECURITY] Tool '{tool_name}' is blocked for security. "
|
||||
"This is enforced by the platform and cannot be bypassed. "
|
||||
@@ -111,7 +111,9 @@ def _validate_user_isolation(
|
||||
# the tool itself via _validate_ephemeral_path.
|
||||
path = tool_input.get("path", "") or tool_input.get("file_path", "")
|
||||
if path and ".." in path:
|
||||
logger.warning(f"Blocked path traversal attempt: {path} by user {user_id}")
|
||||
logger.warning(
|
||||
"Blocked path traversal attempt: %s by user %s", path, user_id
|
||||
)
|
||||
return {
|
||||
"hookSpecificOutput": {
|
||||
"hookEventName": "PreToolUse",
|
||||
@@ -169,7 +171,7 @@ def create_security_hooks(
|
||||
# Block background task execution first — denied calls
|
||||
# should not consume a subtask slot.
|
||||
if tool_input.get("run_in_background"):
|
||||
logger.info(f"[SDK] Blocked background Task, user={user_id}")
|
||||
logger.info("[SDK] Blocked background Task, user=%s", user_id)
|
||||
return cast(
|
||||
SyncHookJSONOutput,
|
||||
_deny(
|
||||
@@ -211,7 +213,7 @@ def create_security_hooks(
|
||||
if tool_name == "Task" and tool_use_id is not None:
|
||||
task_tool_use_ids.add(tool_use_id)
|
||||
|
||||
logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}")
|
||||
logger.debug("[SDK] Tool start: %s, user=%s", tool_name, user_id)
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
def _release_task_slot(tool_name: str, tool_use_id: str | None) -> None:
|
||||
|
||||
@@ -40,11 +40,13 @@ from ..constants import COPILOT_ERROR_PREFIX, COPILOT_SYSTEM_PREFIX
|
||||
from ..model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
Usage,
|
||||
get_chat_session,
|
||||
update_session_title,
|
||||
upsert_chat_session,
|
||||
)
|
||||
from ..prompting import get_sdk_supplement
|
||||
from ..rate_limit import record_token_usage
|
||||
from ..response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
@@ -54,6 +56,7 @@ from ..response_model import (
|
||||
StreamTextDelta,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolOutputAvailable,
|
||||
StreamUsage,
|
||||
)
|
||||
from ..service import (
|
||||
_build_system_prompt,
|
||||
@@ -75,8 +78,12 @@ from .tool_adapter import (
|
||||
wait_for_stash,
|
||||
)
|
||||
from .transcript import (
|
||||
COMPACT_THRESHOLD_BYTES,
|
||||
TranscriptDownload,
|
||||
cleanup_cli_project_dir,
|
||||
compact_transcript,
|
||||
download_transcript,
|
||||
read_cli_session_file,
|
||||
upload_transcript,
|
||||
validate_transcript,
|
||||
write_transcript_to_tempfile,
|
||||
@@ -294,7 +301,7 @@ def _cleanup_sdk_tool_results(cwd: str) -> None:
|
||||
"""
|
||||
normalized = os.path.normpath(cwd)
|
||||
if not normalized.startswith(_SDK_CWD_PREFIX):
|
||||
logger.warning(f"[SDK] Rejecting cleanup for path outside workspace: {cwd}")
|
||||
logger.warning("[SDK] Rejecting cleanup for path outside workspace: %s", cwd)
|
||||
return
|
||||
|
||||
# Clean the CLI's project directory (transcripts + tool-results).
|
||||
@@ -388,7 +395,7 @@ async def _compress_messages(
|
||||
client=client,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[SDK] Context compression with LLM failed: {e}")
|
||||
logger.warning("[SDK] Context compression with LLM failed: %s", e)
|
||||
# Fall back to truncation-only (no LLM summarization)
|
||||
result = await compress_context(
|
||||
messages=messages_dict,
|
||||
@@ -624,6 +631,56 @@ async def _prepare_file_attachments(
|
||||
return PreparedAttachments(hint=hint, image_blocks=image_blocks)
|
||||
|
||||
|
||||
async def _maybe_compact_and_upload(
|
||||
dl: TranscriptDownload,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
log_prefix: str = "[Transcript]",
|
||||
) -> str:
|
||||
"""Compact an oversized transcript and upload the compacted version.
|
||||
|
||||
Returns the (possibly compacted) transcript content, or an empty string
|
||||
if compaction was needed but failed.
|
||||
"""
|
||||
content = dl.content
|
||||
if len(content) <= COMPACT_THRESHOLD_BYTES:
|
||||
return content
|
||||
|
||||
logger.warning(
|
||||
"%s Transcript oversized (%dB > %dB), compacting",
|
||||
log_prefix,
|
||||
len(content),
|
||||
COMPACT_THRESHOLD_BYTES,
|
||||
)
|
||||
compacted = await compact_transcript(content, log_prefix=log_prefix)
|
||||
if not compacted:
|
||||
logger.warning(
|
||||
"%s Compaction failed, skipping resume for this turn", log_prefix
|
||||
)
|
||||
return ""
|
||||
|
||||
# Keep the original message_count: it reflects the number of
|
||||
# session.messages covered by this transcript, which the gap-fill
|
||||
# logic uses as a slice index. Counting JSONL lines would give a
|
||||
# smaller number (compacted messages != session message count) and
|
||||
# cause already-covered messages to be re-injected.
|
||||
try:
|
||||
await upload_transcript(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
content=compacted,
|
||||
message_count=dl.message_count,
|
||||
log_prefix=log_prefix,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"%s Failed to upload compacted transcript",
|
||||
log_prefix,
|
||||
exc_info=True,
|
||||
)
|
||||
return compacted
|
||||
|
||||
|
||||
async def stream_chat_completion_sdk(
|
||||
session_id: str,
|
||||
message: str | None = None,
|
||||
@@ -735,6 +792,14 @@ async def stream_chat_completion_sdk(
|
||||
_otel_ctx: Any = None
|
||||
|
||||
# Make sure there is no more code between the lock acquisition and try-block.
|
||||
# Token usage accumulators — populated from ResultMessage at end of turn
|
||||
turn_prompt_tokens = 0 # uncached input tokens only
|
||||
turn_completion_tokens = 0
|
||||
turn_cache_read_tokens = 0
|
||||
turn_cache_creation_tokens = 0
|
||||
total_tokens = 0 # computed once before StreamUsage, reused in finally
|
||||
turn_cost_usd: float | None = None
|
||||
|
||||
try:
|
||||
# Build system prompt (reuses non-SDK path with Langfuse support).
|
||||
# Pre-compute the cwd here so the exact working directory path can be
|
||||
@@ -827,20 +892,33 @@ async def stream_chat_completion_sdk(
|
||||
is_valid,
|
||||
)
|
||||
if is_valid:
|
||||
# Load previous FULL context into builder
|
||||
transcript_builder.load_previous(dl.content, log_prefix=log_prefix)
|
||||
resume_file = write_transcript_to_tempfile(
|
||||
dl.content, session_id, sdk_cwd
|
||||
transcript_content = await _maybe_compact_and_upload(
|
||||
dl,
|
||||
user_id=user_id or "",
|
||||
session_id=session_id,
|
||||
log_prefix=log_prefix,
|
||||
)
|
||||
# Load previous context into builder (empty string is a no-op)
|
||||
if transcript_content:
|
||||
transcript_builder.load_previous(
|
||||
transcript_content, log_prefix=log_prefix
|
||||
)
|
||||
resume_file = (
|
||||
write_transcript_to_tempfile(
|
||||
transcript_content, session_id, sdk_cwd
|
||||
)
|
||||
if transcript_content
|
||||
else None
|
||||
)
|
||||
if resume_file:
|
||||
use_resume = True
|
||||
transcript_msg_count = dl.message_count
|
||||
logger.debug(
|
||||
f"{log_prefix} Using --resume ({len(dl.content)}B, "
|
||||
f"{log_prefix} Using --resume ({len(transcript_content)}B, "
|
||||
f"msg_count={transcript_msg_count})"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"{log_prefix} Transcript downloaded but invalid")
|
||||
logger.warning("%s Transcript downloaded but invalid", log_prefix)
|
||||
elif config.claude_agent_use_resume and user_id and len(session.messages) > 1:
|
||||
logger.warning(
|
||||
f"{log_prefix} No transcript available "
|
||||
@@ -1110,7 +1188,7 @@ async def stream_chat_completion_sdk(
|
||||
- len(adapter.resolved_tool_calls),
|
||||
)
|
||||
|
||||
# Log ResultMessage details for debugging
|
||||
# Log ResultMessage details and capture token usage
|
||||
if isinstance(sdk_msg, ResultMessage):
|
||||
logger.info(
|
||||
"%s Received: ResultMessage %s "
|
||||
@@ -1129,9 +1207,46 @@ async def stream_chat_completion_sdk(
|
||||
sdk_msg.result or "(no error message provided)",
|
||||
)
|
||||
|
||||
# Emit compaction end if SDK finished compacting
|
||||
for ev in await compaction.emit_end_if_ready(session):
|
||||
# Capture token usage from ResultMessage.
|
||||
# Anthropic reports cached tokens separately:
|
||||
# input_tokens = uncached only
|
||||
# cache_read_input_tokens = served from cache
|
||||
# cache_creation_input_tokens = written to cache
|
||||
if sdk_msg.usage:
|
||||
turn_prompt_tokens += sdk_msg.usage.get("input_tokens", 0)
|
||||
turn_cache_read_tokens += sdk_msg.usage.get(
|
||||
"cache_read_input_tokens", 0
|
||||
)
|
||||
turn_cache_creation_tokens += sdk_msg.usage.get(
|
||||
"cache_creation_input_tokens", 0
|
||||
)
|
||||
turn_completion_tokens += sdk_msg.usage.get(
|
||||
"output_tokens", 0
|
||||
)
|
||||
logger.info(
|
||||
"%s Token usage: uncached=%d, cache_read=%d, cache_create=%d, output=%d",
|
||||
log_prefix,
|
||||
turn_prompt_tokens,
|
||||
turn_cache_read_tokens,
|
||||
turn_cache_creation_tokens,
|
||||
turn_completion_tokens,
|
||||
)
|
||||
if sdk_msg.total_cost_usd is not None:
|
||||
turn_cost_usd = sdk_msg.total_cost_usd
|
||||
|
||||
# Emit compaction end if SDK finished compacting.
|
||||
# When compaction ends, sync TranscriptBuilder with
|
||||
# the CLI's compacted session file so the uploaded
|
||||
# transcript reflects compaction.
|
||||
compaction_events = await compaction.emit_end_if_ready(session)
|
||||
for ev in compaction_events:
|
||||
yield ev
|
||||
if compaction_events and sdk_cwd:
|
||||
cli_content = await read_cli_session_file(sdk_cwd)
|
||||
if cli_content:
|
||||
transcript_builder.replace_entries(
|
||||
cli_content, log_prefix=log_prefix
|
||||
)
|
||||
|
||||
for response in adapter.convert_message(sdk_msg):
|
||||
if isinstance(response, StreamStart):
|
||||
@@ -1325,6 +1440,27 @@ async def stream_chat_completion_sdk(
|
||||
) and not has_appended_assistant:
|
||||
session.messages.append(assistant_response)
|
||||
|
||||
# Emit token usage to the client (must be in try to reach SSE stream).
|
||||
# Session persistence of usage is in finally to stay consistent with
|
||||
# rate-limit recording even if an exception interrupts between here
|
||||
# and the finally block.
|
||||
# Compute total_tokens once; reused in the finally block for
|
||||
# session persistence and rate-limit recording.
|
||||
total_tokens = (
|
||||
turn_prompt_tokens
|
||||
+ turn_cache_read_tokens
|
||||
+ turn_cache_creation_tokens
|
||||
+ turn_completion_tokens
|
||||
)
|
||||
if total_tokens > 0:
|
||||
yield StreamUsage(
|
||||
promptTokens=turn_prompt_tokens,
|
||||
completionTokens=turn_completion_tokens,
|
||||
totalTokens=total_tokens,
|
||||
cacheReadTokens=turn_cache_read_tokens,
|
||||
cacheCreationTokens=turn_cache_creation_tokens,
|
||||
)
|
||||
|
||||
# Transcript upload is handled exclusively in the finally block
|
||||
# to avoid double-uploads (the success path used to upload the
|
||||
# old resume file, then the finally block overwrote it with the
|
||||
@@ -1389,6 +1525,48 @@ async def stream_chat_completion_sdk(
|
||||
except Exception:
|
||||
logger.warning("OTEL context teardown failed", exc_info=True)
|
||||
|
||||
# --- Persist token usage to session + rate-limit counters ---
|
||||
# Both must live in finally so they stay consistent even when an
|
||||
# exception interrupts the try block after StreamUsage was yielded.
|
||||
# total_tokens is computed once before StreamUsage yield above.
|
||||
if total_tokens > 0:
|
||||
if session is not None:
|
||||
session.usage.append(
|
||||
Usage(
|
||||
prompt_tokens=turn_prompt_tokens,
|
||||
completion_tokens=turn_completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
cache_read_tokens=turn_cache_read_tokens,
|
||||
cache_creation_tokens=turn_cache_creation_tokens,
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
"%s Turn usage: uncached=%d, cache_read=%d, cache_create=%d, "
|
||||
"output=%d, total=%d, cost_usd=%s",
|
||||
log_prefix,
|
||||
turn_prompt_tokens,
|
||||
turn_cache_read_tokens,
|
||||
turn_cache_creation_tokens,
|
||||
turn_completion_tokens,
|
||||
total_tokens,
|
||||
turn_cost_usd,
|
||||
)
|
||||
if user_id and total_tokens > 0:
|
||||
try:
|
||||
await record_token_usage(
|
||||
user_id=user_id,
|
||||
prompt_tokens=turn_prompt_tokens,
|
||||
completion_tokens=turn_completion_tokens,
|
||||
cache_read_tokens=turn_cache_read_tokens,
|
||||
cache_creation_tokens=turn_cache_creation_tokens,
|
||||
)
|
||||
except Exception as usage_err:
|
||||
logger.warning(
|
||||
"%s Failed to record token usage: %s",
|
||||
log_prefix,
|
||||
usage_err,
|
||||
)
|
||||
|
||||
# --- Persist session messages ---
|
||||
# This MUST run in finally to persist messages even when the generator
|
||||
# is stopped early (e.g., user clicks stop, processor breaks stream loop).
|
||||
@@ -1484,6 +1662,6 @@ async def _update_title_async(
|
||||
)
|
||||
if title and user_id:
|
||||
await update_session_title(session_id, user_id, title, only_if_empty=True)
|
||||
logger.debug(f"[SDK] Generated title for {session_id}: {title}")
|
||||
logger.debug("[SDK] Generated title for %s: %s", session_id, title)
|
||||
except Exception as e:
|
||||
logger.warning(f"[SDK] Failed to update session title: {e}")
|
||||
logger.warning("[SDK] Failed to update session title: %s", e)
|
||||
|
||||
@@ -234,7 +234,9 @@ def create_tool_handler(base_tool: BaseTool):
|
||||
try:
|
||||
return await _execute_tool_sync(base_tool, user_id, session, args)
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing tool {base_tool.name}: {e}", exc_info=True)
|
||||
logger.error(
|
||||
"Error executing tool %s: %s", base_tool.name, e, exc_info=True
|
||||
)
|
||||
return _mcp_error(f"Failed to execute {base_tool.name}: {e}")
|
||||
|
||||
return tool_handler
|
||||
|
||||
@@ -13,10 +13,17 @@ filesystem for self-hosted) — no DB column needed.
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
import openai
|
||||
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.util import json
|
||||
from backend.util.prompt import CompressResult, compress_context
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -34,6 +41,11 @@ STRIPPABLE_TYPES = frozenset(
|
||||
{"progress", "file-history-snapshot", "queue-operation", "summary", "pr-link"}
|
||||
)
|
||||
|
||||
# JSONL protocol values used in transcript serialization.
|
||||
STOP_REASON_END_TURN = "end_turn"
|
||||
COMPACT_MSG_ID_PREFIX = "msg_compact_"
|
||||
ENTRY_TYPE_MESSAGE = "message"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TranscriptDownload:
|
||||
@@ -82,7 +94,11 @@ def strip_progress_entries(content: str) -> str:
|
||||
parent = entry.get("parentUuid", "")
|
||||
if uid:
|
||||
uuid_to_parent[uid] = parent
|
||||
if entry.get("type", "") in STRIPPABLE_TYPES and uid:
|
||||
if (
|
||||
entry.get("type", "") in STRIPPABLE_TYPES
|
||||
and uid
|
||||
and not entry.get("isCompactSummary")
|
||||
):
|
||||
stripped_uuids.add(uid)
|
||||
|
||||
# Second pass: keep non-stripped entries, reparenting where needed.
|
||||
@@ -93,7 +109,9 @@ def strip_progress_entries(content: str) -> str:
|
||||
continue
|
||||
parent = entry.get("parentUuid", "")
|
||||
original_parent = parent
|
||||
while parent in stripped_uuids:
|
||||
seen_parents: set[str] = set()
|
||||
while parent in stripped_uuids and parent not in seen_parents:
|
||||
seen_parents.add(parent)
|
||||
parent = uuid_to_parent.get(parent, "")
|
||||
if parent != original_parent:
|
||||
entry["parentUuid"] = parent
|
||||
@@ -106,7 +124,9 @@ def strip_progress_entries(content: str) -> str:
|
||||
if not isinstance(entry, dict):
|
||||
result_lines.append(line)
|
||||
continue
|
||||
if entry.get("type", "") in STRIPPABLE_TYPES:
|
||||
if entry.get("type", "") in STRIPPABLE_TYPES and not entry.get(
|
||||
"isCompactSummary"
|
||||
):
|
||||
continue
|
||||
uid = entry.get("uuid", "")
|
||||
if uid in reparented:
|
||||
@@ -137,32 +157,78 @@ def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
|
||||
_SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-")
|
||||
|
||||
|
||||
def cleanup_cli_project_dir(sdk_cwd: str) -> None:
|
||||
"""Remove the CLI's project directory for a specific working directory.
|
||||
def _cli_project_dir(sdk_cwd: str) -> str | None:
|
||||
"""Return the CLI's project directory for a given working directory.
|
||||
|
||||
The CLI stores session data under ``~/.claude/projects/<encoded_cwd>/``.
|
||||
Each SDK turn uses a unique ``sdk_cwd``, so the project directory is
|
||||
safe to remove entirely after the transcript has been uploaded.
|
||||
Returns ``None`` if the path would escape the projects base.
|
||||
"""
|
||||
import shutil
|
||||
|
||||
# Encode cwd the same way CLI does (replaces non-alphanumeric with -)
|
||||
cwd_encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
|
||||
config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
|
||||
projects_base = os.path.realpath(os.path.join(config_dir, "projects"))
|
||||
project_dir = os.path.realpath(os.path.join(projects_base, cwd_encoded))
|
||||
|
||||
if not project_dir.startswith(projects_base + os.sep):
|
||||
logger.warning(
|
||||
f"[Transcript] Cleanup path escaped projects base: {project_dir}"
|
||||
)
|
||||
return
|
||||
logger.warning("[Transcript] Project dir escaped base: %s", project_dir)
|
||||
return None
|
||||
return project_dir
|
||||
|
||||
|
||||
async def read_cli_session_file(sdk_cwd: str) -> str | None:
|
||||
"""Read the CLI's own session file, which reflects any mid-stream compaction.
|
||||
|
||||
After the CLI compacts context, its session file contains the compacted
|
||||
conversation. Reading this file lets ``TranscriptBuilder`` replace its
|
||||
uncompacted entries with the CLI's compacted version.
|
||||
"""
|
||||
import aiofiles
|
||||
|
||||
project_dir = _cli_project_dir(sdk_cwd)
|
||||
if not project_dir or not os.path.isdir(project_dir):
|
||||
return None
|
||||
jsonl_files = list(Path(project_dir).glob("*.jsonl"))
|
||||
if not jsonl_files:
|
||||
logger.debug("[Transcript] No CLI session file in %s", project_dir)
|
||||
return None
|
||||
# Pick the most recently modified file (there should only be one per turn).
|
||||
# Guard against races where a file is deleted between glob and stat.
|
||||
candidates: list[tuple[float, Path]] = []
|
||||
for p in jsonl_files:
|
||||
try:
|
||||
candidates.append((p.stat().st_mtime, p))
|
||||
except OSError:
|
||||
continue
|
||||
if not candidates:
|
||||
logger.debug("[Transcript] No readable CLI session file in %s", project_dir)
|
||||
return None
|
||||
# Resolve + prefix check to prevent symlink escapes.
|
||||
session_file = max(candidates, key=lambda item: item[0])[1]
|
||||
real_path = str(session_file.resolve())
|
||||
if not real_path.startswith(project_dir + os.sep):
|
||||
logger.warning("[Transcript] Session file escaped project dir: %s", real_path)
|
||||
return None
|
||||
try:
|
||||
async with aiofiles.open(real_path) as f:
|
||||
content = await f.read()
|
||||
logger.info(
|
||||
"[Transcript] Read CLI session file: %s (%d bytes)",
|
||||
real_path,
|
||||
len(content),
|
||||
)
|
||||
return content
|
||||
except OSError as e:
|
||||
logger.warning("[Transcript] Failed to read CLI session file: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
def cleanup_cli_project_dir(sdk_cwd: str) -> None:
|
||||
"""Remove the CLI's project directory for a specific working directory."""
|
||||
project_dir = _cli_project_dir(sdk_cwd)
|
||||
if not project_dir:
|
||||
return
|
||||
if os.path.isdir(project_dir):
|
||||
shutil.rmtree(project_dir, ignore_errors=True)
|
||||
logger.debug(f"[Transcript] Cleaned up CLI project dir: {project_dir}")
|
||||
logger.debug("[Transcript] Cleaned up CLI project dir: %s", project_dir)
|
||||
else:
|
||||
logger.debug(f"[Transcript] Project dir not found: {project_dir}")
|
||||
logger.debug("[Transcript] Project dir not found: %s", project_dir)
|
||||
|
||||
|
||||
def write_transcript_to_tempfile(
|
||||
@@ -180,7 +246,7 @@ def write_transcript_to_tempfile(
|
||||
# Validate cwd is under the expected sandbox prefix (CodeQL sanitizer).
|
||||
real_cwd = os.path.realpath(cwd)
|
||||
if not real_cwd.startswith(_SAFE_CWD_PREFIX):
|
||||
logger.warning(f"[Transcript] cwd outside sandbox: {cwd}")
|
||||
logger.warning("[Transcript] cwd outside sandbox: %s", cwd)
|
||||
return None
|
||||
|
||||
try:
|
||||
@@ -190,17 +256,17 @@ def write_transcript_to_tempfile(
|
||||
os.path.join(real_cwd, f"transcript-{safe_id}.jsonl")
|
||||
)
|
||||
if not jsonl_path.startswith(real_cwd):
|
||||
logger.warning(f"[Transcript] Path escaped cwd: {jsonl_path}")
|
||||
logger.warning("[Transcript] Path escaped cwd: %s", jsonl_path)
|
||||
return None
|
||||
|
||||
with open(jsonl_path, "w") as f:
|
||||
f.write(transcript_content)
|
||||
|
||||
logger.info(f"[Transcript] Wrote resume file: {jsonl_path}")
|
||||
logger.info("[Transcript] Wrote resume file: %s", jsonl_path)
|
||||
return jsonl_path
|
||||
|
||||
except OSError as e:
|
||||
logger.warning(f"[Transcript] Failed to write resume file: {e}")
|
||||
logger.warning("[Transcript] Failed to write resume file: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
@@ -344,11 +410,14 @@ async def upload_transcript(
|
||||
content=json.dumps(meta).encode("utf-8"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"{log_prefix} Failed to write metadata: {e}")
|
||||
logger.warning("%s Failed to write metadata: %s", log_prefix, e)
|
||||
|
||||
logger.info(
|
||||
f"{log_prefix} Uploaded {len(encoded)}B "
|
||||
f"(stripped from {len(content)}B, msg_count={message_count})"
|
||||
"%s Uploaded %dB (stripped from %dB, msg_count=%d)",
|
||||
log_prefix,
|
||||
len(encoded),
|
||||
len(content),
|
||||
message_count,
|
||||
)
|
||||
|
||||
|
||||
@@ -371,10 +440,10 @@ async def download_transcript(
|
||||
data = await storage.retrieve(path)
|
||||
content = data.decode("utf-8")
|
||||
except FileNotFoundError:
|
||||
logger.debug(f"{log_prefix} No transcript in storage")
|
||||
logger.debug("%s No transcript in storage", log_prefix)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"{log_prefix} Failed to download transcript: {e}")
|
||||
logger.warning("%s Failed to download transcript: %s", log_prefix, e)
|
||||
return None
|
||||
|
||||
# Try to load metadata (best-effort — old transcripts won't have it)
|
||||
@@ -394,10 +463,14 @@ async def download_transcript(
|
||||
meta = json.loads(meta_data.decode("utf-8"), fallback={})
|
||||
message_count = meta.get("message_count", 0)
|
||||
uploaded_at = meta.get("uploaded_at", 0.0)
|
||||
except (FileNotFoundError, Exception):
|
||||
except FileNotFoundError:
|
||||
pass # No metadata — treat as unknown (msg_count=0 → always fill gap)
|
||||
except Exception as e:
|
||||
logger.debug("%s Failed to load transcript metadata: %s", log_prefix, e)
|
||||
|
||||
logger.info(f"{log_prefix} Downloaded {len(content)}B (msg_count={message_count})")
|
||||
logger.info(
|
||||
"%s Downloaded %dB (msg_count=%d)", log_prefix, len(content), message_count
|
||||
)
|
||||
return TranscriptDownload(
|
||||
content=content,
|
||||
message_count=message_count,
|
||||
@@ -405,15 +478,171 @@ async def download_transcript(
|
||||
)
|
||||
|
||||
|
||||
async def delete_transcript(user_id: str, session_id: str) -> None:
|
||||
"""Delete transcript from bucket storage (e.g. after resume failure)."""
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
# ---------------------------------------------------------------------------
|
||||
# Transcript compaction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
storage = await get_workspace_storage()
|
||||
path = _build_storage_path(user_id, session_id, storage)
|
||||
# Transcripts above this byte threshold are compacted at download time.
|
||||
COMPACT_THRESHOLD_BYTES = 400_000
|
||||
|
||||
|
||||
def _flatten_assistant_content(blocks: list) -> str:
|
||||
"""Flatten assistant content blocks into a single plain-text string."""
|
||||
parts: list[str] = []
|
||||
for block in blocks:
|
||||
if isinstance(block, dict):
|
||||
if block.get("type") == "text":
|
||||
parts.append(block.get("text", ""))
|
||||
elif block.get("type") == "tool_use":
|
||||
parts.append(f"[tool_use: {block.get('name', '?')}]")
|
||||
elif isinstance(block, str):
|
||||
parts.append(block)
|
||||
return "\n".join(parts) if parts else ""
|
||||
|
||||
|
||||
def _flatten_tool_result_content(blocks: list) -> str:
|
||||
"""Flatten tool_result and other content blocks into plain text.
|
||||
|
||||
Handles nested tool_result structures, text blocks, and raw strings.
|
||||
Uses ``json.dumps`` as fallback for dict blocks without a ``text`` key
|
||||
or where ``text`` is ``None``.
|
||||
"""
|
||||
str_parts: list[str] = []
|
||||
for block in blocks:
|
||||
if isinstance(block, dict) and block.get("type") == "tool_result":
|
||||
inner = block.get("content", "")
|
||||
if isinstance(inner, list):
|
||||
for sub in inner:
|
||||
if isinstance(sub, dict):
|
||||
text = sub.get("text")
|
||||
str_parts.append(
|
||||
str(text) if text is not None else json.dumps(sub)
|
||||
)
|
||||
else:
|
||||
str_parts.append(str(sub))
|
||||
else:
|
||||
str_parts.append(str(inner))
|
||||
elif isinstance(block, dict) and block.get("type") == "text":
|
||||
str_parts.append(str(block.get("text", "")))
|
||||
elif isinstance(block, str):
|
||||
str_parts.append(block)
|
||||
return "\n".join(str_parts) if str_parts else ""
|
||||
|
||||
|
||||
def _transcript_to_messages(content: str) -> list[dict]:
|
||||
"""Convert JSONL transcript entries to message dicts for compress_context."""
|
||||
messages: list[dict] = []
|
||||
for line in content.strip().split("\n"):
|
||||
if not line.strip():
|
||||
continue
|
||||
entry = json.loads(line, fallback=None)
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
if entry.get("type", "") in STRIPPABLE_TYPES and not entry.get(
|
||||
"isCompactSummary"
|
||||
):
|
||||
continue
|
||||
msg = entry.get("message", {})
|
||||
role = msg.get("role", "")
|
||||
if not role:
|
||||
continue
|
||||
msg_dict: dict = {"role": role}
|
||||
raw_content = msg.get("content")
|
||||
if role == "assistant" and isinstance(raw_content, list):
|
||||
msg_dict["content"] = _flatten_assistant_content(raw_content)
|
||||
elif isinstance(raw_content, list):
|
||||
msg_dict["content"] = _flatten_tool_result_content(raw_content)
|
||||
else:
|
||||
msg_dict["content"] = raw_content or ""
|
||||
messages.append(msg_dict)
|
||||
return messages
|
||||
|
||||
|
||||
def _messages_to_transcript(messages: list[dict]) -> str:
|
||||
"""Convert compressed message dicts back to JSONL transcript format."""
|
||||
lines: list[str] = []
|
||||
last_uuid: str | None = None
|
||||
for msg in messages:
|
||||
role = msg.get("role", "user")
|
||||
entry_type = "assistant" if role == "assistant" else "user"
|
||||
uid = str(uuid4())
|
||||
content = msg.get("content", "")
|
||||
if role == "assistant":
|
||||
message: dict = {
|
||||
"role": "assistant",
|
||||
"model": "",
|
||||
"id": f"{COMPACT_MSG_ID_PREFIX}{uuid4().hex[:24]}",
|
||||
"type": ENTRY_TYPE_MESSAGE,
|
||||
"content": [{"type": "text", "text": content}] if content else [],
|
||||
"stop_reason": STOP_REASON_END_TURN,
|
||||
"stop_sequence": None,
|
||||
}
|
||||
else:
|
||||
message = {"role": role, "content": content}
|
||||
entry = {
|
||||
"type": entry_type,
|
||||
"uuid": uid,
|
||||
"parentUuid": last_uuid,
|
||||
"message": message,
|
||||
}
|
||||
lines.append(json.dumps(entry, separators=(",", ":")))
|
||||
last_uuid = uid
|
||||
return "\n".join(lines) + "\n" if lines else ""
|
||||
|
||||
|
||||
async def _run_compression(
|
||||
messages: list[dict],
|
||||
model: str,
|
||||
cfg: ChatConfig,
|
||||
log_prefix: str,
|
||||
) -> CompressResult:
|
||||
"""Run LLM-based compression with truncation fallback."""
|
||||
try:
|
||||
await storage.delete(path)
|
||||
logger.info(f"[Transcript] Deleted transcript for session {session_id}")
|
||||
async with openai.AsyncOpenAI(
|
||||
api_key=cfg.api_key, base_url=cfg.base_url, timeout=30.0
|
||||
) as client:
|
||||
return await compress_context(messages=messages, model=model, client=client)
|
||||
except Exception as e:
|
||||
logger.warning(f"[Transcript] Failed to delete transcript: {e}")
|
||||
logger.warning("%s LLM compaction failed, using truncation: %s", log_prefix, e)
|
||||
return await compress_context(messages=messages, model=model, client=None)
|
||||
|
||||
|
||||
async def compact_transcript(
|
||||
content: str,
|
||||
log_prefix: str = "[Transcript]",
|
||||
) -> str | None:
|
||||
"""Compact an oversized JSONL transcript using LLM summarization.
|
||||
|
||||
Converts transcript entries to plain messages, runs ``compress_context``
|
||||
(the same compressor used for pre-query history), and rebuilds JSONL.
|
||||
|
||||
Returns the compacted JSONL string, or ``None`` on failure.
|
||||
"""
|
||||
cfg = ChatConfig()
|
||||
messages = _transcript_to_messages(content)
|
||||
if len(messages) < 2:
|
||||
logger.warning("%s Too few messages to compact (%d)", log_prefix, len(messages))
|
||||
return None
|
||||
try:
|
||||
result = await _run_compression(messages, cfg.model, cfg, log_prefix)
|
||||
if not result.was_compacted:
|
||||
logger.info("%s Transcript already within token budget", log_prefix)
|
||||
return content
|
||||
logger.info(
|
||||
"%s Compacted transcript: %d->%d tokens (%d summarized, %d dropped)",
|
||||
log_prefix,
|
||||
result.original_token_count,
|
||||
result.token_count,
|
||||
result.messages_summarized,
|
||||
result.messages_dropped,
|
||||
)
|
||||
compacted = _messages_to_transcript(result.messages)
|
||||
if not validate_transcript(compacted):
|
||||
logger.warning("%s Compacted transcript failed validation", log_prefix)
|
||||
return None
|
||||
return compacted
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"%s Transcript compaction failed: %s", log_prefix, e, exc_info=True
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -31,6 +31,7 @@ class TranscriptEntry(BaseModel):
|
||||
uuid: str
|
||||
parentUuid: str | None
|
||||
message: dict[str, Any]
|
||||
isCompactSummary: bool | None = None
|
||||
|
||||
|
||||
class TranscriptBuilder:
|
||||
@@ -78,10 +79,12 @@ class TranscriptBuilder:
|
||||
)
|
||||
continue
|
||||
|
||||
# Load all non-strippable entries (user/assistant/system/etc.)
|
||||
# Skip only STRIPPABLE_TYPES to match strip_progress_entries() behavior
|
||||
# Skip STRIPPABLE_TYPES unless the entry is a compaction summary.
|
||||
# Compaction summaries may have type "summary" but must be preserved
|
||||
# so --resume can reconstruct the compacted conversation.
|
||||
entry_type = data.get("type", "")
|
||||
if entry_type in STRIPPABLE_TYPES:
|
||||
is_compact = data.get("isCompactSummary", False)
|
||||
if entry_type in STRIPPABLE_TYPES and not is_compact:
|
||||
continue
|
||||
|
||||
entry = TranscriptEntry(
|
||||
@@ -89,6 +92,7 @@ class TranscriptBuilder:
|
||||
uuid=data.get("uuid") or str(uuid4()),
|
||||
parentUuid=data.get("parentUuid"),
|
||||
message=data.get("message", {}),
|
||||
isCompactSummary=True if is_compact else None,
|
||||
)
|
||||
self._entries.append(entry)
|
||||
self._last_uuid = entry.uuid
|
||||
@@ -177,6 +181,33 @@ class TranscriptBuilder:
|
||||
lines = [entry.model_dump_json(exclude_none=True) for entry in self._entries]
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
def replace_entries(self, content: str, log_prefix: str = "[Transcript]") -> None:
|
||||
"""Replace all entries with compacted JSONL content.
|
||||
|
||||
Called after the CLI performs mid-stream compaction so the builder's
|
||||
state reflects the compacted conversation instead of the full
|
||||
pre-compaction history.
|
||||
"""
|
||||
prev_count = len(self._entries)
|
||||
temp = TranscriptBuilder()
|
||||
try:
|
||||
temp.load_previous(content, log_prefix=log_prefix)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"%s Failed to parse compacted transcript; keeping %d existing entries",
|
||||
log_prefix,
|
||||
prev_count,
|
||||
)
|
||||
return
|
||||
self._entries = temp._entries
|
||||
self._last_uuid = temp._last_uuid
|
||||
logger.info(
|
||||
"%s Replaced %d entries with %d compacted entries",
|
||||
log_prefix,
|
||||
prev_count,
|
||||
len(self._entries),
|
||||
)
|
||||
|
||||
@property
|
||||
def entry_count(self) -> int:
|
||||
"""Total number of entries in the complete context."""
|
||||
|
||||
@@ -2,14 +2,25 @@
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util import json
|
||||
|
||||
from .transcript import (
|
||||
COMPACT_MSG_ID_PREFIX,
|
||||
STRIPPABLE_TYPES,
|
||||
_cli_project_dir,
|
||||
_flatten_assistant_content,
|
||||
_flatten_tool_result_content,
|
||||
_messages_to_transcript,
|
||||
_transcript_to_messages,
|
||||
compact_transcript,
|
||||
read_cli_session_file,
|
||||
strip_progress_entries,
|
||||
validate_transcript,
|
||||
write_transcript_to_tempfile,
|
||||
)
|
||||
from .transcript_builder import TranscriptBuilder
|
||||
|
||||
|
||||
def _make_jsonl(*entries: dict) -> str:
|
||||
@@ -35,6 +46,14 @@ PROGRESS_ENTRY = {
|
||||
"data": {"type": "bash_progress", "stdout": "running..."},
|
||||
}
|
||||
|
||||
COMPACT_SUMMARY = {
|
||||
"type": "summary",
|
||||
"uuid": "cs1",
|
||||
"parentUuid": None,
|
||||
"isCompactSummary": True,
|
||||
"message": {"role": "user", "content": "Summary of previous conversation..."},
|
||||
}
|
||||
|
||||
VALID_TRANSCRIPT = _make_jsonl(METADATA_LINE, FILE_HISTORY, USER_MSG, ASST_MSG)
|
||||
|
||||
|
||||
@@ -237,6 +256,121 @@ class TestStripProgressEntries:
|
||||
# Should return just a newline (empty content stripped)
|
||||
assert result.strip() == ""
|
||||
|
||||
|
||||
# --- _cli_project_dir ---
|
||||
|
||||
|
||||
class TestCliProjectDir:
|
||||
def test_returns_path_for_valid_cwd(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(tmp_path))
|
||||
projects = tmp_path / "projects"
|
||||
projects.mkdir()
|
||||
result = _cli_project_dir("/tmp/copilot-abc")
|
||||
assert result is not None
|
||||
assert "projects" in result
|
||||
|
||||
def test_returns_none_for_path_traversal(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(tmp_path))
|
||||
projects = tmp_path / "projects"
|
||||
projects.mkdir()
|
||||
# A cwd that encodes to something with .. shouldn't escape
|
||||
result = _cli_project_dir("/tmp/copilot-test")
|
||||
# Should return a valid path (no traversal possible with alphanum encoding)
|
||||
assert result is None or result.startswith(str(projects))
|
||||
|
||||
|
||||
# --- read_cli_session_file ---
|
||||
|
||||
|
||||
class TestReadCliSessionFile:
|
||||
@pytest.mark.asyncio
|
||||
async def test_reads_session_file(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(tmp_path))
|
||||
# Create the CLI project directory structure
|
||||
cwd = "/tmp/copilot-testread"
|
||||
import re
|
||||
|
||||
encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
|
||||
project_dir = tmp_path / "projects" / encoded
|
||||
project_dir.mkdir(parents=True)
|
||||
# Write a session file
|
||||
session_file = project_dir / "test-session.jsonl"
|
||||
session_file.write_text(json.dumps(ASST_MSG) + "\n")
|
||||
|
||||
result = await read_cli_session_file(cwd)
|
||||
assert result is not None
|
||||
assert "assistant" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_when_no_files(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(tmp_path))
|
||||
cwd = "/tmp/copilot-nofiles"
|
||||
import re
|
||||
|
||||
encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
|
||||
project_dir = tmp_path / "projects" / encoded
|
||||
project_dir.mkdir(parents=True)
|
||||
# No jsonl files
|
||||
result = await read_cli_session_file(cwd)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_when_dir_missing(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(tmp_path))
|
||||
(tmp_path / "projects").mkdir()
|
||||
result = await read_cli_session_file("/tmp/copilot-nonexistent")
|
||||
assert result is None
|
||||
|
||||
|
||||
# --- _transcript_to_messages / _messages_to_transcript ---
|
||||
|
||||
|
||||
class TestTranscriptMessageConversion:
|
||||
def test_roundtrip_preserves_roles(self):
|
||||
transcript = _make_jsonl(USER_MSG, ASST_MSG)
|
||||
messages = _transcript_to_messages(transcript)
|
||||
assert len(messages) == 2
|
||||
assert messages[0]["role"] == "user"
|
||||
assert messages[1]["role"] == "assistant"
|
||||
|
||||
def test_messages_to_transcript_produces_valid_jsonl(self):
|
||||
messages = [
|
||||
{"role": "user", "content": "hi"},
|
||||
{"role": "assistant", "content": "hello"},
|
||||
]
|
||||
result = _messages_to_transcript(messages)
|
||||
assert validate_transcript(result) is True
|
||||
|
||||
def test_strips_strippable_types(self):
|
||||
transcript = _make_jsonl(
|
||||
{"type": "progress", "uuid": "p1", "message": {"role": "user"}},
|
||||
USER_MSG,
|
||||
ASST_MSG,
|
||||
)
|
||||
messages = _transcript_to_messages(transcript)
|
||||
assert len(messages) == 2 # progress entry skipped
|
||||
|
||||
def test_flattens_assistant_content_blocks(self):
|
||||
asst_with_blocks = {
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "hello"},
|
||||
{"type": "tool_use", "name": "bash"},
|
||||
],
|
||||
},
|
||||
}
|
||||
messages = _transcript_to_messages(_make_jsonl(asst_with_blocks))
|
||||
assert len(messages) == 1
|
||||
assert "hello" in messages[0]["content"]
|
||||
assert "[tool_use: bash]" in messages[0]["content"]
|
||||
|
||||
def test_empty_messages_returns_empty(self):
|
||||
result = _messages_to_transcript([])
|
||||
assert result == ""
|
||||
|
||||
def test_no_strippable_entries(self):
|
||||
"""When there's nothing to strip, output matches input structure."""
|
||||
content = _make_jsonl(USER_MSG, ASST_MSG)
|
||||
@@ -282,3 +416,654 @@ class TestStripProgressEntries:
|
||||
lines = result.strip().split("\n")
|
||||
asst_entry = json.loads(lines[-1])
|
||||
assert asst_entry["parentUuid"] == "u1" # reparented
|
||||
|
||||
|
||||
# --- TranscriptBuilder ---
|
||||
|
||||
|
||||
class TestTranscriptBuilderReplaceEntries:
|
||||
"""Tests for TranscriptBuilder.replace_entries — the compaction sync path."""
|
||||
|
||||
def test_replace_entries_with_valid_content(self):
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user("hello")
|
||||
builder.append_assistant([{"type": "text", "text": "world"}])
|
||||
assert builder.entry_count == 2
|
||||
|
||||
# Replace with compacted content (one user + one assistant)
|
||||
compacted = _make_jsonl(USER_MSG, ASST_MSG)
|
||||
builder.replace_entries(compacted)
|
||||
assert builder.entry_count == 2
|
||||
|
||||
def test_replace_entries_keeps_old_on_corrupt_content(self):
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user("hello")
|
||||
assert builder.entry_count == 1
|
||||
|
||||
# Corrupt content that fails to parse
|
||||
builder.replace_entries("not valid json at all\n")
|
||||
# Should still have old entries (load_previous skips invalid lines,
|
||||
# but if ALL lines are invalid, temp builder is empty → exception path)
|
||||
assert builder.entry_count >= 0 # doesn't crash
|
||||
|
||||
def test_replace_entries_with_empty_content(self):
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user("hello")
|
||||
assert builder.entry_count == 1
|
||||
|
||||
builder.replace_entries("")
|
||||
# Empty content → load_previous returns early → temp is empty
|
||||
# replace_entries swaps to empty (0 entries)
|
||||
assert builder.entry_count == 0
|
||||
|
||||
def test_replace_entries_filters_strippable_types(self):
|
||||
"""Strippable types (progress, file-history-snapshot) are filtered out."""
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user("hello")
|
||||
|
||||
content = _make_jsonl(
|
||||
{"type": "progress", "uuid": "p1", "message": {}},
|
||||
USER_MSG,
|
||||
ASST_MSG,
|
||||
)
|
||||
builder.replace_entries(content)
|
||||
# Only user + assistant should remain (progress filtered)
|
||||
assert builder.entry_count == 2
|
||||
|
||||
def test_replace_entries_preserves_uuids(self):
|
||||
builder = TranscriptBuilder()
|
||||
content = _make_jsonl(USER_MSG, ASST_MSG)
|
||||
builder.replace_entries(content)
|
||||
|
||||
jsonl = builder.to_jsonl()
|
||||
lines = jsonl.strip().split("\n")
|
||||
first = json.loads(lines[0])
|
||||
assert first["uuid"] == "u1"
|
||||
|
||||
|
||||
class TestTranscriptBuilderBasic:
|
||||
def test_append_user_and_assistant(self):
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user("hi")
|
||||
builder.append_assistant([{"type": "text", "text": "hello"}])
|
||||
assert builder.entry_count == 2
|
||||
assert not builder.is_empty
|
||||
|
||||
def test_to_jsonl_empty(self):
|
||||
builder = TranscriptBuilder()
|
||||
assert builder.to_jsonl() == ""
|
||||
assert builder.is_empty
|
||||
|
||||
def test_load_previous_and_append(self):
|
||||
builder = TranscriptBuilder()
|
||||
content = _make_jsonl(USER_MSG, ASST_MSG)
|
||||
builder.load_previous(content)
|
||||
assert builder.entry_count == 2
|
||||
builder.append_user("new message")
|
||||
assert builder.entry_count == 3
|
||||
|
||||
def test_consecutive_assistant_entries_share_message_id(self):
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user("hi")
|
||||
builder.append_assistant([{"type": "text", "text": "part1"}])
|
||||
builder.append_assistant([{"type": "text", "text": "part2"}])
|
||||
|
||||
jsonl = builder.to_jsonl()
|
||||
lines = jsonl.strip().split("\n")
|
||||
asst1 = json.loads(lines[1])
|
||||
asst2 = json.loads(lines[2])
|
||||
assert asst1["message"]["id"] == asst2["message"]["id"]
|
||||
|
||||
def test_non_consecutive_assistant_entries_get_new_id(self):
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user("hi")
|
||||
builder.append_assistant([{"type": "text", "text": "response1"}])
|
||||
builder.append_user("followup")
|
||||
builder.append_assistant([{"type": "text", "text": "response2"}])
|
||||
|
||||
jsonl = builder.to_jsonl()
|
||||
lines = jsonl.strip().split("\n")
|
||||
asst1 = json.loads(lines[1])
|
||||
asst2 = json.loads(lines[3])
|
||||
assert asst1["message"]["id"] != asst2["message"]["id"]
|
||||
|
||||
|
||||
class TestCompactSummaryRoundtrip:
|
||||
"""Verify isCompactSummary survives export→reload roundtrip."""
|
||||
|
||||
def test_load_previous_preserves_compact_summary(self):
|
||||
"""Compaction summary with type 'summary' should not be stripped."""
|
||||
content = _make_jsonl(COMPACT_SUMMARY, USER_MSG, ASST_MSG)
|
||||
builder = TranscriptBuilder()
|
||||
builder.load_previous(content)
|
||||
# summary type is in STRIPPABLE_TYPES, but isCompactSummary keeps it
|
||||
assert builder.entry_count == 3
|
||||
|
||||
def test_export_reload_preserves_compact_summary(self):
|
||||
"""Critical: isCompactSummary must survive to_jsonl → load_previous."""
|
||||
content = _make_jsonl(COMPACT_SUMMARY, USER_MSG, ASST_MSG)
|
||||
builder1 = TranscriptBuilder()
|
||||
builder1.load_previous(content)
|
||||
assert builder1.entry_count == 3
|
||||
|
||||
exported = builder1.to_jsonl()
|
||||
# Verify isCompactSummary is in the exported JSONL
|
||||
first_line = json.loads(exported.strip().split("\n")[0])
|
||||
assert first_line.get("isCompactSummary") is True
|
||||
|
||||
# Reload and verify it's still preserved
|
||||
builder2 = TranscriptBuilder()
|
||||
builder2.load_previous(exported)
|
||||
assert builder2.entry_count == 3
|
||||
|
||||
def test_strip_progress_preserves_compact_summary(self):
|
||||
"""strip_progress_entries should keep isCompactSummary entries."""
|
||||
content = _make_jsonl(COMPACT_SUMMARY, USER_MSG, ASST_MSG)
|
||||
stripped = strip_progress_entries(content)
|
||||
entries = [json.loads(line) for line in stripped.strip().split("\n")]
|
||||
types = [e.get("type") for e in entries]
|
||||
assert "summary" in types # Not stripped despite being in STRIPPABLE_TYPES
|
||||
compact = [e for e in entries if e.get("isCompactSummary")]
|
||||
assert len(compact) == 1
|
||||
|
||||
def test_regular_summary_still_stripped(self):
|
||||
"""Non-compact summaries should still be stripped."""
|
||||
regular_summary = {
|
||||
"type": "summary",
|
||||
"uuid": "rs1",
|
||||
"summary": "Session summary",
|
||||
}
|
||||
content = _make_jsonl(regular_summary, USER_MSG, ASST_MSG)
|
||||
stripped = strip_progress_entries(content)
|
||||
entries = [json.loads(line) for line in stripped.strip().split("\n")]
|
||||
types = [e.get("type") for e in entries]
|
||||
assert "summary" not in types
|
||||
|
||||
def test_replace_entries_preserves_compact_summary(self):
|
||||
"""replace_entries should preserve isCompactSummary entries."""
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user("old")
|
||||
content = _make_jsonl(COMPACT_SUMMARY, USER_MSG, ASST_MSG)
|
||||
builder.replace_entries(content)
|
||||
assert builder.entry_count == 3
|
||||
|
||||
# Verify by re-exporting
|
||||
exported = builder.to_jsonl()
|
||||
first = json.loads(exported.strip().split("\n")[0])
|
||||
assert first.get("isCompactSummary") is True
|
||||
|
||||
|
||||
# --- _flatten_assistant_content ---
|
||||
|
||||
|
||||
class TestFlattenAssistantContent:
|
||||
def test_text_blocks(self):
|
||||
blocks = [
|
||||
{"type": "text", "text": "Hello"},
|
||||
{"type": "text", "text": "World"},
|
||||
]
|
||||
assert _flatten_assistant_content(blocks) == "Hello\nWorld"
|
||||
|
||||
def test_tool_use_blocks(self):
|
||||
blocks = [{"type": "tool_use", "name": "read_file", "id": "t1", "input": {}}]
|
||||
assert _flatten_assistant_content(blocks) == "[tool_use: read_file]"
|
||||
|
||||
def test_mixed_blocks(self):
|
||||
blocks = [
|
||||
{"type": "text", "text": "Let me read that."},
|
||||
{"type": "tool_use", "name": "read", "id": "t1", "input": {}},
|
||||
]
|
||||
result = _flatten_assistant_content(blocks)
|
||||
assert "Let me read that." in result
|
||||
assert "[tool_use: read]" in result
|
||||
|
||||
def test_string_blocks(self):
|
||||
"""Plain strings in the list should be included."""
|
||||
assert _flatten_assistant_content(["hello", "world"]) == "hello\nworld"
|
||||
|
||||
def test_empty_list(self):
|
||||
assert _flatten_assistant_content([]) == ""
|
||||
|
||||
def test_tool_use_missing_name(self):
|
||||
blocks = [{"type": "tool_use", "id": "t1", "input": {}}]
|
||||
assert _flatten_assistant_content(blocks) == "[tool_use: ?]"
|
||||
|
||||
|
||||
# --- _flatten_tool_result_content ---
|
||||
|
||||
|
||||
class TestFlattenToolResultContent:
|
||||
def test_tool_result_with_text(self):
|
||||
blocks = [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "t1",
|
||||
"content": [{"type": "text", "text": "file contents here"}],
|
||||
}
|
||||
]
|
||||
assert _flatten_tool_result_content(blocks) == "file contents here"
|
||||
|
||||
def test_tool_result_with_string_content(self):
|
||||
blocks = [
|
||||
{"type": "tool_result", "tool_use_id": "t1", "content": "simple result"}
|
||||
]
|
||||
assert _flatten_tool_result_content(blocks) == "simple result"
|
||||
|
||||
def test_tool_result_with_nested_list(self):
|
||||
blocks = [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "t1",
|
||||
"content": [
|
||||
{"type": "text", "text": "line 1"},
|
||||
{"type": "text", "text": "line 2"},
|
||||
],
|
||||
}
|
||||
]
|
||||
assert _flatten_tool_result_content(blocks) == "line 1\nline 2"
|
||||
|
||||
def test_text_blocks(self):
|
||||
blocks = [{"type": "text", "text": "some text"}]
|
||||
assert _flatten_tool_result_content(blocks) == "some text"
|
||||
|
||||
def test_string_items(self):
|
||||
assert _flatten_tool_result_content(["raw string"]) == "raw string"
|
||||
|
||||
def test_empty_list(self):
|
||||
assert _flatten_tool_result_content([]) == ""
|
||||
|
||||
def test_tool_result_none_text_uses_json(self):
|
||||
"""Dicts without text key fall back to json.dumps."""
|
||||
blocks = [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "t1",
|
||||
"content": [{"type": "image", "source": "data:..."}],
|
||||
}
|
||||
]
|
||||
result = _flatten_tool_result_content(blocks)
|
||||
assert "image" in result # json.dumps fallback includes the key
|
||||
|
||||
|
||||
# --- _transcript_to_messages ---
|
||||
|
||||
|
||||
class TestTranscriptToMessages:
|
||||
def test_basic_conversation(self):
|
||||
content = _make_jsonl(
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "u1",
|
||||
"message": {"role": "user", "content": "hello"},
|
||||
},
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"parentUuid": "u1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "hi there"}],
|
||||
},
|
||||
},
|
||||
)
|
||||
msgs = _transcript_to_messages(content)
|
||||
assert len(msgs) == 2
|
||||
assert msgs[0] == {"role": "user", "content": "hello"}
|
||||
assert msgs[1] == {"role": "assistant", "content": "hi there"}
|
||||
|
||||
def test_strips_progress_entries(self):
|
||||
content = _make_jsonl(
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "u1",
|
||||
"message": {"role": "user", "content": "hi"},
|
||||
},
|
||||
{
|
||||
"type": "progress",
|
||||
"uuid": "p1",
|
||||
"message": {"role": "user", "content": "..."},
|
||||
},
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "ok"}],
|
||||
},
|
||||
},
|
||||
)
|
||||
msgs = _transcript_to_messages(content)
|
||||
assert len(msgs) == 2
|
||||
assert msgs[0]["role"] == "user"
|
||||
assert msgs[1]["role"] == "assistant"
|
||||
|
||||
def test_preserves_compact_summaries(self):
|
||||
content = _make_jsonl(
|
||||
{
|
||||
"type": "summary",
|
||||
"uuid": "cs1",
|
||||
"isCompactSummary": True,
|
||||
"message": {"role": "user", "content": "Summary of previous..."},
|
||||
},
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "u1",
|
||||
"message": {"role": "user", "content": "hi"},
|
||||
},
|
||||
)
|
||||
msgs = _transcript_to_messages(content)
|
||||
assert len(msgs) == 2
|
||||
assert msgs[0]["content"] == "Summary of previous..."
|
||||
|
||||
def test_strips_regular_summary(self):
|
||||
content = _make_jsonl(
|
||||
{
|
||||
"type": "summary",
|
||||
"uuid": "s1",
|
||||
"message": {"role": "user", "content": "Session summary"},
|
||||
},
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "u1",
|
||||
"message": {"role": "user", "content": "hi"},
|
||||
},
|
||||
)
|
||||
msgs = _transcript_to_messages(content)
|
||||
assert len(msgs) == 1
|
||||
assert msgs[0]["content"] == "hi"
|
||||
|
||||
def test_skips_entries_without_role(self):
|
||||
content = _make_jsonl(
|
||||
{"type": "user", "uuid": "u1", "message": {}},
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "u2",
|
||||
"message": {"role": "user", "content": "hi"},
|
||||
},
|
||||
)
|
||||
msgs = _transcript_to_messages(content)
|
||||
assert len(msgs) == 1
|
||||
|
||||
def test_tool_result_content(self):
|
||||
content = _make_jsonl(
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "u1",
|
||||
"message": {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "t1",
|
||||
"content": "file contents",
|
||||
}
|
||||
],
|
||||
},
|
||||
},
|
||||
)
|
||||
msgs = _transcript_to_messages(content)
|
||||
assert len(msgs) == 1
|
||||
assert "file contents" in msgs[0]["content"]
|
||||
|
||||
def test_empty_content(self):
|
||||
assert _transcript_to_messages("") == []
|
||||
assert _transcript_to_messages(" \n ") == []
|
||||
|
||||
def test_invalid_json_lines_skipped(self):
|
||||
content = '{"type":"user","uuid":"u1","message":{"role":"user","content":"hi"}}\nnot json\n'
|
||||
msgs = _transcript_to_messages(content)
|
||||
assert len(msgs) == 1
|
||||
|
||||
|
||||
# --- _messages_to_transcript ---
|
||||
|
||||
|
||||
class TestMessagesToTranscript:
|
||||
def test_basic_roundtrip_structure(self):
|
||||
messages = [
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "assistant", "content": "hi there"},
|
||||
]
|
||||
result = _messages_to_transcript(messages)
|
||||
assert result.endswith("\n")
|
||||
lines = [json.loads(line) for line in result.strip().split("\n")]
|
||||
assert len(lines) == 2
|
||||
|
||||
# User entry
|
||||
assert lines[0]["type"] == "user"
|
||||
assert lines[0]["message"]["role"] == "user"
|
||||
assert lines[0]["message"]["content"] == "hello"
|
||||
assert lines[0]["parentUuid"] is None
|
||||
|
||||
# Assistant entry
|
||||
assert lines[1]["type"] == "assistant"
|
||||
assert lines[1]["message"]["role"] == "assistant"
|
||||
assert lines[1]["message"]["content"] == [{"type": "text", "text": "hi there"}]
|
||||
assert lines[1]["message"]["id"].startswith(COMPACT_MSG_ID_PREFIX)
|
||||
assert lines[1]["parentUuid"] == lines[0]["uuid"]
|
||||
|
||||
def test_parent_uuid_chain(self):
|
||||
messages = [
|
||||
{"role": "user", "content": "q1"},
|
||||
{"role": "assistant", "content": "a1"},
|
||||
{"role": "user", "content": "q2"},
|
||||
]
|
||||
result = _messages_to_transcript(messages)
|
||||
lines = [json.loads(line) for line in result.strip().split("\n")]
|
||||
assert lines[0]["parentUuid"] is None
|
||||
assert lines[1]["parentUuid"] == lines[0]["uuid"]
|
||||
assert lines[2]["parentUuid"] == lines[1]["uuid"]
|
||||
|
||||
def test_empty_messages(self):
|
||||
assert _messages_to_transcript([]) == ""
|
||||
|
||||
def test_assistant_empty_content(self):
|
||||
messages = [{"role": "assistant", "content": ""}]
|
||||
result = _messages_to_transcript(messages)
|
||||
entry = json.loads(result.strip())
|
||||
assert entry["message"]["content"] == []
|
||||
|
||||
def test_output_is_valid_transcript(self):
|
||||
messages = [
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "assistant", "content": "world"},
|
||||
]
|
||||
result = _messages_to_transcript(messages)
|
||||
assert validate_transcript(result)
|
||||
|
||||
|
||||
# --- _transcript_to_messages + _messages_to_transcript roundtrip ---
|
||||
|
||||
|
||||
class TestTranscriptCompactionRoundtrip:
|
||||
def test_content_preserved_through_roundtrip(self):
|
||||
"""Messages→transcript→messages preserves content."""
|
||||
original = [
|
||||
{"role": "user", "content": "What is 2+2?"},
|
||||
{"role": "assistant", "content": "4"},
|
||||
{"role": "user", "content": "Thanks"},
|
||||
]
|
||||
transcript = _messages_to_transcript(original)
|
||||
recovered = _transcript_to_messages(transcript)
|
||||
assert len(recovered) == len(original)
|
||||
for orig, rec in zip(original, recovered):
|
||||
assert orig["role"] == rec["role"]
|
||||
assert orig["content"] == rec["content"]
|
||||
|
||||
def test_full_transcript_to_messages_and_back(self):
|
||||
"""Real-ish JSONL → messages → transcript → messages roundtrip."""
|
||||
source = _make_jsonl(
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "u1",
|
||||
"message": {"role": "user", "content": "explain python"},
|
||||
},
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"parentUuid": "u1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "Python is a programming language."}
|
||||
],
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "u2",
|
||||
"parentUuid": "a1",
|
||||
"message": {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "t1",
|
||||
"content": "output of ls",
|
||||
}
|
||||
],
|
||||
},
|
||||
},
|
||||
)
|
||||
msgs1 = _transcript_to_messages(source)
|
||||
assert len(msgs1) == 3
|
||||
|
||||
rebuilt = _messages_to_transcript(msgs1)
|
||||
msgs2 = _transcript_to_messages(rebuilt)
|
||||
assert len(msgs2) == len(msgs1)
|
||||
for m1, m2 in zip(msgs1, msgs2):
|
||||
assert m1["role"] == m2["role"]
|
||||
# Content may differ in format (list vs string) but text is preserved
|
||||
assert m1["content"] in m2["content"] or m2["content"] in m1["content"]
|
||||
|
||||
|
||||
# --- compact_transcript ---
|
||||
|
||||
|
||||
class TestCompactTranscript:
|
||||
@pytest.mark.asyncio
|
||||
async def test_too_few_messages_returns_none(self):
|
||||
"""Transcripts with < 2 messages can't be compacted."""
|
||||
single = _make_jsonl(
|
||||
{"type": "user", "uuid": "u1", "message": {"role": "user", "content": "hi"}}
|
||||
)
|
||||
result = await compact_transcript(single)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_transcript_returns_none(self):
|
||||
result = await compact_transcript("")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compaction_produces_valid_transcript(self, monkeypatch):
|
||||
"""When compress_context compacts, result should be valid JSONL."""
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from backend.util.prompt import CompressResult
|
||||
|
||||
mock_result = CompressResult(
|
||||
messages=[
|
||||
{"role": "user", "content": "Summary of conversation"},
|
||||
{"role": "assistant", "content": "Acknowledged"},
|
||||
],
|
||||
token_count=50,
|
||||
was_compacted=True,
|
||||
original_token_count=5000,
|
||||
messages_summarized=10,
|
||||
messages_dropped=5,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
AsyncMock(return_value=mock_result),
|
||||
)
|
||||
|
||||
source = _make_jsonl(
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "u1",
|
||||
"message": {"role": "user", "content": "msg1"},
|
||||
},
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "reply1"}],
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "u2",
|
||||
"message": {"role": "user", "content": "msg2"},
|
||||
},
|
||||
)
|
||||
result = await compact_transcript(source)
|
||||
assert result is not None
|
||||
assert validate_transcript(result)
|
||||
|
||||
# Verify compacted content
|
||||
msgs = _transcript_to_messages(result)
|
||||
assert len(msgs) == 2
|
||||
assert msgs[0]["content"] == "Summary of conversation"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_compaction_needed_returns_original(self, monkeypatch):
|
||||
"""When compress_context says no compaction needed, return original."""
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from backend.util.prompt import CompressResult
|
||||
|
||||
mock_result = CompressResult(
|
||||
messages=[], token_count=100, was_compacted=False, original_token_count=100
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
AsyncMock(return_value=mock_result),
|
||||
)
|
||||
|
||||
source = _make_jsonl(
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "u1",
|
||||
"message": {"role": "user", "content": "hi"},
|
||||
},
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "hello"}],
|
||||
},
|
||||
},
|
||||
)
|
||||
result = await compact_transcript(source)
|
||||
assert result == source # Unchanged
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compression_failure_returns_none(self, monkeypatch):
|
||||
"""When _run_compression raises, compact_transcript returns None."""
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
AsyncMock(side_effect=RuntimeError("LLM unavailable")),
|
||||
)
|
||||
|
||||
source = _make_jsonl(
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "u1",
|
||||
"message": {"role": "user", "content": "hi"},
|
||||
},
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "hello"}],
|
||||
},
|
||||
},
|
||||
)
|
||||
result = await compact_transcript(source)
|
||||
assert result is None
|
||||
|
||||
@@ -8,11 +8,15 @@ from pydantic_core import PydanticUndefined
|
||||
|
||||
from backend.blocks._base import AnyBlockSchema
|
||||
from backend.copilot.constants import COPILOT_NODE_PREFIX, COPILOT_SESSION_PREFIX
|
||||
from backend.data import db
|
||||
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
||||
from backend.executor.utils import block_usage_cost
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.util.exceptions import BlockError
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
from backend.util.exceptions import BlockError, InsufficientBalanceError
|
||||
from backend.util.type import coerce_inputs_to_schema
|
||||
|
||||
from .models import BlockOutputResponse, ErrorResponse, ToolResponseBase
|
||||
@@ -21,6 +25,26 @@ from .utils import match_credentials_to_requirements
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _get_credits(user_id: str) -> int:
|
||||
"""Get user credits using the adapter pattern (RPC when Prisma unavailable)."""
|
||||
if not db.is_connected():
|
||||
return await get_database_manager_async_client().get_credits(user_id)
|
||||
credit_model = await get_user_credit_model(user_id)
|
||||
return await credit_model.get_credits(user_id)
|
||||
|
||||
|
||||
async def _spend_credits(
|
||||
user_id: str, cost: int, metadata: UsageTransactionMetadata
|
||||
) -> int:
|
||||
"""Spend user credits using the adapter pattern (RPC when Prisma unavailable)."""
|
||||
if not db.is_connected():
|
||||
return await get_database_manager_async_client().spend_credits(
|
||||
user_id, cost, metadata
|
||||
)
|
||||
credit_model = await get_user_credit_model(user_id)
|
||||
return await credit_model.spend_credits(user_id, cost, metadata)
|
||||
|
||||
|
||||
def get_inputs_from_schema(
|
||||
input_schema: dict[str, Any],
|
||||
exclude_fields: set[str] | None = None,
|
||||
@@ -115,6 +139,20 @@ async def execute_block(
|
||||
# Coerce non-matching data types to the expected input schema.
|
||||
coerce_inputs_to_schema(input_data, block.input_schema)
|
||||
|
||||
# Pre-execution credit check
|
||||
cost, cost_filter = block_usage_cost(block, input_data)
|
||||
has_cost = cost > 0
|
||||
if has_cost:
|
||||
balance = await _get_credits(user_id)
|
||||
if balance < cost:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Insufficient credits to run '{block.name}'. "
|
||||
"Please top up your credits to continue."
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Execute the block and collect outputs
|
||||
outputs: dict[str, list[Any]] = defaultdict(list)
|
||||
async for output_name, output_data in block.execute(
|
||||
@@ -123,6 +161,37 @@ async def execute_block(
|
||||
):
|
||||
outputs[output_name].append(output_data)
|
||||
|
||||
# Charge credits for block execution
|
||||
if has_cost:
|
||||
try:
|
||||
await _spend_credits(
|
||||
user_id=user_id,
|
||||
cost=cost,
|
||||
metadata=UsageTransactionMetadata(
|
||||
graph_exec_id=synthetic_graph_id,
|
||||
graph_id=synthetic_graph_id,
|
||||
node_id=synthetic_node_id,
|
||||
node_exec_id=node_exec_id,
|
||||
block_id=block_id,
|
||||
block=block.name,
|
||||
input=cost_filter,
|
||||
reason="copilot_block_execution",
|
||||
),
|
||||
)
|
||||
except InsufficientBalanceError:
|
||||
logger.warning(
|
||||
"Post-exec credit charge failed for block %s (cost=%d)",
|
||||
block.name,
|
||||
cost,
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Insufficient credits to complete '{block.name}'. "
|
||||
"Please top up your credits to continue."
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
return BlockOutputResponse(
|
||||
message=f"Block '{block.name}' executed successfully",
|
||||
block_id=block_id,
|
||||
@@ -133,16 +202,16 @@ async def execute_block(
|
||||
)
|
||||
|
||||
except BlockError as e:
|
||||
logger.warning(f"Block execution failed: {e}")
|
||||
logger.warning("Block execution failed: %s", e)
|
||||
return ErrorResponse(
|
||||
message=f"Block execution failed: {e}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error executing block: {e}", exc_info=True)
|
||||
logger.error("Unexpected error executing block: %s", e, exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to execute block: {str(e)}",
|
||||
message="An unexpected error occurred while executing the block",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
@@ -1,24 +1,202 @@
|
||||
"""Tests for execute_block type coercion in helpers.py.
|
||||
|
||||
Verifies that execute_block() coerces string input values to match the block's
|
||||
expected input types, mirroring the executor's validate_exec() logic.
|
||||
This is critical for @@agptfile: expansion, where file content is always a string
|
||||
but the block may expect structured types (e.g. list[list[str]]).
|
||||
"""
|
||||
"""Tests for execute_block — credit charging and type coercion."""
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks._base import BlockType
|
||||
from backend.copilot.tools.helpers import execute_block
|
||||
from backend.copilot.tools.models import BlockOutputResponse
|
||||
from backend.copilot.tools.models import BlockOutputResponse, ErrorResponse
|
||||
|
||||
_USER = "test-user-helpers"
|
||||
_SESSION = "test-session-helpers"
|
||||
|
||||
|
||||
def _make_block(block_id: str = "block-1", name: str = "TestBlock"):
|
||||
"""Create a minimal mock block for execute_block()."""
|
||||
mock = MagicMock()
|
||||
mock.id = block_id
|
||||
mock.name = name
|
||||
mock.block_type = BlockType.STANDARD
|
||||
|
||||
mock.input_schema = MagicMock()
|
||||
mock.input_schema.get_credentials_fields_info.return_value = {}
|
||||
|
||||
async def _execute(
|
||||
input_data: dict, **kwargs: Any
|
||||
) -> AsyncIterator[tuple[str, Any]]:
|
||||
yield "result", "ok"
|
||||
|
||||
mock.execute = _execute
|
||||
return mock
|
||||
|
||||
|
||||
def _patch_workspace():
|
||||
"""Patch workspace_db to return a mock workspace."""
|
||||
mock_workspace = MagicMock()
|
||||
mock_workspace.id = "ws-1"
|
||||
mock_ws_db = MagicMock()
|
||||
mock_ws_db.get_or_create_workspace = AsyncMock(return_value=mock_workspace)
|
||||
return patch("backend.copilot.tools.helpers.workspace_db", return_value=mock_ws_db)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Credit charging tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestExecuteBlockCreditCharging:
|
||||
async def test_charges_credits_when_cost_is_positive(self):
|
||||
"""Block with cost > 0 should call spend_credits after execution."""
|
||||
block = _make_block()
|
||||
mock_spend = AsyncMock()
|
||||
|
||||
with (
|
||||
_patch_workspace(),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.block_usage_cost",
|
||||
return_value=(10, {"key": "val"}),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers._get_credits",
|
||||
new_callable=AsyncMock,
|
||||
return_value=100,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers._spend_credits",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=mock_spend,
|
||||
),
|
||||
):
|
||||
result = await execute_block(
|
||||
block=block,
|
||||
block_id="block-1",
|
||||
input_data={"text": "hello"},
|
||||
user_id=_USER,
|
||||
session_id=_SESSION,
|
||||
node_exec_id="exec-1",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
assert isinstance(result, BlockOutputResponse)
|
||||
assert result.success is True
|
||||
mock_spend.assert_awaited_once()
|
||||
call_kwargs = mock_spend.call_args.kwargs
|
||||
assert call_kwargs["cost"] == 10
|
||||
assert call_kwargs["metadata"].reason == "copilot_block_execution"
|
||||
|
||||
async def test_returns_error_when_insufficient_credits_before_exec(self):
|
||||
"""Pre-execution check should return ErrorResponse when balance < cost."""
|
||||
block = _make_block()
|
||||
|
||||
with (
|
||||
_patch_workspace(),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.block_usage_cost",
|
||||
return_value=(10, {}),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers._get_credits",
|
||||
new_callable=AsyncMock,
|
||||
return_value=5, # balance < cost (10)
|
||||
),
|
||||
):
|
||||
result = await execute_block(
|
||||
block=block,
|
||||
block_id="block-1",
|
||||
input_data={},
|
||||
user_id=_USER,
|
||||
session_id=_SESSION,
|
||||
node_exec_id="exec-1",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "Insufficient credits" in result.message
|
||||
|
||||
async def test_no_charge_when_cost_is_zero(self):
|
||||
"""Block with cost 0 should not call spend_credits."""
|
||||
block = _make_block()
|
||||
|
||||
with (
|
||||
_patch_workspace(),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.block_usage_cost",
|
||||
return_value=(0, {}),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers._get_credits",
|
||||
) as mock_get_credits,
|
||||
patch(
|
||||
"backend.copilot.tools.helpers._spend_credits",
|
||||
) as mock_spend_credits,
|
||||
):
|
||||
result = await execute_block(
|
||||
block=block,
|
||||
block_id="block-1",
|
||||
input_data={},
|
||||
user_id=_USER,
|
||||
session_id=_SESSION,
|
||||
node_exec_id="exec-1",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
assert isinstance(result, BlockOutputResponse)
|
||||
assert result.success is True
|
||||
# Credit functions should not be called at all for zero-cost blocks
|
||||
mock_get_credits.assert_not_awaited()
|
||||
mock_spend_credits.assert_not_awaited()
|
||||
|
||||
async def test_returns_error_on_post_exec_insufficient_balance(self):
|
||||
"""If charging fails after execution, return ErrorResponse."""
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
|
||||
block = _make_block()
|
||||
|
||||
with (
|
||||
_patch_workspace(),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.block_usage_cost",
|
||||
return_value=(10, {}),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers._get_credits",
|
||||
new_callable=AsyncMock,
|
||||
return_value=15, # passes pre-check
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers._spend_credits",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=InsufficientBalanceError(
|
||||
"Low balance", _USER, 5, 10
|
||||
), # fails during actual charge (race with concurrent spend)
|
||||
),
|
||||
):
|
||||
result = await execute_block(
|
||||
block=block,
|
||||
block_id="block-1",
|
||||
input_data={},
|
||||
user_id=_USER,
|
||||
session_id=_SESSION,
|
||||
node_exec_id="exec-1",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "Insufficient credits" in result.message
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Type coercion tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_block_schema(annotations: dict[str, Any]) -> MagicMock:
|
||||
"""Create a mock input_schema with model_fields matching the given annotations."""
|
||||
schema = MagicMock()
|
||||
# coerce_inputs_to_schema uses model_fields (Pydantic v2 API)
|
||||
model_fields = {}
|
||||
for name, ann in annotations.items():
|
||||
field = MagicMock()
|
||||
@@ -28,7 +206,7 @@ def _make_block_schema(annotations: dict[str, Any]) -> MagicMock:
|
||||
return schema
|
||||
|
||||
|
||||
def _make_block(
|
||||
def _make_coerce_block(
|
||||
block_id: str,
|
||||
name: str,
|
||||
annotations: dict[str, Any],
|
||||
@@ -60,7 +238,7 @@ _TEST_USER_ID = "test-user-coerce"
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_json_string_to_nested_list():
|
||||
"""JSON string → list[list[str]] (Google Sheets CSV import case)."""
|
||||
block = _make_block(
|
||||
block = _make_coerce_block(
|
||||
"sheets-write",
|
||||
"Google Sheets Write",
|
||||
{"values": list[list[str]], "spreadsheet_id": str},
|
||||
@@ -90,7 +268,6 @@ async def test_coerce_json_string_to_nested_list():
|
||||
|
||||
assert isinstance(response, BlockOutputResponse)
|
||||
assert response.success is True
|
||||
# Verify the input was coerced from string to list[list[str]]
|
||||
assert block._captured_inputs["values"] == [
|
||||
["Name", "Score"],
|
||||
["Alice", "90"],
|
||||
@@ -103,7 +280,7 @@ async def test_coerce_json_string_to_nested_list():
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_json_string_to_list():
|
||||
"""JSON string → list[str]."""
|
||||
block = _make_block(
|
||||
block = _make_coerce_block(
|
||||
"list-block",
|
||||
"List Block",
|
||||
{"items": list[str]},
|
||||
@@ -135,7 +312,7 @@ async def test_coerce_json_string_to_list():
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_json_string_to_dict():
|
||||
"""JSON string → dict[str, str]."""
|
||||
block = _make_block(
|
||||
block = _make_coerce_block(
|
||||
"dict-block",
|
||||
"Dict Block",
|
||||
{"config": dict[str, str]},
|
||||
@@ -167,7 +344,7 @@ async def test_coerce_json_string_to_dict():
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_no_coercion_when_type_matches():
|
||||
"""Already-correct types pass through without coercion."""
|
||||
block = _make_block(
|
||||
block = _make_coerce_block(
|
||||
"pass-through",
|
||||
"Pass Through",
|
||||
{"values": list[list[str]], "name": str},
|
||||
@@ -201,7 +378,7 @@ async def test_no_coercion_when_type_matches():
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_string_to_int():
|
||||
"""String number → int."""
|
||||
block = _make_block(
|
||||
block = _make_coerce_block(
|
||||
"int-block",
|
||||
"Int Block",
|
||||
{"count": int},
|
||||
@@ -234,7 +411,7 @@ async def test_coerce_string_to_int():
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_skips_none_values():
|
||||
"""None values are not coerced (they may be optional fields)."""
|
||||
block = _make_block(
|
||||
block = _make_coerce_block(
|
||||
"optional-block",
|
||||
"Optional Block",
|
||||
{"data": list[str], "label": str},
|
||||
@@ -260,14 +437,13 @@ async def test_coerce_skips_none_values():
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockOutputResponse)
|
||||
# 'data' was not provided, so it should not appear in captured inputs
|
||||
assert "data" not in block._captured_inputs
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_union_type_preserves_valid_member():
|
||||
"""Union-typed fields should not be coerced when the value matches a member."""
|
||||
block = _make_block(
|
||||
block = _make_coerce_block(
|
||||
"union-block",
|
||||
"Union Block",
|
||||
{"content": str | list[str]},
|
||||
@@ -293,7 +469,6 @@ async def test_coerce_union_type_preserves_valid_member():
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockOutputResponse)
|
||||
# list[str] should NOT be stringified to '["a", "b"]'
|
||||
assert block._captured_inputs["content"] == ["a", "b"]
|
||||
assert isinstance(block._captured_inputs["content"], list)
|
||||
|
||||
@@ -301,7 +476,7 @@ async def test_coerce_union_type_preserves_valid_member():
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_inner_elements_of_generic():
|
||||
"""Inner elements of generic containers are recursively coerced."""
|
||||
block = _make_block(
|
||||
block = _make_coerce_block(
|
||||
"inner-coerce",
|
||||
"Inner Coerce",
|
||||
{"values": list[str]},
|
||||
@@ -319,7 +494,6 @@ async def test_coerce_inner_elements_of_generic():
|
||||
response = await execute_block(
|
||||
block=block,
|
||||
block_id="inner-coerce",
|
||||
# Inner elements are ints, but target is list[str]
|
||||
input_data={"values": [1, 2, 3]},
|
||||
user_id=_TEST_USER_ID,
|
||||
session_id=_TEST_SESSION_ID,
|
||||
@@ -328,6 +502,5 @@ async def test_coerce_inner_elements_of_generic():
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockOutputResponse)
|
||||
# Inner elements should be coerced from int to str
|
||||
assert block._captured_inputs["values"] == ["1", "2", "3"]
|
||||
assert all(isinstance(v, str) for v in block._captured_inputs["values"])
|
||||
|
||||
@@ -512,6 +512,10 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
list_workspace_files = d.list_workspace_files
|
||||
soft_delete_workspace_file = d.soft_delete_workspace_file
|
||||
|
||||
# ============ Credits ============ #
|
||||
spend_credits = d.spend_credits
|
||||
get_credits = d.get_credits
|
||||
|
||||
# ============ Understanding ============ #
|
||||
get_business_understanding = d.get_business_understanding
|
||||
upsert_business_understanding = d.upsert_business_understanding
|
||||
|
||||
@@ -70,6 +70,9 @@ def _msg_tokens(msg: dict, enc) -> int:
|
||||
# Count tool result tokens
|
||||
tool_call_tokens += _tok_len(item.get("tool_use_id", ""), enc)
|
||||
tool_call_tokens += _tok_len(item.get("content", ""), enc)
|
||||
elif isinstance(item, dict) and item.get("type") == "text":
|
||||
# Count text block tokens
|
||||
tool_call_tokens += _tok_len(item.get("text", ""), enc)
|
||||
elif isinstance(item, dict) and "content" in item:
|
||||
# Other content types with content field
|
||||
tool_call_tokens += _tok_len(item.get("content", ""), enc)
|
||||
@@ -145,10 +148,14 @@ def _truncate_middle_tokens(text: str, enc, max_tok: int) -> str:
|
||||
if len(ids) <= max_tok:
|
||||
return text # nothing to do
|
||||
|
||||
# Need at least 3 tokens (head + ellipsis + tail) for meaningful truncation
|
||||
mid = enc.encode(" … ")
|
||||
if max_tok < 3:
|
||||
return enc.decode(mid)
|
||||
|
||||
# Split the allowance between the two ends:
|
||||
head = max_tok // 2 - 1 # -1 for the ellipsis
|
||||
tail = max_tok - head - 1
|
||||
mid = enc.encode(" … ")
|
||||
return enc.decode(ids[:head] + mid + ids[-tail:])
|
||||
|
||||
|
||||
@@ -396,7 +403,7 @@ def validate_and_remove_orphan_tool_responses(
|
||||
|
||||
if log_warning:
|
||||
logger.warning(
|
||||
f"Removing {len(orphan_ids)} orphan tool response(s): {orphan_ids}"
|
||||
"Removing %d orphan tool response(s): %s", len(orphan_ids), orphan_ids
|
||||
)
|
||||
|
||||
return _remove_orphan_tool_responses(messages, orphan_ids)
|
||||
@@ -488,8 +495,9 @@ def _ensure_tool_pairs_intact(
|
||||
# Some tool_call_ids couldn't be resolved - remove those tool responses
|
||||
# This shouldn't happen in normal operation but handles edge cases
|
||||
logger.warning(
|
||||
f"Could not find assistant messages for tool_call_ids: {orphan_tool_call_ids}. "
|
||||
"Removing orphan tool responses."
|
||||
"Could not find assistant messages for tool_call_ids: %s. "
|
||||
"Removing orphan tool responses.",
|
||||
orphan_tool_call_ids,
|
||||
)
|
||||
recent_messages = _remove_orphan_tool_responses(
|
||||
recent_messages, orphan_tool_call_ids
|
||||
@@ -497,8 +505,8 @@ def _ensure_tool_pairs_intact(
|
||||
|
||||
if messages_to_prepend:
|
||||
logger.info(
|
||||
f"Extended recent messages by {len(messages_to_prepend)} to preserve "
|
||||
f"tool_call/tool_response pairs"
|
||||
"Extended recent messages by %d to preserve tool_call/tool_response pairs",
|
||||
len(messages_to_prepend),
|
||||
)
|
||||
return messages_to_prepend + recent_messages
|
||||
|
||||
@@ -686,11 +694,15 @@ async def compress_context(
|
||||
msgs = [summary_msg] + recent_msgs
|
||||
|
||||
logger.info(
|
||||
f"Context summarized: {original_count} -> {total_tokens()} tokens, "
|
||||
f"summarized {messages_summarized} messages"
|
||||
"Context summarized: %d -> %d tokens, summarized %d messages",
|
||||
original_count,
|
||||
total_tokens(),
|
||||
messages_summarized,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Summarization failed, continuing with truncation: {e}")
|
||||
logger.warning(
|
||||
"Summarization failed, continuing with truncation: %s", e
|
||||
)
|
||||
# Fall through to content truncation
|
||||
|
||||
# ---- STEP 2: Normalize content ----------------------------------------
|
||||
@@ -728,6 +740,12 @@ async def compress_context(
|
||||
# This is more granular than dropping all old messages at once.
|
||||
while total_tokens() + reserve > target_tokens and len(msgs) > 2:
|
||||
deletable: list[int] = []
|
||||
# Count assistant messages to ensure we keep at least one
|
||||
assistant_indices: set[int] = {
|
||||
i
|
||||
for i in range(len(msgs))
|
||||
if msgs[i] is not None and msgs[i].get("role") == "assistant"
|
||||
}
|
||||
for i in range(1, len(msgs) - 1):
|
||||
msg = msgs[i]
|
||||
if (
|
||||
@@ -735,6 +753,9 @@ async def compress_context(
|
||||
and not _is_tool_message(msg)
|
||||
and not _is_objective_message(msg)
|
||||
):
|
||||
# Skip if this is the last remaining assistant message
|
||||
if msg.get("role") == "assistant" and len(assistant_indices) <= 1:
|
||||
continue
|
||||
deletable.append(i)
|
||||
if not deletable:
|
||||
break
|
||||
|
||||
@@ -1,14 +1,8 @@
|
||||
"use client";
|
||||
|
||||
import {
|
||||
DropdownMenu,
|
||||
DropdownMenuContent,
|
||||
DropdownMenuItem,
|
||||
DropdownMenuTrigger,
|
||||
} from "@/components/molecules/DropdownMenu/DropdownMenu";
|
||||
import { SidebarProvider } from "@/components/ui/sidebar";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { DotsThree, UploadSimple } from "@phosphor-icons/react";
|
||||
import { UploadSimple } from "@phosphor-icons/react";
|
||||
import { useCallback, useRef, useState } from "react";
|
||||
import { ChatContainer } from "./components/ChatContainer/ChatContainer";
|
||||
import { ChatSidebar } from "./components/ChatSidebar/ChatSidebar";
|
||||
@@ -92,7 +86,6 @@ export function CopilotPage() {
|
||||
// Delete functionality
|
||||
sessionToDelete,
|
||||
isDeleting,
|
||||
handleDeleteClick,
|
||||
handleConfirmDelete,
|
||||
handleCancelDelete,
|
||||
} = useCopilotPage();
|
||||
@@ -148,38 +141,6 @@ export function CopilotPage() {
|
||||
isUploadingFiles={isUploadingFiles}
|
||||
droppedFiles={droppedFiles}
|
||||
onDroppedFilesConsumed={handleDroppedFilesConsumed}
|
||||
headerSlot={
|
||||
isMobile && sessionId ? (
|
||||
<div className="flex justify-end">
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<button
|
||||
className="rounded p-1.5 hover:bg-neutral-100"
|
||||
aria-label="More actions"
|
||||
>
|
||||
<DotsThree className="h-5 w-5 text-neutral-600" />
|
||||
</button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="end">
|
||||
<DropdownMenuItem
|
||||
onClick={() => {
|
||||
const session = sessions.find(
|
||||
(s) => s.id === sessionId,
|
||||
);
|
||||
if (session) {
|
||||
handleDeleteClick(session.id, session.title);
|
||||
}
|
||||
}}
|
||||
disabled={isDeleting}
|
||||
className="text-red-600 focus:bg-red-50 focus:text-red-600"
|
||||
>
|
||||
Delete chat
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
</div>
|
||||
) : undefined
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
import { ChatInput } from "@/app/(platform)/copilot/components/ChatInput/ChatInput";
|
||||
import { UIDataTypes, UIMessage, UITools } from "ai";
|
||||
import { LayoutGroup, motion } from "framer-motion";
|
||||
import { ReactNode } from "react";
|
||||
import { ChatMessagesContainer } from "../ChatMessagesContainer/ChatMessagesContainer";
|
||||
import { CopilotChatActionsProvider } from "../CopilotChatActionsProvider/CopilotChatActionsProvider";
|
||||
import { EmptySession } from "../EmptySession/EmptySession";
|
||||
@@ -21,7 +20,6 @@ export interface ChatContainerProps {
|
||||
onSend: (message: string, files?: File[]) => void | Promise<void>;
|
||||
onStop: () => void;
|
||||
isUploadingFiles?: boolean;
|
||||
headerSlot?: ReactNode;
|
||||
/** Files dropped onto the chat window. */
|
||||
droppedFiles?: File[];
|
||||
/** Called after droppedFiles have been consumed by ChatInput. */
|
||||
@@ -40,7 +38,6 @@ export const ChatContainer = ({
|
||||
onSend,
|
||||
onStop,
|
||||
isUploadingFiles,
|
||||
headerSlot,
|
||||
droppedFiles,
|
||||
onDroppedFilesConsumed,
|
||||
}: ChatContainerProps) => {
|
||||
@@ -63,7 +60,6 @@ export const ChatContainer = ({
|
||||
status={status}
|
||||
error={error}
|
||||
isLoading={isLoadingSession}
|
||||
headerSlot={headerSlot}
|
||||
sessionID={sessionId}
|
||||
/>
|
||||
<motion.div
|
||||
|
||||
@@ -30,7 +30,6 @@ interface Props {
|
||||
status: string;
|
||||
error: Error | undefined;
|
||||
isLoading: boolean;
|
||||
headerSlot?: React.ReactNode;
|
||||
sessionID?: string | null;
|
||||
}
|
||||
|
||||
@@ -102,7 +101,6 @@ export function ChatMessagesContainer({
|
||||
status,
|
||||
error,
|
||||
isLoading,
|
||||
headerSlot,
|
||||
sessionID,
|
||||
}: Props) {
|
||||
const lastMessage = messages[messages.length - 1];
|
||||
@@ -135,7 +133,6 @@ export function ChatMessagesContainer({
|
||||
return (
|
||||
<Conversation className="min-h-0 flex-1">
|
||||
<ConversationContent className="flex flex-1 flex-col gap-6 px-3 py-6">
|
||||
{headerSlot}
|
||||
{isLoading && messages.length === 0 && (
|
||||
<div
|
||||
className="flex flex-1 items-center justify-center"
|
||||
|
||||
@@ -37,6 +37,7 @@ import { useCopilotUIStore } from "../../store";
|
||||
import { NotificationToggle } from "./components/NotificationToggle/NotificationToggle";
|
||||
import { DeleteChatDialog } from "../DeleteChatDialog/DeleteChatDialog";
|
||||
import { PulseLoader } from "../PulseLoader/PulseLoader";
|
||||
import { UsageLimits } from "../UsageLimits/UsageLimits";
|
||||
|
||||
export function ChatSidebar() {
|
||||
const { state } = useSidebar();
|
||||
@@ -256,11 +257,10 @@ export function ChatSidebar() {
|
||||
<Text variant="h3" size="body-medium">
|
||||
Your chats
|
||||
</Text>
|
||||
<div className="relative left-5 flex items-center gap-1">
|
||||
<div className="flex items-center">
|
||||
<UsageLimits />
|
||||
<NotificationToggle />
|
||||
<div className="relative left-1">
|
||||
<SidebarTrigger />
|
||||
</div>
|
||||
<SidebarTrigger />
|
||||
</div>
|
||||
</div>
|
||||
{sessionId ? (
|
||||
|
||||
@@ -7,6 +7,7 @@ import {
|
||||
PopoverTrigger,
|
||||
} from "@/components/molecules/Popover/Popover";
|
||||
import { toast } from "@/components/molecules/Toast/use-toast";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { Bell, BellRinging, BellSlash } from "@phosphor-icons/react";
|
||||
import { useCopilotUIStore } from "../../../../store";
|
||||
@@ -48,10 +49,7 @@ export function NotificationToggle() {
|
||||
return (
|
||||
<Popover>
|
||||
<PopoverTrigger asChild>
|
||||
<button
|
||||
className="rounded p-1 text-black transition-colors hover:bg-zinc-50"
|
||||
aria-label="Notification settings"
|
||||
>
|
||||
<Button variant="ghost" size="icon" aria-label="Notification settings">
|
||||
{!isNotificationsEnabled ? (
|
||||
<BellSlash className="!size-5" />
|
||||
) : isSoundEnabled ? (
|
||||
@@ -59,7 +57,7 @@ export function NotificationToggle() {
|
||||
) : (
|
||||
<Bell className="!size-5" />
|
||||
)}
|
||||
</button>
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent align="start" className="w-56 p-3">
|
||||
<div className="flex flex-col gap-3">
|
||||
|
||||
@@ -0,0 +1,146 @@
|
||||
import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus";
|
||||
import {
|
||||
Popover,
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
} from "@/components/molecules/Popover/Popover";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { ChartBar } from "@phosphor-icons/react";
|
||||
import { useUsageLimits } from "./useUsageLimits";
|
||||
|
||||
const MS_PER_MINUTE = 60_000;
|
||||
const MS_PER_HOUR = 3_600_000;
|
||||
|
||||
function formatResetTime(resetsAt: Date | string): string {
|
||||
const resetDate =
|
||||
typeof resetsAt === "string" ? new Date(resetsAt) : resetsAt;
|
||||
const now = new Date();
|
||||
const diffMs = resetDate.getTime() - now.getTime();
|
||||
if (diffMs <= 0) return "now";
|
||||
|
||||
const hours = Math.floor(diffMs / MS_PER_HOUR);
|
||||
|
||||
// Under 24h: show relative time ("in 4h 23m")
|
||||
if (hours < 24) {
|
||||
const minutes = Math.floor((diffMs % MS_PER_HOUR) / MS_PER_MINUTE);
|
||||
if (hours > 0) return `in ${hours}h ${minutes}m`;
|
||||
return `in ${minutes}m`;
|
||||
}
|
||||
|
||||
// Over 24h: show day and time in local timezone ("Mon 12:00 AM PST")
|
||||
return resetDate.toLocaleString(undefined, {
|
||||
weekday: "short",
|
||||
hour: "numeric",
|
||||
minute: "2-digit",
|
||||
timeZoneName: "short",
|
||||
});
|
||||
}
|
||||
|
||||
function UsageBar({
|
||||
label,
|
||||
used,
|
||||
limit,
|
||||
resetsAt,
|
||||
}: {
|
||||
label: string;
|
||||
used: number;
|
||||
limit: number;
|
||||
resetsAt: Date | string;
|
||||
}) {
|
||||
if (limit <= 0) return null;
|
||||
|
||||
const rawPercent = (used / limit) * 100;
|
||||
const percent = Math.min(100, Math.round(rawPercent));
|
||||
const isHigh = percent >= 80;
|
||||
const percentLabel =
|
||||
used > 0 && percent === 0 ? "<1% used" : `${percent}% used`;
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-1">
|
||||
<div className="flex items-baseline justify-between">
|
||||
<span className="text-xs font-medium text-neutral-700">{label}</span>
|
||||
<span className="text-[11px] tabular-nums text-neutral-500">
|
||||
{percentLabel}
|
||||
</span>
|
||||
</div>
|
||||
<div className="text-[10px] text-neutral-400">
|
||||
Resets {formatResetTime(resetsAt)}
|
||||
</div>
|
||||
<div className="h-2 w-full overflow-hidden rounded-full bg-neutral-200">
|
||||
<div
|
||||
className={`h-full rounded-full transition-[width] duration-300 ease-out ${
|
||||
isHigh ? "bg-orange-500" : "bg-blue-500"
|
||||
}`}
|
||||
style={{ width: `${Math.max(used > 0 ? 1 : 0, percent)}%` }}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export function UsagePanelContent({
|
||||
usage,
|
||||
showBillingLink = true,
|
||||
}: {
|
||||
usage: CoPilotUsageStatus;
|
||||
showBillingLink?: boolean;
|
||||
}) {
|
||||
const hasDailyLimit = usage.daily.limit > 0;
|
||||
const hasWeeklyLimit = usage.weekly.limit > 0;
|
||||
|
||||
if (!hasDailyLimit && !hasWeeklyLimit) {
|
||||
return (
|
||||
<div className="text-xs text-neutral-500">No usage limits configured</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-3">
|
||||
<div className="text-xs font-semibold text-neutral-800">Usage limits</div>
|
||||
{hasDailyLimit && (
|
||||
<UsageBar
|
||||
label="Today"
|
||||
used={usage.daily.used}
|
||||
limit={usage.daily.limit}
|
||||
resetsAt={usage.daily.resets_at}
|
||||
/>
|
||||
)}
|
||||
{hasWeeklyLimit && (
|
||||
<UsageBar
|
||||
label="This week"
|
||||
used={usage.weekly.used}
|
||||
limit={usage.weekly.limit}
|
||||
resetsAt={usage.weekly.resets_at}
|
||||
/>
|
||||
)}
|
||||
{showBillingLink && (
|
||||
<a
|
||||
href="/profile/credits"
|
||||
className="text-[11px] text-blue-600 hover:underline"
|
||||
>
|
||||
Learn more about usage limits
|
||||
</a>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export function UsageLimits() {
|
||||
const { data: usage, isLoading } = useUsageLimits();
|
||||
|
||||
if (isLoading || !usage) return null;
|
||||
if (usage.daily.limit <= 0 && usage.weekly.limit <= 0) return null;
|
||||
|
||||
return (
|
||||
<Popover>
|
||||
<PopoverTrigger asChild>
|
||||
<Button variant="ghost" size="icon" aria-label="Usage limits">
|
||||
<ChartBar className="!size-5" weight="light" />
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent align="start" className="w-64 p-3">
|
||||
<UsagePanelContent usage={usage} />
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,121 @@
|
||||
import { render, screen, cleanup } from "@/tests/integrations/test-utils";
|
||||
import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
import { UsageLimits } from "../UsageLimits";
|
||||
|
||||
// Mock the useUsageLimits hook
|
||||
const mockUseUsageLimits = vi.fn();
|
||||
vi.mock("../useUsageLimits", () => ({
|
||||
useUsageLimits: () => mockUseUsageLimits(),
|
||||
}));
|
||||
|
||||
// Mock Popover to render children directly (Radix portals don't work in happy-dom)
|
||||
vi.mock("@/components/molecules/Popover/Popover", () => ({
|
||||
Popover: ({ children }: { children: React.ReactNode }) => (
|
||||
<div>{children}</div>
|
||||
),
|
||||
PopoverTrigger: ({ children }: { children: React.ReactNode }) => (
|
||||
<div>{children}</div>
|
||||
),
|
||||
PopoverContent: ({ children }: { children: React.ReactNode }) => (
|
||||
<div>{children}</div>
|
||||
),
|
||||
}));
|
||||
|
||||
afterEach(() => {
|
||||
cleanup();
|
||||
mockUseUsageLimits.mockReset();
|
||||
});
|
||||
|
||||
function makeUsage({
|
||||
dailyUsed = 500,
|
||||
dailyLimit = 10000,
|
||||
weeklyUsed = 2000,
|
||||
weeklyLimit = 50000,
|
||||
}: {
|
||||
dailyUsed?: number;
|
||||
dailyLimit?: number;
|
||||
weeklyUsed?: number;
|
||||
weeklyLimit?: number;
|
||||
} = {}) {
|
||||
const future = new Date(Date.now() + 3600 * 1000); // 1h from now
|
||||
return {
|
||||
daily: { used: dailyUsed, limit: dailyLimit, resets_at: future },
|
||||
weekly: { used: weeklyUsed, limit: weeklyLimit, resets_at: future },
|
||||
};
|
||||
}
|
||||
|
||||
describe("UsageLimits", () => {
|
||||
it("renders nothing while loading", () => {
|
||||
mockUseUsageLimits.mockReturnValue({ data: undefined, isLoading: true });
|
||||
const { container } = render(<UsageLimits />);
|
||||
expect(container.innerHTML).toBe("");
|
||||
});
|
||||
|
||||
it("renders nothing when no limits are configured", () => {
|
||||
mockUseUsageLimits.mockReturnValue({
|
||||
data: makeUsage({ dailyLimit: 0, weeklyLimit: 0 }),
|
||||
isLoading: false,
|
||||
});
|
||||
const { container } = render(<UsageLimits />);
|
||||
expect(container.innerHTML).toBe("");
|
||||
});
|
||||
|
||||
it("renders the usage button when limits exist", () => {
|
||||
mockUseUsageLimits.mockReturnValue({
|
||||
data: makeUsage(),
|
||||
isLoading: false,
|
||||
});
|
||||
render(<UsageLimits />);
|
||||
expect(screen.getByRole("button", { name: /usage limits/i })).toBeDefined();
|
||||
});
|
||||
|
||||
it("displays daily and weekly usage percentages", () => {
|
||||
mockUseUsageLimits.mockReturnValue({
|
||||
data: makeUsage({ dailyUsed: 5000, dailyLimit: 10000 }),
|
||||
isLoading: false,
|
||||
});
|
||||
render(<UsageLimits />);
|
||||
|
||||
expect(screen.getByText("50% used")).toBeDefined();
|
||||
expect(screen.getByText("Today")).toBeDefined();
|
||||
expect(screen.getByText("This week")).toBeDefined();
|
||||
expect(screen.getByText("Usage limits")).toBeDefined();
|
||||
});
|
||||
|
||||
it("shows only weekly bar when daily limit is 0", () => {
|
||||
mockUseUsageLimits.mockReturnValue({
|
||||
data: makeUsage({
|
||||
dailyLimit: 0,
|
||||
weeklyUsed: 25000,
|
||||
weeklyLimit: 50000,
|
||||
}),
|
||||
isLoading: false,
|
||||
});
|
||||
render(<UsageLimits />);
|
||||
|
||||
expect(screen.getByText("This week")).toBeDefined();
|
||||
expect(screen.queryByText("Today")).toBeNull();
|
||||
});
|
||||
|
||||
it("caps percentage at 100% when over limit", () => {
|
||||
mockUseUsageLimits.mockReturnValue({
|
||||
data: makeUsage({ dailyUsed: 15000, dailyLimit: 10000 }),
|
||||
isLoading: false,
|
||||
});
|
||||
render(<UsageLimits />);
|
||||
|
||||
expect(screen.getByText("100% used")).toBeDefined();
|
||||
});
|
||||
|
||||
it("shows learn more link to credits page", () => {
|
||||
mockUseUsageLimits.mockReturnValue({
|
||||
data: makeUsage(),
|
||||
isLoading: false,
|
||||
});
|
||||
render(<UsageLimits />);
|
||||
|
||||
const link = screen.getByText("Learn more about usage limits");
|
||||
expect(link).toBeDefined();
|
||||
expect(link.closest("a")?.getAttribute("href")).toBe("/profile/credits");
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,12 @@
|
||||
import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus";
|
||||
import { useGetV2GetCopilotUsage } from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
|
||||
export function useUsageLimits() {
|
||||
return useGetV2GetCopilotUsage({
|
||||
query: {
|
||||
select: (res) => res.data as CoPilotUsageStatus,
|
||||
refetchInterval: 30000,
|
||||
staleTime: 10000,
|
||||
},
|
||||
});
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
import {
|
||||
getGetV2GetCopilotUsageQueryKey,
|
||||
getGetV2GetSessionQueryKey,
|
||||
postV2CancelSessionTask,
|
||||
} from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
@@ -307,6 +308,9 @@ export function useCopilotStream({
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2GetSessionQueryKey(sessionId),
|
||||
});
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2GetCopilotUsageQueryKey(),
|
||||
});
|
||||
if (status === "ready") {
|
||||
reconnectAttemptsRef.current = 0;
|
||||
hasShownDisconnectToast.current = false;
|
||||
|
||||
@@ -11,6 +11,8 @@ import {
|
||||
|
||||
import { RefundModal } from "./RefundModal";
|
||||
import { CreditTransaction } from "@/lib/autogpt-server-api";
|
||||
import { UsagePanelContent } from "@/app/(platform)/copilot/components/UsageLimits/UsageLimits";
|
||||
import { useUsageLimits } from "@/app/(platform)/copilot/components/UsageLimits/useUsageLimits";
|
||||
|
||||
import {
|
||||
Table,
|
||||
@@ -21,6 +23,26 @@ import {
|
||||
TableRow,
|
||||
} from "@/components/__legacy__/ui/table";
|
||||
|
||||
function CoPilotUsageSection() {
|
||||
const { data: usage, isLoading } = useUsageLimits();
|
||||
const router = useRouter();
|
||||
|
||||
if (isLoading || !usage) return null;
|
||||
if (usage.daily.limit <= 0 && usage.weekly.limit <= 0) return null;
|
||||
|
||||
return (
|
||||
<div className="my-6 space-y-4">
|
||||
<h3 className="text-lg font-medium">CoPilot Usage Limits</h3>
|
||||
<div className="rounded-lg border border-neutral-200 p-4 dark:border-neutral-700">
|
||||
<UsagePanelContent usage={usage} showBillingLink={false} />
|
||||
</div>
|
||||
<Button className="w-full" onClick={() => router.push("/copilot")}>
|
||||
Open CoPilot
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default function CreditsPage() {
|
||||
const api = useBackendAPI();
|
||||
const {
|
||||
@@ -237,11 +259,13 @@ export default function CreditsPage() {
|
||||
</Button>
|
||||
)}
|
||||
</form>
|
||||
|
||||
{/* CoPilot Usage Limits */}
|
||||
<CoPilotUsageSection />
|
||||
</div>
|
||||
|
||||
<div className="my-6 space-y-4">
|
||||
{/* Payment Portal */}
|
||||
|
||||
<h3 className="text-lg font-medium">Manage Your Payment Methods</h3>
|
||||
<p className="text-neutral-600">
|
||||
You can manage your cards and see your payment history in the
|
||||
|
||||
@@ -1382,6 +1382,28 @@
|
||||
"security": [{ "HTTPBearerJWT": [] }]
|
||||
}
|
||||
},
|
||||
"/api/chat/usage": {
|
||||
"get": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Get Copilot Usage",
|
||||
"description": "Get CoPilot usage status for the authenticated user.\n\nReturns current token usage vs limits for daily and weekly windows.",
|
||||
"operationId": "getV2GetCopilotUsage",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/CoPilotUsageStatus" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
}
|
||||
},
|
||||
"security": [{ "HTTPBearerJWT": [] }]
|
||||
}
|
||||
},
|
||||
"/api/credits": {
|
||||
"get": {
|
||||
"tags": ["v1", "credits"],
|
||||
@@ -8455,6 +8477,16 @@
|
||||
"title": "ClarifyingQuestion",
|
||||
"description": "A question that needs user clarification."
|
||||
},
|
||||
"CoPilotUsageStatus": {
|
||||
"properties": {
|
||||
"daily": { "$ref": "#/components/schemas/UsageWindow" },
|
||||
"weekly": { "$ref": "#/components/schemas/UsageWindow" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["daily", "weekly"],
|
||||
"title": "CoPilotUsageStatus",
|
||||
"description": "Current usage status for a user across all windows."
|
||||
},
|
||||
"ContentType": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
@@ -12190,6 +12222,16 @@
|
||||
{ "$ref": "#/components/schemas/ActiveStreamInfo" },
|
||||
{ "type": "null" }
|
||||
]
|
||||
},
|
||||
"total_prompt_tokens": {
|
||||
"type": "integer",
|
||||
"title": "Total Prompt Tokens",
|
||||
"default": 0
|
||||
},
|
||||
"total_completion_tokens": {
|
||||
"type": "integer",
|
||||
"title": "Total Completion Tokens",
|
||||
"default": 0
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
@@ -14587,6 +14629,25 @@
|
||||
"required": ["timezone"],
|
||||
"title": "UpdateTimezoneRequest"
|
||||
},
|
||||
"UsageWindow": {
|
||||
"properties": {
|
||||
"used": { "type": "integer", "title": "Used" },
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"title": "Limit",
|
||||
"description": "Maximum tokens allowed in this window. 0 means unlimited."
|
||||
},
|
||||
"resets_at": {
|
||||
"type": "string",
|
||||
"format": "date-time",
|
||||
"title": "Resets At"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["used", "limit", "resets_at"],
|
||||
"title": "UsageWindow",
|
||||
"description": "Usage within a single time window."
|
||||
},
|
||||
"UserHistoryResponse": {
|
||||
"properties": {
|
||||
"history": {
|
||||
|
||||
@@ -288,6 +288,7 @@ const SidebarTrigger = React.forwardRef<
|
||||
ref={ref}
|
||||
data-sidebar="trigger"
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
onClick={(event) => {
|
||||
onClick?.(event);
|
||||
toggleSidebar();
|
||||
|
||||
Reference in New Issue
Block a user