mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-03-17 03:00:27 -04:00
Compare commits
49 Commits
fix/copilo
...
fix/copilo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9684b99949 | ||
|
|
d00059dc94 | ||
|
|
3a14077d52 | ||
|
|
4446be94ae | ||
|
|
983aed2b0a | ||
|
|
66a8cf69be | ||
|
|
d1b8766fa4 | ||
|
|
628b779128 | ||
|
|
90b7edf1f1 | ||
|
|
8b970c4c3d | ||
|
|
601fed93b8 | ||
|
|
e3f9fa3648 | ||
|
|
809ba56f0b | ||
|
|
9a467e1dba | ||
|
|
0200748225 | ||
|
|
1704214394 | ||
|
|
f2efd3ad7f | ||
|
|
ee841d1515 | ||
|
|
5966d3669d | ||
|
|
c81ab1fc3b | ||
|
|
5446c7f18f | ||
|
|
2b0c9ba703 | ||
|
|
195c7011ae | ||
|
|
d4944fb22b | ||
|
|
a5ed8fefa9 | ||
|
|
a52a777b29 | ||
|
|
8bec7a6933 | ||
|
|
e73791efed | ||
|
|
2d161ce2b9 | ||
|
|
6fc4989654 | ||
|
|
976443bf6e | ||
|
|
4ceb15b3f1 | ||
|
|
3096f94996 | ||
|
|
6f90729612 | ||
|
|
ebf89dde8b | ||
|
|
5d057e97e5 | ||
|
|
1d2f641a26 | ||
|
|
dcb71ab0b9 | ||
|
|
8136b90860 | ||
|
|
4d179a7c37 | ||
|
|
f78adcdc65 | ||
|
|
40388b7520 | ||
|
|
dd7be1158b | ||
|
|
c0e59f0a6b | ||
|
|
104d1f1bf4 | ||
|
|
d9e9cd4c98 | ||
|
|
ca416300ec | ||
|
|
c589cd0c43 | ||
|
|
b6d863fcd2 |
@@ -27,6 +27,12 @@ from backend.copilot.model import (
|
|||||||
get_user_sessions,
|
get_user_sessions,
|
||||||
update_session_title,
|
update_session_title,
|
||||||
)
|
)
|
||||||
|
from backend.copilot.rate_limit import (
|
||||||
|
CoPilotUsageStatus,
|
||||||
|
RateLimitExceeded,
|
||||||
|
check_rate_limit,
|
||||||
|
get_usage_status,
|
||||||
|
)
|
||||||
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
|
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
|
||||||
from backend.copilot.tools.e2b_sandbox import kill_sandbox
|
from backend.copilot.tools.e2b_sandbox import kill_sandbox
|
||||||
from backend.copilot.tools.models import (
|
from backend.copilot.tools.models import (
|
||||||
@@ -120,6 +126,8 @@ class SessionDetailResponse(BaseModel):
|
|||||||
user_id: str | None
|
user_id: str | None
|
||||||
messages: list[dict]
|
messages: list[dict]
|
||||||
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
|
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
|
||||||
|
total_prompt_tokens: int = 0
|
||||||
|
total_completion_tokens: int = 0
|
||||||
|
|
||||||
|
|
||||||
class SessionSummaryResponse(BaseModel):
|
class SessionSummaryResponse(BaseModel):
|
||||||
@@ -389,6 +397,10 @@ async def get_session(
|
|||||||
last_message_id=last_message_id,
|
last_message_id=last_message_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Sum token usage from session
|
||||||
|
total_prompt = sum(u.prompt_tokens for u in session.usage)
|
||||||
|
total_completion = sum(u.completion_tokens for u in session.usage)
|
||||||
|
|
||||||
return SessionDetailResponse(
|
return SessionDetailResponse(
|
||||||
id=session.session_id,
|
id=session.session_id,
|
||||||
created_at=session.started_at.isoformat(),
|
created_at=session.started_at.isoformat(),
|
||||||
@@ -396,6 +408,26 @@ async def get_session(
|
|||||||
user_id=session.user_id or None,
|
user_id=session.user_id or None,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
active_stream=active_stream_info,
|
active_stream=active_stream_info,
|
||||||
|
total_prompt_tokens=total_prompt,
|
||||||
|
total_completion_tokens=total_completion,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/usage")
|
||||||
|
async def get_copilot_usage(
|
||||||
|
user_id: Annotated[str | None, Depends(auth.get_user_id)],
|
||||||
|
) -> CoPilotUsageStatus:
|
||||||
|
"""Get CoPilot usage status for the authenticated user.
|
||||||
|
|
||||||
|
Returns current token usage vs limits for daily and weekly windows.
|
||||||
|
"""
|
||||||
|
if not user_id:
|
||||||
|
raise HTTPException(status_code=401, detail="Authentication required")
|
||||||
|
|
||||||
|
return await get_usage_status(
|
||||||
|
user_id=user_id,
|
||||||
|
daily_token_limit=config.daily_token_limit,
|
||||||
|
weekly_token_limit=config.weekly_token_limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -496,6 +528,17 @@ async def stream_chat_post(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Pre-turn rate limit check (token-based)
|
||||||
|
if user_id and (config.daily_token_limit > 0 or config.weekly_token_limit > 0):
|
||||||
|
try:
|
||||||
|
await check_rate_limit(
|
||||||
|
user_id=user_id,
|
||||||
|
daily_token_limit=config.daily_token_limit,
|
||||||
|
weekly_token_limit=config.weekly_token_limit,
|
||||||
|
)
|
||||||
|
except RateLimitExceeded as e:
|
||||||
|
raise HTTPException(status_code=429, detail=str(e)) from e
|
||||||
|
|
||||||
# Enrich message with file metadata if file_ids are provided.
|
# Enrich message with file metadata if file_ids are provided.
|
||||||
# Also sanitise file_ids so only validated, workspace-scoped IDs are
|
# Also sanitise file_ids so only validated, workspace-scoped IDs are
|
||||||
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
|
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
"""Tests for chat API routes: session title update, file attachment validation, and suggested prompts."""
|
"""Tests for chat API routes: session title update, file attachment validation, usage, and suggested prompts."""
|
||||||
|
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
import fastapi
|
import fastapi
|
||||||
@@ -251,6 +252,74 @@ def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
|
|||||||
assert call_kwargs["where"]["isDeleted"] is False
|
assert call_kwargs["where"]["isDeleted"] is False
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Usage endpoint ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_usage(
|
||||||
|
mocker: pytest_mock.MockerFixture,
|
||||||
|
*,
|
||||||
|
daily_used: int = 500,
|
||||||
|
weekly_used: int = 2000,
|
||||||
|
) -> AsyncMock:
|
||||||
|
"""Mock get_usage_status to return a predictable CoPilotUsageStatus."""
|
||||||
|
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
|
||||||
|
|
||||||
|
resets_at = datetime.now(UTC) + timedelta(days=1)
|
||||||
|
status = CoPilotUsageStatus(
|
||||||
|
daily=UsageWindow(used=daily_used, limit=10000, resets_at=resets_at),
|
||||||
|
weekly=UsageWindow(used=weekly_used, limit=50000, resets_at=resets_at),
|
||||||
|
)
|
||||||
|
return mocker.patch(
|
||||||
|
"backend.api.features.chat.routes.get_usage_status",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=status,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_usage_returns_daily_and_weekly(
|
||||||
|
mocker: pytest_mock.MockerFixture,
|
||||||
|
test_user_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""GET /usage returns daily and weekly usage."""
|
||||||
|
mock_get = _mock_usage(mocker, daily_used=500, weekly_used=2000)
|
||||||
|
|
||||||
|
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
|
||||||
|
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
|
||||||
|
|
||||||
|
response = client.get("/usage")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["daily"]["used"] == 500
|
||||||
|
assert data["weekly"]["used"] == 2000
|
||||||
|
|
||||||
|
mock_get.assert_called_once_with(
|
||||||
|
user_id=test_user_id,
|
||||||
|
daily_token_limit=10000,
|
||||||
|
weekly_token_limit=50000,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_usage_uses_config_limits(
|
||||||
|
mocker: pytest_mock.MockerFixture,
|
||||||
|
test_user_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""The endpoint forwards daily_token_limit and weekly_token_limit from config."""
|
||||||
|
mock_get = _mock_usage(mocker)
|
||||||
|
|
||||||
|
mocker.patch.object(chat_routes.config, "daily_token_limit", 99999)
|
||||||
|
mocker.patch.object(chat_routes.config, "weekly_token_limit", 77777)
|
||||||
|
|
||||||
|
response = client.get("/usage")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
mock_get.assert_called_once_with(
|
||||||
|
user_id=test_user_id,
|
||||||
|
daily_token_limit=99999,
|
||||||
|
weekly_token_limit=77777,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ─── Suggested prompts endpoint ──────────────────────────────────────
|
# ─── Suggested prompts endpoint ──────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -18,11 +18,13 @@ from langfuse import propagate_attributes
|
|||||||
from backend.copilot.model import (
|
from backend.copilot.model import (
|
||||||
ChatMessage,
|
ChatMessage,
|
||||||
ChatSession,
|
ChatSession,
|
||||||
|
Usage,
|
||||||
get_chat_session,
|
get_chat_session,
|
||||||
update_session_title,
|
update_session_title,
|
||||||
upsert_chat_session,
|
upsert_chat_session,
|
||||||
)
|
)
|
||||||
from backend.copilot.prompting import get_baseline_supplement
|
from backend.copilot.prompting import get_baseline_supplement
|
||||||
|
from backend.copilot.rate_limit import record_token_usage
|
||||||
from backend.copilot.response_model import (
|
from backend.copilot.response_model import (
|
||||||
StreamBaseResponse,
|
StreamBaseResponse,
|
||||||
StreamError,
|
StreamError,
|
||||||
@@ -36,6 +38,7 @@ from backend.copilot.response_model import (
|
|||||||
StreamToolInputAvailable,
|
StreamToolInputAvailable,
|
||||||
StreamToolInputStart,
|
StreamToolInputStart,
|
||||||
StreamToolOutputAvailable,
|
StreamToolOutputAvailable,
|
||||||
|
StreamUsage,
|
||||||
)
|
)
|
||||||
from backend.copilot.service import (
|
from backend.copilot.service import (
|
||||||
_build_system_prompt,
|
_build_system_prompt,
|
||||||
@@ -46,7 +49,11 @@ from backend.copilot.service import (
|
|||||||
from backend.copilot.tools import execute_tool, get_available_tools
|
from backend.copilot.tools import execute_tool, get_available_tools
|
||||||
from backend.copilot.tracking import track_user_message
|
from backend.copilot.tracking import track_user_message
|
||||||
from backend.util.exceptions import NotFoundError
|
from backend.util.exceptions import NotFoundError
|
||||||
from backend.util.prompt import compress_context
|
from backend.util.prompt import (
|
||||||
|
compress_context,
|
||||||
|
estimate_token_count,
|
||||||
|
estimate_token_count_str,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -221,6 +228,9 @@ async def stream_chat_completion_baseline(
|
|||||||
text_block_id = str(uuid.uuid4())
|
text_block_id = str(uuid.uuid4())
|
||||||
text_started = False
|
text_started = False
|
||||||
step_open = False
|
step_open = False
|
||||||
|
# Token usage accumulators — populated from streaming chunks
|
||||||
|
turn_prompt_tokens = 0
|
||||||
|
turn_completion_tokens = 0
|
||||||
try:
|
try:
|
||||||
for _round in range(_MAX_TOOL_ROUNDS):
|
for _round in range(_MAX_TOOL_ROUNDS):
|
||||||
# Open a new step for each LLM round
|
# Open a new step for each LLM round
|
||||||
@@ -232,6 +242,7 @@ async def stream_chat_completion_baseline(
|
|||||||
model=config.model,
|
model=config.model,
|
||||||
messages=openai_messages,
|
messages=openai_messages,
|
||||||
stream=True,
|
stream=True,
|
||||||
|
stream_options={"include_usage": True},
|
||||||
)
|
)
|
||||||
if tools:
|
if tools:
|
||||||
create_kwargs["tools"] = tools
|
create_kwargs["tools"] = tools
|
||||||
@@ -242,7 +253,18 @@ async def stream_chat_completion_baseline(
|
|||||||
tool_calls_by_index: dict[int, dict[str, str]] = {}
|
tool_calls_by_index: dict[int, dict[str, str]] = {}
|
||||||
|
|
||||||
async for chunk in response:
|
async for chunk in response:
|
||||||
delta = chunk.choices[0].delta if chunk.choices else None
|
# Capture token usage from the streaming chunk.
|
||||||
|
# OpenRouter normalises all providers into OpenAI format
|
||||||
|
# where prompt_tokens already includes cached tokens
|
||||||
|
# (unlike Anthropic's native API). Use += to sum all
|
||||||
|
# tool-call rounds since each API call is independent.
|
||||||
|
if chunk.usage:
|
||||||
|
turn_prompt_tokens += chunk.usage.prompt_tokens or 0
|
||||||
|
turn_completion_tokens += chunk.usage.completion_tokens or 0
|
||||||
|
|
||||||
|
if not chunk.choices:
|
||||||
|
continue
|
||||||
|
delta = chunk.choices[0].delta
|
||||||
if not delta:
|
if not delta:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -411,6 +433,53 @@ async def stream_chat_completion_baseline(
|
|||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("[Baseline] Langfuse trace context teardown failed")
|
logger.warning("[Baseline] Langfuse trace context teardown failed")
|
||||||
|
|
||||||
|
# Fallback: estimate tokens via tiktoken when the provider does
|
||||||
|
# not honour stream_options={"include_usage": True}.
|
||||||
|
# Count the full message list (system + history + turn) since
|
||||||
|
# each API call sends the complete context window.
|
||||||
|
if turn_prompt_tokens == 0 and turn_completion_tokens == 0:
|
||||||
|
turn_prompt_tokens = max(
|
||||||
|
estimate_token_count(openai_messages, model=config.model), 0
|
||||||
|
)
|
||||||
|
turn_completion_tokens = max(
|
||||||
|
estimate_token_count_str(assistant_text, model=config.model), 0
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"[Baseline] No streaming usage reported; estimated tokens: "
|
||||||
|
"prompt=%d, completion=%d",
|
||||||
|
turn_prompt_tokens,
|
||||||
|
turn_completion_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Emit token usage and update session for persistence
|
||||||
|
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
|
||||||
|
total_tokens = turn_prompt_tokens + turn_completion_tokens
|
||||||
|
session.usage.append(
|
||||||
|
Usage(
|
||||||
|
prompt_tokens=turn_prompt_tokens,
|
||||||
|
completion_tokens=turn_completion_tokens,
|
||||||
|
total_tokens=total_tokens,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"[Baseline] Turn usage: prompt=%d, completion=%d, total=%d",
|
||||||
|
turn_prompt_tokens,
|
||||||
|
turn_completion_tokens,
|
||||||
|
total_tokens,
|
||||||
|
)
|
||||||
|
# Record for rate limiting counters
|
||||||
|
if user_id:
|
||||||
|
try:
|
||||||
|
await record_token_usage(
|
||||||
|
user_id=user_id,
|
||||||
|
prompt_tokens=turn_prompt_tokens,
|
||||||
|
completion_tokens=turn_completion_tokens,
|
||||||
|
)
|
||||||
|
except Exception as usage_err:
|
||||||
|
logger.warning(
|
||||||
|
"[Baseline] Failed to record token usage: %s", usage_err
|
||||||
|
)
|
||||||
|
|
||||||
# Persist assistant response
|
# Persist assistant response
|
||||||
if assistant_text:
|
if assistant_text:
|
||||||
session.messages.append(
|
session.messages.append(
|
||||||
@@ -421,4 +490,16 @@ async def stream_chat_completion_baseline(
|
|||||||
except Exception as persist_err:
|
except Exception as persist_err:
|
||||||
logger.error("[Baseline] Failed to persist session: %s", persist_err)
|
logger.error("[Baseline] Failed to persist session: %s", persist_err)
|
||||||
|
|
||||||
|
# Yield usage and finish AFTER try/finally (not inside finally).
|
||||||
|
# PEP 525 prohibits yielding from finally in async generators during
|
||||||
|
# aclose() — doing so raises RuntimeError on client disconnect.
|
||||||
|
# On GeneratorExit the client is already gone, so unreachable yields
|
||||||
|
# are harmless; on normal completion they reach the SSE stream.
|
||||||
|
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
|
||||||
|
yield StreamUsage(
|
||||||
|
promptTokens=turn_prompt_tokens,
|
||||||
|
completionTokens=turn_completion_tokens,
|
||||||
|
totalTokens=turn_prompt_tokens + turn_completion_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
yield StreamFinish()
|
yield StreamFinish()
|
||||||
|
|||||||
@@ -70,6 +70,20 @@ class ChatConfig(BaseSettings):
|
|||||||
description="Cache TTL in seconds for Langfuse prompt (0 to disable caching)",
|
description="Cache TTL in seconds for Langfuse prompt (0 to disable caching)",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Rate limiting — token-based limits per day and per week.
|
||||||
|
# Each CoPilot turn consumes ~10-15K tokens (system prompt + tool schemas + response),
|
||||||
|
# so 2.5M daily allows ~170-250 turns/day which is reasonable for normal use.
|
||||||
|
# TODO: These are global deploy-time constants. For per-user or per-plan limits,
|
||||||
|
# move to the database (e.g. UserPlan table) and look up in get_usage_status.
|
||||||
|
daily_token_limit: int = Field(
|
||||||
|
default=2_500_000,
|
||||||
|
description="Max tokens per day, resets at midnight UTC (0 = unlimited)",
|
||||||
|
)
|
||||||
|
weekly_token_limit: int = Field(
|
||||||
|
default=12_500_000,
|
||||||
|
description="Max tokens per week, resets Monday 00:00 UTC (0 = unlimited)",
|
||||||
|
)
|
||||||
|
|
||||||
# Claude Agent SDK Configuration
|
# Claude Agent SDK Configuration
|
||||||
use_claude_agent_sdk: bool = Field(
|
use_claude_agent_sdk: bool = Field(
|
||||||
default=True,
|
default=True,
|
||||||
|
|||||||
@@ -73,6 +73,9 @@ class Usage(BaseModel):
|
|||||||
prompt_tokens: int
|
prompt_tokens: int
|
||||||
completion_tokens: int
|
completion_tokens: int
|
||||||
total_tokens: int
|
total_tokens: int
|
||||||
|
# Cache breakdown (Anthropic-specific; zero for non-Anthropic models)
|
||||||
|
cache_read_tokens: int = 0
|
||||||
|
cache_creation_tokens: int = 0
|
||||||
|
|
||||||
|
|
||||||
class ChatSessionInfo(BaseModel):
|
class ChatSessionInfo(BaseModel):
|
||||||
|
|||||||
253
autogpt_platform/backend/backend/copilot/rate_limit.py
Normal file
253
autogpt_platform/backend/backend/copilot/rate_limit.py
Normal file
@@ -0,0 +1,253 @@
|
|||||||
|
"""CoPilot rate limiting based on token usage.
|
||||||
|
|
||||||
|
Uses Redis fixed-window counters to track per-user token consumption
|
||||||
|
with configurable daily and weekly limits. Daily windows reset at
|
||||||
|
midnight UTC; weekly windows reset at ISO week boundary (Monday 00:00
|
||||||
|
UTC). Fails open when Redis is unavailable to avoid blocking users.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from backend.data.redis_client import get_redis_async
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Redis key prefixes
|
||||||
|
_PREFIX = "copilot:usage"
|
||||||
|
|
||||||
|
|
||||||
|
class UsageWindow(BaseModel):
|
||||||
|
"""Usage within a single time window."""
|
||||||
|
|
||||||
|
used: int
|
||||||
|
limit: int = Field(
|
||||||
|
description="Maximum tokens allowed in this window. 0 means unlimited."
|
||||||
|
)
|
||||||
|
resets_at: datetime
|
||||||
|
|
||||||
|
|
||||||
|
class CoPilotUsageStatus(BaseModel):
|
||||||
|
"""Current usage status for a user across all windows."""
|
||||||
|
|
||||||
|
daily: UsageWindow
|
||||||
|
weekly: UsageWindow
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimitExceeded(Exception):
|
||||||
|
"""Raised when a user exceeds their CoPilot usage limit."""
|
||||||
|
|
||||||
|
def __init__(self, window: str, resets_at: datetime):
|
||||||
|
self.window = window
|
||||||
|
self.resets_at = resets_at
|
||||||
|
delta = resets_at - datetime.now(UTC)
|
||||||
|
total_secs = delta.total_seconds()
|
||||||
|
if total_secs <= 0:
|
||||||
|
time_str = "now"
|
||||||
|
else:
|
||||||
|
hours = int(total_secs // 3600)
|
||||||
|
minutes = int((total_secs % 3600) // 60)
|
||||||
|
time_str = f"{hours}h {minutes}m" if hours > 0 else f"{minutes}m"
|
||||||
|
super().__init__(
|
||||||
|
f"You've reached your {window} usage limit. Resets in {time_str}."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _daily_key(user_id: str, now: datetime | None = None) -> str:
|
||||||
|
if now is None:
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
return f"{_PREFIX}:daily:{user_id}:{now.strftime('%Y-%m-%d')}"
|
||||||
|
|
||||||
|
|
||||||
|
def _weekly_key(user_id: str, now: datetime | None = None) -> str:
|
||||||
|
if now is None:
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
year, week, _ = now.isocalendar()
|
||||||
|
return f"{_PREFIX}:weekly:{user_id}:{year}-W{week:02d}"
|
||||||
|
|
||||||
|
|
||||||
|
def _daily_reset_time(now: datetime | None = None) -> datetime:
|
||||||
|
"""Calculate when the current daily window resets (next midnight UTC)."""
|
||||||
|
if now is None:
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
return now.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(days=1)
|
||||||
|
|
||||||
|
|
||||||
|
def _weekly_reset_time(now: datetime | None = None) -> datetime:
|
||||||
|
"""Calculate when the current weekly window resets (next Monday 00:00 UTC).
|
||||||
|
|
||||||
|
On Monday itself, ``(7 - weekday) % 7`` is 0; the ``or 7`` fallback
|
||||||
|
pushes to *next* Monday so the current week's window stays open.
|
||||||
|
"""
|
||||||
|
if now is None:
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
days_until_monday = (7 - now.weekday()) % 7 or 7
|
||||||
|
return now.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(
|
||||||
|
days=days_until_monday
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _fetch_counters(user_id: str, now: datetime) -> tuple[int, int]:
|
||||||
|
"""Fetch daily and weekly token counters from Redis.
|
||||||
|
|
||||||
|
Returns (daily_used, weekly_used). Returns (0, 0) if Redis is unavailable.
|
||||||
|
"""
|
||||||
|
redis = await get_redis_async()
|
||||||
|
daily_raw, weekly_raw = await asyncio.gather(
|
||||||
|
redis.get(_daily_key(user_id, now=now)),
|
||||||
|
redis.get(_weekly_key(user_id, now=now)),
|
||||||
|
)
|
||||||
|
return int(daily_raw or 0), int(weekly_raw or 0)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_usage_status(
|
||||||
|
user_id: str,
|
||||||
|
daily_token_limit: int,
|
||||||
|
weekly_token_limit: int,
|
||||||
|
) -> CoPilotUsageStatus:
|
||||||
|
"""Get current usage status for a user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's ID.
|
||||||
|
daily_token_limit: Max tokens per day (0 = unlimited).
|
||||||
|
weekly_token_limit: Max tokens per week (0 = unlimited).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CoPilotUsageStatus with current usage and limits.
|
||||||
|
"""
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
try:
|
||||||
|
daily_used, weekly_used = await _fetch_counters(user_id, now)
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"Redis unavailable for usage status, returning zeros", exc_info=True
|
||||||
|
)
|
||||||
|
daily_used, weekly_used = 0, 0
|
||||||
|
|
||||||
|
return CoPilotUsageStatus(
|
||||||
|
daily=UsageWindow(
|
||||||
|
used=daily_used,
|
||||||
|
limit=daily_token_limit,
|
||||||
|
resets_at=_daily_reset_time(now=now),
|
||||||
|
),
|
||||||
|
weekly=UsageWindow(
|
||||||
|
used=weekly_used,
|
||||||
|
limit=weekly_token_limit,
|
||||||
|
resets_at=_weekly_reset_time(now=now),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def check_rate_limit(
|
||||||
|
user_id: str,
|
||||||
|
daily_token_limit: int,
|
||||||
|
weekly_token_limit: int,
|
||||||
|
) -> None:
|
||||||
|
"""Check if user is within rate limits. Raises RateLimitExceeded if not.
|
||||||
|
|
||||||
|
This is a pre-turn soft check. The authoritative usage counter is updated
|
||||||
|
by ``record_token_usage()`` after the turn completes. Under concurrency,
|
||||||
|
two parallel turns may both pass this check against the same snapshot.
|
||||||
|
This is acceptable because token-based limits are approximate by nature
|
||||||
|
(the exact token count is unknown until after generation).
|
||||||
|
|
||||||
|
Fails open: if Redis is unavailable, allows the request.
|
||||||
|
"""
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
try:
|
||||||
|
daily_used, weekly_used = await _fetch_counters(user_id, now)
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"Redis unavailable for rate limit check, allowing request", exc_info=True
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if daily_token_limit > 0 and daily_used >= daily_token_limit:
|
||||||
|
raise RateLimitExceeded("daily", _daily_reset_time(now=now))
|
||||||
|
|
||||||
|
if weekly_token_limit > 0 and weekly_used >= weekly_token_limit:
|
||||||
|
raise RateLimitExceeded("weekly", _weekly_reset_time(now=now))
|
||||||
|
|
||||||
|
|
||||||
|
async def record_token_usage(
|
||||||
|
user_id: str,
|
||||||
|
prompt_tokens: int,
|
||||||
|
completion_tokens: int,
|
||||||
|
*,
|
||||||
|
cache_read_tokens: int = 0,
|
||||||
|
cache_creation_tokens: int = 0,
|
||||||
|
) -> None:
|
||||||
|
"""Record token usage for a user across all windows.
|
||||||
|
|
||||||
|
Uses cost-weighted counting so cached tokens don't unfairly penalise
|
||||||
|
multi-turn conversations. Anthropic's pricing:
|
||||||
|
- uncached input: 100%
|
||||||
|
- cache creation: 25%
|
||||||
|
- cache read: 10%
|
||||||
|
- output: 100%
|
||||||
|
|
||||||
|
``prompt_tokens`` should be the *uncached* input count (``input_tokens``
|
||||||
|
from the API response). Cache counts are passed separately.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's ID.
|
||||||
|
prompt_tokens: Uncached input tokens.
|
||||||
|
completion_tokens: Output tokens.
|
||||||
|
cache_read_tokens: Tokens served from prompt cache (10% cost).
|
||||||
|
cache_creation_tokens: Tokens written to prompt cache (25% cost).
|
||||||
|
"""
|
||||||
|
weighted_input = (
|
||||||
|
prompt_tokens
|
||||||
|
+ round(cache_creation_tokens * 0.25)
|
||||||
|
+ round(cache_read_tokens * 0.1)
|
||||||
|
)
|
||||||
|
total = weighted_input + completion_tokens
|
||||||
|
if total <= 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
raw_total = (
|
||||||
|
prompt_tokens + cache_read_tokens + cache_creation_tokens + completion_tokens
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"Recording token usage for %s: raw=%d, weighted=%d "
|
||||||
|
"(uncached=%d, cache_read=%d@10%%, cache_create=%d@25%%, output=%d)",
|
||||||
|
user_id[:8],
|
||||||
|
raw_total,
|
||||||
|
total,
|
||||||
|
prompt_tokens,
|
||||||
|
cache_read_tokens,
|
||||||
|
cache_creation_tokens,
|
||||||
|
completion_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
try:
|
||||||
|
redis = await get_redis_async()
|
||||||
|
pipe = redis.pipeline(transaction=False)
|
||||||
|
|
||||||
|
# Daily counter (expires at next midnight UTC)
|
||||||
|
d_key = _daily_key(user_id, now=now)
|
||||||
|
pipe.incrby(d_key, total)
|
||||||
|
seconds_until_daily_reset = int(
|
||||||
|
(_daily_reset_time(now=now) - now).total_seconds()
|
||||||
|
)
|
||||||
|
pipe.expire(d_key, max(seconds_until_daily_reset, 1))
|
||||||
|
|
||||||
|
# Weekly counter (expires end of week)
|
||||||
|
w_key = _weekly_key(user_id, now=now)
|
||||||
|
pipe.incrby(w_key, total)
|
||||||
|
seconds_until_weekly_reset = int(
|
||||||
|
(_weekly_reset_time(now=now) - now).total_seconds()
|
||||||
|
)
|
||||||
|
pipe.expire(w_key, max(seconds_until_weekly_reset, 1))
|
||||||
|
|
||||||
|
await pipe.execute()
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"Redis unavailable for recording token usage (tokens=%d)",
|
||||||
|
total,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
334
autogpt_platform/backend/backend/copilot/rate_limit_test.py
Normal file
334
autogpt_platform/backend/backend/copilot/rate_limit_test.py
Normal file
@@ -0,0 +1,334 @@
|
|||||||
|
"""Unit tests for CoPilot rate limiting."""
|
||||||
|
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from redis.exceptions import RedisError
|
||||||
|
|
||||||
|
from .rate_limit import (
|
||||||
|
CoPilotUsageStatus,
|
||||||
|
RateLimitExceeded,
|
||||||
|
check_rate_limit,
|
||||||
|
get_usage_status,
|
||||||
|
record_token_usage,
|
||||||
|
)
|
||||||
|
|
||||||
|
_USER = "test-user-rl"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# RateLimitExceeded
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRateLimitExceeded:
|
||||||
|
def test_message_contains_window_name(self):
|
||||||
|
exc = RateLimitExceeded("daily", datetime.now(UTC) + timedelta(hours=1))
|
||||||
|
assert "daily" in str(exc)
|
||||||
|
|
||||||
|
def test_message_contains_reset_time(self):
|
||||||
|
exc = RateLimitExceeded(
|
||||||
|
"weekly", datetime.now(UTC) + timedelta(hours=2, minutes=30)
|
||||||
|
)
|
||||||
|
msg = str(exc)
|
||||||
|
# Allow for slight timing drift (29m or 30m)
|
||||||
|
assert "2h " in msg
|
||||||
|
assert "Resets in" in msg
|
||||||
|
|
||||||
|
def test_message_minutes_only_when_under_one_hour(self):
|
||||||
|
exc = RateLimitExceeded("daily", datetime.now(UTC) + timedelta(minutes=15))
|
||||||
|
msg = str(exc)
|
||||||
|
assert "Resets in" in msg
|
||||||
|
# Should not have "0h"
|
||||||
|
assert "0h" not in msg
|
||||||
|
|
||||||
|
def test_message_says_now_when_resets_at_is_in_the_past(self):
|
||||||
|
"""Negative delta (clock skew / stale TTL) should say 'now', not '-1h -30m'."""
|
||||||
|
exc = RateLimitExceeded("daily", datetime.now(UTC) - timedelta(minutes=5))
|
||||||
|
assert "Resets in now" in str(exc)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# get_usage_status
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetUsageStatus:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_redis_values(self):
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.get = AsyncMock(side_effect=["500", "2000"])
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.copilot.rate_limit.get_redis_async",
|
||||||
|
return_value=mock_redis,
|
||||||
|
):
|
||||||
|
status = await get_usage_status(
|
||||||
|
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(status, CoPilotUsageStatus)
|
||||||
|
assert status.daily.used == 500
|
||||||
|
assert status.daily.limit == 10000
|
||||||
|
assert status.weekly.used == 2000
|
||||||
|
assert status.weekly.limit == 50000
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_zeros_when_redis_unavailable(self):
|
||||||
|
with patch(
|
||||||
|
"backend.copilot.rate_limit.get_redis_async",
|
||||||
|
side_effect=ConnectionError("Redis down"),
|
||||||
|
):
|
||||||
|
status = await get_usage_status(
|
||||||
|
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||||
|
)
|
||||||
|
|
||||||
|
assert status.daily.used == 0
|
||||||
|
assert status.weekly.used == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_partial_none_daily_counter(self):
|
||||||
|
"""Daily counter is None (new day), weekly has usage."""
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.get = AsyncMock(side_effect=[None, "3000"])
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.copilot.rate_limit.get_redis_async",
|
||||||
|
return_value=mock_redis,
|
||||||
|
):
|
||||||
|
status = await get_usage_status(
|
||||||
|
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||||
|
)
|
||||||
|
|
||||||
|
assert status.daily.used == 0
|
||||||
|
assert status.weekly.used == 3000
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_partial_none_weekly_counter(self):
|
||||||
|
"""Weekly counter is None (start of week), daily has usage."""
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.get = AsyncMock(side_effect=["500", None])
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.copilot.rate_limit.get_redis_async",
|
||||||
|
return_value=mock_redis,
|
||||||
|
):
|
||||||
|
status = await get_usage_status(
|
||||||
|
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||||
|
)
|
||||||
|
|
||||||
|
assert status.daily.used == 500
|
||||||
|
assert status.weekly.used == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resets_at_daily_is_next_midnight_utc(self):
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.get = AsyncMock(side_effect=["0", "0"])
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.copilot.rate_limit.get_redis_async",
|
||||||
|
return_value=mock_redis,
|
||||||
|
):
|
||||||
|
status = await get_usage_status(
|
||||||
|
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||||
|
)
|
||||||
|
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
# Daily reset should be within 24h
|
||||||
|
assert status.daily.resets_at > now
|
||||||
|
assert status.daily.resets_at <= now + timedelta(hours=24, seconds=5)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# check_rate_limit
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCheckRateLimit:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_allows_when_under_limit(self):
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.get = AsyncMock(side_effect=["100", "200"])
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.copilot.rate_limit.get_redis_async",
|
||||||
|
return_value=mock_redis,
|
||||||
|
):
|
||||||
|
# Should not raise
|
||||||
|
await check_rate_limit(
|
||||||
|
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_raises_when_daily_limit_exceeded(self):
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.get = AsyncMock(side_effect=["10000", "200"])
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.copilot.rate_limit.get_redis_async",
|
||||||
|
return_value=mock_redis,
|
||||||
|
):
|
||||||
|
with pytest.raises(RateLimitExceeded) as exc_info:
|
||||||
|
await check_rate_limit(
|
||||||
|
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||||
|
)
|
||||||
|
assert exc_info.value.window == "daily"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_raises_when_weekly_limit_exceeded(self):
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.get = AsyncMock(side_effect=["100", "50000"])
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.copilot.rate_limit.get_redis_async",
|
||||||
|
return_value=mock_redis,
|
||||||
|
):
|
||||||
|
with pytest.raises(RateLimitExceeded) as exc_info:
|
||||||
|
await check_rate_limit(
|
||||||
|
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||||
|
)
|
||||||
|
assert exc_info.value.window == "weekly"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_allows_when_redis_unavailable(self):
|
||||||
|
"""Fail-open: allow requests when Redis is down."""
|
||||||
|
with patch(
|
||||||
|
"backend.copilot.rate_limit.get_redis_async",
|
||||||
|
side_effect=ConnectionError("Redis down"),
|
||||||
|
):
|
||||||
|
# Should not raise
|
||||||
|
await check_rate_limit(
|
||||||
|
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_skips_check_when_limit_is_zero(self):
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.get = AsyncMock(side_effect=["999999", "999999"])
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.copilot.rate_limit.get_redis_async",
|
||||||
|
return_value=mock_redis,
|
||||||
|
):
|
||||||
|
# Should not raise — limits of 0 mean unlimited
|
||||||
|
await check_rate_limit(_USER, daily_token_limit=0, weekly_token_limit=0)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# record_token_usage
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRecordTokenUsage:
|
||||||
|
@staticmethod
|
||||||
|
def _make_pipeline_mock() -> MagicMock:
|
||||||
|
"""Create a pipeline mock with sync methods and async execute."""
|
||||||
|
pipe = MagicMock()
|
||||||
|
pipe.execute = AsyncMock(return_value=[])
|
||||||
|
return pipe
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_increments_redis_counters(self):
|
||||||
|
mock_pipe = self._make_pipeline_mock()
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.pipeline = lambda **_kw: mock_pipe
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.copilot.rate_limit.get_redis_async",
|
||||||
|
return_value=mock_redis,
|
||||||
|
):
|
||||||
|
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
|
||||||
|
|
||||||
|
# Should call incrby twice (daily + weekly) with total=150
|
||||||
|
incrby_calls = mock_pipe.incrby.call_args_list
|
||||||
|
assert len(incrby_calls) == 2
|
||||||
|
assert incrby_calls[0].args[1] == 150 # daily
|
||||||
|
assert incrby_calls[1].args[1] == 150 # weekly
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_skips_when_zero_tokens(self):
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.copilot.rate_limit.get_redis_async",
|
||||||
|
return_value=mock_redis,
|
||||||
|
):
|
||||||
|
await record_token_usage(_USER, prompt_tokens=0, completion_tokens=0)
|
||||||
|
|
||||||
|
# Should not call pipeline at all
|
||||||
|
mock_redis.pipeline.assert_not_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sets_expire_on_both_keys(self):
|
||||||
|
"""Pipeline should call expire for both daily and weekly keys."""
|
||||||
|
mock_pipe = self._make_pipeline_mock()
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.pipeline = lambda **_kw: mock_pipe
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.copilot.rate_limit.get_redis_async",
|
||||||
|
return_value=mock_redis,
|
||||||
|
):
|
||||||
|
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
|
||||||
|
|
||||||
|
expire_calls = mock_pipe.expire.call_args_list
|
||||||
|
assert len(expire_calls) == 2
|
||||||
|
|
||||||
|
# Daily key TTL should be positive (seconds until next midnight)
|
||||||
|
daily_ttl = expire_calls[0].args[1]
|
||||||
|
assert daily_ttl >= 1
|
||||||
|
|
||||||
|
# Weekly key TTL should be positive (seconds until next Monday)
|
||||||
|
weekly_ttl = expire_calls[1].args[1]
|
||||||
|
assert weekly_ttl >= 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handles_redis_failure_gracefully(self):
|
||||||
|
"""Should not raise when Redis is unavailable."""
|
||||||
|
with patch(
|
||||||
|
"backend.copilot.rate_limit.get_redis_async",
|
||||||
|
side_effect=ConnectionError("Redis down"),
|
||||||
|
):
|
||||||
|
# Should not raise
|
||||||
|
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cost_weighted_counting(self):
|
||||||
|
"""Cached tokens should be weighted: cache_read=10%, cache_create=25%."""
|
||||||
|
mock_pipe = self._make_pipeline_mock()
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.pipeline = lambda **_kw: mock_pipe
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.copilot.rate_limit.get_redis_async",
|
||||||
|
return_value=mock_redis,
|
||||||
|
):
|
||||||
|
await record_token_usage(
|
||||||
|
_USER,
|
||||||
|
prompt_tokens=100, # uncached → 100
|
||||||
|
completion_tokens=50, # output → 50
|
||||||
|
cache_read_tokens=10000, # 10% → 1000
|
||||||
|
cache_creation_tokens=400, # 25% → 100
|
||||||
|
)
|
||||||
|
|
||||||
|
# Expected weighted total: 100 + 1000 + 100 + 50 = 1250
|
||||||
|
incrby_calls = mock_pipe.incrby.call_args_list
|
||||||
|
assert len(incrby_calls) == 2
|
||||||
|
assert incrby_calls[0].args[1] == 1250 # daily
|
||||||
|
assert incrby_calls[1].args[1] == 1250 # weekly
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handles_redis_error_during_pipeline_execute(self):
|
||||||
|
"""Should not raise when pipeline.execute() fails with RedisError."""
|
||||||
|
mock_pipe = self._make_pipeline_mock()
|
||||||
|
mock_pipe.execute = AsyncMock(side_effect=RedisError("Pipeline failed"))
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.pipeline = lambda **_kw: mock_pipe
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.copilot.rate_limit.get_redis_async",
|
||||||
|
return_value=mock_redis,
|
||||||
|
):
|
||||||
|
# Should not raise — fail-open
|
||||||
|
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
|
||||||
@@ -186,12 +186,29 @@ class StreamToolOutputAvailable(StreamBaseResponse):
|
|||||||
|
|
||||||
|
|
||||||
class StreamUsage(StreamBaseResponse):
|
class StreamUsage(StreamBaseResponse):
|
||||||
"""Token usage statistics."""
|
"""Token usage statistics.
|
||||||
|
|
||||||
|
Emitted as an SSE comment so the Vercel AI SDK parser ignores it
|
||||||
|
(it uses z.strictObject() and rejects unknown event types).
|
||||||
|
Usage data is recorded server-side (session DB + Redis counters).
|
||||||
|
"""
|
||||||
|
|
||||||
type: ResponseType = ResponseType.USAGE
|
type: ResponseType = ResponseType.USAGE
|
||||||
promptTokens: int = Field(..., description="Number of prompt tokens")
|
promptTokens: int = Field(..., description="Number of uncached prompt tokens")
|
||||||
completionTokens: int = Field(..., description="Number of completion tokens")
|
completionTokens: int = Field(..., description="Number of completion tokens")
|
||||||
totalTokens: int = Field(..., description="Total number of tokens")
|
totalTokens: int = Field(
|
||||||
|
..., description="Total number of tokens (raw, not weighted)"
|
||||||
|
)
|
||||||
|
cacheReadTokens: int = Field(
|
||||||
|
default=0, description="Prompt tokens served from cache (10% cost)"
|
||||||
|
)
|
||||||
|
cacheCreationTokens: int = Field(
|
||||||
|
default=0, description="Prompt tokens written to cache (25% cost)"
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_sse(self) -> str:
|
||||||
|
"""Emit as SSE comment so the AI SDK parser ignores it."""
|
||||||
|
return f": usage {self.model_dump_json(exclude_none=True)}\n\n"
|
||||||
|
|
||||||
|
|
||||||
class StreamError(StreamBaseResponse):
|
class StreamError(StreamBaseResponse):
|
||||||
|
|||||||
@@ -198,6 +198,7 @@ class CompactionTracker:
|
|||||||
|
|
||||||
def reset_for_query(self) -> None:
|
def reset_for_query(self) -> None:
|
||||||
"""Reset per-query state before a new SDK query."""
|
"""Reset per-query state before a new SDK query."""
|
||||||
|
self._compact_start.clear()
|
||||||
self._done = False
|
self._done = False
|
||||||
self._start_emitted = False
|
self._start_emitted = False
|
||||||
self._tool_call_id = ""
|
self._tool_call_id = ""
|
||||||
|
|||||||
@@ -0,0 +1,546 @@
|
|||||||
|
"""End-to-end compaction flow test.
|
||||||
|
|
||||||
|
Simulates the full service.py compaction lifecycle using real-format
|
||||||
|
JSONL session files — no SDK subprocess needed. Exercises:
|
||||||
|
|
||||||
|
1. TranscriptBuilder loads a "downloaded" transcript
|
||||||
|
2. User query appended, assistant response streamed
|
||||||
|
3. PreCompact hook fires → CompactionTracker.on_compact()
|
||||||
|
4. Next message → emit_start_if_ready() yields spinner events
|
||||||
|
5. Message after that → emit_end_if_ready() returns end events
|
||||||
|
6. _read_compacted_entries() reads the CLI session file
|
||||||
|
7. TranscriptBuilder.replace_entries() syncs state
|
||||||
|
8. More messages appended post-compaction
|
||||||
|
9. to_jsonl() exports full state for upload
|
||||||
|
10. Fresh builder loads the export — roundtrip verified
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from backend.copilot.model import ChatSession
|
||||||
|
from backend.copilot.response_model import (
|
||||||
|
StreamFinishStep,
|
||||||
|
StreamStartStep,
|
||||||
|
StreamToolInputAvailable,
|
||||||
|
StreamToolInputStart,
|
||||||
|
StreamToolOutputAvailable,
|
||||||
|
)
|
||||||
|
from backend.copilot.sdk.compaction import CompactionTracker
|
||||||
|
from backend.copilot.sdk.transcript import strip_progress_entries
|
||||||
|
from backend.copilot.sdk.transcript_builder import TranscriptBuilder
|
||||||
|
from backend.util import json
|
||||||
|
|
||||||
|
|
||||||
|
def _make_jsonl(*entries: dict) -> str:
|
||||||
|
return "\n".join(json.dumps(e) for e in entries) + "\n"
|
||||||
|
|
||||||
|
|
||||||
|
def _run(coro):
|
||||||
|
"""Run an async coroutine synchronously."""
|
||||||
|
return asyncio.run(coro)
|
||||||
|
|
||||||
|
|
||||||
|
def _read_compacted_entries(path: str) -> tuple[list[dict], str] | None:
|
||||||
|
"""Test-only: read compacted entries from a session JSONL file.
|
||||||
|
|
||||||
|
Returns (parsed_dicts, jsonl_string) from the first ``isCompactSummary``
|
||||||
|
entry onward, or ``None`` if no summary is found.
|
||||||
|
"""
|
||||||
|
content = Path(path).read_text()
|
||||||
|
lines = content.strip().split("\n")
|
||||||
|
compact_idx: int | None = None
|
||||||
|
parsed: list[dict] = []
|
||||||
|
raw_lines: list[str] = []
|
||||||
|
for line in lines:
|
||||||
|
if not line.strip():
|
||||||
|
continue
|
||||||
|
entry = json.loads(line, fallback=None)
|
||||||
|
if not isinstance(entry, dict):
|
||||||
|
continue
|
||||||
|
parsed.append(entry)
|
||||||
|
raw_lines.append(line.strip())
|
||||||
|
if compact_idx is None and entry.get("isCompactSummary"):
|
||||||
|
compact_idx = len(parsed) - 1
|
||||||
|
if compact_idx is None:
|
||||||
|
return None
|
||||||
|
return parsed[compact_idx:], "\n".join(raw_lines[compact_idx:]) + "\n"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixtures: realistic CLI session file content
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Pre-compaction conversation
|
||||||
|
USER_1 = {
|
||||||
|
"type": "user",
|
||||||
|
"uuid": "u1",
|
||||||
|
"message": {"role": "user", "content": "What files are in this project?"},
|
||||||
|
}
|
||||||
|
ASST_1_THINKING = {
|
||||||
|
"type": "assistant",
|
||||||
|
"uuid": "a1-think",
|
||||||
|
"parentUuid": "u1",
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"id": "msg_sdk_aaa",
|
||||||
|
"type": "message",
|
||||||
|
"content": [{"type": "thinking", "thinking": "Let me look at the files..."}],
|
||||||
|
"stop_reason": None,
|
||||||
|
"stop_sequence": None,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ASST_1_TOOL = {
|
||||||
|
"type": "assistant",
|
||||||
|
"uuid": "a1-tool",
|
||||||
|
"parentUuid": "u1",
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"id": "msg_sdk_aaa",
|
||||||
|
"type": "message",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": "tu1",
|
||||||
|
"name": "Bash",
|
||||||
|
"input": {"command": "ls"},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"stop_reason": "tool_use",
|
||||||
|
"stop_sequence": None,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
TOOL_RESULT_1 = {
|
||||||
|
"type": "user",
|
||||||
|
"uuid": "tr1",
|
||||||
|
"parentUuid": "a1-tool",
|
||||||
|
"message": {
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "tu1",
|
||||||
|
"content": "file1.py\nfile2.py",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ASST_1_TEXT = {
|
||||||
|
"type": "assistant",
|
||||||
|
"uuid": "a1-text",
|
||||||
|
"parentUuid": "tr1",
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"id": "msg_sdk_bbb",
|
||||||
|
"type": "message",
|
||||||
|
"content": [{"type": "text", "text": "I found file1.py and file2.py."}],
|
||||||
|
"stop_reason": "end_turn",
|
||||||
|
"stop_sequence": None,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
# Progress entries (should be stripped during upload)
|
||||||
|
PROGRESS_1 = {
|
||||||
|
"type": "progress",
|
||||||
|
"uuid": "prog1",
|
||||||
|
"parentUuid": "a1-tool",
|
||||||
|
"data": {"type": "bash_progress", "stdout": "running ls..."},
|
||||||
|
}
|
||||||
|
# Second user message
|
||||||
|
USER_2 = {
|
||||||
|
"type": "user",
|
||||||
|
"uuid": "u2",
|
||||||
|
"parentUuid": "a1-text",
|
||||||
|
"message": {"role": "user", "content": "Show me file1.py"},
|
||||||
|
}
|
||||||
|
ASST_2 = {
|
||||||
|
"type": "assistant",
|
||||||
|
"uuid": "a2",
|
||||||
|
"parentUuid": "u2",
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"id": "msg_sdk_ccc",
|
||||||
|
"type": "message",
|
||||||
|
"content": [{"type": "text", "text": "Here is file1.py content..."}],
|
||||||
|
"stop_reason": "end_turn",
|
||||||
|
"stop_sequence": None,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# --- Compaction summary (written by CLI after context compaction) ---
|
||||||
|
COMPACT_SUMMARY = {
|
||||||
|
"type": "summary",
|
||||||
|
"uuid": "cs1",
|
||||||
|
"isCompactSummary": True,
|
||||||
|
"message": {
|
||||||
|
"role": "user",
|
||||||
|
"content": (
|
||||||
|
"Summary: User asked about project files. Found file1.py and file2.py. "
|
||||||
|
"User then asked to see file1.py."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Post-compaction assistant response
|
||||||
|
POST_COMPACT_ASST = {
|
||||||
|
"type": "assistant",
|
||||||
|
"uuid": "a3",
|
||||||
|
"parentUuid": "cs1",
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"id": "msg_sdk_ddd",
|
||||||
|
"type": "message",
|
||||||
|
"content": [{"type": "text", "text": "Here is the content of file1.py..."}],
|
||||||
|
"stop_reason": "end_turn",
|
||||||
|
"stop_sequence": None,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Post-compaction user follow-up
|
||||||
|
USER_3 = {
|
||||||
|
"type": "user",
|
||||||
|
"uuid": "u3",
|
||||||
|
"parentUuid": "a3",
|
||||||
|
"message": {"role": "user", "content": "Now show file2.py"},
|
||||||
|
}
|
||||||
|
ASST_3 = {
|
||||||
|
"type": "assistant",
|
||||||
|
"uuid": "a4",
|
||||||
|
"parentUuid": "u3",
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"id": "msg_sdk_eee",
|
||||||
|
"type": "message",
|
||||||
|
"content": [{"type": "text", "text": "Here is file2.py..."}],
|
||||||
|
"stop_reason": "end_turn",
|
||||||
|
"stop_sequence": None,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# E2E test
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCompactionE2E:
|
||||||
|
def _write_session_file(self, session_dir, entries):
|
||||||
|
"""Write a CLI session JSONL file."""
|
||||||
|
path = session_dir / "session.jsonl"
|
||||||
|
path.write_text(_make_jsonl(*entries))
|
||||||
|
return path
|
||||||
|
|
||||||
|
def test_full_compaction_lifecycle(self, tmp_path):
|
||||||
|
"""Simulate the complete service.py compaction flow.
|
||||||
|
|
||||||
|
Timeline:
|
||||||
|
1. Previous turn uploaded transcript with [USER_1, ASST_1, USER_2, ASST_2]
|
||||||
|
2. Current turn: download → load_previous
|
||||||
|
3. User sends "Now show file2.py" → append_user
|
||||||
|
4. SDK starts streaming response
|
||||||
|
5. Mid-stream: PreCompact hook fires (context too large)
|
||||||
|
6. CLI writes compaction summary to session file
|
||||||
|
7. Next SDK message → emit_start (spinner)
|
||||||
|
8. Following message → emit_end (end events)
|
||||||
|
9. _read_compacted_entries reads the session file
|
||||||
|
10. replace_entries syncs TranscriptBuilder
|
||||||
|
11. More assistant messages appended
|
||||||
|
12. Export → upload → next turn downloads it
|
||||||
|
"""
|
||||||
|
session_dir = tmp_path / "session"
|
||||||
|
session_dir.mkdir(parents=True)
|
||||||
|
|
||||||
|
# --- Step 1-2: Load "downloaded" transcript from previous turn ---
|
||||||
|
previous_transcript = _make_jsonl(
|
||||||
|
USER_1,
|
||||||
|
ASST_1_THINKING,
|
||||||
|
ASST_1_TOOL,
|
||||||
|
TOOL_RESULT_1,
|
||||||
|
ASST_1_TEXT,
|
||||||
|
USER_2,
|
||||||
|
ASST_2,
|
||||||
|
)
|
||||||
|
builder = TranscriptBuilder()
|
||||||
|
builder.load_previous(previous_transcript)
|
||||||
|
assert builder.entry_count == 7
|
||||||
|
|
||||||
|
# --- Step 3: User sends new query ---
|
||||||
|
builder.append_user("Now show file2.py")
|
||||||
|
assert builder.entry_count == 8
|
||||||
|
|
||||||
|
# --- Step 4: SDK starts streaming ---
|
||||||
|
builder.append_assistant(
|
||||||
|
[{"type": "thinking", "thinking": "Let me read file2.py..."}],
|
||||||
|
model="claude-sonnet-4-20250514",
|
||||||
|
)
|
||||||
|
assert builder.entry_count == 9
|
||||||
|
|
||||||
|
# --- Step 5-6: PreCompact fires, CLI writes session file ---
|
||||||
|
session_file = self._write_session_file(
|
||||||
|
session_dir,
|
||||||
|
[
|
||||||
|
USER_1,
|
||||||
|
ASST_1_THINKING,
|
||||||
|
ASST_1_TOOL,
|
||||||
|
PROGRESS_1,
|
||||||
|
TOOL_RESULT_1,
|
||||||
|
ASST_1_TEXT,
|
||||||
|
USER_2,
|
||||||
|
ASST_2,
|
||||||
|
COMPACT_SUMMARY,
|
||||||
|
POST_COMPACT_ASST,
|
||||||
|
USER_3,
|
||||||
|
ASST_3,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Step 7: CompactionTracker receives PreCompact hook ---
|
||||||
|
tracker = CompactionTracker()
|
||||||
|
session = ChatSession.new(user_id="test-user")
|
||||||
|
# on_compact is a property returning Event.set callable
|
||||||
|
tracker.on_compact()
|
||||||
|
|
||||||
|
# --- Step 8: Next SDK message arrives → emit_start ---
|
||||||
|
start_events = tracker.emit_start_if_ready()
|
||||||
|
assert len(start_events) == 3
|
||||||
|
assert isinstance(start_events[0], StreamStartStep)
|
||||||
|
assert isinstance(start_events[1], StreamToolInputStart)
|
||||||
|
assert isinstance(start_events[2], StreamToolInputAvailable)
|
||||||
|
|
||||||
|
# Verify tool_call_id is set
|
||||||
|
tool_call_id = start_events[1].toolCallId
|
||||||
|
assert tool_call_id.startswith("compaction-")
|
||||||
|
|
||||||
|
# --- Step 9: Following message → emit_end ---
|
||||||
|
end_events = _run(tracker.emit_end_if_ready(session))
|
||||||
|
assert len(end_events) == 2
|
||||||
|
assert isinstance(end_events[0], StreamToolOutputAvailable)
|
||||||
|
assert isinstance(end_events[1], StreamFinishStep)
|
||||||
|
# Verify same tool_call_id
|
||||||
|
assert end_events[0].toolCallId == tool_call_id
|
||||||
|
|
||||||
|
# Session should have compaction messages persisted
|
||||||
|
assert len(session.messages) == 2
|
||||||
|
assert session.messages[0].role == "assistant"
|
||||||
|
assert session.messages[1].role == "tool"
|
||||||
|
|
||||||
|
# --- Step 10: _read_compacted_entries + replace_entries ---
|
||||||
|
result = _read_compacted_entries(str(session_file))
|
||||||
|
assert result is not None
|
||||||
|
compacted_dicts, compacted_jsonl = result
|
||||||
|
# Should have: COMPACT_SUMMARY + POST_COMPACT_ASST + USER_3 + ASST_3
|
||||||
|
assert len(compacted_dicts) == 4
|
||||||
|
assert compacted_dicts[0]["uuid"] == "cs1"
|
||||||
|
assert compacted_dicts[0]["isCompactSummary"] is True
|
||||||
|
|
||||||
|
# Replace builder state with compacted JSONL
|
||||||
|
old_count = builder.entry_count
|
||||||
|
builder.replace_entries(compacted_jsonl)
|
||||||
|
assert builder.entry_count == 4 # Only compacted entries
|
||||||
|
assert builder.entry_count < old_count # Compaction reduced entries
|
||||||
|
|
||||||
|
# --- Step 11: More assistant messages after compaction ---
|
||||||
|
builder.append_assistant(
|
||||||
|
[{"type": "text", "text": "Here is file2.py:\n\ndef hello():\n pass"}],
|
||||||
|
model="claude-sonnet-4-20250514",
|
||||||
|
stop_reason="end_turn",
|
||||||
|
)
|
||||||
|
assert builder.entry_count == 5
|
||||||
|
|
||||||
|
# --- Step 12: Export for upload ---
|
||||||
|
output = builder.to_jsonl()
|
||||||
|
assert output # Not empty
|
||||||
|
output_entries = [json.loads(line) for line in output.strip().split("\n")]
|
||||||
|
assert len(output_entries) == 5
|
||||||
|
|
||||||
|
# Verify structure:
|
||||||
|
# [COMPACT_SUMMARY, POST_COMPACT_ASST, USER_3, ASST_3, new_assistant]
|
||||||
|
assert output_entries[0]["type"] == "summary"
|
||||||
|
assert output_entries[0].get("isCompactSummary") is True
|
||||||
|
assert output_entries[0]["uuid"] == "cs1"
|
||||||
|
assert output_entries[1]["uuid"] == "a3"
|
||||||
|
assert output_entries[2]["uuid"] == "u3"
|
||||||
|
assert output_entries[3]["uuid"] == "a4"
|
||||||
|
assert output_entries[4]["type"] == "assistant"
|
||||||
|
|
||||||
|
# Verify parent chain is intact
|
||||||
|
assert output_entries[1]["parentUuid"] == "cs1" # a3 → cs1
|
||||||
|
assert output_entries[2]["parentUuid"] == "a3" # u3 → a3
|
||||||
|
assert output_entries[3]["parentUuid"] == "u3" # a4 → u3
|
||||||
|
assert output_entries[4]["parentUuid"] == "a4" # new → a4
|
||||||
|
|
||||||
|
# --- Step 13: Roundtrip — next turn loads this export ---
|
||||||
|
builder2 = TranscriptBuilder()
|
||||||
|
builder2.load_previous(output)
|
||||||
|
assert builder2.entry_count == 5
|
||||||
|
|
||||||
|
# isCompactSummary survives roundtrip
|
||||||
|
output2 = builder2.to_jsonl()
|
||||||
|
first_entry = json.loads(output2.strip().split("\n")[0])
|
||||||
|
assert first_entry.get("isCompactSummary") is True
|
||||||
|
|
||||||
|
# Can append more messages
|
||||||
|
builder2.append_user("What about file3.py?")
|
||||||
|
assert builder2.entry_count == 6
|
||||||
|
final_output = builder2.to_jsonl()
|
||||||
|
last_entry = json.loads(final_output.strip().split("\n")[-1])
|
||||||
|
assert last_entry["type"] == "user"
|
||||||
|
# Parented to the last entry from previous turn
|
||||||
|
assert last_entry["parentUuid"] == output_entries[-1]["uuid"]
|
||||||
|
|
||||||
|
def test_double_compaction_within_session(self, tmp_path):
|
||||||
|
"""Two compactions in the same session (across reset_for_query)."""
|
||||||
|
session_dir = tmp_path / "session"
|
||||||
|
session_dir.mkdir(parents=True)
|
||||||
|
|
||||||
|
tracker = CompactionTracker()
|
||||||
|
session = ChatSession.new(user_id="test")
|
||||||
|
builder = TranscriptBuilder()
|
||||||
|
|
||||||
|
# --- First query with compaction ---
|
||||||
|
builder.append_user("first question")
|
||||||
|
builder.append_assistant([{"type": "text", "text": "first answer"}])
|
||||||
|
|
||||||
|
# Write session file for first compaction
|
||||||
|
first_summary = {
|
||||||
|
"type": "summary",
|
||||||
|
"uuid": "cs-first",
|
||||||
|
"isCompactSummary": True,
|
||||||
|
"message": {"role": "user", "content": "First compaction summary"},
|
||||||
|
}
|
||||||
|
first_post = {
|
||||||
|
"type": "assistant",
|
||||||
|
"uuid": "a-first",
|
||||||
|
"parentUuid": "cs-first",
|
||||||
|
"message": {"role": "assistant", "content": "first post-compact"},
|
||||||
|
}
|
||||||
|
file1 = session_dir / "session1.jsonl"
|
||||||
|
file1.write_text(_make_jsonl(first_summary, first_post))
|
||||||
|
|
||||||
|
tracker.on_compact()
|
||||||
|
tracker.emit_start_if_ready()
|
||||||
|
end_events1 = _run(tracker.emit_end_if_ready(session))
|
||||||
|
assert len(end_events1) == 2 # output + finish
|
||||||
|
|
||||||
|
result1_entries = _read_compacted_entries(str(file1))
|
||||||
|
assert result1_entries is not None
|
||||||
|
_, compacted1_jsonl = result1_entries
|
||||||
|
builder.replace_entries(compacted1_jsonl)
|
||||||
|
assert builder.entry_count == 2
|
||||||
|
|
||||||
|
# --- Reset for second query ---
|
||||||
|
tracker.reset_for_query()
|
||||||
|
|
||||||
|
# --- Second query with compaction ---
|
||||||
|
builder.append_user("second question")
|
||||||
|
builder.append_assistant([{"type": "text", "text": "second answer"}])
|
||||||
|
|
||||||
|
second_summary = {
|
||||||
|
"type": "summary",
|
||||||
|
"uuid": "cs-second",
|
||||||
|
"isCompactSummary": True,
|
||||||
|
"message": {"role": "user", "content": "Second compaction summary"},
|
||||||
|
}
|
||||||
|
second_post = {
|
||||||
|
"type": "assistant",
|
||||||
|
"uuid": "a-second",
|
||||||
|
"parentUuid": "cs-second",
|
||||||
|
"message": {"role": "assistant", "content": "second post-compact"},
|
||||||
|
}
|
||||||
|
file2 = session_dir / "session2.jsonl"
|
||||||
|
file2.write_text(_make_jsonl(second_summary, second_post))
|
||||||
|
|
||||||
|
tracker.on_compact()
|
||||||
|
tracker.emit_start_if_ready()
|
||||||
|
end_events2 = _run(tracker.emit_end_if_ready(session))
|
||||||
|
assert len(end_events2) == 2 # output + finish
|
||||||
|
|
||||||
|
result2_entries = _read_compacted_entries(str(file2))
|
||||||
|
assert result2_entries is not None
|
||||||
|
_, compacted2_jsonl = result2_entries
|
||||||
|
builder.replace_entries(compacted2_jsonl)
|
||||||
|
assert builder.entry_count == 2 # Only second compaction entries
|
||||||
|
|
||||||
|
# Export and verify
|
||||||
|
output = builder.to_jsonl()
|
||||||
|
entries = [json.loads(line) for line in output.strip().split("\n")]
|
||||||
|
assert entries[0]["uuid"] == "cs-second"
|
||||||
|
assert entries[0].get("isCompactSummary") is True
|
||||||
|
|
||||||
|
def test_strip_progress_then_load_then_compact_roundtrip(self, tmp_path):
|
||||||
|
"""Full pipeline: strip → load → compact → replace → export → reload.
|
||||||
|
|
||||||
|
This tests the exact sequence that happens across two turns:
|
||||||
|
Turn 1: SDK produces transcript with progress entries
|
||||||
|
Upload: strip_progress_entries removes progress, upload to cloud
|
||||||
|
Turn 2: Download → load_previous → compaction fires → replace → export
|
||||||
|
Turn 3: Download the Turn 2 export → load_previous (roundtrip)
|
||||||
|
"""
|
||||||
|
session_dir = tmp_path / "session"
|
||||||
|
session_dir.mkdir(parents=True)
|
||||||
|
|
||||||
|
# --- Turn 1: SDK produces raw transcript ---
|
||||||
|
raw_content = _make_jsonl(
|
||||||
|
USER_1,
|
||||||
|
ASST_1_THINKING,
|
||||||
|
ASST_1_TOOL,
|
||||||
|
PROGRESS_1,
|
||||||
|
TOOL_RESULT_1,
|
||||||
|
ASST_1_TEXT,
|
||||||
|
USER_2,
|
||||||
|
ASST_2,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Strip progress for upload
|
||||||
|
stripped = strip_progress_entries(raw_content)
|
||||||
|
stripped_entries = [
|
||||||
|
json.loads(line) for line in stripped.strip().split("\n") if line.strip()
|
||||||
|
]
|
||||||
|
# Progress should be gone
|
||||||
|
assert not any(e.get("type") == "progress" for e in stripped_entries)
|
||||||
|
assert len(stripped_entries) == 7 # 8 - 1 progress
|
||||||
|
|
||||||
|
# --- Turn 2: Download stripped, load, compaction happens ---
|
||||||
|
builder = TranscriptBuilder()
|
||||||
|
builder.load_previous(stripped)
|
||||||
|
assert builder.entry_count == 7
|
||||||
|
|
||||||
|
builder.append_user("Now show file2.py")
|
||||||
|
builder.append_assistant(
|
||||||
|
[{"type": "text", "text": "Reading file2.py..."}],
|
||||||
|
model="claude-sonnet-4-20250514",
|
||||||
|
)
|
||||||
|
|
||||||
|
# CLI writes session file with compaction
|
||||||
|
session_file = self._write_session_file(
|
||||||
|
session_dir,
|
||||||
|
[
|
||||||
|
USER_1,
|
||||||
|
ASST_1_TOOL,
|
||||||
|
TOOL_RESULT_1,
|
||||||
|
ASST_1_TEXT,
|
||||||
|
USER_2,
|
||||||
|
ASST_2,
|
||||||
|
COMPACT_SUMMARY,
|
||||||
|
POST_COMPACT_ASST,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
result = _read_compacted_entries(str(session_file))
|
||||||
|
assert result is not None
|
||||||
|
_, compacted_jsonl = result
|
||||||
|
builder.replace_entries(compacted_jsonl)
|
||||||
|
|
||||||
|
# Append post-compaction message
|
||||||
|
builder.append_user("Thanks!")
|
||||||
|
output = builder.to_jsonl()
|
||||||
|
|
||||||
|
# --- Turn 3: Fresh load of Turn 2 export ---
|
||||||
|
builder3 = TranscriptBuilder()
|
||||||
|
builder3.load_previous(output)
|
||||||
|
# Should have: compact_summary + post_compact_asst + "Thanks!"
|
||||||
|
assert builder3.entry_count == 3
|
||||||
|
|
||||||
|
# Compact summary survived the full pipeline
|
||||||
|
first = json.loads(builder3.to_jsonl().strip().split("\n")[0])
|
||||||
|
assert first.get("isCompactSummary") is True
|
||||||
|
assert first["type"] == "summary"
|
||||||
@@ -221,12 +221,12 @@ class SDKResponseAdapter:
|
|||||||
responses.append(StreamFinish())
|
responses.append(StreamFinish())
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Unexpected ResultMessage subtype: {sdk_message.subtype}"
|
"Unexpected ResultMessage subtype: %s", sdk_message.subtype
|
||||||
)
|
)
|
||||||
responses.append(StreamFinish())
|
responses.append(StreamFinish())
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logger.debug(f"Unhandled SDK message type: {type(sdk_message).__name__}")
|
logger.debug("Unhandled SDK message type: %s", type(sdk_message).__name__)
|
||||||
|
|
||||||
return responses
|
return responses
|
||||||
|
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ def _validate_workspace_path(
|
|||||||
if is_allowed_local_path(path, sdk_cwd):
|
if is_allowed_local_path(path, sdk_cwd):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
logger.warning(f"Blocked {tool_name} outside workspace: {path}")
|
logger.warning("Blocked %s outside workspace: %s", tool_name, path)
|
||||||
workspace_hint = f" Allowed workspace: {sdk_cwd}" if sdk_cwd else ""
|
workspace_hint = f" Allowed workspace: {sdk_cwd}" if sdk_cwd else ""
|
||||||
return _deny(
|
return _deny(
|
||||||
f"[SECURITY] Tool '{tool_name}' can only access files within the workspace "
|
f"[SECURITY] Tool '{tool_name}' can only access files within the workspace "
|
||||||
@@ -71,7 +71,7 @@ def _validate_tool_access(
|
|||||||
"""
|
"""
|
||||||
# Block forbidden tools
|
# Block forbidden tools
|
||||||
if tool_name in BLOCKED_TOOLS:
|
if tool_name in BLOCKED_TOOLS:
|
||||||
logger.warning(f"Blocked tool access attempt: {tool_name}")
|
logger.warning("Blocked tool access attempt: %s", tool_name)
|
||||||
return _deny(
|
return _deny(
|
||||||
f"[SECURITY] Tool '{tool_name}' is blocked for security. "
|
f"[SECURITY] Tool '{tool_name}' is blocked for security. "
|
||||||
"This is enforced by the platform and cannot be bypassed. "
|
"This is enforced by the platform and cannot be bypassed. "
|
||||||
@@ -111,7 +111,9 @@ def _validate_user_isolation(
|
|||||||
# the tool itself via _validate_ephemeral_path.
|
# the tool itself via _validate_ephemeral_path.
|
||||||
path = tool_input.get("path", "") or tool_input.get("file_path", "")
|
path = tool_input.get("path", "") or tool_input.get("file_path", "")
|
||||||
if path and ".." in path:
|
if path and ".." in path:
|
||||||
logger.warning(f"Blocked path traversal attempt: {path} by user {user_id}")
|
logger.warning(
|
||||||
|
"Blocked path traversal attempt: %s by user %s", path, user_id
|
||||||
|
)
|
||||||
return {
|
return {
|
||||||
"hookSpecificOutput": {
|
"hookSpecificOutput": {
|
||||||
"hookEventName": "PreToolUse",
|
"hookEventName": "PreToolUse",
|
||||||
@@ -169,7 +171,7 @@ def create_security_hooks(
|
|||||||
# Block background task execution first — denied calls
|
# Block background task execution first — denied calls
|
||||||
# should not consume a subtask slot.
|
# should not consume a subtask slot.
|
||||||
if tool_input.get("run_in_background"):
|
if tool_input.get("run_in_background"):
|
||||||
logger.info(f"[SDK] Blocked background Task, user={user_id}")
|
logger.info("[SDK] Blocked background Task, user=%s", user_id)
|
||||||
return cast(
|
return cast(
|
||||||
SyncHookJSONOutput,
|
SyncHookJSONOutput,
|
||||||
_deny(
|
_deny(
|
||||||
@@ -211,7 +213,7 @@ def create_security_hooks(
|
|||||||
if tool_name == "Task" and tool_use_id is not None:
|
if tool_name == "Task" and tool_use_id is not None:
|
||||||
task_tool_use_ids.add(tool_use_id)
|
task_tool_use_ids.add(tool_use_id)
|
||||||
|
|
||||||
logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}")
|
logger.debug("[SDK] Tool start: %s, user=%s", tool_name, user_id)
|
||||||
return cast(SyncHookJSONOutput, {})
|
return cast(SyncHookJSONOutput, {})
|
||||||
|
|
||||||
def _release_task_slot(tool_name: str, tool_use_id: str | None) -> None:
|
def _release_task_slot(tool_name: str, tool_use_id: str | None) -> None:
|
||||||
|
|||||||
@@ -40,11 +40,13 @@ from ..constants import COPILOT_ERROR_PREFIX, COPILOT_SYSTEM_PREFIX
|
|||||||
from ..model import (
|
from ..model import (
|
||||||
ChatMessage,
|
ChatMessage,
|
||||||
ChatSession,
|
ChatSession,
|
||||||
|
Usage,
|
||||||
get_chat_session,
|
get_chat_session,
|
||||||
update_session_title,
|
update_session_title,
|
||||||
upsert_chat_session,
|
upsert_chat_session,
|
||||||
)
|
)
|
||||||
from ..prompting import get_sdk_supplement
|
from ..prompting import get_sdk_supplement
|
||||||
|
from ..rate_limit import record_token_usage
|
||||||
from ..response_model import (
|
from ..response_model import (
|
||||||
StreamBaseResponse,
|
StreamBaseResponse,
|
||||||
StreamError,
|
StreamError,
|
||||||
@@ -54,6 +56,7 @@ from ..response_model import (
|
|||||||
StreamTextDelta,
|
StreamTextDelta,
|
||||||
StreamToolInputAvailable,
|
StreamToolInputAvailable,
|
||||||
StreamToolOutputAvailable,
|
StreamToolOutputAvailable,
|
||||||
|
StreamUsage,
|
||||||
)
|
)
|
||||||
from ..service import (
|
from ..service import (
|
||||||
_build_system_prompt,
|
_build_system_prompt,
|
||||||
@@ -75,8 +78,12 @@ from .tool_adapter import (
|
|||||||
wait_for_stash,
|
wait_for_stash,
|
||||||
)
|
)
|
||||||
from .transcript import (
|
from .transcript import (
|
||||||
|
COMPACT_THRESHOLD_BYTES,
|
||||||
|
TranscriptDownload,
|
||||||
cleanup_cli_project_dir,
|
cleanup_cli_project_dir,
|
||||||
|
compact_transcript,
|
||||||
download_transcript,
|
download_transcript,
|
||||||
|
read_cli_session_file,
|
||||||
upload_transcript,
|
upload_transcript,
|
||||||
validate_transcript,
|
validate_transcript,
|
||||||
write_transcript_to_tempfile,
|
write_transcript_to_tempfile,
|
||||||
@@ -294,7 +301,7 @@ def _cleanup_sdk_tool_results(cwd: str) -> None:
|
|||||||
"""
|
"""
|
||||||
normalized = os.path.normpath(cwd)
|
normalized = os.path.normpath(cwd)
|
||||||
if not normalized.startswith(_SDK_CWD_PREFIX):
|
if not normalized.startswith(_SDK_CWD_PREFIX):
|
||||||
logger.warning(f"[SDK] Rejecting cleanup for path outside workspace: {cwd}")
|
logger.warning("[SDK] Rejecting cleanup for path outside workspace: %s", cwd)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Clean the CLI's project directory (transcripts + tool-results).
|
# Clean the CLI's project directory (transcripts + tool-results).
|
||||||
@@ -388,7 +395,7 @@ async def _compress_messages(
|
|||||||
client=client,
|
client=client,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"[SDK] Context compression with LLM failed: {e}")
|
logger.warning("[SDK] Context compression with LLM failed: %s", e)
|
||||||
# Fall back to truncation-only (no LLM summarization)
|
# Fall back to truncation-only (no LLM summarization)
|
||||||
result = await compress_context(
|
result = await compress_context(
|
||||||
messages=messages_dict,
|
messages=messages_dict,
|
||||||
@@ -624,6 +631,56 @@ async def _prepare_file_attachments(
|
|||||||
return PreparedAttachments(hint=hint, image_blocks=image_blocks)
|
return PreparedAttachments(hint=hint, image_blocks=image_blocks)
|
||||||
|
|
||||||
|
|
||||||
|
async def _maybe_compact_and_upload(
|
||||||
|
dl: TranscriptDownload,
|
||||||
|
user_id: str,
|
||||||
|
session_id: str,
|
||||||
|
log_prefix: str = "[Transcript]",
|
||||||
|
) -> str:
|
||||||
|
"""Compact an oversized transcript and upload the compacted version.
|
||||||
|
|
||||||
|
Returns the (possibly compacted) transcript content, or an empty string
|
||||||
|
if compaction was needed but failed.
|
||||||
|
"""
|
||||||
|
content = dl.content
|
||||||
|
if len(content) <= COMPACT_THRESHOLD_BYTES:
|
||||||
|
return content
|
||||||
|
|
||||||
|
logger.warning(
|
||||||
|
"%s Transcript oversized (%dB > %dB), compacting",
|
||||||
|
log_prefix,
|
||||||
|
len(content),
|
||||||
|
COMPACT_THRESHOLD_BYTES,
|
||||||
|
)
|
||||||
|
compacted = await compact_transcript(content, log_prefix=log_prefix)
|
||||||
|
if not compacted:
|
||||||
|
logger.warning(
|
||||||
|
"%s Compaction failed, skipping resume for this turn", log_prefix
|
||||||
|
)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# Keep the original message_count: it reflects the number of
|
||||||
|
# session.messages covered by this transcript, which the gap-fill
|
||||||
|
# logic uses as a slice index. Counting JSONL lines would give a
|
||||||
|
# smaller number (compacted messages != session message count) and
|
||||||
|
# cause already-covered messages to be re-injected.
|
||||||
|
try:
|
||||||
|
await upload_transcript(
|
||||||
|
user_id=user_id,
|
||||||
|
session_id=session_id,
|
||||||
|
content=compacted,
|
||||||
|
message_count=dl.message_count,
|
||||||
|
log_prefix=log_prefix,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"%s Failed to upload compacted transcript",
|
||||||
|
log_prefix,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
return compacted
|
||||||
|
|
||||||
|
|
||||||
async def stream_chat_completion_sdk(
|
async def stream_chat_completion_sdk(
|
||||||
session_id: str,
|
session_id: str,
|
||||||
message: str | None = None,
|
message: str | None = None,
|
||||||
@@ -735,6 +792,14 @@ async def stream_chat_completion_sdk(
|
|||||||
_otel_ctx: Any = None
|
_otel_ctx: Any = None
|
||||||
|
|
||||||
# Make sure there is no more code between the lock acquisition and try-block.
|
# Make sure there is no more code between the lock acquisition and try-block.
|
||||||
|
# Token usage accumulators — populated from ResultMessage at end of turn
|
||||||
|
turn_prompt_tokens = 0 # uncached input tokens only
|
||||||
|
turn_completion_tokens = 0
|
||||||
|
turn_cache_read_tokens = 0
|
||||||
|
turn_cache_creation_tokens = 0
|
||||||
|
total_tokens = 0 # computed once before StreamUsage, reused in finally
|
||||||
|
turn_cost_usd: float | None = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Build system prompt (reuses non-SDK path with Langfuse support).
|
# Build system prompt (reuses non-SDK path with Langfuse support).
|
||||||
# Pre-compute the cwd here so the exact working directory path can be
|
# Pre-compute the cwd here so the exact working directory path can be
|
||||||
@@ -827,20 +892,33 @@ async def stream_chat_completion_sdk(
|
|||||||
is_valid,
|
is_valid,
|
||||||
)
|
)
|
||||||
if is_valid:
|
if is_valid:
|
||||||
# Load previous FULL context into builder
|
transcript_content = await _maybe_compact_and_upload(
|
||||||
transcript_builder.load_previous(dl.content, log_prefix=log_prefix)
|
dl,
|
||||||
resume_file = write_transcript_to_tempfile(
|
user_id=user_id or "",
|
||||||
dl.content, session_id, sdk_cwd
|
session_id=session_id,
|
||||||
|
log_prefix=log_prefix,
|
||||||
|
)
|
||||||
|
# Load previous context into builder (empty string is a no-op)
|
||||||
|
if transcript_content:
|
||||||
|
transcript_builder.load_previous(
|
||||||
|
transcript_content, log_prefix=log_prefix
|
||||||
|
)
|
||||||
|
resume_file = (
|
||||||
|
write_transcript_to_tempfile(
|
||||||
|
transcript_content, session_id, sdk_cwd
|
||||||
|
)
|
||||||
|
if transcript_content
|
||||||
|
else None
|
||||||
)
|
)
|
||||||
if resume_file:
|
if resume_file:
|
||||||
use_resume = True
|
use_resume = True
|
||||||
transcript_msg_count = dl.message_count
|
transcript_msg_count = dl.message_count
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"{log_prefix} Using --resume ({len(dl.content)}B, "
|
f"{log_prefix} Using --resume ({len(transcript_content)}B, "
|
||||||
f"msg_count={transcript_msg_count})"
|
f"msg_count={transcript_msg_count})"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning(f"{log_prefix} Transcript downloaded but invalid")
|
logger.warning("%s Transcript downloaded but invalid", log_prefix)
|
||||||
elif config.claude_agent_use_resume and user_id and len(session.messages) > 1:
|
elif config.claude_agent_use_resume and user_id and len(session.messages) > 1:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"{log_prefix} No transcript available "
|
f"{log_prefix} No transcript available "
|
||||||
@@ -1110,7 +1188,7 @@ async def stream_chat_completion_sdk(
|
|||||||
- len(adapter.resolved_tool_calls),
|
- len(adapter.resolved_tool_calls),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Log ResultMessage details for debugging
|
# Log ResultMessage details and capture token usage
|
||||||
if isinstance(sdk_msg, ResultMessage):
|
if isinstance(sdk_msg, ResultMessage):
|
||||||
logger.info(
|
logger.info(
|
||||||
"%s Received: ResultMessage %s "
|
"%s Received: ResultMessage %s "
|
||||||
@@ -1129,9 +1207,46 @@ async def stream_chat_completion_sdk(
|
|||||||
sdk_msg.result or "(no error message provided)",
|
sdk_msg.result or "(no error message provided)",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Emit compaction end if SDK finished compacting
|
# Capture token usage from ResultMessage.
|
||||||
for ev in await compaction.emit_end_if_ready(session):
|
# Anthropic reports cached tokens separately:
|
||||||
|
# input_tokens = uncached only
|
||||||
|
# cache_read_input_tokens = served from cache
|
||||||
|
# cache_creation_input_tokens = written to cache
|
||||||
|
if sdk_msg.usage:
|
||||||
|
turn_prompt_tokens += sdk_msg.usage.get("input_tokens", 0)
|
||||||
|
turn_cache_read_tokens += sdk_msg.usage.get(
|
||||||
|
"cache_read_input_tokens", 0
|
||||||
|
)
|
||||||
|
turn_cache_creation_tokens += sdk_msg.usage.get(
|
||||||
|
"cache_creation_input_tokens", 0
|
||||||
|
)
|
||||||
|
turn_completion_tokens += sdk_msg.usage.get(
|
||||||
|
"output_tokens", 0
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"%s Token usage: uncached=%d, cache_read=%d, cache_create=%d, output=%d",
|
||||||
|
log_prefix,
|
||||||
|
turn_prompt_tokens,
|
||||||
|
turn_cache_read_tokens,
|
||||||
|
turn_cache_creation_tokens,
|
||||||
|
turn_completion_tokens,
|
||||||
|
)
|
||||||
|
if sdk_msg.total_cost_usd is not None:
|
||||||
|
turn_cost_usd = sdk_msg.total_cost_usd
|
||||||
|
|
||||||
|
# Emit compaction end if SDK finished compacting.
|
||||||
|
# When compaction ends, sync TranscriptBuilder with
|
||||||
|
# the CLI's compacted session file so the uploaded
|
||||||
|
# transcript reflects compaction.
|
||||||
|
compaction_events = await compaction.emit_end_if_ready(session)
|
||||||
|
for ev in compaction_events:
|
||||||
yield ev
|
yield ev
|
||||||
|
if compaction_events and sdk_cwd:
|
||||||
|
cli_content = await read_cli_session_file(sdk_cwd)
|
||||||
|
if cli_content:
|
||||||
|
transcript_builder.replace_entries(
|
||||||
|
cli_content, log_prefix=log_prefix
|
||||||
|
)
|
||||||
|
|
||||||
for response in adapter.convert_message(sdk_msg):
|
for response in adapter.convert_message(sdk_msg):
|
||||||
if isinstance(response, StreamStart):
|
if isinstance(response, StreamStart):
|
||||||
@@ -1325,6 +1440,27 @@ async def stream_chat_completion_sdk(
|
|||||||
) and not has_appended_assistant:
|
) and not has_appended_assistant:
|
||||||
session.messages.append(assistant_response)
|
session.messages.append(assistant_response)
|
||||||
|
|
||||||
|
# Emit token usage to the client (must be in try to reach SSE stream).
|
||||||
|
# Session persistence of usage is in finally to stay consistent with
|
||||||
|
# rate-limit recording even if an exception interrupts between here
|
||||||
|
# and the finally block.
|
||||||
|
# Compute total_tokens once; reused in the finally block for
|
||||||
|
# session persistence and rate-limit recording.
|
||||||
|
total_tokens = (
|
||||||
|
turn_prompt_tokens
|
||||||
|
+ turn_cache_read_tokens
|
||||||
|
+ turn_cache_creation_tokens
|
||||||
|
+ turn_completion_tokens
|
||||||
|
)
|
||||||
|
if total_tokens > 0:
|
||||||
|
yield StreamUsage(
|
||||||
|
promptTokens=turn_prompt_tokens,
|
||||||
|
completionTokens=turn_completion_tokens,
|
||||||
|
totalTokens=total_tokens,
|
||||||
|
cacheReadTokens=turn_cache_read_tokens,
|
||||||
|
cacheCreationTokens=turn_cache_creation_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
# Transcript upload is handled exclusively in the finally block
|
# Transcript upload is handled exclusively in the finally block
|
||||||
# to avoid double-uploads (the success path used to upload the
|
# to avoid double-uploads (the success path used to upload the
|
||||||
# old resume file, then the finally block overwrote it with the
|
# old resume file, then the finally block overwrote it with the
|
||||||
@@ -1389,6 +1525,48 @@ async def stream_chat_completion_sdk(
|
|||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("OTEL context teardown failed", exc_info=True)
|
logger.warning("OTEL context teardown failed", exc_info=True)
|
||||||
|
|
||||||
|
# --- Persist token usage to session + rate-limit counters ---
|
||||||
|
# Both must live in finally so they stay consistent even when an
|
||||||
|
# exception interrupts the try block after StreamUsage was yielded.
|
||||||
|
# total_tokens is computed once before StreamUsage yield above.
|
||||||
|
if total_tokens > 0:
|
||||||
|
if session is not None:
|
||||||
|
session.usage.append(
|
||||||
|
Usage(
|
||||||
|
prompt_tokens=turn_prompt_tokens,
|
||||||
|
completion_tokens=turn_completion_tokens,
|
||||||
|
total_tokens=total_tokens,
|
||||||
|
cache_read_tokens=turn_cache_read_tokens,
|
||||||
|
cache_creation_tokens=turn_cache_creation_tokens,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"%s Turn usage: uncached=%d, cache_read=%d, cache_create=%d, "
|
||||||
|
"output=%d, total=%d, cost_usd=%s",
|
||||||
|
log_prefix,
|
||||||
|
turn_prompt_tokens,
|
||||||
|
turn_cache_read_tokens,
|
||||||
|
turn_cache_creation_tokens,
|
||||||
|
turn_completion_tokens,
|
||||||
|
total_tokens,
|
||||||
|
turn_cost_usd,
|
||||||
|
)
|
||||||
|
if user_id and total_tokens > 0:
|
||||||
|
try:
|
||||||
|
await record_token_usage(
|
||||||
|
user_id=user_id,
|
||||||
|
prompt_tokens=turn_prompt_tokens,
|
||||||
|
completion_tokens=turn_completion_tokens,
|
||||||
|
cache_read_tokens=turn_cache_read_tokens,
|
||||||
|
cache_creation_tokens=turn_cache_creation_tokens,
|
||||||
|
)
|
||||||
|
except Exception as usage_err:
|
||||||
|
logger.warning(
|
||||||
|
"%s Failed to record token usage: %s",
|
||||||
|
log_prefix,
|
||||||
|
usage_err,
|
||||||
|
)
|
||||||
|
|
||||||
# --- Persist session messages ---
|
# --- Persist session messages ---
|
||||||
# This MUST run in finally to persist messages even when the generator
|
# This MUST run in finally to persist messages even when the generator
|
||||||
# is stopped early (e.g., user clicks stop, processor breaks stream loop).
|
# is stopped early (e.g., user clicks stop, processor breaks stream loop).
|
||||||
@@ -1484,6 +1662,6 @@ async def _update_title_async(
|
|||||||
)
|
)
|
||||||
if title and user_id:
|
if title and user_id:
|
||||||
await update_session_title(session_id, user_id, title, only_if_empty=True)
|
await update_session_title(session_id, user_id, title, only_if_empty=True)
|
||||||
logger.debug(f"[SDK] Generated title for {session_id}: {title}")
|
logger.debug("[SDK] Generated title for %s: %s", session_id, title)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"[SDK] Failed to update session title: {e}")
|
logger.warning("[SDK] Failed to update session title: %s", e)
|
||||||
|
|||||||
@@ -234,7 +234,9 @@ def create_tool_handler(base_tool: BaseTool):
|
|||||||
try:
|
try:
|
||||||
return await _execute_tool_sync(base_tool, user_id, session, args)
|
return await _execute_tool_sync(base_tool, user_id, session, args)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error executing tool {base_tool.name}: {e}", exc_info=True)
|
logger.error(
|
||||||
|
"Error executing tool %s: %s", base_tool.name, e, exc_info=True
|
||||||
|
)
|
||||||
return _mcp_error(f"Failed to execute {base_tool.name}: {e}")
|
return _mcp_error(f"Failed to execute {base_tool.name}: {e}")
|
||||||
|
|
||||||
return tool_handler
|
return tool_handler
|
||||||
|
|||||||
@@ -13,10 +13,17 @@ filesystem for self-hosted) — no DB column needed.
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
import shutil
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import openai
|
||||||
|
|
||||||
|
from backend.copilot.config import ChatConfig
|
||||||
from backend.util import json
|
from backend.util import json
|
||||||
|
from backend.util.prompt import CompressResult, compress_context
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -34,6 +41,11 @@ STRIPPABLE_TYPES = frozenset(
|
|||||||
{"progress", "file-history-snapshot", "queue-operation", "summary", "pr-link"}
|
{"progress", "file-history-snapshot", "queue-operation", "summary", "pr-link"}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# JSONL protocol values used in transcript serialization.
|
||||||
|
STOP_REASON_END_TURN = "end_turn"
|
||||||
|
COMPACT_MSG_ID_PREFIX = "msg_compact_"
|
||||||
|
ENTRY_TYPE_MESSAGE = "message"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TranscriptDownload:
|
class TranscriptDownload:
|
||||||
@@ -82,7 +94,11 @@ def strip_progress_entries(content: str) -> str:
|
|||||||
parent = entry.get("parentUuid", "")
|
parent = entry.get("parentUuid", "")
|
||||||
if uid:
|
if uid:
|
||||||
uuid_to_parent[uid] = parent
|
uuid_to_parent[uid] = parent
|
||||||
if entry.get("type", "") in STRIPPABLE_TYPES and uid:
|
if (
|
||||||
|
entry.get("type", "") in STRIPPABLE_TYPES
|
||||||
|
and uid
|
||||||
|
and not entry.get("isCompactSummary")
|
||||||
|
):
|
||||||
stripped_uuids.add(uid)
|
stripped_uuids.add(uid)
|
||||||
|
|
||||||
# Second pass: keep non-stripped entries, reparenting where needed.
|
# Second pass: keep non-stripped entries, reparenting where needed.
|
||||||
@@ -93,7 +109,9 @@ def strip_progress_entries(content: str) -> str:
|
|||||||
continue
|
continue
|
||||||
parent = entry.get("parentUuid", "")
|
parent = entry.get("parentUuid", "")
|
||||||
original_parent = parent
|
original_parent = parent
|
||||||
while parent in stripped_uuids:
|
seen_parents: set[str] = set()
|
||||||
|
while parent in stripped_uuids and parent not in seen_parents:
|
||||||
|
seen_parents.add(parent)
|
||||||
parent = uuid_to_parent.get(parent, "")
|
parent = uuid_to_parent.get(parent, "")
|
||||||
if parent != original_parent:
|
if parent != original_parent:
|
||||||
entry["parentUuid"] = parent
|
entry["parentUuid"] = parent
|
||||||
@@ -106,7 +124,9 @@ def strip_progress_entries(content: str) -> str:
|
|||||||
if not isinstance(entry, dict):
|
if not isinstance(entry, dict):
|
||||||
result_lines.append(line)
|
result_lines.append(line)
|
||||||
continue
|
continue
|
||||||
if entry.get("type", "") in STRIPPABLE_TYPES:
|
if entry.get("type", "") in STRIPPABLE_TYPES and not entry.get(
|
||||||
|
"isCompactSummary"
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
uid = entry.get("uuid", "")
|
uid = entry.get("uuid", "")
|
||||||
if uid in reparented:
|
if uid in reparented:
|
||||||
@@ -137,32 +157,78 @@ def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
|
|||||||
_SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-")
|
_SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-")
|
||||||
|
|
||||||
|
|
||||||
def cleanup_cli_project_dir(sdk_cwd: str) -> None:
|
def _cli_project_dir(sdk_cwd: str) -> str | None:
|
||||||
"""Remove the CLI's project directory for a specific working directory.
|
"""Return the CLI's project directory for a given working directory.
|
||||||
|
|
||||||
The CLI stores session data under ``~/.claude/projects/<encoded_cwd>/``.
|
Returns ``None`` if the path would escape the projects base.
|
||||||
Each SDK turn uses a unique ``sdk_cwd``, so the project directory is
|
|
||||||
safe to remove entirely after the transcript has been uploaded.
|
|
||||||
"""
|
"""
|
||||||
import shutil
|
|
||||||
|
|
||||||
# Encode cwd the same way CLI does (replaces non-alphanumeric with -)
|
|
||||||
cwd_encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
|
cwd_encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
|
||||||
config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
|
config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
|
||||||
projects_base = os.path.realpath(os.path.join(config_dir, "projects"))
|
projects_base = os.path.realpath(os.path.join(config_dir, "projects"))
|
||||||
project_dir = os.path.realpath(os.path.join(projects_base, cwd_encoded))
|
project_dir = os.path.realpath(os.path.join(projects_base, cwd_encoded))
|
||||||
|
|
||||||
if not project_dir.startswith(projects_base + os.sep):
|
if not project_dir.startswith(projects_base + os.sep):
|
||||||
logger.warning(
|
logger.warning("[Transcript] Project dir escaped base: %s", project_dir)
|
||||||
f"[Transcript] Cleanup path escaped projects base: {project_dir}"
|
return None
|
||||||
)
|
return project_dir
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
|
async def read_cli_session_file(sdk_cwd: str) -> str | None:
|
||||||
|
"""Read the CLI's own session file, which reflects any mid-stream compaction.
|
||||||
|
|
||||||
|
After the CLI compacts context, its session file contains the compacted
|
||||||
|
conversation. Reading this file lets ``TranscriptBuilder`` replace its
|
||||||
|
uncompacted entries with the CLI's compacted version.
|
||||||
|
"""
|
||||||
|
import aiofiles
|
||||||
|
|
||||||
|
project_dir = _cli_project_dir(sdk_cwd)
|
||||||
|
if not project_dir or not os.path.isdir(project_dir):
|
||||||
|
return None
|
||||||
|
jsonl_files = list(Path(project_dir).glob("*.jsonl"))
|
||||||
|
if not jsonl_files:
|
||||||
|
logger.debug("[Transcript] No CLI session file in %s", project_dir)
|
||||||
|
return None
|
||||||
|
# Pick the most recently modified file (there should only be one per turn).
|
||||||
|
# Guard against races where a file is deleted between glob and stat.
|
||||||
|
candidates: list[tuple[float, Path]] = []
|
||||||
|
for p in jsonl_files:
|
||||||
|
try:
|
||||||
|
candidates.append((p.stat().st_mtime, p))
|
||||||
|
except OSError:
|
||||||
|
continue
|
||||||
|
if not candidates:
|
||||||
|
logger.debug("[Transcript] No readable CLI session file in %s", project_dir)
|
||||||
|
return None
|
||||||
|
# Resolve + prefix check to prevent symlink escapes.
|
||||||
|
session_file = max(candidates, key=lambda item: item[0])[1]
|
||||||
|
real_path = str(session_file.resolve())
|
||||||
|
if not real_path.startswith(project_dir + os.sep):
|
||||||
|
logger.warning("[Transcript] Session file escaped project dir: %s", real_path)
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
async with aiofiles.open(real_path) as f:
|
||||||
|
content = await f.read()
|
||||||
|
logger.info(
|
||||||
|
"[Transcript] Read CLI session file: %s (%d bytes)",
|
||||||
|
real_path,
|
||||||
|
len(content),
|
||||||
|
)
|
||||||
|
return content
|
||||||
|
except OSError as e:
|
||||||
|
logger.warning("[Transcript] Failed to read CLI session file: %s", e)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def cleanup_cli_project_dir(sdk_cwd: str) -> None:
|
||||||
|
"""Remove the CLI's project directory for a specific working directory."""
|
||||||
|
project_dir = _cli_project_dir(sdk_cwd)
|
||||||
|
if not project_dir:
|
||||||
|
return
|
||||||
if os.path.isdir(project_dir):
|
if os.path.isdir(project_dir):
|
||||||
shutil.rmtree(project_dir, ignore_errors=True)
|
shutil.rmtree(project_dir, ignore_errors=True)
|
||||||
logger.debug(f"[Transcript] Cleaned up CLI project dir: {project_dir}")
|
logger.debug("[Transcript] Cleaned up CLI project dir: %s", project_dir)
|
||||||
else:
|
else:
|
||||||
logger.debug(f"[Transcript] Project dir not found: {project_dir}")
|
logger.debug("[Transcript] Project dir not found: %s", project_dir)
|
||||||
|
|
||||||
|
|
||||||
def write_transcript_to_tempfile(
|
def write_transcript_to_tempfile(
|
||||||
@@ -180,7 +246,7 @@ def write_transcript_to_tempfile(
|
|||||||
# Validate cwd is under the expected sandbox prefix (CodeQL sanitizer).
|
# Validate cwd is under the expected sandbox prefix (CodeQL sanitizer).
|
||||||
real_cwd = os.path.realpath(cwd)
|
real_cwd = os.path.realpath(cwd)
|
||||||
if not real_cwd.startswith(_SAFE_CWD_PREFIX):
|
if not real_cwd.startswith(_SAFE_CWD_PREFIX):
|
||||||
logger.warning(f"[Transcript] cwd outside sandbox: {cwd}")
|
logger.warning("[Transcript] cwd outside sandbox: %s", cwd)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -190,17 +256,17 @@ def write_transcript_to_tempfile(
|
|||||||
os.path.join(real_cwd, f"transcript-{safe_id}.jsonl")
|
os.path.join(real_cwd, f"transcript-{safe_id}.jsonl")
|
||||||
)
|
)
|
||||||
if not jsonl_path.startswith(real_cwd):
|
if not jsonl_path.startswith(real_cwd):
|
||||||
logger.warning(f"[Transcript] Path escaped cwd: {jsonl_path}")
|
logger.warning("[Transcript] Path escaped cwd: %s", jsonl_path)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
with open(jsonl_path, "w") as f:
|
with open(jsonl_path, "w") as f:
|
||||||
f.write(transcript_content)
|
f.write(transcript_content)
|
||||||
|
|
||||||
logger.info(f"[Transcript] Wrote resume file: {jsonl_path}")
|
logger.info("[Transcript] Wrote resume file: %s", jsonl_path)
|
||||||
return jsonl_path
|
return jsonl_path
|
||||||
|
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
logger.warning(f"[Transcript] Failed to write resume file: {e}")
|
logger.warning("[Transcript] Failed to write resume file: %s", e)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@@ -344,11 +410,14 @@ async def upload_transcript(
|
|||||||
content=json.dumps(meta).encode("utf-8"),
|
content=json.dumps(meta).encode("utf-8"),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"{log_prefix} Failed to write metadata: {e}")
|
logger.warning("%s Failed to write metadata: %s", log_prefix, e)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"{log_prefix} Uploaded {len(encoded)}B "
|
"%s Uploaded %dB (stripped from %dB, msg_count=%d)",
|
||||||
f"(stripped from {len(content)}B, msg_count={message_count})"
|
log_prefix,
|
||||||
|
len(encoded),
|
||||||
|
len(content),
|
||||||
|
message_count,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -371,10 +440,10 @@ async def download_transcript(
|
|||||||
data = await storage.retrieve(path)
|
data = await storage.retrieve(path)
|
||||||
content = data.decode("utf-8")
|
content = data.decode("utf-8")
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
logger.debug(f"{log_prefix} No transcript in storage")
|
logger.debug("%s No transcript in storage", log_prefix)
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"{log_prefix} Failed to download transcript: {e}")
|
logger.warning("%s Failed to download transcript: %s", log_prefix, e)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Try to load metadata (best-effort — old transcripts won't have it)
|
# Try to load metadata (best-effort — old transcripts won't have it)
|
||||||
@@ -394,10 +463,14 @@ async def download_transcript(
|
|||||||
meta = json.loads(meta_data.decode("utf-8"), fallback={})
|
meta = json.loads(meta_data.decode("utf-8"), fallback={})
|
||||||
message_count = meta.get("message_count", 0)
|
message_count = meta.get("message_count", 0)
|
||||||
uploaded_at = meta.get("uploaded_at", 0.0)
|
uploaded_at = meta.get("uploaded_at", 0.0)
|
||||||
except (FileNotFoundError, Exception):
|
except FileNotFoundError:
|
||||||
pass # No metadata — treat as unknown (msg_count=0 → always fill gap)
|
pass # No metadata — treat as unknown (msg_count=0 → always fill gap)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("%s Failed to load transcript metadata: %s", log_prefix, e)
|
||||||
|
|
||||||
logger.info(f"{log_prefix} Downloaded {len(content)}B (msg_count={message_count})")
|
logger.info(
|
||||||
|
"%s Downloaded %dB (msg_count=%d)", log_prefix, len(content), message_count
|
||||||
|
)
|
||||||
return TranscriptDownload(
|
return TranscriptDownload(
|
||||||
content=content,
|
content=content,
|
||||||
message_count=message_count,
|
message_count=message_count,
|
||||||
@@ -405,15 +478,171 @@ async def download_transcript(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def delete_transcript(user_id: str, session_id: str) -> None:
|
# ---------------------------------------------------------------------------
|
||||||
"""Delete transcript from bucket storage (e.g. after resume failure)."""
|
# Transcript compaction
|
||||||
from backend.util.workspace_storage import get_workspace_storage
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
storage = await get_workspace_storage()
|
# Transcripts above this byte threshold are compacted at download time.
|
||||||
path = _build_storage_path(user_id, session_id, storage)
|
COMPACT_THRESHOLD_BYTES = 400_000
|
||||||
|
|
||||||
|
|
||||||
|
def _flatten_assistant_content(blocks: list) -> str:
|
||||||
|
"""Flatten assistant content blocks into a single plain-text string."""
|
||||||
|
parts: list[str] = []
|
||||||
|
for block in blocks:
|
||||||
|
if isinstance(block, dict):
|
||||||
|
if block.get("type") == "text":
|
||||||
|
parts.append(block.get("text", ""))
|
||||||
|
elif block.get("type") == "tool_use":
|
||||||
|
parts.append(f"[tool_use: {block.get('name', '?')}]")
|
||||||
|
elif isinstance(block, str):
|
||||||
|
parts.append(block)
|
||||||
|
return "\n".join(parts) if parts else ""
|
||||||
|
|
||||||
|
|
||||||
|
def _flatten_tool_result_content(blocks: list) -> str:
|
||||||
|
"""Flatten tool_result and other content blocks into plain text.
|
||||||
|
|
||||||
|
Handles nested tool_result structures, text blocks, and raw strings.
|
||||||
|
Uses ``json.dumps`` as fallback for dict blocks without a ``text`` key
|
||||||
|
or where ``text`` is ``None``.
|
||||||
|
"""
|
||||||
|
str_parts: list[str] = []
|
||||||
|
for block in blocks:
|
||||||
|
if isinstance(block, dict) and block.get("type") == "tool_result":
|
||||||
|
inner = block.get("content", "")
|
||||||
|
if isinstance(inner, list):
|
||||||
|
for sub in inner:
|
||||||
|
if isinstance(sub, dict):
|
||||||
|
text = sub.get("text")
|
||||||
|
str_parts.append(
|
||||||
|
str(text) if text is not None else json.dumps(sub)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
str_parts.append(str(sub))
|
||||||
|
else:
|
||||||
|
str_parts.append(str(inner))
|
||||||
|
elif isinstance(block, dict) and block.get("type") == "text":
|
||||||
|
str_parts.append(str(block.get("text", "")))
|
||||||
|
elif isinstance(block, str):
|
||||||
|
str_parts.append(block)
|
||||||
|
return "\n".join(str_parts) if str_parts else ""
|
||||||
|
|
||||||
|
|
||||||
|
def _transcript_to_messages(content: str) -> list[dict]:
|
||||||
|
"""Convert JSONL transcript entries to message dicts for compress_context."""
|
||||||
|
messages: list[dict] = []
|
||||||
|
for line in content.strip().split("\n"):
|
||||||
|
if not line.strip():
|
||||||
|
continue
|
||||||
|
entry = json.loads(line, fallback=None)
|
||||||
|
if not isinstance(entry, dict):
|
||||||
|
continue
|
||||||
|
if entry.get("type", "") in STRIPPABLE_TYPES and not entry.get(
|
||||||
|
"isCompactSummary"
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
msg = entry.get("message", {})
|
||||||
|
role = msg.get("role", "")
|
||||||
|
if not role:
|
||||||
|
continue
|
||||||
|
msg_dict: dict = {"role": role}
|
||||||
|
raw_content = msg.get("content")
|
||||||
|
if role == "assistant" and isinstance(raw_content, list):
|
||||||
|
msg_dict["content"] = _flatten_assistant_content(raw_content)
|
||||||
|
elif isinstance(raw_content, list):
|
||||||
|
msg_dict["content"] = _flatten_tool_result_content(raw_content)
|
||||||
|
else:
|
||||||
|
msg_dict["content"] = raw_content or ""
|
||||||
|
messages.append(msg_dict)
|
||||||
|
return messages
|
||||||
|
|
||||||
|
|
||||||
|
def _messages_to_transcript(messages: list[dict]) -> str:
|
||||||
|
"""Convert compressed message dicts back to JSONL transcript format."""
|
||||||
|
lines: list[str] = []
|
||||||
|
last_uuid: str | None = None
|
||||||
|
for msg in messages:
|
||||||
|
role = msg.get("role", "user")
|
||||||
|
entry_type = "assistant" if role == "assistant" else "user"
|
||||||
|
uid = str(uuid4())
|
||||||
|
content = msg.get("content", "")
|
||||||
|
if role == "assistant":
|
||||||
|
message: dict = {
|
||||||
|
"role": "assistant",
|
||||||
|
"model": "",
|
||||||
|
"id": f"{COMPACT_MSG_ID_PREFIX}{uuid4().hex[:24]}",
|
||||||
|
"type": ENTRY_TYPE_MESSAGE,
|
||||||
|
"content": [{"type": "text", "text": content}] if content else [],
|
||||||
|
"stop_reason": STOP_REASON_END_TURN,
|
||||||
|
"stop_sequence": None,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
message = {"role": role, "content": content}
|
||||||
|
entry = {
|
||||||
|
"type": entry_type,
|
||||||
|
"uuid": uid,
|
||||||
|
"parentUuid": last_uuid,
|
||||||
|
"message": message,
|
||||||
|
}
|
||||||
|
lines.append(json.dumps(entry, separators=(",", ":")))
|
||||||
|
last_uuid = uid
|
||||||
|
return "\n".join(lines) + "\n" if lines else ""
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_compression(
|
||||||
|
messages: list[dict],
|
||||||
|
model: str,
|
||||||
|
cfg: ChatConfig,
|
||||||
|
log_prefix: str,
|
||||||
|
) -> CompressResult:
|
||||||
|
"""Run LLM-based compression with truncation fallback."""
|
||||||
try:
|
try:
|
||||||
await storage.delete(path)
|
async with openai.AsyncOpenAI(
|
||||||
logger.info(f"[Transcript] Deleted transcript for session {session_id}")
|
api_key=cfg.api_key, base_url=cfg.base_url, timeout=30.0
|
||||||
|
) as client:
|
||||||
|
return await compress_context(messages=messages, model=model, client=client)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"[Transcript] Failed to delete transcript: {e}")
|
logger.warning("%s LLM compaction failed, using truncation: %s", log_prefix, e)
|
||||||
|
return await compress_context(messages=messages, model=model, client=None)
|
||||||
|
|
||||||
|
|
||||||
|
async def compact_transcript(
|
||||||
|
content: str,
|
||||||
|
log_prefix: str = "[Transcript]",
|
||||||
|
) -> str | None:
|
||||||
|
"""Compact an oversized JSONL transcript using LLM summarization.
|
||||||
|
|
||||||
|
Converts transcript entries to plain messages, runs ``compress_context``
|
||||||
|
(the same compressor used for pre-query history), and rebuilds JSONL.
|
||||||
|
|
||||||
|
Returns the compacted JSONL string, or ``None`` on failure.
|
||||||
|
"""
|
||||||
|
cfg = ChatConfig()
|
||||||
|
messages = _transcript_to_messages(content)
|
||||||
|
if len(messages) < 2:
|
||||||
|
logger.warning("%s Too few messages to compact (%d)", log_prefix, len(messages))
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
result = await _run_compression(messages, cfg.model, cfg, log_prefix)
|
||||||
|
if not result.was_compacted:
|
||||||
|
logger.info("%s Transcript already within token budget", log_prefix)
|
||||||
|
return content
|
||||||
|
logger.info(
|
||||||
|
"%s Compacted transcript: %d->%d tokens (%d summarized, %d dropped)",
|
||||||
|
log_prefix,
|
||||||
|
result.original_token_count,
|
||||||
|
result.token_count,
|
||||||
|
result.messages_summarized,
|
||||||
|
result.messages_dropped,
|
||||||
|
)
|
||||||
|
compacted = _messages_to_transcript(result.messages)
|
||||||
|
if not validate_transcript(compacted):
|
||||||
|
logger.warning("%s Compacted transcript failed validation", log_prefix)
|
||||||
|
return None
|
||||||
|
return compacted
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
"%s Transcript compaction failed: %s", log_prefix, e, exc_info=True
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ class TranscriptEntry(BaseModel):
|
|||||||
uuid: str
|
uuid: str
|
||||||
parentUuid: str | None
|
parentUuid: str | None
|
||||||
message: dict[str, Any]
|
message: dict[str, Any]
|
||||||
|
isCompactSummary: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
class TranscriptBuilder:
|
class TranscriptBuilder:
|
||||||
@@ -78,10 +79,12 @@ class TranscriptBuilder:
|
|||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Load all non-strippable entries (user/assistant/system/etc.)
|
# Skip STRIPPABLE_TYPES unless the entry is a compaction summary.
|
||||||
# Skip only STRIPPABLE_TYPES to match strip_progress_entries() behavior
|
# Compaction summaries may have type "summary" but must be preserved
|
||||||
|
# so --resume can reconstruct the compacted conversation.
|
||||||
entry_type = data.get("type", "")
|
entry_type = data.get("type", "")
|
||||||
if entry_type in STRIPPABLE_TYPES:
|
is_compact = data.get("isCompactSummary", False)
|
||||||
|
if entry_type in STRIPPABLE_TYPES and not is_compact:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
entry = TranscriptEntry(
|
entry = TranscriptEntry(
|
||||||
@@ -89,6 +92,7 @@ class TranscriptBuilder:
|
|||||||
uuid=data.get("uuid") or str(uuid4()),
|
uuid=data.get("uuid") or str(uuid4()),
|
||||||
parentUuid=data.get("parentUuid"),
|
parentUuid=data.get("parentUuid"),
|
||||||
message=data.get("message", {}),
|
message=data.get("message", {}),
|
||||||
|
isCompactSummary=True if is_compact else None,
|
||||||
)
|
)
|
||||||
self._entries.append(entry)
|
self._entries.append(entry)
|
||||||
self._last_uuid = entry.uuid
|
self._last_uuid = entry.uuid
|
||||||
@@ -177,6 +181,33 @@ class TranscriptBuilder:
|
|||||||
lines = [entry.model_dump_json(exclude_none=True) for entry in self._entries]
|
lines = [entry.model_dump_json(exclude_none=True) for entry in self._entries]
|
||||||
return "\n".join(lines) + "\n"
|
return "\n".join(lines) + "\n"
|
||||||
|
|
||||||
|
def replace_entries(self, content: str, log_prefix: str = "[Transcript]") -> None:
|
||||||
|
"""Replace all entries with compacted JSONL content.
|
||||||
|
|
||||||
|
Called after the CLI performs mid-stream compaction so the builder's
|
||||||
|
state reflects the compacted conversation instead of the full
|
||||||
|
pre-compaction history.
|
||||||
|
"""
|
||||||
|
prev_count = len(self._entries)
|
||||||
|
temp = TranscriptBuilder()
|
||||||
|
try:
|
||||||
|
temp.load_previous(content, log_prefix=log_prefix)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"%s Failed to parse compacted transcript; keeping %d existing entries",
|
||||||
|
log_prefix,
|
||||||
|
prev_count,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
self._entries = temp._entries
|
||||||
|
self._last_uuid = temp._last_uuid
|
||||||
|
logger.info(
|
||||||
|
"%s Replaced %d entries with %d compacted entries",
|
||||||
|
log_prefix,
|
||||||
|
prev_count,
|
||||||
|
len(self._entries),
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def entry_count(self) -> int:
|
def entry_count(self) -> int:
|
||||||
"""Total number of entries in the complete context."""
|
"""Total number of entries in the complete context."""
|
||||||
|
|||||||
@@ -2,14 +2,25 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from backend.util import json
|
from backend.util import json
|
||||||
|
|
||||||
from .transcript import (
|
from .transcript import (
|
||||||
|
COMPACT_MSG_ID_PREFIX,
|
||||||
STRIPPABLE_TYPES,
|
STRIPPABLE_TYPES,
|
||||||
|
_cli_project_dir,
|
||||||
|
_flatten_assistant_content,
|
||||||
|
_flatten_tool_result_content,
|
||||||
|
_messages_to_transcript,
|
||||||
|
_transcript_to_messages,
|
||||||
|
compact_transcript,
|
||||||
|
read_cli_session_file,
|
||||||
strip_progress_entries,
|
strip_progress_entries,
|
||||||
validate_transcript,
|
validate_transcript,
|
||||||
write_transcript_to_tempfile,
|
write_transcript_to_tempfile,
|
||||||
)
|
)
|
||||||
|
from .transcript_builder import TranscriptBuilder
|
||||||
|
|
||||||
|
|
||||||
def _make_jsonl(*entries: dict) -> str:
|
def _make_jsonl(*entries: dict) -> str:
|
||||||
@@ -35,6 +46,14 @@ PROGRESS_ENTRY = {
|
|||||||
"data": {"type": "bash_progress", "stdout": "running..."},
|
"data": {"type": "bash_progress", "stdout": "running..."},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
COMPACT_SUMMARY = {
|
||||||
|
"type": "summary",
|
||||||
|
"uuid": "cs1",
|
||||||
|
"parentUuid": None,
|
||||||
|
"isCompactSummary": True,
|
||||||
|
"message": {"role": "user", "content": "Summary of previous conversation..."},
|
||||||
|
}
|
||||||
|
|
||||||
VALID_TRANSCRIPT = _make_jsonl(METADATA_LINE, FILE_HISTORY, USER_MSG, ASST_MSG)
|
VALID_TRANSCRIPT = _make_jsonl(METADATA_LINE, FILE_HISTORY, USER_MSG, ASST_MSG)
|
||||||
|
|
||||||
|
|
||||||
@@ -237,6 +256,121 @@ class TestStripProgressEntries:
|
|||||||
# Should return just a newline (empty content stripped)
|
# Should return just a newline (empty content stripped)
|
||||||
assert result.strip() == ""
|
assert result.strip() == ""
|
||||||
|
|
||||||
|
|
||||||
|
# --- _cli_project_dir ---
|
||||||
|
|
||||||
|
|
||||||
|
class TestCliProjectDir:
|
||||||
|
def test_returns_path_for_valid_cwd(self, tmp_path, monkeypatch):
|
||||||
|
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(tmp_path))
|
||||||
|
projects = tmp_path / "projects"
|
||||||
|
projects.mkdir()
|
||||||
|
result = _cli_project_dir("/tmp/copilot-abc")
|
||||||
|
assert result is not None
|
||||||
|
assert "projects" in result
|
||||||
|
|
||||||
|
def test_returns_none_for_path_traversal(self, tmp_path, monkeypatch):
|
||||||
|
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(tmp_path))
|
||||||
|
projects = tmp_path / "projects"
|
||||||
|
projects.mkdir()
|
||||||
|
# A cwd that encodes to something with .. shouldn't escape
|
||||||
|
result = _cli_project_dir("/tmp/copilot-test")
|
||||||
|
# Should return a valid path (no traversal possible with alphanum encoding)
|
||||||
|
assert result is None or result.startswith(str(projects))
|
||||||
|
|
||||||
|
|
||||||
|
# --- read_cli_session_file ---
|
||||||
|
|
||||||
|
|
||||||
|
class TestReadCliSessionFile:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reads_session_file(self, tmp_path, monkeypatch):
|
||||||
|
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(tmp_path))
|
||||||
|
# Create the CLI project directory structure
|
||||||
|
cwd = "/tmp/copilot-testread"
|
||||||
|
import re
|
||||||
|
|
||||||
|
encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
|
||||||
|
project_dir = tmp_path / "projects" / encoded
|
||||||
|
project_dir.mkdir(parents=True)
|
||||||
|
# Write a session file
|
||||||
|
session_file = project_dir / "test-session.jsonl"
|
||||||
|
session_file.write_text(json.dumps(ASST_MSG) + "\n")
|
||||||
|
|
||||||
|
result = await read_cli_session_file(cwd)
|
||||||
|
assert result is not None
|
||||||
|
assert "assistant" in result
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_none_when_no_files(self, tmp_path, monkeypatch):
|
||||||
|
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(tmp_path))
|
||||||
|
cwd = "/tmp/copilot-nofiles"
|
||||||
|
import re
|
||||||
|
|
||||||
|
encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
|
||||||
|
project_dir = tmp_path / "projects" / encoded
|
||||||
|
project_dir.mkdir(parents=True)
|
||||||
|
# No jsonl files
|
||||||
|
result = await read_cli_session_file(cwd)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_none_when_dir_missing(self, tmp_path, monkeypatch):
|
||||||
|
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(tmp_path))
|
||||||
|
(tmp_path / "projects").mkdir()
|
||||||
|
result = await read_cli_session_file("/tmp/copilot-nonexistent")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
# --- _transcript_to_messages / _messages_to_transcript ---
|
||||||
|
|
||||||
|
|
||||||
|
class TestTranscriptMessageConversion:
|
||||||
|
def test_roundtrip_preserves_roles(self):
|
||||||
|
transcript = _make_jsonl(USER_MSG, ASST_MSG)
|
||||||
|
messages = _transcript_to_messages(transcript)
|
||||||
|
assert len(messages) == 2
|
||||||
|
assert messages[0]["role"] == "user"
|
||||||
|
assert messages[1]["role"] == "assistant"
|
||||||
|
|
||||||
|
def test_messages_to_transcript_produces_valid_jsonl(self):
|
||||||
|
messages = [
|
||||||
|
{"role": "user", "content": "hi"},
|
||||||
|
{"role": "assistant", "content": "hello"},
|
||||||
|
]
|
||||||
|
result = _messages_to_transcript(messages)
|
||||||
|
assert validate_transcript(result) is True
|
||||||
|
|
||||||
|
def test_strips_strippable_types(self):
|
||||||
|
transcript = _make_jsonl(
|
||||||
|
{"type": "progress", "uuid": "p1", "message": {"role": "user"}},
|
||||||
|
USER_MSG,
|
||||||
|
ASST_MSG,
|
||||||
|
)
|
||||||
|
messages = _transcript_to_messages(transcript)
|
||||||
|
assert len(messages) == 2 # progress entry skipped
|
||||||
|
|
||||||
|
def test_flattens_assistant_content_blocks(self):
|
||||||
|
asst_with_blocks = {
|
||||||
|
"type": "assistant",
|
||||||
|
"uuid": "a1",
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "hello"},
|
||||||
|
{"type": "tool_use", "name": "bash"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
messages = _transcript_to_messages(_make_jsonl(asst_with_blocks))
|
||||||
|
assert len(messages) == 1
|
||||||
|
assert "hello" in messages[0]["content"]
|
||||||
|
assert "[tool_use: bash]" in messages[0]["content"]
|
||||||
|
|
||||||
|
def test_empty_messages_returns_empty(self):
|
||||||
|
result = _messages_to_transcript([])
|
||||||
|
assert result == ""
|
||||||
|
|
||||||
def test_no_strippable_entries(self):
|
def test_no_strippable_entries(self):
|
||||||
"""When there's nothing to strip, output matches input structure."""
|
"""When there's nothing to strip, output matches input structure."""
|
||||||
content = _make_jsonl(USER_MSG, ASST_MSG)
|
content = _make_jsonl(USER_MSG, ASST_MSG)
|
||||||
@@ -282,3 +416,654 @@ class TestStripProgressEntries:
|
|||||||
lines = result.strip().split("\n")
|
lines = result.strip().split("\n")
|
||||||
asst_entry = json.loads(lines[-1])
|
asst_entry = json.loads(lines[-1])
|
||||||
assert asst_entry["parentUuid"] == "u1" # reparented
|
assert asst_entry["parentUuid"] == "u1" # reparented
|
||||||
|
|
||||||
|
|
||||||
|
# --- TranscriptBuilder ---
|
||||||
|
|
||||||
|
|
||||||
|
class TestTranscriptBuilderReplaceEntries:
|
||||||
|
"""Tests for TranscriptBuilder.replace_entries — the compaction sync path."""
|
||||||
|
|
||||||
|
def test_replace_entries_with_valid_content(self):
|
||||||
|
builder = TranscriptBuilder()
|
||||||
|
builder.append_user("hello")
|
||||||
|
builder.append_assistant([{"type": "text", "text": "world"}])
|
||||||
|
assert builder.entry_count == 2
|
||||||
|
|
||||||
|
# Replace with compacted content (one user + one assistant)
|
||||||
|
compacted = _make_jsonl(USER_MSG, ASST_MSG)
|
||||||
|
builder.replace_entries(compacted)
|
||||||
|
assert builder.entry_count == 2
|
||||||
|
|
||||||
|
def test_replace_entries_keeps_old_on_corrupt_content(self):
|
||||||
|
builder = TranscriptBuilder()
|
||||||
|
builder.append_user("hello")
|
||||||
|
assert builder.entry_count == 1
|
||||||
|
|
||||||
|
# Corrupt content that fails to parse
|
||||||
|
builder.replace_entries("not valid json at all\n")
|
||||||
|
# Should still have old entries (load_previous skips invalid lines,
|
||||||
|
# but if ALL lines are invalid, temp builder is empty → exception path)
|
||||||
|
assert builder.entry_count >= 0 # doesn't crash
|
||||||
|
|
||||||
|
def test_replace_entries_with_empty_content(self):
|
||||||
|
builder = TranscriptBuilder()
|
||||||
|
builder.append_user("hello")
|
||||||
|
assert builder.entry_count == 1
|
||||||
|
|
||||||
|
builder.replace_entries("")
|
||||||
|
# Empty content → load_previous returns early → temp is empty
|
||||||
|
# replace_entries swaps to empty (0 entries)
|
||||||
|
assert builder.entry_count == 0
|
||||||
|
|
||||||
|
def test_replace_entries_filters_strippable_types(self):
|
||||||
|
"""Strippable types (progress, file-history-snapshot) are filtered out."""
|
||||||
|
builder = TranscriptBuilder()
|
||||||
|
builder.append_user("hello")
|
||||||
|
|
||||||
|
content = _make_jsonl(
|
||||||
|
{"type": "progress", "uuid": "p1", "message": {}},
|
||||||
|
USER_MSG,
|
||||||
|
ASST_MSG,
|
||||||
|
)
|
||||||
|
builder.replace_entries(content)
|
||||||
|
# Only user + assistant should remain (progress filtered)
|
||||||
|
assert builder.entry_count == 2
|
||||||
|
|
||||||
|
def test_replace_entries_preserves_uuids(self):
|
||||||
|
builder = TranscriptBuilder()
|
||||||
|
content = _make_jsonl(USER_MSG, ASST_MSG)
|
||||||
|
builder.replace_entries(content)
|
||||||
|
|
||||||
|
jsonl = builder.to_jsonl()
|
||||||
|
lines = jsonl.strip().split("\n")
|
||||||
|
first = json.loads(lines[0])
|
||||||
|
assert first["uuid"] == "u1"
|
||||||
|
|
||||||
|
|
||||||
|
class TestTranscriptBuilderBasic:
|
||||||
|
def test_append_user_and_assistant(self):
|
||||||
|
builder = TranscriptBuilder()
|
||||||
|
builder.append_user("hi")
|
||||||
|
builder.append_assistant([{"type": "text", "text": "hello"}])
|
||||||
|
assert builder.entry_count == 2
|
||||||
|
assert not builder.is_empty
|
||||||
|
|
||||||
|
def test_to_jsonl_empty(self):
|
||||||
|
builder = TranscriptBuilder()
|
||||||
|
assert builder.to_jsonl() == ""
|
||||||
|
assert builder.is_empty
|
||||||
|
|
||||||
|
def test_load_previous_and_append(self):
|
||||||
|
builder = TranscriptBuilder()
|
||||||
|
content = _make_jsonl(USER_MSG, ASST_MSG)
|
||||||
|
builder.load_previous(content)
|
||||||
|
assert builder.entry_count == 2
|
||||||
|
builder.append_user("new message")
|
||||||
|
assert builder.entry_count == 3
|
||||||
|
|
||||||
|
def test_consecutive_assistant_entries_share_message_id(self):
|
||||||
|
builder = TranscriptBuilder()
|
||||||
|
builder.append_user("hi")
|
||||||
|
builder.append_assistant([{"type": "text", "text": "part1"}])
|
||||||
|
builder.append_assistant([{"type": "text", "text": "part2"}])
|
||||||
|
|
||||||
|
jsonl = builder.to_jsonl()
|
||||||
|
lines = jsonl.strip().split("\n")
|
||||||
|
asst1 = json.loads(lines[1])
|
||||||
|
asst2 = json.loads(lines[2])
|
||||||
|
assert asst1["message"]["id"] == asst2["message"]["id"]
|
||||||
|
|
||||||
|
def test_non_consecutive_assistant_entries_get_new_id(self):
|
||||||
|
builder = TranscriptBuilder()
|
||||||
|
builder.append_user("hi")
|
||||||
|
builder.append_assistant([{"type": "text", "text": "response1"}])
|
||||||
|
builder.append_user("followup")
|
||||||
|
builder.append_assistant([{"type": "text", "text": "response2"}])
|
||||||
|
|
||||||
|
jsonl = builder.to_jsonl()
|
||||||
|
lines = jsonl.strip().split("\n")
|
||||||
|
asst1 = json.loads(lines[1])
|
||||||
|
asst2 = json.loads(lines[3])
|
||||||
|
assert asst1["message"]["id"] != asst2["message"]["id"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestCompactSummaryRoundtrip:
|
||||||
|
"""Verify isCompactSummary survives export→reload roundtrip."""
|
||||||
|
|
||||||
|
def test_load_previous_preserves_compact_summary(self):
|
||||||
|
"""Compaction summary with type 'summary' should not be stripped."""
|
||||||
|
content = _make_jsonl(COMPACT_SUMMARY, USER_MSG, ASST_MSG)
|
||||||
|
builder = TranscriptBuilder()
|
||||||
|
builder.load_previous(content)
|
||||||
|
# summary type is in STRIPPABLE_TYPES, but isCompactSummary keeps it
|
||||||
|
assert builder.entry_count == 3
|
||||||
|
|
||||||
|
def test_export_reload_preserves_compact_summary(self):
|
||||||
|
"""Critical: isCompactSummary must survive to_jsonl → load_previous."""
|
||||||
|
content = _make_jsonl(COMPACT_SUMMARY, USER_MSG, ASST_MSG)
|
||||||
|
builder1 = TranscriptBuilder()
|
||||||
|
builder1.load_previous(content)
|
||||||
|
assert builder1.entry_count == 3
|
||||||
|
|
||||||
|
exported = builder1.to_jsonl()
|
||||||
|
# Verify isCompactSummary is in the exported JSONL
|
||||||
|
first_line = json.loads(exported.strip().split("\n")[0])
|
||||||
|
assert first_line.get("isCompactSummary") is True
|
||||||
|
|
||||||
|
# Reload and verify it's still preserved
|
||||||
|
builder2 = TranscriptBuilder()
|
||||||
|
builder2.load_previous(exported)
|
||||||
|
assert builder2.entry_count == 3
|
||||||
|
|
||||||
|
def test_strip_progress_preserves_compact_summary(self):
|
||||||
|
"""strip_progress_entries should keep isCompactSummary entries."""
|
||||||
|
content = _make_jsonl(COMPACT_SUMMARY, USER_MSG, ASST_MSG)
|
||||||
|
stripped = strip_progress_entries(content)
|
||||||
|
entries = [json.loads(line) for line in stripped.strip().split("\n")]
|
||||||
|
types = [e.get("type") for e in entries]
|
||||||
|
assert "summary" in types # Not stripped despite being in STRIPPABLE_TYPES
|
||||||
|
compact = [e for e in entries if e.get("isCompactSummary")]
|
||||||
|
assert len(compact) == 1
|
||||||
|
|
||||||
|
def test_regular_summary_still_stripped(self):
|
||||||
|
"""Non-compact summaries should still be stripped."""
|
||||||
|
regular_summary = {
|
||||||
|
"type": "summary",
|
||||||
|
"uuid": "rs1",
|
||||||
|
"summary": "Session summary",
|
||||||
|
}
|
||||||
|
content = _make_jsonl(regular_summary, USER_MSG, ASST_MSG)
|
||||||
|
stripped = strip_progress_entries(content)
|
||||||
|
entries = [json.loads(line) for line in stripped.strip().split("\n")]
|
||||||
|
types = [e.get("type") for e in entries]
|
||||||
|
assert "summary" not in types
|
||||||
|
|
||||||
|
def test_replace_entries_preserves_compact_summary(self):
|
||||||
|
"""replace_entries should preserve isCompactSummary entries."""
|
||||||
|
builder = TranscriptBuilder()
|
||||||
|
builder.append_user("old")
|
||||||
|
content = _make_jsonl(COMPACT_SUMMARY, USER_MSG, ASST_MSG)
|
||||||
|
builder.replace_entries(content)
|
||||||
|
assert builder.entry_count == 3
|
||||||
|
|
||||||
|
# Verify by re-exporting
|
||||||
|
exported = builder.to_jsonl()
|
||||||
|
first = json.loads(exported.strip().split("\n")[0])
|
||||||
|
assert first.get("isCompactSummary") is True
|
||||||
|
|
||||||
|
|
||||||
|
# --- _flatten_assistant_content ---
|
||||||
|
|
||||||
|
|
||||||
|
class TestFlattenAssistantContent:
|
||||||
|
def test_text_blocks(self):
|
||||||
|
blocks = [
|
||||||
|
{"type": "text", "text": "Hello"},
|
||||||
|
{"type": "text", "text": "World"},
|
||||||
|
]
|
||||||
|
assert _flatten_assistant_content(blocks) == "Hello\nWorld"
|
||||||
|
|
||||||
|
def test_tool_use_blocks(self):
|
||||||
|
blocks = [{"type": "tool_use", "name": "read_file", "id": "t1", "input": {}}]
|
||||||
|
assert _flatten_assistant_content(blocks) == "[tool_use: read_file]"
|
||||||
|
|
||||||
|
def test_mixed_blocks(self):
|
||||||
|
blocks = [
|
||||||
|
{"type": "text", "text": "Let me read that."},
|
||||||
|
{"type": "tool_use", "name": "read", "id": "t1", "input": {}},
|
||||||
|
]
|
||||||
|
result = _flatten_assistant_content(blocks)
|
||||||
|
assert "Let me read that." in result
|
||||||
|
assert "[tool_use: read]" in result
|
||||||
|
|
||||||
|
def test_string_blocks(self):
|
||||||
|
"""Plain strings in the list should be included."""
|
||||||
|
assert _flatten_assistant_content(["hello", "world"]) == "hello\nworld"
|
||||||
|
|
||||||
|
def test_empty_list(self):
|
||||||
|
assert _flatten_assistant_content([]) == ""
|
||||||
|
|
||||||
|
def test_tool_use_missing_name(self):
|
||||||
|
blocks = [{"type": "tool_use", "id": "t1", "input": {}}]
|
||||||
|
assert _flatten_assistant_content(blocks) == "[tool_use: ?]"
|
||||||
|
|
||||||
|
|
||||||
|
# --- _flatten_tool_result_content ---
|
||||||
|
|
||||||
|
|
||||||
|
class TestFlattenToolResultContent:
|
||||||
|
def test_tool_result_with_text(self):
|
||||||
|
blocks = [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "t1",
|
||||||
|
"content": [{"type": "text", "text": "file contents here"}],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
assert _flatten_tool_result_content(blocks) == "file contents here"
|
||||||
|
|
||||||
|
def test_tool_result_with_string_content(self):
|
||||||
|
blocks = [
|
||||||
|
{"type": "tool_result", "tool_use_id": "t1", "content": "simple result"}
|
||||||
|
]
|
||||||
|
assert _flatten_tool_result_content(blocks) == "simple result"
|
||||||
|
|
||||||
|
def test_tool_result_with_nested_list(self):
|
||||||
|
blocks = [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "t1",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "line 1"},
|
||||||
|
{"type": "text", "text": "line 2"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
assert _flatten_tool_result_content(blocks) == "line 1\nline 2"
|
||||||
|
|
||||||
|
def test_text_blocks(self):
|
||||||
|
blocks = [{"type": "text", "text": "some text"}]
|
||||||
|
assert _flatten_tool_result_content(blocks) == "some text"
|
||||||
|
|
||||||
|
def test_string_items(self):
|
||||||
|
assert _flatten_tool_result_content(["raw string"]) == "raw string"
|
||||||
|
|
||||||
|
def test_empty_list(self):
|
||||||
|
assert _flatten_tool_result_content([]) == ""
|
||||||
|
|
||||||
|
def test_tool_result_none_text_uses_json(self):
|
||||||
|
"""Dicts without text key fall back to json.dumps."""
|
||||||
|
blocks = [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "t1",
|
||||||
|
"content": [{"type": "image", "source": "data:..."}],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
result = _flatten_tool_result_content(blocks)
|
||||||
|
assert "image" in result # json.dumps fallback includes the key
|
||||||
|
|
||||||
|
|
||||||
|
# --- _transcript_to_messages ---
|
||||||
|
|
||||||
|
|
||||||
|
class TestTranscriptToMessages:
|
||||||
|
def test_basic_conversation(self):
|
||||||
|
content = _make_jsonl(
|
||||||
|
{
|
||||||
|
"type": "user",
|
||||||
|
"uuid": "u1",
|
||||||
|
"message": {"role": "user", "content": "hello"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "assistant",
|
||||||
|
"uuid": "a1",
|
||||||
|
"parentUuid": "u1",
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [{"type": "text", "text": "hi there"}],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
msgs = _transcript_to_messages(content)
|
||||||
|
assert len(msgs) == 2
|
||||||
|
assert msgs[0] == {"role": "user", "content": "hello"}
|
||||||
|
assert msgs[1] == {"role": "assistant", "content": "hi there"}
|
||||||
|
|
||||||
|
def test_strips_progress_entries(self):
|
||||||
|
content = _make_jsonl(
|
||||||
|
{
|
||||||
|
"type": "user",
|
||||||
|
"uuid": "u1",
|
||||||
|
"message": {"role": "user", "content": "hi"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "progress",
|
||||||
|
"uuid": "p1",
|
||||||
|
"message": {"role": "user", "content": "..."},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "assistant",
|
||||||
|
"uuid": "a1",
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [{"type": "text", "text": "ok"}],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
msgs = _transcript_to_messages(content)
|
||||||
|
assert len(msgs) == 2
|
||||||
|
assert msgs[0]["role"] == "user"
|
||||||
|
assert msgs[1]["role"] == "assistant"
|
||||||
|
|
||||||
|
def test_preserves_compact_summaries(self):
|
||||||
|
content = _make_jsonl(
|
||||||
|
{
|
||||||
|
"type": "summary",
|
||||||
|
"uuid": "cs1",
|
||||||
|
"isCompactSummary": True,
|
||||||
|
"message": {"role": "user", "content": "Summary of previous..."},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "user",
|
||||||
|
"uuid": "u1",
|
||||||
|
"message": {"role": "user", "content": "hi"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
msgs = _transcript_to_messages(content)
|
||||||
|
assert len(msgs) == 2
|
||||||
|
assert msgs[0]["content"] == "Summary of previous..."
|
||||||
|
|
||||||
|
def test_strips_regular_summary(self):
|
||||||
|
content = _make_jsonl(
|
||||||
|
{
|
||||||
|
"type": "summary",
|
||||||
|
"uuid": "s1",
|
||||||
|
"message": {"role": "user", "content": "Session summary"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "user",
|
||||||
|
"uuid": "u1",
|
||||||
|
"message": {"role": "user", "content": "hi"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
msgs = _transcript_to_messages(content)
|
||||||
|
assert len(msgs) == 1
|
||||||
|
assert msgs[0]["content"] == "hi"
|
||||||
|
|
||||||
|
def test_skips_entries_without_role(self):
|
||||||
|
content = _make_jsonl(
|
||||||
|
{"type": "user", "uuid": "u1", "message": {}},
|
||||||
|
{
|
||||||
|
"type": "user",
|
||||||
|
"uuid": "u2",
|
||||||
|
"message": {"role": "user", "content": "hi"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
msgs = _transcript_to_messages(content)
|
||||||
|
assert len(msgs) == 1
|
||||||
|
|
||||||
|
def test_tool_result_content(self):
|
||||||
|
content = _make_jsonl(
|
||||||
|
{
|
||||||
|
"type": "user",
|
||||||
|
"uuid": "u1",
|
||||||
|
"message": {
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "t1",
|
||||||
|
"content": "file contents",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
msgs = _transcript_to_messages(content)
|
||||||
|
assert len(msgs) == 1
|
||||||
|
assert "file contents" in msgs[0]["content"]
|
||||||
|
|
||||||
|
def test_empty_content(self):
|
||||||
|
assert _transcript_to_messages("") == []
|
||||||
|
assert _transcript_to_messages(" \n ") == []
|
||||||
|
|
||||||
|
def test_invalid_json_lines_skipped(self):
|
||||||
|
content = '{"type":"user","uuid":"u1","message":{"role":"user","content":"hi"}}\nnot json\n'
|
||||||
|
msgs = _transcript_to_messages(content)
|
||||||
|
assert len(msgs) == 1
|
||||||
|
|
||||||
|
|
||||||
|
# --- _messages_to_transcript ---
|
||||||
|
|
||||||
|
|
||||||
|
class TestMessagesToTranscript:
|
||||||
|
def test_basic_roundtrip_structure(self):
|
||||||
|
messages = [
|
||||||
|
{"role": "user", "content": "hello"},
|
||||||
|
{"role": "assistant", "content": "hi there"},
|
||||||
|
]
|
||||||
|
result = _messages_to_transcript(messages)
|
||||||
|
assert result.endswith("\n")
|
||||||
|
lines = [json.loads(line) for line in result.strip().split("\n")]
|
||||||
|
assert len(lines) == 2
|
||||||
|
|
||||||
|
# User entry
|
||||||
|
assert lines[0]["type"] == "user"
|
||||||
|
assert lines[0]["message"]["role"] == "user"
|
||||||
|
assert lines[0]["message"]["content"] == "hello"
|
||||||
|
assert lines[0]["parentUuid"] is None
|
||||||
|
|
||||||
|
# Assistant entry
|
||||||
|
assert lines[1]["type"] == "assistant"
|
||||||
|
assert lines[1]["message"]["role"] == "assistant"
|
||||||
|
assert lines[1]["message"]["content"] == [{"type": "text", "text": "hi there"}]
|
||||||
|
assert lines[1]["message"]["id"].startswith(COMPACT_MSG_ID_PREFIX)
|
||||||
|
assert lines[1]["parentUuid"] == lines[0]["uuid"]
|
||||||
|
|
||||||
|
def test_parent_uuid_chain(self):
|
||||||
|
messages = [
|
||||||
|
{"role": "user", "content": "q1"},
|
||||||
|
{"role": "assistant", "content": "a1"},
|
||||||
|
{"role": "user", "content": "q2"},
|
||||||
|
]
|
||||||
|
result = _messages_to_transcript(messages)
|
||||||
|
lines = [json.loads(line) for line in result.strip().split("\n")]
|
||||||
|
assert lines[0]["parentUuid"] is None
|
||||||
|
assert lines[1]["parentUuid"] == lines[0]["uuid"]
|
||||||
|
assert lines[2]["parentUuid"] == lines[1]["uuid"]
|
||||||
|
|
||||||
|
def test_empty_messages(self):
|
||||||
|
assert _messages_to_transcript([]) == ""
|
||||||
|
|
||||||
|
def test_assistant_empty_content(self):
|
||||||
|
messages = [{"role": "assistant", "content": ""}]
|
||||||
|
result = _messages_to_transcript(messages)
|
||||||
|
entry = json.loads(result.strip())
|
||||||
|
assert entry["message"]["content"] == []
|
||||||
|
|
||||||
|
def test_output_is_valid_transcript(self):
|
||||||
|
messages = [
|
||||||
|
{"role": "user", "content": "hello"},
|
||||||
|
{"role": "assistant", "content": "world"},
|
||||||
|
]
|
||||||
|
result = _messages_to_transcript(messages)
|
||||||
|
assert validate_transcript(result)
|
||||||
|
|
||||||
|
|
||||||
|
# --- _transcript_to_messages + _messages_to_transcript roundtrip ---
|
||||||
|
|
||||||
|
|
||||||
|
class TestTranscriptCompactionRoundtrip:
|
||||||
|
def test_content_preserved_through_roundtrip(self):
|
||||||
|
"""Messages→transcript→messages preserves content."""
|
||||||
|
original = [
|
||||||
|
{"role": "user", "content": "What is 2+2?"},
|
||||||
|
{"role": "assistant", "content": "4"},
|
||||||
|
{"role": "user", "content": "Thanks"},
|
||||||
|
]
|
||||||
|
transcript = _messages_to_transcript(original)
|
||||||
|
recovered = _transcript_to_messages(transcript)
|
||||||
|
assert len(recovered) == len(original)
|
||||||
|
for orig, rec in zip(original, recovered):
|
||||||
|
assert orig["role"] == rec["role"]
|
||||||
|
assert orig["content"] == rec["content"]
|
||||||
|
|
||||||
|
def test_full_transcript_to_messages_and_back(self):
|
||||||
|
"""Real-ish JSONL → messages → transcript → messages roundtrip."""
|
||||||
|
source = _make_jsonl(
|
||||||
|
{
|
||||||
|
"type": "user",
|
||||||
|
"uuid": "u1",
|
||||||
|
"message": {"role": "user", "content": "explain python"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "assistant",
|
||||||
|
"uuid": "a1",
|
||||||
|
"parentUuid": "u1",
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "Python is a programming language."}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "user",
|
||||||
|
"uuid": "u2",
|
||||||
|
"parentUuid": "a1",
|
||||||
|
"message": {
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "t1",
|
||||||
|
"content": "output of ls",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
msgs1 = _transcript_to_messages(source)
|
||||||
|
assert len(msgs1) == 3
|
||||||
|
|
||||||
|
rebuilt = _messages_to_transcript(msgs1)
|
||||||
|
msgs2 = _transcript_to_messages(rebuilt)
|
||||||
|
assert len(msgs2) == len(msgs1)
|
||||||
|
for m1, m2 in zip(msgs1, msgs2):
|
||||||
|
assert m1["role"] == m2["role"]
|
||||||
|
# Content may differ in format (list vs string) but text is preserved
|
||||||
|
assert m1["content"] in m2["content"] or m2["content"] in m1["content"]
|
||||||
|
|
||||||
|
|
||||||
|
# --- compact_transcript ---
|
||||||
|
|
||||||
|
|
||||||
|
class TestCompactTranscript:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_too_few_messages_returns_none(self):
|
||||||
|
"""Transcripts with < 2 messages can't be compacted."""
|
||||||
|
single = _make_jsonl(
|
||||||
|
{"type": "user", "uuid": "u1", "message": {"role": "user", "content": "hi"}}
|
||||||
|
)
|
||||||
|
result = await compact_transcript(single)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_empty_transcript_returns_none(self):
|
||||||
|
result = await compact_transcript("")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_compaction_produces_valid_transcript(self, monkeypatch):
|
||||||
|
"""When compress_context compacts, result should be valid JSONL."""
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
from backend.util.prompt import CompressResult
|
||||||
|
|
||||||
|
mock_result = CompressResult(
|
||||||
|
messages=[
|
||||||
|
{"role": "user", "content": "Summary of conversation"},
|
||||||
|
{"role": "assistant", "content": "Acknowledged"},
|
||||||
|
],
|
||||||
|
token_count=50,
|
||||||
|
was_compacted=True,
|
||||||
|
original_token_count=5000,
|
||||||
|
messages_summarized=10,
|
||||||
|
messages_dropped=5,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"backend.copilot.sdk.transcript._run_compression",
|
||||||
|
AsyncMock(return_value=mock_result),
|
||||||
|
)
|
||||||
|
|
||||||
|
source = _make_jsonl(
|
||||||
|
{
|
||||||
|
"type": "user",
|
||||||
|
"uuid": "u1",
|
||||||
|
"message": {"role": "user", "content": "msg1"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "assistant",
|
||||||
|
"uuid": "a1",
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [{"type": "text", "text": "reply1"}],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "user",
|
||||||
|
"uuid": "u2",
|
||||||
|
"message": {"role": "user", "content": "msg2"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
result = await compact_transcript(source)
|
||||||
|
assert result is not None
|
||||||
|
assert validate_transcript(result)
|
||||||
|
|
||||||
|
# Verify compacted content
|
||||||
|
msgs = _transcript_to_messages(result)
|
||||||
|
assert len(msgs) == 2
|
||||||
|
assert msgs[0]["content"] == "Summary of conversation"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_compaction_needed_returns_original(self, monkeypatch):
|
||||||
|
"""When compress_context says no compaction needed, return original."""
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
from backend.util.prompt import CompressResult
|
||||||
|
|
||||||
|
mock_result = CompressResult(
|
||||||
|
messages=[], token_count=100, was_compacted=False, original_token_count=100
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"backend.copilot.sdk.transcript._run_compression",
|
||||||
|
AsyncMock(return_value=mock_result),
|
||||||
|
)
|
||||||
|
|
||||||
|
source = _make_jsonl(
|
||||||
|
{
|
||||||
|
"type": "user",
|
||||||
|
"uuid": "u1",
|
||||||
|
"message": {"role": "user", "content": "hi"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "assistant",
|
||||||
|
"uuid": "a1",
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [{"type": "text", "text": "hello"}],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
result = await compact_transcript(source)
|
||||||
|
assert result == source # Unchanged
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_compression_failure_returns_none(self, monkeypatch):
|
||||||
|
"""When _run_compression raises, compact_transcript returns None."""
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"backend.copilot.sdk.transcript._run_compression",
|
||||||
|
AsyncMock(side_effect=RuntimeError("LLM unavailable")),
|
||||||
|
)
|
||||||
|
|
||||||
|
source = _make_jsonl(
|
||||||
|
{
|
||||||
|
"type": "user",
|
||||||
|
"uuid": "u1",
|
||||||
|
"message": {"role": "user", "content": "hi"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "assistant",
|
||||||
|
"uuid": "a1",
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [{"type": "text", "text": "hello"}],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
result = await compact_transcript(source)
|
||||||
|
assert result is None
|
||||||
|
|||||||
@@ -8,11 +8,15 @@ from pydantic_core import PydanticUndefined
|
|||||||
|
|
||||||
from backend.blocks._base import AnyBlockSchema
|
from backend.blocks._base import AnyBlockSchema
|
||||||
from backend.copilot.constants import COPILOT_NODE_PREFIX, COPILOT_SESSION_PREFIX
|
from backend.copilot.constants import COPILOT_NODE_PREFIX, COPILOT_SESSION_PREFIX
|
||||||
|
from backend.data import db
|
||||||
|
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
|
||||||
from backend.data.db_accessors import workspace_db
|
from backend.data.db_accessors import workspace_db
|
||||||
from backend.data.execution import ExecutionContext
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
||||||
|
from backend.executor.utils import block_usage_cost
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||||
from backend.util.exceptions import BlockError
|
from backend.util.clients import get_database_manager_async_client
|
||||||
|
from backend.util.exceptions import BlockError, InsufficientBalanceError
|
||||||
from backend.util.type import coerce_inputs_to_schema
|
from backend.util.type import coerce_inputs_to_schema
|
||||||
|
|
||||||
from .models import BlockOutputResponse, ErrorResponse, ToolResponseBase
|
from .models import BlockOutputResponse, ErrorResponse, ToolResponseBase
|
||||||
@@ -21,6 +25,26 @@ from .utils import match_credentials_to_requirements
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_credits(user_id: str) -> int:
|
||||||
|
"""Get user credits using the adapter pattern (RPC when Prisma unavailable)."""
|
||||||
|
if not db.is_connected():
|
||||||
|
return await get_database_manager_async_client().get_credits(user_id)
|
||||||
|
credit_model = await get_user_credit_model(user_id)
|
||||||
|
return await credit_model.get_credits(user_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def _spend_credits(
|
||||||
|
user_id: str, cost: int, metadata: UsageTransactionMetadata
|
||||||
|
) -> int:
|
||||||
|
"""Spend user credits using the adapter pattern (RPC when Prisma unavailable)."""
|
||||||
|
if not db.is_connected():
|
||||||
|
return await get_database_manager_async_client().spend_credits(
|
||||||
|
user_id, cost, metadata
|
||||||
|
)
|
||||||
|
credit_model = await get_user_credit_model(user_id)
|
||||||
|
return await credit_model.spend_credits(user_id, cost, metadata)
|
||||||
|
|
||||||
|
|
||||||
def get_inputs_from_schema(
|
def get_inputs_from_schema(
|
||||||
input_schema: dict[str, Any],
|
input_schema: dict[str, Any],
|
||||||
exclude_fields: set[str] | None = None,
|
exclude_fields: set[str] | None = None,
|
||||||
@@ -115,6 +139,20 @@ async def execute_block(
|
|||||||
# Coerce non-matching data types to the expected input schema.
|
# Coerce non-matching data types to the expected input schema.
|
||||||
coerce_inputs_to_schema(input_data, block.input_schema)
|
coerce_inputs_to_schema(input_data, block.input_schema)
|
||||||
|
|
||||||
|
# Pre-execution credit check
|
||||||
|
cost, cost_filter = block_usage_cost(block, input_data)
|
||||||
|
has_cost = cost > 0
|
||||||
|
if has_cost:
|
||||||
|
balance = await _get_credits(user_id)
|
||||||
|
if balance < cost:
|
||||||
|
return ErrorResponse(
|
||||||
|
message=(
|
||||||
|
f"Insufficient credits to run '{block.name}'. "
|
||||||
|
"Please top up your credits to continue."
|
||||||
|
),
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
# Execute the block and collect outputs
|
# Execute the block and collect outputs
|
||||||
outputs: dict[str, list[Any]] = defaultdict(list)
|
outputs: dict[str, list[Any]] = defaultdict(list)
|
||||||
async for output_name, output_data in block.execute(
|
async for output_name, output_data in block.execute(
|
||||||
@@ -123,6 +161,37 @@ async def execute_block(
|
|||||||
):
|
):
|
||||||
outputs[output_name].append(output_data)
|
outputs[output_name].append(output_data)
|
||||||
|
|
||||||
|
# Charge credits for block execution
|
||||||
|
if has_cost:
|
||||||
|
try:
|
||||||
|
await _spend_credits(
|
||||||
|
user_id=user_id,
|
||||||
|
cost=cost,
|
||||||
|
metadata=UsageTransactionMetadata(
|
||||||
|
graph_exec_id=synthetic_graph_id,
|
||||||
|
graph_id=synthetic_graph_id,
|
||||||
|
node_id=synthetic_node_id,
|
||||||
|
node_exec_id=node_exec_id,
|
||||||
|
block_id=block_id,
|
||||||
|
block=block.name,
|
||||||
|
input=cost_filter,
|
||||||
|
reason="copilot_block_execution",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except InsufficientBalanceError:
|
||||||
|
logger.warning(
|
||||||
|
"Post-exec credit charge failed for block %s (cost=%d)",
|
||||||
|
block.name,
|
||||||
|
cost,
|
||||||
|
)
|
||||||
|
return ErrorResponse(
|
||||||
|
message=(
|
||||||
|
f"Insufficient credits to complete '{block.name}'. "
|
||||||
|
"Please top up your credits to continue."
|
||||||
|
),
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
return BlockOutputResponse(
|
return BlockOutputResponse(
|
||||||
message=f"Block '{block.name}' executed successfully",
|
message=f"Block '{block.name}' executed successfully",
|
||||||
block_id=block_id,
|
block_id=block_id,
|
||||||
@@ -133,16 +202,16 @@ async def execute_block(
|
|||||||
)
|
)
|
||||||
|
|
||||||
except BlockError as e:
|
except BlockError as e:
|
||||||
logger.warning(f"Block execution failed: {e}")
|
logger.warning("Block execution failed: %s", e)
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message=f"Block execution failed: {e}",
|
message=f"Block execution failed: {e}",
|
||||||
error=str(e),
|
error=str(e),
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Unexpected error executing block: {e}", exc_info=True)
|
logger.error("Unexpected error executing block: %s", e, exc_info=True)
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message=f"Failed to execute block: {str(e)}",
|
message="An unexpected error occurred while executing the block",
|
||||||
error=str(e),
|
error=str(e),
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,24 +1,202 @@
|
|||||||
"""Tests for execute_block type coercion in helpers.py.
|
"""Tests for execute_block — credit charging and type coercion."""
|
||||||
|
|
||||||
Verifies that execute_block() coerces string input values to match the block's
|
|
||||||
expected input types, mirroring the executor's validate_exec() logic.
|
|
||||||
This is critical for @@agptfile: expansion, where file content is always a string
|
|
||||||
but the block may expect structured types (e.g. list[list[str]]).
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from backend.blocks._base import BlockType
|
||||||
from backend.copilot.tools.helpers import execute_block
|
from backend.copilot.tools.helpers import execute_block
|
||||||
from backend.copilot.tools.models import BlockOutputResponse
|
from backend.copilot.tools.models import BlockOutputResponse, ErrorResponse
|
||||||
|
|
||||||
|
_USER = "test-user-helpers"
|
||||||
|
_SESSION = "test-session-helpers"
|
||||||
|
|
||||||
|
|
||||||
|
def _make_block(block_id: str = "block-1", name: str = "TestBlock"):
|
||||||
|
"""Create a minimal mock block for execute_block()."""
|
||||||
|
mock = MagicMock()
|
||||||
|
mock.id = block_id
|
||||||
|
mock.name = name
|
||||||
|
mock.block_type = BlockType.STANDARD
|
||||||
|
|
||||||
|
mock.input_schema = MagicMock()
|
||||||
|
mock.input_schema.get_credentials_fields_info.return_value = {}
|
||||||
|
|
||||||
|
async def _execute(
|
||||||
|
input_data: dict, **kwargs: Any
|
||||||
|
) -> AsyncIterator[tuple[str, Any]]:
|
||||||
|
yield "result", "ok"
|
||||||
|
|
||||||
|
mock.execute = _execute
|
||||||
|
return mock
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_workspace():
|
||||||
|
"""Patch workspace_db to return a mock workspace."""
|
||||||
|
mock_workspace = MagicMock()
|
||||||
|
mock_workspace.id = "ws-1"
|
||||||
|
mock_ws_db = MagicMock()
|
||||||
|
mock_ws_db.get_or_create_workspace = AsyncMock(return_value=mock_workspace)
|
||||||
|
return patch("backend.copilot.tools.helpers.workspace_db", return_value=mock_ws_db)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Credit charging tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestExecuteBlockCreditCharging:
|
||||||
|
async def test_charges_credits_when_cost_is_positive(self):
|
||||||
|
"""Block with cost > 0 should call spend_credits after execution."""
|
||||||
|
block = _make_block()
|
||||||
|
mock_spend = AsyncMock()
|
||||||
|
|
||||||
|
with (
|
||||||
|
_patch_workspace(),
|
||||||
|
patch(
|
||||||
|
"backend.copilot.tools.helpers.block_usage_cost",
|
||||||
|
return_value=(10, {"key": "val"}),
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"backend.copilot.tools.helpers._get_credits",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=100,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"backend.copilot.tools.helpers._spend_credits",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
side_effect=mock_spend,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
result = await execute_block(
|
||||||
|
block=block,
|
||||||
|
block_id="block-1",
|
||||||
|
input_data={"text": "hello"},
|
||||||
|
user_id=_USER,
|
||||||
|
session_id=_SESSION,
|
||||||
|
node_exec_id="exec-1",
|
||||||
|
matched_credentials={},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, BlockOutputResponse)
|
||||||
|
assert result.success is True
|
||||||
|
mock_spend.assert_awaited_once()
|
||||||
|
call_kwargs = mock_spend.call_args.kwargs
|
||||||
|
assert call_kwargs["cost"] == 10
|
||||||
|
assert call_kwargs["metadata"].reason == "copilot_block_execution"
|
||||||
|
|
||||||
|
async def test_returns_error_when_insufficient_credits_before_exec(self):
|
||||||
|
"""Pre-execution check should return ErrorResponse when balance < cost."""
|
||||||
|
block = _make_block()
|
||||||
|
|
||||||
|
with (
|
||||||
|
_patch_workspace(),
|
||||||
|
patch(
|
||||||
|
"backend.copilot.tools.helpers.block_usage_cost",
|
||||||
|
return_value=(10, {}),
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"backend.copilot.tools.helpers._get_credits",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=5, # balance < cost (10)
|
||||||
|
),
|
||||||
|
):
|
||||||
|
result = await execute_block(
|
||||||
|
block=block,
|
||||||
|
block_id="block-1",
|
||||||
|
input_data={},
|
||||||
|
user_id=_USER,
|
||||||
|
session_id=_SESSION,
|
||||||
|
node_exec_id="exec-1",
|
||||||
|
matched_credentials={},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, ErrorResponse)
|
||||||
|
assert "Insufficient credits" in result.message
|
||||||
|
|
||||||
|
async def test_no_charge_when_cost_is_zero(self):
|
||||||
|
"""Block with cost 0 should not call spend_credits."""
|
||||||
|
block = _make_block()
|
||||||
|
|
||||||
|
with (
|
||||||
|
_patch_workspace(),
|
||||||
|
patch(
|
||||||
|
"backend.copilot.tools.helpers.block_usage_cost",
|
||||||
|
return_value=(0, {}),
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"backend.copilot.tools.helpers._get_credits",
|
||||||
|
) as mock_get_credits,
|
||||||
|
patch(
|
||||||
|
"backend.copilot.tools.helpers._spend_credits",
|
||||||
|
) as mock_spend_credits,
|
||||||
|
):
|
||||||
|
result = await execute_block(
|
||||||
|
block=block,
|
||||||
|
block_id="block-1",
|
||||||
|
input_data={},
|
||||||
|
user_id=_USER,
|
||||||
|
session_id=_SESSION,
|
||||||
|
node_exec_id="exec-1",
|
||||||
|
matched_credentials={},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, BlockOutputResponse)
|
||||||
|
assert result.success is True
|
||||||
|
# Credit functions should not be called at all for zero-cost blocks
|
||||||
|
mock_get_credits.assert_not_awaited()
|
||||||
|
mock_spend_credits.assert_not_awaited()
|
||||||
|
|
||||||
|
async def test_returns_error_on_post_exec_insufficient_balance(self):
|
||||||
|
"""If charging fails after execution, return ErrorResponse."""
|
||||||
|
from backend.util.exceptions import InsufficientBalanceError
|
||||||
|
|
||||||
|
block = _make_block()
|
||||||
|
|
||||||
|
with (
|
||||||
|
_patch_workspace(),
|
||||||
|
patch(
|
||||||
|
"backend.copilot.tools.helpers.block_usage_cost",
|
||||||
|
return_value=(10, {}),
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"backend.copilot.tools.helpers._get_credits",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=15, # passes pre-check
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"backend.copilot.tools.helpers._spend_credits",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
side_effect=InsufficientBalanceError(
|
||||||
|
"Low balance", _USER, 5, 10
|
||||||
|
), # fails during actual charge (race with concurrent spend)
|
||||||
|
),
|
||||||
|
):
|
||||||
|
result = await execute_block(
|
||||||
|
block=block,
|
||||||
|
block_id="block-1",
|
||||||
|
input_data={},
|
||||||
|
user_id=_USER,
|
||||||
|
session_id=_SESSION,
|
||||||
|
node_exec_id="exec-1",
|
||||||
|
matched_credentials={},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, ErrorResponse)
|
||||||
|
assert "Insufficient credits" in result.message
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Type coercion tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _make_block_schema(annotations: dict[str, Any]) -> MagicMock:
|
def _make_block_schema(annotations: dict[str, Any]) -> MagicMock:
|
||||||
"""Create a mock input_schema with model_fields matching the given annotations."""
|
"""Create a mock input_schema with model_fields matching the given annotations."""
|
||||||
schema = MagicMock()
|
schema = MagicMock()
|
||||||
# coerce_inputs_to_schema uses model_fields (Pydantic v2 API)
|
|
||||||
model_fields = {}
|
model_fields = {}
|
||||||
for name, ann in annotations.items():
|
for name, ann in annotations.items():
|
||||||
field = MagicMock()
|
field = MagicMock()
|
||||||
@@ -28,7 +206,7 @@ def _make_block_schema(annotations: dict[str, Any]) -> MagicMock:
|
|||||||
return schema
|
return schema
|
||||||
|
|
||||||
|
|
||||||
def _make_block(
|
def _make_coerce_block(
|
||||||
block_id: str,
|
block_id: str,
|
||||||
name: str,
|
name: str,
|
||||||
annotations: dict[str, Any],
|
annotations: dict[str, Any],
|
||||||
@@ -60,7 +238,7 @@ _TEST_USER_ID = "test-user-coerce"
|
|||||||
@pytest.mark.asyncio(loop_scope="session")
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
async def test_coerce_json_string_to_nested_list():
|
async def test_coerce_json_string_to_nested_list():
|
||||||
"""JSON string → list[list[str]] (Google Sheets CSV import case)."""
|
"""JSON string → list[list[str]] (Google Sheets CSV import case)."""
|
||||||
block = _make_block(
|
block = _make_coerce_block(
|
||||||
"sheets-write",
|
"sheets-write",
|
||||||
"Google Sheets Write",
|
"Google Sheets Write",
|
||||||
{"values": list[list[str]], "spreadsheet_id": str},
|
{"values": list[list[str]], "spreadsheet_id": str},
|
||||||
@@ -90,7 +268,6 @@ async def test_coerce_json_string_to_nested_list():
|
|||||||
|
|
||||||
assert isinstance(response, BlockOutputResponse)
|
assert isinstance(response, BlockOutputResponse)
|
||||||
assert response.success is True
|
assert response.success is True
|
||||||
# Verify the input was coerced from string to list[list[str]]
|
|
||||||
assert block._captured_inputs["values"] == [
|
assert block._captured_inputs["values"] == [
|
||||||
["Name", "Score"],
|
["Name", "Score"],
|
||||||
["Alice", "90"],
|
["Alice", "90"],
|
||||||
@@ -103,7 +280,7 @@ async def test_coerce_json_string_to_nested_list():
|
|||||||
@pytest.mark.asyncio(loop_scope="session")
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
async def test_coerce_json_string_to_list():
|
async def test_coerce_json_string_to_list():
|
||||||
"""JSON string → list[str]."""
|
"""JSON string → list[str]."""
|
||||||
block = _make_block(
|
block = _make_coerce_block(
|
||||||
"list-block",
|
"list-block",
|
||||||
"List Block",
|
"List Block",
|
||||||
{"items": list[str]},
|
{"items": list[str]},
|
||||||
@@ -135,7 +312,7 @@ async def test_coerce_json_string_to_list():
|
|||||||
@pytest.mark.asyncio(loop_scope="session")
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
async def test_coerce_json_string_to_dict():
|
async def test_coerce_json_string_to_dict():
|
||||||
"""JSON string → dict[str, str]."""
|
"""JSON string → dict[str, str]."""
|
||||||
block = _make_block(
|
block = _make_coerce_block(
|
||||||
"dict-block",
|
"dict-block",
|
||||||
"Dict Block",
|
"Dict Block",
|
||||||
{"config": dict[str, str]},
|
{"config": dict[str, str]},
|
||||||
@@ -167,7 +344,7 @@ async def test_coerce_json_string_to_dict():
|
|||||||
@pytest.mark.asyncio(loop_scope="session")
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
async def test_no_coercion_when_type_matches():
|
async def test_no_coercion_when_type_matches():
|
||||||
"""Already-correct types pass through without coercion."""
|
"""Already-correct types pass through without coercion."""
|
||||||
block = _make_block(
|
block = _make_coerce_block(
|
||||||
"pass-through",
|
"pass-through",
|
||||||
"Pass Through",
|
"Pass Through",
|
||||||
{"values": list[list[str]], "name": str},
|
{"values": list[list[str]], "name": str},
|
||||||
@@ -201,7 +378,7 @@ async def test_no_coercion_when_type_matches():
|
|||||||
@pytest.mark.asyncio(loop_scope="session")
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
async def test_coerce_string_to_int():
|
async def test_coerce_string_to_int():
|
||||||
"""String number → int."""
|
"""String number → int."""
|
||||||
block = _make_block(
|
block = _make_coerce_block(
|
||||||
"int-block",
|
"int-block",
|
||||||
"Int Block",
|
"Int Block",
|
||||||
{"count": int},
|
{"count": int},
|
||||||
@@ -234,7 +411,7 @@ async def test_coerce_string_to_int():
|
|||||||
@pytest.mark.asyncio(loop_scope="session")
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
async def test_coerce_skips_none_values():
|
async def test_coerce_skips_none_values():
|
||||||
"""None values are not coerced (they may be optional fields)."""
|
"""None values are not coerced (they may be optional fields)."""
|
||||||
block = _make_block(
|
block = _make_coerce_block(
|
||||||
"optional-block",
|
"optional-block",
|
||||||
"Optional Block",
|
"Optional Block",
|
||||||
{"data": list[str], "label": str},
|
{"data": list[str], "label": str},
|
||||||
@@ -260,14 +437,13 @@ async def test_coerce_skips_none_values():
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(response, BlockOutputResponse)
|
assert isinstance(response, BlockOutputResponse)
|
||||||
# 'data' was not provided, so it should not appear in captured inputs
|
|
||||||
assert "data" not in block._captured_inputs
|
assert "data" not in block._captured_inputs
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
async def test_coerce_union_type_preserves_valid_member():
|
async def test_coerce_union_type_preserves_valid_member():
|
||||||
"""Union-typed fields should not be coerced when the value matches a member."""
|
"""Union-typed fields should not be coerced when the value matches a member."""
|
||||||
block = _make_block(
|
block = _make_coerce_block(
|
||||||
"union-block",
|
"union-block",
|
||||||
"Union Block",
|
"Union Block",
|
||||||
{"content": str | list[str]},
|
{"content": str | list[str]},
|
||||||
@@ -293,7 +469,6 @@ async def test_coerce_union_type_preserves_valid_member():
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(response, BlockOutputResponse)
|
assert isinstance(response, BlockOutputResponse)
|
||||||
# list[str] should NOT be stringified to '["a", "b"]'
|
|
||||||
assert block._captured_inputs["content"] == ["a", "b"]
|
assert block._captured_inputs["content"] == ["a", "b"]
|
||||||
assert isinstance(block._captured_inputs["content"], list)
|
assert isinstance(block._captured_inputs["content"], list)
|
||||||
|
|
||||||
@@ -301,7 +476,7 @@ async def test_coerce_union_type_preserves_valid_member():
|
|||||||
@pytest.mark.asyncio(loop_scope="session")
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
async def test_coerce_inner_elements_of_generic():
|
async def test_coerce_inner_elements_of_generic():
|
||||||
"""Inner elements of generic containers are recursively coerced."""
|
"""Inner elements of generic containers are recursively coerced."""
|
||||||
block = _make_block(
|
block = _make_coerce_block(
|
||||||
"inner-coerce",
|
"inner-coerce",
|
||||||
"Inner Coerce",
|
"Inner Coerce",
|
||||||
{"values": list[str]},
|
{"values": list[str]},
|
||||||
@@ -319,7 +494,6 @@ async def test_coerce_inner_elements_of_generic():
|
|||||||
response = await execute_block(
|
response = await execute_block(
|
||||||
block=block,
|
block=block,
|
||||||
block_id="inner-coerce",
|
block_id="inner-coerce",
|
||||||
# Inner elements are ints, but target is list[str]
|
|
||||||
input_data={"values": [1, 2, 3]},
|
input_data={"values": [1, 2, 3]},
|
||||||
user_id=_TEST_USER_ID,
|
user_id=_TEST_USER_ID,
|
||||||
session_id=_TEST_SESSION_ID,
|
session_id=_TEST_SESSION_ID,
|
||||||
@@ -328,6 +502,5 @@ async def test_coerce_inner_elements_of_generic():
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(response, BlockOutputResponse)
|
assert isinstance(response, BlockOutputResponse)
|
||||||
# Inner elements should be coerced from int to str
|
|
||||||
assert block._captured_inputs["values"] == ["1", "2", "3"]
|
assert block._captured_inputs["values"] == ["1", "2", "3"]
|
||||||
assert all(isinstance(v, str) for v in block._captured_inputs["values"])
|
assert all(isinstance(v, str) for v in block._captured_inputs["values"])
|
||||||
|
|||||||
@@ -512,6 +512,10 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
|||||||
list_workspace_files = d.list_workspace_files
|
list_workspace_files = d.list_workspace_files
|
||||||
soft_delete_workspace_file = d.soft_delete_workspace_file
|
soft_delete_workspace_file = d.soft_delete_workspace_file
|
||||||
|
|
||||||
|
# ============ Credits ============ #
|
||||||
|
spend_credits = d.spend_credits
|
||||||
|
get_credits = d.get_credits
|
||||||
|
|
||||||
# ============ Understanding ============ #
|
# ============ Understanding ============ #
|
||||||
get_business_understanding = d.get_business_understanding
|
get_business_understanding = d.get_business_understanding
|
||||||
upsert_business_understanding = d.upsert_business_understanding
|
upsert_business_understanding = d.upsert_business_understanding
|
||||||
|
|||||||
@@ -70,6 +70,9 @@ def _msg_tokens(msg: dict, enc) -> int:
|
|||||||
# Count tool result tokens
|
# Count tool result tokens
|
||||||
tool_call_tokens += _tok_len(item.get("tool_use_id", ""), enc)
|
tool_call_tokens += _tok_len(item.get("tool_use_id", ""), enc)
|
||||||
tool_call_tokens += _tok_len(item.get("content", ""), enc)
|
tool_call_tokens += _tok_len(item.get("content", ""), enc)
|
||||||
|
elif isinstance(item, dict) and item.get("type") == "text":
|
||||||
|
# Count text block tokens
|
||||||
|
tool_call_tokens += _tok_len(item.get("text", ""), enc)
|
||||||
elif isinstance(item, dict) and "content" in item:
|
elif isinstance(item, dict) and "content" in item:
|
||||||
# Other content types with content field
|
# Other content types with content field
|
||||||
tool_call_tokens += _tok_len(item.get("content", ""), enc)
|
tool_call_tokens += _tok_len(item.get("content", ""), enc)
|
||||||
@@ -145,10 +148,14 @@ def _truncate_middle_tokens(text: str, enc, max_tok: int) -> str:
|
|||||||
if len(ids) <= max_tok:
|
if len(ids) <= max_tok:
|
||||||
return text # nothing to do
|
return text # nothing to do
|
||||||
|
|
||||||
|
# Need at least 3 tokens (head + ellipsis + tail) for meaningful truncation
|
||||||
|
mid = enc.encode(" … ")
|
||||||
|
if max_tok < 3:
|
||||||
|
return enc.decode(mid)
|
||||||
|
|
||||||
# Split the allowance between the two ends:
|
# Split the allowance between the two ends:
|
||||||
head = max_tok // 2 - 1 # -1 for the ellipsis
|
head = max_tok // 2 - 1 # -1 for the ellipsis
|
||||||
tail = max_tok - head - 1
|
tail = max_tok - head - 1
|
||||||
mid = enc.encode(" … ")
|
|
||||||
return enc.decode(ids[:head] + mid + ids[-tail:])
|
return enc.decode(ids[:head] + mid + ids[-tail:])
|
||||||
|
|
||||||
|
|
||||||
@@ -396,7 +403,7 @@ def validate_and_remove_orphan_tool_responses(
|
|||||||
|
|
||||||
if log_warning:
|
if log_warning:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Removing {len(orphan_ids)} orphan tool response(s): {orphan_ids}"
|
"Removing %d orphan tool response(s): %s", len(orphan_ids), orphan_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
return _remove_orphan_tool_responses(messages, orphan_ids)
|
return _remove_orphan_tool_responses(messages, orphan_ids)
|
||||||
@@ -488,8 +495,9 @@ def _ensure_tool_pairs_intact(
|
|||||||
# Some tool_call_ids couldn't be resolved - remove those tool responses
|
# Some tool_call_ids couldn't be resolved - remove those tool responses
|
||||||
# This shouldn't happen in normal operation but handles edge cases
|
# This shouldn't happen in normal operation but handles edge cases
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Could not find assistant messages for tool_call_ids: {orphan_tool_call_ids}. "
|
"Could not find assistant messages for tool_call_ids: %s. "
|
||||||
"Removing orphan tool responses."
|
"Removing orphan tool responses.",
|
||||||
|
orphan_tool_call_ids,
|
||||||
)
|
)
|
||||||
recent_messages = _remove_orphan_tool_responses(
|
recent_messages = _remove_orphan_tool_responses(
|
||||||
recent_messages, orphan_tool_call_ids
|
recent_messages, orphan_tool_call_ids
|
||||||
@@ -497,8 +505,8 @@ def _ensure_tool_pairs_intact(
|
|||||||
|
|
||||||
if messages_to_prepend:
|
if messages_to_prepend:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Extended recent messages by {len(messages_to_prepend)} to preserve "
|
"Extended recent messages by %d to preserve tool_call/tool_response pairs",
|
||||||
f"tool_call/tool_response pairs"
|
len(messages_to_prepend),
|
||||||
)
|
)
|
||||||
return messages_to_prepend + recent_messages
|
return messages_to_prepend + recent_messages
|
||||||
|
|
||||||
@@ -686,11 +694,15 @@ async def compress_context(
|
|||||||
msgs = [summary_msg] + recent_msgs
|
msgs = [summary_msg] + recent_msgs
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Context summarized: {original_count} -> {total_tokens()} tokens, "
|
"Context summarized: %d -> %d tokens, summarized %d messages",
|
||||||
f"summarized {messages_summarized} messages"
|
original_count,
|
||||||
|
total_tokens(),
|
||||||
|
messages_summarized,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Summarization failed, continuing with truncation: {e}")
|
logger.warning(
|
||||||
|
"Summarization failed, continuing with truncation: %s", e
|
||||||
|
)
|
||||||
# Fall through to content truncation
|
# Fall through to content truncation
|
||||||
|
|
||||||
# ---- STEP 2: Normalize content ----------------------------------------
|
# ---- STEP 2: Normalize content ----------------------------------------
|
||||||
@@ -728,6 +740,12 @@ async def compress_context(
|
|||||||
# This is more granular than dropping all old messages at once.
|
# This is more granular than dropping all old messages at once.
|
||||||
while total_tokens() + reserve > target_tokens and len(msgs) > 2:
|
while total_tokens() + reserve > target_tokens and len(msgs) > 2:
|
||||||
deletable: list[int] = []
|
deletable: list[int] = []
|
||||||
|
# Count assistant messages to ensure we keep at least one
|
||||||
|
assistant_indices: set[int] = {
|
||||||
|
i
|
||||||
|
for i in range(len(msgs))
|
||||||
|
if msgs[i] is not None and msgs[i].get("role") == "assistant"
|
||||||
|
}
|
||||||
for i in range(1, len(msgs) - 1):
|
for i in range(1, len(msgs) - 1):
|
||||||
msg = msgs[i]
|
msg = msgs[i]
|
||||||
if (
|
if (
|
||||||
@@ -735,6 +753,9 @@ async def compress_context(
|
|||||||
and not _is_tool_message(msg)
|
and not _is_tool_message(msg)
|
||||||
and not _is_objective_message(msg)
|
and not _is_objective_message(msg)
|
||||||
):
|
):
|
||||||
|
# Skip if this is the last remaining assistant message
|
||||||
|
if msg.get("role") == "assistant" and len(assistant_indices) <= 1:
|
||||||
|
continue
|
||||||
deletable.append(i)
|
deletable.append(i)
|
||||||
if not deletable:
|
if not deletable:
|
||||||
break
|
break
|
||||||
|
|||||||
@@ -1,14 +1,8 @@
|
|||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import {
|
|
||||||
DropdownMenu,
|
|
||||||
DropdownMenuContent,
|
|
||||||
DropdownMenuItem,
|
|
||||||
DropdownMenuTrigger,
|
|
||||||
} from "@/components/molecules/DropdownMenu/DropdownMenu";
|
|
||||||
import { SidebarProvider } from "@/components/ui/sidebar";
|
import { SidebarProvider } from "@/components/ui/sidebar";
|
||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
import { DotsThree, UploadSimple } from "@phosphor-icons/react";
|
import { UploadSimple } from "@phosphor-icons/react";
|
||||||
import { useCallback, useRef, useState } from "react";
|
import { useCallback, useRef, useState } from "react";
|
||||||
import { ChatContainer } from "./components/ChatContainer/ChatContainer";
|
import { ChatContainer } from "./components/ChatContainer/ChatContainer";
|
||||||
import { ChatSidebar } from "./components/ChatSidebar/ChatSidebar";
|
import { ChatSidebar } from "./components/ChatSidebar/ChatSidebar";
|
||||||
@@ -92,7 +86,6 @@ export function CopilotPage() {
|
|||||||
// Delete functionality
|
// Delete functionality
|
||||||
sessionToDelete,
|
sessionToDelete,
|
||||||
isDeleting,
|
isDeleting,
|
||||||
handleDeleteClick,
|
|
||||||
handleConfirmDelete,
|
handleConfirmDelete,
|
||||||
handleCancelDelete,
|
handleCancelDelete,
|
||||||
} = useCopilotPage();
|
} = useCopilotPage();
|
||||||
@@ -148,38 +141,6 @@ export function CopilotPage() {
|
|||||||
isUploadingFiles={isUploadingFiles}
|
isUploadingFiles={isUploadingFiles}
|
||||||
droppedFiles={droppedFiles}
|
droppedFiles={droppedFiles}
|
||||||
onDroppedFilesConsumed={handleDroppedFilesConsumed}
|
onDroppedFilesConsumed={handleDroppedFilesConsumed}
|
||||||
headerSlot={
|
|
||||||
isMobile && sessionId ? (
|
|
||||||
<div className="flex justify-end">
|
|
||||||
<DropdownMenu>
|
|
||||||
<DropdownMenuTrigger asChild>
|
|
||||||
<button
|
|
||||||
className="rounded p-1.5 hover:bg-neutral-100"
|
|
||||||
aria-label="More actions"
|
|
||||||
>
|
|
||||||
<DotsThree className="h-5 w-5 text-neutral-600" />
|
|
||||||
</button>
|
|
||||||
</DropdownMenuTrigger>
|
|
||||||
<DropdownMenuContent align="end">
|
|
||||||
<DropdownMenuItem
|
|
||||||
onClick={() => {
|
|
||||||
const session = sessions.find(
|
|
||||||
(s) => s.id === sessionId,
|
|
||||||
);
|
|
||||||
if (session) {
|
|
||||||
handleDeleteClick(session.id, session.title);
|
|
||||||
}
|
|
||||||
}}
|
|
||||||
disabled={isDeleting}
|
|
||||||
className="text-red-600 focus:bg-red-50 focus:text-red-600"
|
|
||||||
>
|
|
||||||
Delete chat
|
|
||||||
</DropdownMenuItem>
|
|
||||||
</DropdownMenuContent>
|
|
||||||
</DropdownMenu>
|
|
||||||
</div>
|
|
||||||
) : undefined
|
|
||||||
}
|
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
import { ChatInput } from "@/app/(platform)/copilot/components/ChatInput/ChatInput";
|
import { ChatInput } from "@/app/(platform)/copilot/components/ChatInput/ChatInput";
|
||||||
import { UIDataTypes, UIMessage, UITools } from "ai";
|
import { UIDataTypes, UIMessage, UITools } from "ai";
|
||||||
import { LayoutGroup, motion } from "framer-motion";
|
import { LayoutGroup, motion } from "framer-motion";
|
||||||
import { ReactNode } from "react";
|
|
||||||
import { ChatMessagesContainer } from "../ChatMessagesContainer/ChatMessagesContainer";
|
import { ChatMessagesContainer } from "../ChatMessagesContainer/ChatMessagesContainer";
|
||||||
import { CopilotChatActionsProvider } from "../CopilotChatActionsProvider/CopilotChatActionsProvider";
|
import { CopilotChatActionsProvider } from "../CopilotChatActionsProvider/CopilotChatActionsProvider";
|
||||||
import { EmptySession } from "../EmptySession/EmptySession";
|
import { EmptySession } from "../EmptySession/EmptySession";
|
||||||
@@ -21,7 +20,6 @@ export interface ChatContainerProps {
|
|||||||
onSend: (message: string, files?: File[]) => void | Promise<void>;
|
onSend: (message: string, files?: File[]) => void | Promise<void>;
|
||||||
onStop: () => void;
|
onStop: () => void;
|
||||||
isUploadingFiles?: boolean;
|
isUploadingFiles?: boolean;
|
||||||
headerSlot?: ReactNode;
|
|
||||||
/** Files dropped onto the chat window. */
|
/** Files dropped onto the chat window. */
|
||||||
droppedFiles?: File[];
|
droppedFiles?: File[];
|
||||||
/** Called after droppedFiles have been consumed by ChatInput. */
|
/** Called after droppedFiles have been consumed by ChatInput. */
|
||||||
@@ -40,7 +38,6 @@ export const ChatContainer = ({
|
|||||||
onSend,
|
onSend,
|
||||||
onStop,
|
onStop,
|
||||||
isUploadingFiles,
|
isUploadingFiles,
|
||||||
headerSlot,
|
|
||||||
droppedFiles,
|
droppedFiles,
|
||||||
onDroppedFilesConsumed,
|
onDroppedFilesConsumed,
|
||||||
}: ChatContainerProps) => {
|
}: ChatContainerProps) => {
|
||||||
@@ -63,7 +60,6 @@ export const ChatContainer = ({
|
|||||||
status={status}
|
status={status}
|
||||||
error={error}
|
error={error}
|
||||||
isLoading={isLoadingSession}
|
isLoading={isLoadingSession}
|
||||||
headerSlot={headerSlot}
|
|
||||||
sessionID={sessionId}
|
sessionID={sessionId}
|
||||||
/>
|
/>
|
||||||
<motion.div
|
<motion.div
|
||||||
|
|||||||
@@ -30,7 +30,6 @@ interface Props {
|
|||||||
status: string;
|
status: string;
|
||||||
error: Error | undefined;
|
error: Error | undefined;
|
||||||
isLoading: boolean;
|
isLoading: boolean;
|
||||||
headerSlot?: React.ReactNode;
|
|
||||||
sessionID?: string | null;
|
sessionID?: string | null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -102,7 +101,6 @@ export function ChatMessagesContainer({
|
|||||||
status,
|
status,
|
||||||
error,
|
error,
|
||||||
isLoading,
|
isLoading,
|
||||||
headerSlot,
|
|
||||||
sessionID,
|
sessionID,
|
||||||
}: Props) {
|
}: Props) {
|
||||||
const lastMessage = messages[messages.length - 1];
|
const lastMessage = messages[messages.length - 1];
|
||||||
@@ -135,7 +133,6 @@ export function ChatMessagesContainer({
|
|||||||
return (
|
return (
|
||||||
<Conversation className="min-h-0 flex-1">
|
<Conversation className="min-h-0 flex-1">
|
||||||
<ConversationContent className="flex flex-1 flex-col gap-6 px-3 py-6">
|
<ConversationContent className="flex flex-1 flex-col gap-6 px-3 py-6">
|
||||||
{headerSlot}
|
|
||||||
{isLoading && messages.length === 0 && (
|
{isLoading && messages.length === 0 && (
|
||||||
<div
|
<div
|
||||||
className="flex flex-1 items-center justify-center"
|
className="flex flex-1 items-center justify-center"
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ import { useCopilotUIStore } from "../../store";
|
|||||||
import { NotificationToggle } from "./components/NotificationToggle/NotificationToggle";
|
import { NotificationToggle } from "./components/NotificationToggle/NotificationToggle";
|
||||||
import { DeleteChatDialog } from "../DeleteChatDialog/DeleteChatDialog";
|
import { DeleteChatDialog } from "../DeleteChatDialog/DeleteChatDialog";
|
||||||
import { PulseLoader } from "../PulseLoader/PulseLoader";
|
import { PulseLoader } from "../PulseLoader/PulseLoader";
|
||||||
|
import { UsageLimits } from "../UsageLimits/UsageLimits";
|
||||||
|
|
||||||
export function ChatSidebar() {
|
export function ChatSidebar() {
|
||||||
const { state } = useSidebar();
|
const { state } = useSidebar();
|
||||||
@@ -256,11 +257,10 @@ export function ChatSidebar() {
|
|||||||
<Text variant="h3" size="body-medium">
|
<Text variant="h3" size="body-medium">
|
||||||
Your chats
|
Your chats
|
||||||
</Text>
|
</Text>
|
||||||
<div className="relative left-5 flex items-center gap-1">
|
<div className="flex items-center">
|
||||||
|
<UsageLimits />
|
||||||
<NotificationToggle />
|
<NotificationToggle />
|
||||||
<div className="relative left-1">
|
<SidebarTrigger />
|
||||||
<SidebarTrigger />
|
|
||||||
</div>
|
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
{sessionId ? (
|
{sessionId ? (
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import {
|
|||||||
PopoverTrigger,
|
PopoverTrigger,
|
||||||
} from "@/components/molecules/Popover/Popover";
|
} from "@/components/molecules/Popover/Popover";
|
||||||
import { toast } from "@/components/molecules/Toast/use-toast";
|
import { toast } from "@/components/molecules/Toast/use-toast";
|
||||||
|
import { Button } from "@/components/ui/button";
|
||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
import { Bell, BellRinging, BellSlash } from "@phosphor-icons/react";
|
import { Bell, BellRinging, BellSlash } from "@phosphor-icons/react";
|
||||||
import { useCopilotUIStore } from "../../../../store";
|
import { useCopilotUIStore } from "../../../../store";
|
||||||
@@ -48,10 +49,7 @@ export function NotificationToggle() {
|
|||||||
return (
|
return (
|
||||||
<Popover>
|
<Popover>
|
||||||
<PopoverTrigger asChild>
|
<PopoverTrigger asChild>
|
||||||
<button
|
<Button variant="ghost" size="icon" aria-label="Notification settings">
|
||||||
className="rounded p-1 text-black transition-colors hover:bg-zinc-50"
|
|
||||||
aria-label="Notification settings"
|
|
||||||
>
|
|
||||||
{!isNotificationsEnabled ? (
|
{!isNotificationsEnabled ? (
|
||||||
<BellSlash className="!size-5" />
|
<BellSlash className="!size-5" />
|
||||||
) : isSoundEnabled ? (
|
) : isSoundEnabled ? (
|
||||||
@@ -59,7 +57,7 @@ export function NotificationToggle() {
|
|||||||
) : (
|
) : (
|
||||||
<Bell className="!size-5" />
|
<Bell className="!size-5" />
|
||||||
)}
|
)}
|
||||||
</button>
|
</Button>
|
||||||
</PopoverTrigger>
|
</PopoverTrigger>
|
||||||
<PopoverContent align="start" className="w-56 p-3">
|
<PopoverContent align="start" className="w-56 p-3">
|
||||||
<div className="flex flex-col gap-3">
|
<div className="flex flex-col gap-3">
|
||||||
|
|||||||
@@ -0,0 +1,146 @@
|
|||||||
|
import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus";
|
||||||
|
import {
|
||||||
|
Popover,
|
||||||
|
PopoverContent,
|
||||||
|
PopoverTrigger,
|
||||||
|
} from "@/components/molecules/Popover/Popover";
|
||||||
|
import { Button } from "@/components/ui/button";
|
||||||
|
import { ChartBar } from "@phosphor-icons/react";
|
||||||
|
import { useUsageLimits } from "./useUsageLimits";
|
||||||
|
|
||||||
|
const MS_PER_MINUTE = 60_000;
|
||||||
|
const MS_PER_HOUR = 3_600_000;
|
||||||
|
|
||||||
|
function formatResetTime(resetsAt: Date | string): string {
|
||||||
|
const resetDate =
|
||||||
|
typeof resetsAt === "string" ? new Date(resetsAt) : resetsAt;
|
||||||
|
const now = new Date();
|
||||||
|
const diffMs = resetDate.getTime() - now.getTime();
|
||||||
|
if (diffMs <= 0) return "now";
|
||||||
|
|
||||||
|
const hours = Math.floor(diffMs / MS_PER_HOUR);
|
||||||
|
|
||||||
|
// Under 24h: show relative time ("in 4h 23m")
|
||||||
|
if (hours < 24) {
|
||||||
|
const minutes = Math.floor((diffMs % MS_PER_HOUR) / MS_PER_MINUTE);
|
||||||
|
if (hours > 0) return `in ${hours}h ${minutes}m`;
|
||||||
|
return `in ${minutes}m`;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Over 24h: show day and time in local timezone ("Mon 12:00 AM PST")
|
||||||
|
return resetDate.toLocaleString(undefined, {
|
||||||
|
weekday: "short",
|
||||||
|
hour: "numeric",
|
||||||
|
minute: "2-digit",
|
||||||
|
timeZoneName: "short",
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
function UsageBar({
|
||||||
|
label,
|
||||||
|
used,
|
||||||
|
limit,
|
||||||
|
resetsAt,
|
||||||
|
}: {
|
||||||
|
label: string;
|
||||||
|
used: number;
|
||||||
|
limit: number;
|
||||||
|
resetsAt: Date | string;
|
||||||
|
}) {
|
||||||
|
if (limit <= 0) return null;
|
||||||
|
|
||||||
|
const rawPercent = (used / limit) * 100;
|
||||||
|
const percent = Math.min(100, Math.round(rawPercent));
|
||||||
|
const isHigh = percent >= 80;
|
||||||
|
const percentLabel =
|
||||||
|
used > 0 && percent === 0 ? "<1% used" : `${percent}% used`;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="flex flex-col gap-1">
|
||||||
|
<div className="flex items-baseline justify-between">
|
||||||
|
<span className="text-xs font-medium text-neutral-700">{label}</span>
|
||||||
|
<span className="text-[11px] tabular-nums text-neutral-500">
|
||||||
|
{percentLabel}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
<div className="text-[10px] text-neutral-400">
|
||||||
|
Resets {formatResetTime(resetsAt)}
|
||||||
|
</div>
|
||||||
|
<div className="h-2 w-full overflow-hidden rounded-full bg-neutral-200">
|
||||||
|
<div
|
||||||
|
className={`h-full rounded-full transition-[width] duration-300 ease-out ${
|
||||||
|
isHigh ? "bg-orange-500" : "bg-blue-500"
|
||||||
|
}`}
|
||||||
|
style={{ width: `${Math.max(used > 0 ? 1 : 0, percent)}%` }}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function UsagePanelContent({
|
||||||
|
usage,
|
||||||
|
showBillingLink = true,
|
||||||
|
}: {
|
||||||
|
usage: CoPilotUsageStatus;
|
||||||
|
showBillingLink?: boolean;
|
||||||
|
}) {
|
||||||
|
const hasDailyLimit = usage.daily.limit > 0;
|
||||||
|
const hasWeeklyLimit = usage.weekly.limit > 0;
|
||||||
|
|
||||||
|
if (!hasDailyLimit && !hasWeeklyLimit) {
|
||||||
|
return (
|
||||||
|
<div className="text-xs text-neutral-500">No usage limits configured</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="flex flex-col gap-3">
|
||||||
|
<div className="text-xs font-semibold text-neutral-800">Usage limits</div>
|
||||||
|
{hasDailyLimit && (
|
||||||
|
<UsageBar
|
||||||
|
label="Today"
|
||||||
|
used={usage.daily.used}
|
||||||
|
limit={usage.daily.limit}
|
||||||
|
resetsAt={usage.daily.resets_at}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{hasWeeklyLimit && (
|
||||||
|
<UsageBar
|
||||||
|
label="This week"
|
||||||
|
used={usage.weekly.used}
|
||||||
|
limit={usage.weekly.limit}
|
||||||
|
resetsAt={usage.weekly.resets_at}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{showBillingLink && (
|
||||||
|
<a
|
||||||
|
href="/profile/credits"
|
||||||
|
className="text-[11px] text-blue-600 hover:underline"
|
||||||
|
>
|
||||||
|
Learn more about usage limits
|
||||||
|
</a>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function UsageLimits() {
|
||||||
|
const { data: usage, isLoading } = useUsageLimits();
|
||||||
|
|
||||||
|
if (isLoading || !usage) return null;
|
||||||
|
if (usage.daily.limit <= 0 && usage.weekly.limit <= 0) return null;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Popover>
|
||||||
|
<PopoverTrigger asChild>
|
||||||
|
<Button variant="ghost" size="icon" aria-label="Usage limits">
|
||||||
|
<ChartBar className="!size-5" weight="light" />
|
||||||
|
</Button>
|
||||||
|
</PopoverTrigger>
|
||||||
|
<PopoverContent align="start" className="w-64 p-3">
|
||||||
|
<UsagePanelContent usage={usage} />
|
||||||
|
</PopoverContent>
|
||||||
|
</Popover>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -0,0 +1,121 @@
|
|||||||
|
import { render, screen, cleanup } from "@/tests/integrations/test-utils";
|
||||||
|
import { afterEach, describe, expect, it, vi } from "vitest";
|
||||||
|
import { UsageLimits } from "../UsageLimits";
|
||||||
|
|
||||||
|
// Mock the useUsageLimits hook
|
||||||
|
const mockUseUsageLimits = vi.fn();
|
||||||
|
vi.mock("../useUsageLimits", () => ({
|
||||||
|
useUsageLimits: () => mockUseUsageLimits(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
// Mock Popover to render children directly (Radix portals don't work in happy-dom)
|
||||||
|
vi.mock("@/components/molecules/Popover/Popover", () => ({
|
||||||
|
Popover: ({ children }: { children: React.ReactNode }) => (
|
||||||
|
<div>{children}</div>
|
||||||
|
),
|
||||||
|
PopoverTrigger: ({ children }: { children: React.ReactNode }) => (
|
||||||
|
<div>{children}</div>
|
||||||
|
),
|
||||||
|
PopoverContent: ({ children }: { children: React.ReactNode }) => (
|
||||||
|
<div>{children}</div>
|
||||||
|
),
|
||||||
|
}));
|
||||||
|
|
||||||
|
afterEach(() => {
|
||||||
|
cleanup();
|
||||||
|
mockUseUsageLimits.mockReset();
|
||||||
|
});
|
||||||
|
|
||||||
|
function makeUsage({
|
||||||
|
dailyUsed = 500,
|
||||||
|
dailyLimit = 10000,
|
||||||
|
weeklyUsed = 2000,
|
||||||
|
weeklyLimit = 50000,
|
||||||
|
}: {
|
||||||
|
dailyUsed?: number;
|
||||||
|
dailyLimit?: number;
|
||||||
|
weeklyUsed?: number;
|
||||||
|
weeklyLimit?: number;
|
||||||
|
} = {}) {
|
||||||
|
const future = new Date(Date.now() + 3600 * 1000); // 1h from now
|
||||||
|
return {
|
||||||
|
daily: { used: dailyUsed, limit: dailyLimit, resets_at: future },
|
||||||
|
weekly: { used: weeklyUsed, limit: weeklyLimit, resets_at: future },
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
describe("UsageLimits", () => {
|
||||||
|
it("renders nothing while loading", () => {
|
||||||
|
mockUseUsageLimits.mockReturnValue({ data: undefined, isLoading: true });
|
||||||
|
const { container } = render(<UsageLimits />);
|
||||||
|
expect(container.innerHTML).toBe("");
|
||||||
|
});
|
||||||
|
|
||||||
|
it("renders nothing when no limits are configured", () => {
|
||||||
|
mockUseUsageLimits.mockReturnValue({
|
||||||
|
data: makeUsage({ dailyLimit: 0, weeklyLimit: 0 }),
|
||||||
|
isLoading: false,
|
||||||
|
});
|
||||||
|
const { container } = render(<UsageLimits />);
|
||||||
|
expect(container.innerHTML).toBe("");
|
||||||
|
});
|
||||||
|
|
||||||
|
it("renders the usage button when limits exist", () => {
|
||||||
|
mockUseUsageLimits.mockReturnValue({
|
||||||
|
data: makeUsage(),
|
||||||
|
isLoading: false,
|
||||||
|
});
|
||||||
|
render(<UsageLimits />);
|
||||||
|
expect(screen.getByRole("button", { name: /usage limits/i })).toBeDefined();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("displays daily and weekly usage percentages", () => {
|
||||||
|
mockUseUsageLimits.mockReturnValue({
|
||||||
|
data: makeUsage({ dailyUsed: 5000, dailyLimit: 10000 }),
|
||||||
|
isLoading: false,
|
||||||
|
});
|
||||||
|
render(<UsageLimits />);
|
||||||
|
|
||||||
|
expect(screen.getByText("50% used")).toBeDefined();
|
||||||
|
expect(screen.getByText("Today")).toBeDefined();
|
||||||
|
expect(screen.getByText("This week")).toBeDefined();
|
||||||
|
expect(screen.getByText("Usage limits")).toBeDefined();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("shows only weekly bar when daily limit is 0", () => {
|
||||||
|
mockUseUsageLimits.mockReturnValue({
|
||||||
|
data: makeUsage({
|
||||||
|
dailyLimit: 0,
|
||||||
|
weeklyUsed: 25000,
|
||||||
|
weeklyLimit: 50000,
|
||||||
|
}),
|
||||||
|
isLoading: false,
|
||||||
|
});
|
||||||
|
render(<UsageLimits />);
|
||||||
|
|
||||||
|
expect(screen.getByText("This week")).toBeDefined();
|
||||||
|
expect(screen.queryByText("Today")).toBeNull();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("caps percentage at 100% when over limit", () => {
|
||||||
|
mockUseUsageLimits.mockReturnValue({
|
||||||
|
data: makeUsage({ dailyUsed: 15000, dailyLimit: 10000 }),
|
||||||
|
isLoading: false,
|
||||||
|
});
|
||||||
|
render(<UsageLimits />);
|
||||||
|
|
||||||
|
expect(screen.getByText("100% used")).toBeDefined();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("shows learn more link to credits page", () => {
|
||||||
|
mockUseUsageLimits.mockReturnValue({
|
||||||
|
data: makeUsage(),
|
||||||
|
isLoading: false,
|
||||||
|
});
|
||||||
|
render(<UsageLimits />);
|
||||||
|
|
||||||
|
const link = screen.getByText("Learn more about usage limits");
|
||||||
|
expect(link).toBeDefined();
|
||||||
|
expect(link.closest("a")?.getAttribute("href")).toBe("/profile/credits");
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus";
|
||||||
|
import { useGetV2GetCopilotUsage } from "@/app/api/__generated__/endpoints/chat/chat";
|
||||||
|
|
||||||
|
export function useUsageLimits() {
|
||||||
|
return useGetV2GetCopilotUsage({
|
||||||
|
query: {
|
||||||
|
select: (res) => res.data as CoPilotUsageStatus,
|
||||||
|
refetchInterval: 30000,
|
||||||
|
staleTime: 10000,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
import {
|
import {
|
||||||
|
getGetV2GetCopilotUsageQueryKey,
|
||||||
getGetV2GetSessionQueryKey,
|
getGetV2GetSessionQueryKey,
|
||||||
postV2CancelSessionTask,
|
postV2CancelSessionTask,
|
||||||
} from "@/app/api/__generated__/endpoints/chat/chat";
|
} from "@/app/api/__generated__/endpoints/chat/chat";
|
||||||
@@ -307,6 +308,9 @@ export function useCopilotStream({
|
|||||||
queryClient.invalidateQueries({
|
queryClient.invalidateQueries({
|
||||||
queryKey: getGetV2GetSessionQueryKey(sessionId),
|
queryKey: getGetV2GetSessionQueryKey(sessionId),
|
||||||
});
|
});
|
||||||
|
queryClient.invalidateQueries({
|
||||||
|
queryKey: getGetV2GetCopilotUsageQueryKey(),
|
||||||
|
});
|
||||||
if (status === "ready") {
|
if (status === "ready") {
|
||||||
reconnectAttemptsRef.current = 0;
|
reconnectAttemptsRef.current = 0;
|
||||||
hasShownDisconnectToast.current = false;
|
hasShownDisconnectToast.current = false;
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ import {
|
|||||||
|
|
||||||
import { RefundModal } from "./RefundModal";
|
import { RefundModal } from "./RefundModal";
|
||||||
import { CreditTransaction } from "@/lib/autogpt-server-api";
|
import { CreditTransaction } from "@/lib/autogpt-server-api";
|
||||||
|
import { UsagePanelContent } from "@/app/(platform)/copilot/components/UsageLimits/UsageLimits";
|
||||||
|
import { useUsageLimits } from "@/app/(platform)/copilot/components/UsageLimits/useUsageLimits";
|
||||||
|
|
||||||
import {
|
import {
|
||||||
Table,
|
Table,
|
||||||
@@ -21,6 +23,26 @@ import {
|
|||||||
TableRow,
|
TableRow,
|
||||||
} from "@/components/__legacy__/ui/table";
|
} from "@/components/__legacy__/ui/table";
|
||||||
|
|
||||||
|
function CoPilotUsageSection() {
|
||||||
|
const { data: usage, isLoading } = useUsageLimits();
|
||||||
|
const router = useRouter();
|
||||||
|
|
||||||
|
if (isLoading || !usage) return null;
|
||||||
|
if (usage.daily.limit <= 0 && usage.weekly.limit <= 0) return null;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="my-6 space-y-4">
|
||||||
|
<h3 className="text-lg font-medium">CoPilot Usage Limits</h3>
|
||||||
|
<div className="rounded-lg border border-neutral-200 p-4 dark:border-neutral-700">
|
||||||
|
<UsagePanelContent usage={usage} showBillingLink={false} />
|
||||||
|
</div>
|
||||||
|
<Button className="w-full" onClick={() => router.push("/copilot")}>
|
||||||
|
Open CoPilot
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
export default function CreditsPage() {
|
export default function CreditsPage() {
|
||||||
const api = useBackendAPI();
|
const api = useBackendAPI();
|
||||||
const {
|
const {
|
||||||
@@ -237,11 +259,13 @@ export default function CreditsPage() {
|
|||||||
</Button>
|
</Button>
|
||||||
)}
|
)}
|
||||||
</form>
|
</form>
|
||||||
|
|
||||||
|
{/* CoPilot Usage Limits */}
|
||||||
|
<CoPilotUsageSection />
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div className="my-6 space-y-4">
|
<div className="my-6 space-y-4">
|
||||||
{/* Payment Portal */}
|
{/* Payment Portal */}
|
||||||
|
|
||||||
<h3 className="text-lg font-medium">Manage Your Payment Methods</h3>
|
<h3 className="text-lg font-medium">Manage Your Payment Methods</h3>
|
||||||
<p className="text-neutral-600">
|
<p className="text-neutral-600">
|
||||||
You can manage your cards and see your payment history in the
|
You can manage your cards and see your payment history in the
|
||||||
|
|||||||
@@ -1382,6 +1382,28 @@
|
|||||||
"security": [{ "HTTPBearerJWT": [] }]
|
"security": [{ "HTTPBearerJWT": [] }]
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"/api/chat/usage": {
|
||||||
|
"get": {
|
||||||
|
"tags": ["v2", "chat", "chat"],
|
||||||
|
"summary": "Get Copilot Usage",
|
||||||
|
"description": "Get CoPilot usage status for the authenticated user.\n\nReturns current token usage vs limits for daily and weekly windows.",
|
||||||
|
"operationId": "getV2GetCopilotUsage",
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "Successful Response",
|
||||||
|
"content": {
|
||||||
|
"application/json": {
|
||||||
|
"schema": { "$ref": "#/components/schemas/CoPilotUsageStatus" }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"401": {
|
||||||
|
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"security": [{ "HTTPBearerJWT": [] }]
|
||||||
|
}
|
||||||
|
},
|
||||||
"/api/credits": {
|
"/api/credits": {
|
||||||
"get": {
|
"get": {
|
||||||
"tags": ["v1", "credits"],
|
"tags": ["v1", "credits"],
|
||||||
@@ -8455,6 +8477,16 @@
|
|||||||
"title": "ClarifyingQuestion",
|
"title": "ClarifyingQuestion",
|
||||||
"description": "A question that needs user clarification."
|
"description": "A question that needs user clarification."
|
||||||
},
|
},
|
||||||
|
"CoPilotUsageStatus": {
|
||||||
|
"properties": {
|
||||||
|
"daily": { "$ref": "#/components/schemas/UsageWindow" },
|
||||||
|
"weekly": { "$ref": "#/components/schemas/UsageWindow" }
|
||||||
|
},
|
||||||
|
"type": "object",
|
||||||
|
"required": ["daily", "weekly"],
|
||||||
|
"title": "CoPilotUsageStatus",
|
||||||
|
"description": "Current usage status for a user across all windows."
|
||||||
|
},
|
||||||
"ContentType": {
|
"ContentType": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"enum": [
|
"enum": [
|
||||||
@@ -12190,6 +12222,16 @@
|
|||||||
{ "$ref": "#/components/schemas/ActiveStreamInfo" },
|
{ "$ref": "#/components/schemas/ActiveStreamInfo" },
|
||||||
{ "type": "null" }
|
{ "type": "null" }
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
"total_prompt_tokens": {
|
||||||
|
"type": "integer",
|
||||||
|
"title": "Total Prompt Tokens",
|
||||||
|
"default": 0
|
||||||
|
},
|
||||||
|
"total_completion_tokens": {
|
||||||
|
"type": "integer",
|
||||||
|
"title": "Total Completion Tokens",
|
||||||
|
"default": 0
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"type": "object",
|
"type": "object",
|
||||||
@@ -14587,6 +14629,25 @@
|
|||||||
"required": ["timezone"],
|
"required": ["timezone"],
|
||||||
"title": "UpdateTimezoneRequest"
|
"title": "UpdateTimezoneRequest"
|
||||||
},
|
},
|
||||||
|
"UsageWindow": {
|
||||||
|
"properties": {
|
||||||
|
"used": { "type": "integer", "title": "Used" },
|
||||||
|
"limit": {
|
||||||
|
"type": "integer",
|
||||||
|
"title": "Limit",
|
||||||
|
"description": "Maximum tokens allowed in this window. 0 means unlimited."
|
||||||
|
},
|
||||||
|
"resets_at": {
|
||||||
|
"type": "string",
|
||||||
|
"format": "date-time",
|
||||||
|
"title": "Resets At"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"type": "object",
|
||||||
|
"required": ["used", "limit", "resets_at"],
|
||||||
|
"title": "UsageWindow",
|
||||||
|
"description": "Usage within a single time window."
|
||||||
|
},
|
||||||
"UserHistoryResponse": {
|
"UserHistoryResponse": {
|
||||||
"properties": {
|
"properties": {
|
||||||
"history": {
|
"history": {
|
||||||
|
|||||||
@@ -288,6 +288,7 @@ const SidebarTrigger = React.forwardRef<
|
|||||||
ref={ref}
|
ref={ref}
|
||||||
data-sidebar="trigger"
|
data-sidebar="trigger"
|
||||||
variant="ghost"
|
variant="ghost"
|
||||||
|
size="icon"
|
||||||
onClick={(event) => {
|
onClick={(event) => {
|
||||||
onClick?.(event);
|
onClick?.(event);
|
||||||
toggleSidebar();
|
toggleSidebar();
|
||||||
|
|||||||
Reference in New Issue
Block a user