Compare commits

...

49 Commits

Author SHA1 Message Date
Zamil Majdy
9684b99949 fix(backend/copilot): fix logging, token counting, and compaction edge cases
- Convert all remaining f-string logger calls to %-style across SDK files
- Fix _msg_tokens not counting Anthropic text blocks (underestimated tokens)
- Fix _truncate_middle_tokens producing wrong output when max_tok < 3
- Fix CompactionTracker.reset_for_query not clearing _compact_start event
- Add cycle guard to strip_progress_entries reparenting loop
- Replace bare except with logged exception in transcript metadata loading
2026-03-14 21:58:42 +07:00
Zamil Majdy
d00059dc94 test(backend/copilot): add tests for transcript compaction helpers
Add comprehensive unit tests for the JSONL-message conversion layer:
- _flatten_assistant_content and _flatten_tool_result_content
- _transcript_to_messages: strippable types, compact summaries, edge cases
- _messages_to_transcript: structure, parentUuid chain, validation
- Roundtrip: messages to transcript to messages content preservation
- compact_transcript: too few messages, mock compaction, no-op, failure
2026-03-14 04:34:32 +07:00
Zamil Majdy
3a14077d52 refactor(backend/copilot): remove dead delete_transcript function
delete_transcript is no longer needed — oversized transcripts are now
compacted proactively at download time via _maybe_compact_and_upload
(400KB threshold), preventing the "prompt too long" error before it
can occur. If compaction fails, resume is gracefully skipped.
2026-03-14 03:58:36 +07:00
Zamil Majdy
4446be94ae test(backend/copilot): add e2e compaction lifecycle tests
Exercises the full service.py compaction flow end-to-end:
TranscriptBuilder load → CompactionTracker state machine →
read compacted entries → replace_entries → export → roundtrip.
2026-03-14 01:38:11 +07:00
Zamil Majdy
983aed2b0a fix(backend/copilot): preserve isCompactSummary through transcript roundtrip
TranscriptEntry didn't carry isCompactSummary, so exported JSONL had
bare "type": "summary" entries. On next turn's load_previous, these
were stripped as regular summaries, losing compaction context.

- Add isCompactSummary field to TranscriptEntry
- Preserve isCompactSummary entries in load_previous, strip_progress_entries,
  and _transcript_to_messages
- Add 6 roundtrip tests verifying isCompactSummary survives export→reload
2026-03-14 01:11:56 +07:00
Zamil Majdy
66a8cf69be fix(backend/copilot): async file I/O, add TranscriptBuilder tests
- Convert read_cli_session_file to async using aiofiles (addresses
  review comment about sync I/O in async streaming loop)
- Add 10 tests for TranscriptBuilder (replace_entries, load_previous,
  append operations, message ID consistency)
- Update existing read_cli_session_file tests to async
2026-03-14 00:16:51 +07:00
Zamil Majdy
d1b8766fa4 fix(platform): restore compaction code lost in merge, address remaining review comments
The merge of feat/tracking-cost-block reverted transcript compaction code
and several review fixes from 90b7edf1f. This commit restores the lost
code and applies additional improvements requested in review:

- Restore transcript compaction functions (_transcript_to_messages,
  _messages_to_transcript, compact_transcript, _flatten_* helpers)
- Restore _maybe_compact_and_upload helper in service.py to flatten
  deep nesting (5 levels -> 2) in transcript compaction block
- Restore CLI session file reading (read_cli_session_file) for
  mid-stream compaction sync
- Restore total_tokens DRY fix (compute once, reuse in finally)
- Extract _run_compression() helper to eliminate nested try blocks
- Add STOP_REASON_END_TURN, COMPACT_MSG_ID_PREFIX, ENTRY_TYPE_MESSAGE
  named constants replacing magic strings
- Add MS_PER_MINUTE and MS_PER_HOUR constants in UsageLimits.tsx
- Add docstring explaining Monday edge case in _weekly_reset_time
2026-03-13 22:56:19 +07:00
Zamil Majdy
628b779128 fix(backend): preserve at least one assistant message during middle-out deletion
The middle-out deletion step in compress_context could remove all assistant
messages when client=None (fallback compaction), causing validate_transcript
to fail and returning None (context loss). Now skips deleting the last
remaining assistant message.
2026-03-13 22:09:28 +07:00
Zamil Majdy
90b7edf1f1 fix(backend/copilot): address review comments — top-level imports, bug fixes, refactoring
- Move all local imports to top-level (transcript.py, helpers.py,
  baseline/service.py) per project style rules
- Fix sub.get("text") bug: use `or` instead of default arg to handle
  None values in tool_result content blocks
- Extract _flatten_assistant_content and _flatten_tool_result_content
  helpers from _transcript_to_messages to reduce nesting
- Use early returns in _get_credits/_spend_credits
- Extract _fetch_counters helper in rate_limit.py to DRY the Redis
  fetch pattern between get_usage_status and check_rate_limit
- Fix DRY violation: compute total_tokens once before StreamUsage,
  reuse in finally block for session persistence
- Fix chunk.choices guard: use early continue instead of ternary
  to prevent IndexError on empty list
- Make generic error message in execute_block to avoid leaking internals
2026-03-13 19:49:28 +07:00
Zamil Majdy
8b970c4c3d Merge remote-tracking branch 'origin/dev' into fix/copilot-transcript-compaction-v2 2026-03-13 19:25:12 +07:00
Zamil Majdy
601fed93b8 fix(backend/copilot): resolve stash conflicts in transcript.py, handle file-stat race condition 2026-03-13 19:10:07 +07:00
Zamil Majdy
e3f9fa3648 fix(platform): merge dev branch, resolve conflicts in routes_test and openapi.json 2026-03-13 19:03:20 +07:00
Zamil Majdy
809ba56f0b fix(backend/copilot): preserve structured tool_result content during transcript flattening
When flattening tool_result blocks for summarisation, dict blocks without a
"text" key were silently dropped (replaced with empty string). Now falls back
to json.dumps(sub) so JSON/structured payloads are preserved for compaction.
2026-03-13 18:30:23 +07:00
Zamil Majdy
9a467e1dba fix(platform): preserve message_count after transcript compaction, clean up imports
- Fix bug where message_count was reset to JSONL line count after
  compaction instead of preserving the original session.messages
  watermark. The JSONL line count is smaller than the original message
  count, causing the gap-fill logic to re-inject already-covered
  messages into the prompt.
- Move pathlib.Path and uuid.uuid4 imports to module level in
  transcript.py.
- Use Next.js router instead of window.location.href in credits page.
2026-03-13 18:21:15 +07:00
Zamil Majdy
0200748225 fix(platform): address PR review comments
- transcript.py: properly flatten tool_result content blocks including
  nested list structures instead of losing tool_use_id silently
- transcript_builder.py: make replace_entries atomic — parse into temp
  builder first, only swap on success
- service.py: skip resume when compaction fails (avoid re-triggering
  "Prompt is too long"), wrap upload in try/except for best-effort
- baseline/service.py: use floor of 0 instead of 1 for token estimates
- rate_limit.py: broaden Redis exception handling to catch all errors
  (timeouts, etc.) for true fail-open behavior; remove unused import
- helpers.py: return ErrorResponse on post-exec InsufficientBalanceError
  instead of swallowing; remove raw exception text from user message
- helpers_test.py: update test to expect ErrorResponse
- UsageLimits.tsx: remove dark:* classes (copilot has no dark mode yet)
2026-03-13 18:12:24 +07:00
Zamil Majdy
1704214394 fix(backend/copilot): recount message_count after transcript compaction
After compacting a transcript at download time, the message_count was
stale (from the original download), causing the gap-fill logic to miss
messages on subsequent turns. Now recount from the compacted content
before uploading.
2026-03-13 18:01:48 +07:00
Zamil Majdy
f2efd3ad7f fix(backend/copilot): address review - path traversal fix, pathlib, tests
- Use pathlib.Path.glob instead of import glob
- Add realpath + prefix check on glob results to prevent symlink escapes
- Add unit tests for _cli_project_dir, read_cli_session_file,
  _transcript_to_messages, and _messages_to_transcript
2026-03-13 17:51:45 +07:00
Zamil Majdy
ee841d1515 fix(backend/copilot): make TranscriptBuilder compaction-aware and compact oversized transcripts
TranscriptBuilder accumulated all messages including pre-compaction content,
causing uploaded transcripts to grow unbounded. When the CLI compacted
mid-stream, TranscriptBuilder kept the full uncompacted history. On the next
turn, --resume would fail with "Prompt is too long".

Fix 1 (prevent): After the CLI's PreCompact hook fires and compaction ends,
read the CLI's session file (which reflects compaction) and replace
TranscriptBuilder's entries via new replace_entries() method.

Fix 2 (mitigate): At download time, if transcript exceeds 400KB threshold,
compact it using compress_context (LLM summarization + truncation fallback)
before passing to --resume. Upload the compacted version for future turns.
2026-03-13 17:43:48 +07:00
Zamil Majdy
5966d3669d fix(platform): use round() for cache weights, accept string dates in UsageLimits 2026-03-13 17:38:25 +07:00
Zamil Majdy
c81ab1fc3b fix(backend/copilot): use adapter pattern for credit operations in executor
The CoPilot executor runs without a Prisma connection, so direct calls
to get_user_credit_model() caused "Client is not connected to the query
engine" errors. Replace with _get_credits/_spend_credits adapters that
fall back to RPC via DatabaseManagerAsyncClient when Prisma is unavailable.
Also add missing spend_credits/get_credits to DatabaseManagerAsyncClient.
2026-03-13 17:14:08 +07:00
Zamil Majdy
5446c7f18f fix(platform): consistent header icon sizing, halve rate limits
- Add size="icon" to SidebarTrigger for uniform h-9 w-9 button sizing
- Remove extra positioning wrappers (relative left-1, left-5) around header icons
- Halve daily token limit to 2.5M and weekly to 12.5M for more reasonable defaults
2026-03-13 15:55:18 +07:00
Zamil Majdy
2b0c9ba703 fix(frontend): use Button component for icon buttons, add timezone and <1% display
- UsageLimits and NotificationToggle now use <Button variant="ghost" size="icon">
  to match SidebarTrigger's padding/sizing
- Weekly reset time shows timezone abbreviation (e.g., "Mon 7:00 AM PST")
- Usage below 1% shows "<1% used" instead of "0% used" with a 1% min bar width
2026-03-13 15:42:05 +07:00
Zamil Majdy
195c7011ae fix(frontend): update usage after chat, show reset day, fix icon weight
- Invalidate usage query when stream completes so the usage bar
  updates immediately after chatting instead of waiting 30s.
- Show reset time as day/time in local timezone when over 24h away
  (e.g. "Mon 12:00 AM") instead of unclear "63h 41m".
- Use weight="light" on ChartBar icon to match other header icons.
2026-03-13 15:21:42 +07:00
Zamil Majdy
d4944fb22b fix(platform): emit StreamUsage as SSE comment, move usage to popover
StreamUsage events crashed the frontend because the Vercel AI SDK uses
z.strictObject() and rejects unknown event types. Fix by overriding
to_sse() to emit as an SSE comment (invisible to the parser). Usage
data is already recorded server-side (session DB + Redis counters).

Move usage limits from sidebar footer back to a ChartBar icon button
in the sidebar header that opens a popover on click.
2026-03-13 15:14:14 +07:00
Zamil Majdy
a5ed8fefa9 feat(copilot): cost-weighted token rate limiting with cache breakdown
- Rate limiter now uses Anthropic's cost model: cache_read at 10%,
  cache_creation at 25%, uncached and output at 100%
- Track cache_read_tokens and cache_creation_tokens separately in
  Usage model, StreamUsage response, and SDK token extraction
- Pass cache breakdown through to record_token_usage() for accurate
  weighted counting
- Add test for cost-weighted counting (10K cache_read → 1K weighted)

This makes multi-turn conversations fairer: cached system prompts
and tool schemas don't penalize users at full token cost.
2026-03-13 14:36:04 +07:00
Zamil Majdy
a52a777b29 fix(copilot): increase rate limits from 500K/5M to 5M/25M daily/weekly
A single CoPilot turn consumes ~10-15K tokens (system prompt + tool
schemas), so 500K daily only allowed ~35-50 turns which is too
restrictive for normal use.
2026-03-13 14:21:01 +07:00
Zamil Majdy
8bec7a6933 fix(frontend): move usage limits to left column on billing page
Move CoPilot Usage Limits below Automatic Refill Settings in the left
column and add "Open CoPilot" button for consistency with other sections.
2026-03-13 14:17:57 +07:00
Zamil Majdy
e73791efed fix(copilot): move usage limits to sidebar bottom, add to billing page
- Move UsageLimits from top-right headerSlot to left sidebar footer
- Show usage limits on both main copilot page and inside chat sessions
- Add CoPilot Usage Limits section to billing/credits page
- Change "Learn more about usage limits" to "Manage billing & credits"
- Remove unused headerSlot prop from ChatContainer/ChatMessagesContainer
- Clean up unused imports in CopilotPage
2026-03-13 14:08:57 +07:00
Zamil Majdy
2d161ce2b9 refactor(platform): replace per-session rate limits with daily fixed-window
Per-session limits were gameable (create new session to reset) and had
confusing reset semantics (12h inactivity TTL that refreshed on use).
Replace with daily fixed-window counter that resets at midnight UTC,
matching the weekly window pattern.

- session_token_limit → daily_token_limit (500K tokens/day)
- Redis key: copilot:usage:daily:{user_id}:{YYYY-MM-DD} with TTL
- Remove session_id from usage API endpoint and record_token_usage()
- Frontend: "Current session" → "Today", "Weekly limits" → "This week"
2026-03-13 13:22:59 +07:00
Zamil Majdy
6fc4989654 fix(copilot): include full context window in fallback token estimation
Each API call sends the complete openai_messages list (system prompt +
history + turn), so the fallback estimator should count all of it to
match what prompt_tokens would have reported.
2026-03-13 12:36:51 +07:00
Zamil Majdy
976443bf6e refactor(copilot): use tiktoken for fallback token estimation
Replace rough chars/4 heuristic with proper tiktoken tokenizer via
estimate_token_count/estimate_token_count_str from backend.util.prompt.
2026-03-13 05:24:53 +07:00
Zamil Majdy
4ceb15b3f1 fix(copilot): scope fallback token estimation to current turn only
The fallback estimator was counting the entire openai_messages history
(system prompt + all previous turns) instead of just the messages added
during the current turn. This caused overcounting and overly strict
rate limiting when providers don't return streaming usage data.
2026-03-13 03:44:30 +07:00
Zamil Majdy
3096f94996 feat(copilot): set default rate limits based on observed usage data
Session: 500K tokens (P90 session usage ~550K)
Weekly: 5M tokens (~10x heaviest observed weekly user)
2026-03-13 03:34:05 +07:00
Zamil Majdy
6f90729612 merge: resolve conflicts with dev (coerce_inputs_to_schema)
Merge dev's coerce_inputs_to_schema into execute_block alongside
credit charging. Both features coexist: coercion runs before
credit check. Test file combines both test suites.
2026-03-13 03:15:37 +07:00
Zamil Majdy
ebf89dde8b fix(copilot): include cached tokens in SDK token tracking
Anthropic's API reports cached tokens separately (cache_read_input_tokens,
cache_creation_input_tokens) from input_tokens. The previous code only read
input_tokens, undercounting total tokens for rate limiting.

OpenRouter (baseline path) already includes cached tokens in prompt_tokens
per OpenAI-compatible format — added clarifying comment.
2026-03-13 02:46:57 +07:00
Zamil Majdy
5d057e97e5 fix(copilot): move StreamUsage/StreamFinish back out of finally block
PEP 525 prohibits yielding from finally in async generators during
aclose() — doing so raises RuntimeError on client disconnect. Move
yields after try/finally where they work on normal completion and are
harmlessly unreachable on GeneratorExit.
2026-03-13 00:40:58 +07:00
Zamil Majdy
1d2f641a26 fix(copilot): move session.usage.append to finally block in SDK service
Ensures session usage persistence and rate-limit recording stay
consistent even when an exception interrupts the try block. Mirrors
the baseline service pattern.
2026-03-13 00:35:09 +07:00
Zamil Majdy
dcb71ab0b9 test(platform): add tests for usage endpoint and UsageLimits component
- 3 new tests for GET /usage endpoint (with/without session_id, config limits)
- 9 new tests for UsageLimits frontend component (loading, empty, rendering,
  percentage capping, session-only/weekly-only, link, hook args)
2026-03-13 00:30:52 +07:00
Zamil Majdy
8136b90860 fix(copilot): move StreamUsage/StreamFinish yields into finally block
On client disconnect, GeneratorExit terminates the generator after the
finally block, making yields after it unreachable. Moving them inside
finally ensures they are at least attempted.
2026-03-13 00:05:41 +07:00
Zamil Majdy
4d179a7c37 Merge branch 'dev' into feat/tracking-cost-block 2026-03-12 23:56:38 +07:00
Zamil Majdy
f78adcdc65 fix(platform): address all open review items — clock skew, token recording, generated hooks
- Fix clock skew: share single `now` timestamp across `_weekly_key` and
  `_weekly_reset_time` calls to prevent ISO week boundary race condition
- SDK: move `record_token_usage` to finally block so tokens are always
  recorded even when exceptions interrupt the stream (prevents rate limit bypass)
- Baseline: wrap `record_token_usage` in try/except so it cannot block
  session persistence
- Routes: pass `session_id=None` instead of empty string to avoid
  malformed Redis keys; `get_usage_status` now skips session lookup when
  session_id is None
- Tests: add partial None counter tests and no-session-id test
- Frontend: replace raw fetch with generated Orval hook
  (`useGetV2GetCopilotUsage`), use generated `CoPilotUsageStatus` type,
  fix `Date` vs `string` type for `resets_at`
2026-03-12 23:18:31 +07:00
Zamil Majdy
40388b7520 fix(platform): address review — orphan keys, TTL gather, dark mode, lazy logging
- Guard record_token_usage with user_id check to prevent orphan Redis keys
- Fold session TTL lookup into asyncio.gather (eliminate serial round-trip)
- Add dark: variants to UsageLimits component
- Use lazy %s formatting in logger calls instead of eager f-strings
2026-03-12 23:03:36 +07:00
Zamil Majdy
dd7be1158b test(backend): expand rate limit test coverage, fix token estimation null safety
- Add tests for: past resets_at (negative time), expired TTL fallback,
  _session_reset_from_ttl Redis error, pipeline TTL assertions,
  pipeline execute RedisError
- Fix null safety in baseline token estimation (m.get("content") or "")
- Use MagicMock for pipeline sync methods to eliminate coroutine warnings
2026-03-12 22:38:31 +07:00
Zamil Majdy
c0e59f0a6b fix(backend): address pushback on credit charging and token estimation
- Post-exec InsufficientBalanceError: return output + log warning
  instead of ErrorResponse (block already executed with side effects)
- Add token estimation fallback for providers that don't support
  stream_options include_usage (1 token ≈ 4 chars)
- Remove unnecessary # type: ignore on AsyncRedis param
2026-03-12 22:25:14 +07:00
Zamil Majdy
104d1f1bf4 fix(backend): revert to += for prompt tokens, add negative time guard
- Revert prompt token tracking back to += (sum all rounds): each API
  call bills independently, so total billed tokens is the correct
  metric for rate limiting
- Guard against negative reset time in RateLimitExceeded message
  (clock skew / stale Redis TTL)
2026-03-12 22:09:05 +07:00
Zamil Majdy
d9e9cd4c98 fix(backend): address human + Sentry review feedback on CoPilot tracking
- Fix double-counting of prompt tokens in multi-round tool calls:
  use = (not +=) since each API call's prompt_tokens includes full
  conversation history (both baseline and SDK paths)
- Fix docstring: "sliding-window" → "fixed-window" counters
- Type redis param as AsyncRedis instead of object
- Use pipeline(transaction=False) for independent INCRBY+EXPIRE
- Simplify days_until_monday calculation with `or 7`
- Flatten time_str into ternary, merge f-strings
- Parallelize Redis gets with asyncio.gather
- Document session_id=None behavior in usage endpoint
2026-03-12 21:54:31 +07:00
Zamil Majdy
ca416300ec fix(backend): address second round CodeRabbit review feedback
- Narrow except blocks to (RedisError, ConnectionError, OSError) instead
  of bare Exception to avoid hiding coding bugs
- Remove raw user_id from log messages to prevent PII leaks under Redis
  outages
- Reuse credit_model instead of fetching twice in execute_block()
- Treat post-execution InsufficientBalanceError as fatal (matches
  executor behavior) instead of silently swallowing it
2026-03-12 21:37:32 +07:00
Zamil Majdy
c589cd0c43 fix(backend): address CodeRabbit review feedback
- Derive session reset time from Redis TTL instead of hardcoded 3h
- Add description to UsageWindow.limit documenting 0 = unlimited
- Compare balance < cost instead of balance <= 0 in pre-exec check
- Document TOCTOU behavior in check_rate_limit docstring
2026-03-12 21:10:52 +07:00
Zamil Majdy
b6d863fcd2 feat(platform): add CoPilot credit charging, token tracking, and rate limiting
- Charge credits for block execution in CoPilot (matching graph executor behavior)
- Track LLM token usage for both SDK (Claude) and baseline (OpenAI) paths
- Add Redis-based per-user rate limiting with session and weekly token windows
- Expose usage status via GET /api/chat/usage endpoint
- Add frontend UsageLimits component with progress bars in CoPilot header
- Include unit tests for rate limiting and block credit charging
2026-03-12 20:50:21 +07:00
33 changed files with 3336 additions and 160 deletions

View File

@@ -27,6 +27,12 @@ from backend.copilot.model import (
get_user_sessions,
update_session_title,
)
from backend.copilot.rate_limit import (
CoPilotUsageStatus,
RateLimitExceeded,
check_rate_limit,
get_usage_status,
)
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
from backend.copilot.tools.e2b_sandbox import kill_sandbox
from backend.copilot.tools.models import (
@@ -120,6 +126,8 @@ class SessionDetailResponse(BaseModel):
user_id: str | None
messages: list[dict]
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
total_prompt_tokens: int = 0
total_completion_tokens: int = 0
class SessionSummaryResponse(BaseModel):
@@ -389,6 +397,10 @@ async def get_session(
last_message_id=last_message_id,
)
# Sum token usage from session
total_prompt = sum(u.prompt_tokens for u in session.usage)
total_completion = sum(u.completion_tokens for u in session.usage)
return SessionDetailResponse(
id=session.session_id,
created_at=session.started_at.isoformat(),
@@ -396,6 +408,26 @@ async def get_session(
user_id=session.user_id or None,
messages=messages,
active_stream=active_stream_info,
total_prompt_tokens=total_prompt,
total_completion_tokens=total_completion,
)
@router.get("/usage")
async def get_copilot_usage(
user_id: Annotated[str | None, Depends(auth.get_user_id)],
) -> CoPilotUsageStatus:
"""Get CoPilot usage status for the authenticated user.
Returns current token usage vs limits for daily and weekly windows.
"""
if not user_id:
raise HTTPException(status_code=401, detail="Authentication required")
return await get_usage_status(
user_id=user_id,
daily_token_limit=config.daily_token_limit,
weekly_token_limit=config.weekly_token_limit,
)
@@ -496,6 +528,17 @@ async def stream_chat_post(
},
)
# Pre-turn rate limit check (token-based)
if user_id and (config.daily_token_limit > 0 or config.weekly_token_limit > 0):
try:
await check_rate_limit(
user_id=user_id,
daily_token_limit=config.daily_token_limit,
weekly_token_limit=config.weekly_token_limit,
)
except RateLimitExceeded as e:
raise HTTPException(status_code=429, detail=str(e)) from e
# Enrich message with file metadata if file_ids are provided.
# Also sanitise file_ids so only validated, workspace-scoped IDs are
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).

View File

@@ -1,5 +1,6 @@
"""Tests for chat API routes: session title update, file attachment validation, and suggested prompts."""
"""Tests for chat API routes: session title update, file attachment validation, usage, and suggested prompts."""
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock
import fastapi
@@ -251,6 +252,74 @@ def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
assert call_kwargs["where"]["isDeleted"] is False
# ─── Usage endpoint ───────────────────────────────────────────────────
def _mock_usage(
mocker: pytest_mock.MockerFixture,
*,
daily_used: int = 500,
weekly_used: int = 2000,
) -> AsyncMock:
"""Mock get_usage_status to return a predictable CoPilotUsageStatus."""
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
resets_at = datetime.now(UTC) + timedelta(days=1)
status = CoPilotUsageStatus(
daily=UsageWindow(used=daily_used, limit=10000, resets_at=resets_at),
weekly=UsageWindow(used=weekly_used, limit=50000, resets_at=resets_at),
)
return mocker.patch(
"backend.api.features.chat.routes.get_usage_status",
new_callable=AsyncMock,
return_value=status,
)
def test_usage_returns_daily_and_weekly(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""GET /usage returns daily and weekly usage."""
mock_get = _mock_usage(mocker, daily_used=500, weekly_used=2000)
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
response = client.get("/usage")
assert response.status_code == 200
data = response.json()
assert data["daily"]["used"] == 500
assert data["weekly"]["used"] == 2000
mock_get.assert_called_once_with(
user_id=test_user_id,
daily_token_limit=10000,
weekly_token_limit=50000,
)
def test_usage_uses_config_limits(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""The endpoint forwards daily_token_limit and weekly_token_limit from config."""
mock_get = _mock_usage(mocker)
mocker.patch.object(chat_routes.config, "daily_token_limit", 99999)
mocker.patch.object(chat_routes.config, "weekly_token_limit", 77777)
response = client.get("/usage")
assert response.status_code == 200
mock_get.assert_called_once_with(
user_id=test_user_id,
daily_token_limit=99999,
weekly_token_limit=77777,
)
# ─── Suggested prompts endpoint ──────────────────────────────────────

View File

@@ -18,11 +18,13 @@ from langfuse import propagate_attributes
from backend.copilot.model import (
ChatMessage,
ChatSession,
Usage,
get_chat_session,
update_session_title,
upsert_chat_session,
)
from backend.copilot.prompting import get_baseline_supplement
from backend.copilot.rate_limit import record_token_usage
from backend.copilot.response_model import (
StreamBaseResponse,
StreamError,
@@ -36,6 +38,7 @@ from backend.copilot.response_model import (
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
StreamUsage,
)
from backend.copilot.service import (
_build_system_prompt,
@@ -46,7 +49,11 @@ from backend.copilot.service import (
from backend.copilot.tools import execute_tool, get_available_tools
from backend.copilot.tracking import track_user_message
from backend.util.exceptions import NotFoundError
from backend.util.prompt import compress_context
from backend.util.prompt import (
compress_context,
estimate_token_count,
estimate_token_count_str,
)
logger = logging.getLogger(__name__)
@@ -221,6 +228,9 @@ async def stream_chat_completion_baseline(
text_block_id = str(uuid.uuid4())
text_started = False
step_open = False
# Token usage accumulators — populated from streaming chunks
turn_prompt_tokens = 0
turn_completion_tokens = 0
try:
for _round in range(_MAX_TOOL_ROUNDS):
# Open a new step for each LLM round
@@ -232,6 +242,7 @@ async def stream_chat_completion_baseline(
model=config.model,
messages=openai_messages,
stream=True,
stream_options={"include_usage": True},
)
if tools:
create_kwargs["tools"] = tools
@@ -242,7 +253,18 @@ async def stream_chat_completion_baseline(
tool_calls_by_index: dict[int, dict[str, str]] = {}
async for chunk in response:
delta = chunk.choices[0].delta if chunk.choices else None
# Capture token usage from the streaming chunk.
# OpenRouter normalises all providers into OpenAI format
# where prompt_tokens already includes cached tokens
# (unlike Anthropic's native API). Use += to sum all
# tool-call rounds since each API call is independent.
if chunk.usage:
turn_prompt_tokens += chunk.usage.prompt_tokens or 0
turn_completion_tokens += chunk.usage.completion_tokens or 0
if not chunk.choices:
continue
delta = chunk.choices[0].delta
if not delta:
continue
@@ -411,6 +433,53 @@ async def stream_chat_completion_baseline(
except Exception:
logger.warning("[Baseline] Langfuse trace context teardown failed")
# Fallback: estimate tokens via tiktoken when the provider does
# not honour stream_options={"include_usage": True}.
# Count the full message list (system + history + turn) since
# each API call sends the complete context window.
if turn_prompt_tokens == 0 and turn_completion_tokens == 0:
turn_prompt_tokens = max(
estimate_token_count(openai_messages, model=config.model), 0
)
turn_completion_tokens = max(
estimate_token_count_str(assistant_text, model=config.model), 0
)
logger.info(
"[Baseline] No streaming usage reported; estimated tokens: "
"prompt=%d, completion=%d",
turn_prompt_tokens,
turn_completion_tokens,
)
# Emit token usage and update session for persistence
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
total_tokens = turn_prompt_tokens + turn_completion_tokens
session.usage.append(
Usage(
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
total_tokens=total_tokens,
)
)
logger.info(
"[Baseline] Turn usage: prompt=%d, completion=%d, total=%d",
turn_prompt_tokens,
turn_completion_tokens,
total_tokens,
)
# Record for rate limiting counters
if user_id:
try:
await record_token_usage(
user_id=user_id,
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
)
except Exception as usage_err:
logger.warning(
"[Baseline] Failed to record token usage: %s", usage_err
)
# Persist assistant response
if assistant_text:
session.messages.append(
@@ -421,4 +490,16 @@ async def stream_chat_completion_baseline(
except Exception as persist_err:
logger.error("[Baseline] Failed to persist session: %s", persist_err)
# Yield usage and finish AFTER try/finally (not inside finally).
# PEP 525 prohibits yielding from finally in async generators during
# aclose() — doing so raises RuntimeError on client disconnect.
# On GeneratorExit the client is already gone, so unreachable yields
# are harmless; on normal completion they reach the SSE stream.
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
yield StreamUsage(
promptTokens=turn_prompt_tokens,
completionTokens=turn_completion_tokens,
totalTokens=turn_prompt_tokens + turn_completion_tokens,
)
yield StreamFinish()

View File

@@ -70,6 +70,20 @@ class ChatConfig(BaseSettings):
description="Cache TTL in seconds for Langfuse prompt (0 to disable caching)",
)
# Rate limiting — token-based limits per day and per week.
# Each CoPilot turn consumes ~10-15K tokens (system prompt + tool schemas + response),
# so 2.5M daily allows ~170-250 turns/day which is reasonable for normal use.
# TODO: These are global deploy-time constants. For per-user or per-plan limits,
# move to the database (e.g. UserPlan table) and look up in get_usage_status.
daily_token_limit: int = Field(
default=2_500_000,
description="Max tokens per day, resets at midnight UTC (0 = unlimited)",
)
weekly_token_limit: int = Field(
default=12_500_000,
description="Max tokens per week, resets Monday 00:00 UTC (0 = unlimited)",
)
# Claude Agent SDK Configuration
use_claude_agent_sdk: bool = Field(
default=True,

View File

@@ -73,6 +73,9 @@ class Usage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
# Cache breakdown (Anthropic-specific; zero for non-Anthropic models)
cache_read_tokens: int = 0
cache_creation_tokens: int = 0
class ChatSessionInfo(BaseModel):

View File

@@ -0,0 +1,253 @@
"""CoPilot rate limiting based on token usage.
Uses Redis fixed-window counters to track per-user token consumption
with configurable daily and weekly limits. Daily windows reset at
midnight UTC; weekly windows reset at ISO week boundary (Monday 00:00
UTC). Fails open when Redis is unavailable to avoid blocking users.
"""
import asyncio
import logging
from datetime import UTC, datetime, timedelta
from pydantic import BaseModel, Field
from backend.data.redis_client import get_redis_async
logger = logging.getLogger(__name__)
# Redis key prefixes
_PREFIX = "copilot:usage"
class UsageWindow(BaseModel):
"""Usage within a single time window."""
used: int
limit: int = Field(
description="Maximum tokens allowed in this window. 0 means unlimited."
)
resets_at: datetime
class CoPilotUsageStatus(BaseModel):
"""Current usage status for a user across all windows."""
daily: UsageWindow
weekly: UsageWindow
class RateLimitExceeded(Exception):
"""Raised when a user exceeds their CoPilot usage limit."""
def __init__(self, window: str, resets_at: datetime):
self.window = window
self.resets_at = resets_at
delta = resets_at - datetime.now(UTC)
total_secs = delta.total_seconds()
if total_secs <= 0:
time_str = "now"
else:
hours = int(total_secs // 3600)
minutes = int((total_secs % 3600) // 60)
time_str = f"{hours}h {minutes}m" if hours > 0 else f"{minutes}m"
super().__init__(
f"You've reached your {window} usage limit. Resets in {time_str}."
)
def _daily_key(user_id: str, now: datetime | None = None) -> str:
if now is None:
now = datetime.now(UTC)
return f"{_PREFIX}:daily:{user_id}:{now.strftime('%Y-%m-%d')}"
def _weekly_key(user_id: str, now: datetime | None = None) -> str:
if now is None:
now = datetime.now(UTC)
year, week, _ = now.isocalendar()
return f"{_PREFIX}:weekly:{user_id}:{year}-W{week:02d}"
def _daily_reset_time(now: datetime | None = None) -> datetime:
"""Calculate when the current daily window resets (next midnight UTC)."""
if now is None:
now = datetime.now(UTC)
return now.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(days=1)
def _weekly_reset_time(now: datetime | None = None) -> datetime:
"""Calculate when the current weekly window resets (next Monday 00:00 UTC).
On Monday itself, ``(7 - weekday) % 7`` is 0; the ``or 7`` fallback
pushes to *next* Monday so the current week's window stays open.
"""
if now is None:
now = datetime.now(UTC)
days_until_monday = (7 - now.weekday()) % 7 or 7
return now.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(
days=days_until_monday
)
async def _fetch_counters(user_id: str, now: datetime) -> tuple[int, int]:
"""Fetch daily and weekly token counters from Redis.
Returns (daily_used, weekly_used). Returns (0, 0) if Redis is unavailable.
"""
redis = await get_redis_async()
daily_raw, weekly_raw = await asyncio.gather(
redis.get(_daily_key(user_id, now=now)),
redis.get(_weekly_key(user_id, now=now)),
)
return int(daily_raw or 0), int(weekly_raw or 0)
async def get_usage_status(
user_id: str,
daily_token_limit: int,
weekly_token_limit: int,
) -> CoPilotUsageStatus:
"""Get current usage status for a user.
Args:
user_id: The user's ID.
daily_token_limit: Max tokens per day (0 = unlimited).
weekly_token_limit: Max tokens per week (0 = unlimited).
Returns:
CoPilotUsageStatus with current usage and limits.
"""
now = datetime.now(UTC)
try:
daily_used, weekly_used = await _fetch_counters(user_id, now)
except Exception:
logger.warning(
"Redis unavailable for usage status, returning zeros", exc_info=True
)
daily_used, weekly_used = 0, 0
return CoPilotUsageStatus(
daily=UsageWindow(
used=daily_used,
limit=daily_token_limit,
resets_at=_daily_reset_time(now=now),
),
weekly=UsageWindow(
used=weekly_used,
limit=weekly_token_limit,
resets_at=_weekly_reset_time(now=now),
),
)
async def check_rate_limit(
user_id: str,
daily_token_limit: int,
weekly_token_limit: int,
) -> None:
"""Check if user is within rate limits. Raises RateLimitExceeded if not.
This is a pre-turn soft check. The authoritative usage counter is updated
by ``record_token_usage()`` after the turn completes. Under concurrency,
two parallel turns may both pass this check against the same snapshot.
This is acceptable because token-based limits are approximate by nature
(the exact token count is unknown until after generation).
Fails open: if Redis is unavailable, allows the request.
"""
now = datetime.now(UTC)
try:
daily_used, weekly_used = await _fetch_counters(user_id, now)
except Exception:
logger.warning(
"Redis unavailable for rate limit check, allowing request", exc_info=True
)
return
if daily_token_limit > 0 and daily_used >= daily_token_limit:
raise RateLimitExceeded("daily", _daily_reset_time(now=now))
if weekly_token_limit > 0 and weekly_used >= weekly_token_limit:
raise RateLimitExceeded("weekly", _weekly_reset_time(now=now))
async def record_token_usage(
user_id: str,
prompt_tokens: int,
completion_tokens: int,
*,
cache_read_tokens: int = 0,
cache_creation_tokens: int = 0,
) -> None:
"""Record token usage for a user across all windows.
Uses cost-weighted counting so cached tokens don't unfairly penalise
multi-turn conversations. Anthropic's pricing:
- uncached input: 100%
- cache creation: 25%
- cache read: 10%
- output: 100%
``prompt_tokens`` should be the *uncached* input count (``input_tokens``
from the API response). Cache counts are passed separately.
Args:
user_id: The user's ID.
prompt_tokens: Uncached input tokens.
completion_tokens: Output tokens.
cache_read_tokens: Tokens served from prompt cache (10% cost).
cache_creation_tokens: Tokens written to prompt cache (25% cost).
"""
weighted_input = (
prompt_tokens
+ round(cache_creation_tokens * 0.25)
+ round(cache_read_tokens * 0.1)
)
total = weighted_input + completion_tokens
if total <= 0:
return
raw_total = (
prompt_tokens + cache_read_tokens + cache_creation_tokens + completion_tokens
)
logger.info(
"Recording token usage for %s: raw=%d, weighted=%d "
"(uncached=%d, cache_read=%d@10%%, cache_create=%d@25%%, output=%d)",
user_id[:8],
raw_total,
total,
prompt_tokens,
cache_read_tokens,
cache_creation_tokens,
completion_tokens,
)
now = datetime.now(UTC)
try:
redis = await get_redis_async()
pipe = redis.pipeline(transaction=False)
# Daily counter (expires at next midnight UTC)
d_key = _daily_key(user_id, now=now)
pipe.incrby(d_key, total)
seconds_until_daily_reset = int(
(_daily_reset_time(now=now) - now).total_seconds()
)
pipe.expire(d_key, max(seconds_until_daily_reset, 1))
# Weekly counter (expires end of week)
w_key = _weekly_key(user_id, now=now)
pipe.incrby(w_key, total)
seconds_until_weekly_reset = int(
(_weekly_reset_time(now=now) - now).total_seconds()
)
pipe.expire(w_key, max(seconds_until_weekly_reset, 1))
await pipe.execute()
except Exception:
logger.warning(
"Redis unavailable for recording token usage (tokens=%d)",
total,
exc_info=True,
)

View File

@@ -0,0 +1,334 @@
"""Unit tests for CoPilot rate limiting."""
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from redis.exceptions import RedisError
from .rate_limit import (
CoPilotUsageStatus,
RateLimitExceeded,
check_rate_limit,
get_usage_status,
record_token_usage,
)
_USER = "test-user-rl"
# ---------------------------------------------------------------------------
# RateLimitExceeded
# ---------------------------------------------------------------------------
class TestRateLimitExceeded:
def test_message_contains_window_name(self):
exc = RateLimitExceeded("daily", datetime.now(UTC) + timedelta(hours=1))
assert "daily" in str(exc)
def test_message_contains_reset_time(self):
exc = RateLimitExceeded(
"weekly", datetime.now(UTC) + timedelta(hours=2, minutes=30)
)
msg = str(exc)
# Allow for slight timing drift (29m or 30m)
assert "2h " in msg
assert "Resets in" in msg
def test_message_minutes_only_when_under_one_hour(self):
exc = RateLimitExceeded("daily", datetime.now(UTC) + timedelta(minutes=15))
msg = str(exc)
assert "Resets in" in msg
# Should not have "0h"
assert "0h" not in msg
def test_message_says_now_when_resets_at_is_in_the_past(self):
"""Negative delta (clock skew / stale TTL) should say 'now', not '-1h -30m'."""
exc = RateLimitExceeded("daily", datetime.now(UTC) - timedelta(minutes=5))
assert "Resets in now" in str(exc)
# ---------------------------------------------------------------------------
# get_usage_status
# ---------------------------------------------------------------------------
class TestGetUsageStatus:
@pytest.mark.asyncio
async def test_returns_redis_values(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["500", "2000"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert isinstance(status, CoPilotUsageStatus)
assert status.daily.used == 500
assert status.daily.limit == 10000
assert status.weekly.used == 2000
assert status.weekly.limit == 50000
@pytest.mark.asyncio
async def test_returns_zeros_when_redis_unavailable(self):
with patch(
"backend.copilot.rate_limit.get_redis_async",
side_effect=ConnectionError("Redis down"),
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert status.daily.used == 0
assert status.weekly.used == 0
@pytest.mark.asyncio
async def test_partial_none_daily_counter(self):
"""Daily counter is None (new day), weekly has usage."""
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=[None, "3000"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert status.daily.used == 0
assert status.weekly.used == 3000
@pytest.mark.asyncio
async def test_partial_none_weekly_counter(self):
"""Weekly counter is None (start of week), daily has usage."""
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["500", None])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert status.daily.used == 500
assert status.weekly.used == 0
@pytest.mark.asyncio
async def test_resets_at_daily_is_next_midnight_utc(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["0", "0"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
now = datetime.now(UTC)
# Daily reset should be within 24h
assert status.daily.resets_at > now
assert status.daily.resets_at <= now + timedelta(hours=24, seconds=5)
# ---------------------------------------------------------------------------
# check_rate_limit
# ---------------------------------------------------------------------------
class TestCheckRateLimit:
@pytest.mark.asyncio
async def test_allows_when_under_limit(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["100", "200"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
# Should not raise
await check_rate_limit(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
@pytest.mark.asyncio
async def test_raises_when_daily_limit_exceeded(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["10000", "200"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
with pytest.raises(RateLimitExceeded) as exc_info:
await check_rate_limit(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert exc_info.value.window == "daily"
@pytest.mark.asyncio
async def test_raises_when_weekly_limit_exceeded(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["100", "50000"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
with pytest.raises(RateLimitExceeded) as exc_info:
await check_rate_limit(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert exc_info.value.window == "weekly"
@pytest.mark.asyncio
async def test_allows_when_redis_unavailable(self):
"""Fail-open: allow requests when Redis is down."""
with patch(
"backend.copilot.rate_limit.get_redis_async",
side_effect=ConnectionError("Redis down"),
):
# Should not raise
await check_rate_limit(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
@pytest.mark.asyncio
async def test_skips_check_when_limit_is_zero(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["999999", "999999"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
# Should not raise — limits of 0 mean unlimited
await check_rate_limit(_USER, daily_token_limit=0, weekly_token_limit=0)
# ---------------------------------------------------------------------------
# record_token_usage
# ---------------------------------------------------------------------------
class TestRecordTokenUsage:
@staticmethod
def _make_pipeline_mock() -> MagicMock:
"""Create a pipeline mock with sync methods and async execute."""
pipe = MagicMock()
pipe.execute = AsyncMock(return_value=[])
return pipe
@pytest.mark.asyncio
async def test_increments_redis_counters(self):
mock_pipe = self._make_pipeline_mock()
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
# Should call incrby twice (daily + weekly) with total=150
incrby_calls = mock_pipe.incrby.call_args_list
assert len(incrby_calls) == 2
assert incrby_calls[0].args[1] == 150 # daily
assert incrby_calls[1].args[1] == 150 # weekly
@pytest.mark.asyncio
async def test_skips_when_zero_tokens(self):
mock_redis = AsyncMock()
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await record_token_usage(_USER, prompt_tokens=0, completion_tokens=0)
# Should not call pipeline at all
mock_redis.pipeline.assert_not_called()
@pytest.mark.asyncio
async def test_sets_expire_on_both_keys(self):
"""Pipeline should call expire for both daily and weekly keys."""
mock_pipe = self._make_pipeline_mock()
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
expire_calls = mock_pipe.expire.call_args_list
assert len(expire_calls) == 2
# Daily key TTL should be positive (seconds until next midnight)
daily_ttl = expire_calls[0].args[1]
assert daily_ttl >= 1
# Weekly key TTL should be positive (seconds until next Monday)
weekly_ttl = expire_calls[1].args[1]
assert weekly_ttl >= 1
@pytest.mark.asyncio
async def test_handles_redis_failure_gracefully(self):
"""Should not raise when Redis is unavailable."""
with patch(
"backend.copilot.rate_limit.get_redis_async",
side_effect=ConnectionError("Redis down"),
):
# Should not raise
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
@pytest.mark.asyncio
async def test_cost_weighted_counting(self):
"""Cached tokens should be weighted: cache_read=10%, cache_create=25%."""
mock_pipe = self._make_pipeline_mock()
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await record_token_usage(
_USER,
prompt_tokens=100, # uncached → 100
completion_tokens=50, # output → 50
cache_read_tokens=10000, # 10% → 1000
cache_creation_tokens=400, # 25% → 100
)
# Expected weighted total: 100 + 1000 + 100 + 50 = 1250
incrby_calls = mock_pipe.incrby.call_args_list
assert len(incrby_calls) == 2
assert incrby_calls[0].args[1] == 1250 # daily
assert incrby_calls[1].args[1] == 1250 # weekly
@pytest.mark.asyncio
async def test_handles_redis_error_during_pipeline_execute(self):
"""Should not raise when pipeline.execute() fails with RedisError."""
mock_pipe = self._make_pipeline_mock()
mock_pipe.execute = AsyncMock(side_effect=RedisError("Pipeline failed"))
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
# Should not raise — fail-open
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)

View File

@@ -186,12 +186,29 @@ class StreamToolOutputAvailable(StreamBaseResponse):
class StreamUsage(StreamBaseResponse):
"""Token usage statistics."""
"""Token usage statistics.
Emitted as an SSE comment so the Vercel AI SDK parser ignores it
(it uses z.strictObject() and rejects unknown event types).
Usage data is recorded server-side (session DB + Redis counters).
"""
type: ResponseType = ResponseType.USAGE
promptTokens: int = Field(..., description="Number of prompt tokens")
promptTokens: int = Field(..., description="Number of uncached prompt tokens")
completionTokens: int = Field(..., description="Number of completion tokens")
totalTokens: int = Field(..., description="Total number of tokens")
totalTokens: int = Field(
..., description="Total number of tokens (raw, not weighted)"
)
cacheReadTokens: int = Field(
default=0, description="Prompt tokens served from cache (10% cost)"
)
cacheCreationTokens: int = Field(
default=0, description="Prompt tokens written to cache (25% cost)"
)
def to_sse(self) -> str:
"""Emit as SSE comment so the AI SDK parser ignores it."""
return f": usage {self.model_dump_json(exclude_none=True)}\n\n"
class StreamError(StreamBaseResponse):

View File

@@ -198,6 +198,7 @@ class CompactionTracker:
def reset_for_query(self) -> None:
"""Reset per-query state before a new SDK query."""
self._compact_start.clear()
self._done = False
self._start_emitted = False
self._tool_call_id = ""

View File

@@ -0,0 +1,546 @@
"""End-to-end compaction flow test.
Simulates the full service.py compaction lifecycle using real-format
JSONL session files — no SDK subprocess needed. Exercises:
1. TranscriptBuilder loads a "downloaded" transcript
2. User query appended, assistant response streamed
3. PreCompact hook fires → CompactionTracker.on_compact()
4. Next message → emit_start_if_ready() yields spinner events
5. Message after that → emit_end_if_ready() returns end events
6. _read_compacted_entries() reads the CLI session file
7. TranscriptBuilder.replace_entries() syncs state
8. More messages appended post-compaction
9. to_jsonl() exports full state for upload
10. Fresh builder loads the export — roundtrip verified
"""
import asyncio
from pathlib import Path
from backend.copilot.model import ChatSession
from backend.copilot.response_model import (
StreamFinishStep,
StreamStartStep,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
)
from backend.copilot.sdk.compaction import CompactionTracker
from backend.copilot.sdk.transcript import strip_progress_entries
from backend.copilot.sdk.transcript_builder import TranscriptBuilder
from backend.util import json
def _make_jsonl(*entries: dict) -> str:
return "\n".join(json.dumps(e) for e in entries) + "\n"
def _run(coro):
"""Run an async coroutine synchronously."""
return asyncio.run(coro)
def _read_compacted_entries(path: str) -> tuple[list[dict], str] | None:
"""Test-only: read compacted entries from a session JSONL file.
Returns (parsed_dicts, jsonl_string) from the first ``isCompactSummary``
entry onward, or ``None`` if no summary is found.
"""
content = Path(path).read_text()
lines = content.strip().split("\n")
compact_idx: int | None = None
parsed: list[dict] = []
raw_lines: list[str] = []
for line in lines:
if not line.strip():
continue
entry = json.loads(line, fallback=None)
if not isinstance(entry, dict):
continue
parsed.append(entry)
raw_lines.append(line.strip())
if compact_idx is None and entry.get("isCompactSummary"):
compact_idx = len(parsed) - 1
if compact_idx is None:
return None
return parsed[compact_idx:], "\n".join(raw_lines[compact_idx:]) + "\n"
# ---------------------------------------------------------------------------
# Fixtures: realistic CLI session file content
# ---------------------------------------------------------------------------
# Pre-compaction conversation
USER_1 = {
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "What files are in this project?"},
}
ASST_1_THINKING = {
"type": "assistant",
"uuid": "a1-think",
"parentUuid": "u1",
"message": {
"role": "assistant",
"id": "msg_sdk_aaa",
"type": "message",
"content": [{"type": "thinking", "thinking": "Let me look at the files..."}],
"stop_reason": None,
"stop_sequence": None,
},
}
ASST_1_TOOL = {
"type": "assistant",
"uuid": "a1-tool",
"parentUuid": "u1",
"message": {
"role": "assistant",
"id": "msg_sdk_aaa",
"type": "message",
"content": [
{
"type": "tool_use",
"id": "tu1",
"name": "Bash",
"input": {"command": "ls"},
}
],
"stop_reason": "tool_use",
"stop_sequence": None,
},
}
TOOL_RESULT_1 = {
"type": "user",
"uuid": "tr1",
"parentUuid": "a1-tool",
"message": {
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "tu1",
"content": "file1.py\nfile2.py",
}
],
},
}
ASST_1_TEXT = {
"type": "assistant",
"uuid": "a1-text",
"parentUuid": "tr1",
"message": {
"role": "assistant",
"id": "msg_sdk_bbb",
"type": "message",
"content": [{"type": "text", "text": "I found file1.py and file2.py."}],
"stop_reason": "end_turn",
"stop_sequence": None,
},
}
# Progress entries (should be stripped during upload)
PROGRESS_1 = {
"type": "progress",
"uuid": "prog1",
"parentUuid": "a1-tool",
"data": {"type": "bash_progress", "stdout": "running ls..."},
}
# Second user message
USER_2 = {
"type": "user",
"uuid": "u2",
"parentUuid": "a1-text",
"message": {"role": "user", "content": "Show me file1.py"},
}
ASST_2 = {
"type": "assistant",
"uuid": "a2",
"parentUuid": "u2",
"message": {
"role": "assistant",
"id": "msg_sdk_ccc",
"type": "message",
"content": [{"type": "text", "text": "Here is file1.py content..."}],
"stop_reason": "end_turn",
"stop_sequence": None,
},
}
# --- Compaction summary (written by CLI after context compaction) ---
COMPACT_SUMMARY = {
"type": "summary",
"uuid": "cs1",
"isCompactSummary": True,
"message": {
"role": "user",
"content": (
"Summary: User asked about project files. Found file1.py and file2.py. "
"User then asked to see file1.py."
),
},
}
# Post-compaction assistant response
POST_COMPACT_ASST = {
"type": "assistant",
"uuid": "a3",
"parentUuid": "cs1",
"message": {
"role": "assistant",
"id": "msg_sdk_ddd",
"type": "message",
"content": [{"type": "text", "text": "Here is the content of file1.py..."}],
"stop_reason": "end_turn",
"stop_sequence": None,
},
}
# Post-compaction user follow-up
USER_3 = {
"type": "user",
"uuid": "u3",
"parentUuid": "a3",
"message": {"role": "user", "content": "Now show file2.py"},
}
ASST_3 = {
"type": "assistant",
"uuid": "a4",
"parentUuid": "u3",
"message": {
"role": "assistant",
"id": "msg_sdk_eee",
"type": "message",
"content": [{"type": "text", "text": "Here is file2.py..."}],
"stop_reason": "end_turn",
"stop_sequence": None,
},
}
# ---------------------------------------------------------------------------
# E2E test
# ---------------------------------------------------------------------------
class TestCompactionE2E:
def _write_session_file(self, session_dir, entries):
"""Write a CLI session JSONL file."""
path = session_dir / "session.jsonl"
path.write_text(_make_jsonl(*entries))
return path
def test_full_compaction_lifecycle(self, tmp_path):
"""Simulate the complete service.py compaction flow.
Timeline:
1. Previous turn uploaded transcript with [USER_1, ASST_1, USER_2, ASST_2]
2. Current turn: download → load_previous
3. User sends "Now show file2.py" → append_user
4. SDK starts streaming response
5. Mid-stream: PreCompact hook fires (context too large)
6. CLI writes compaction summary to session file
7. Next SDK message → emit_start (spinner)
8. Following message → emit_end (end events)
9. _read_compacted_entries reads the session file
10. replace_entries syncs TranscriptBuilder
11. More assistant messages appended
12. Export → upload → next turn downloads it
"""
session_dir = tmp_path / "session"
session_dir.mkdir(parents=True)
# --- Step 1-2: Load "downloaded" transcript from previous turn ---
previous_transcript = _make_jsonl(
USER_1,
ASST_1_THINKING,
ASST_1_TOOL,
TOOL_RESULT_1,
ASST_1_TEXT,
USER_2,
ASST_2,
)
builder = TranscriptBuilder()
builder.load_previous(previous_transcript)
assert builder.entry_count == 7
# --- Step 3: User sends new query ---
builder.append_user("Now show file2.py")
assert builder.entry_count == 8
# --- Step 4: SDK starts streaming ---
builder.append_assistant(
[{"type": "thinking", "thinking": "Let me read file2.py..."}],
model="claude-sonnet-4-20250514",
)
assert builder.entry_count == 9
# --- Step 5-6: PreCompact fires, CLI writes session file ---
session_file = self._write_session_file(
session_dir,
[
USER_1,
ASST_1_THINKING,
ASST_1_TOOL,
PROGRESS_1,
TOOL_RESULT_1,
ASST_1_TEXT,
USER_2,
ASST_2,
COMPACT_SUMMARY,
POST_COMPACT_ASST,
USER_3,
ASST_3,
],
)
# --- Step 7: CompactionTracker receives PreCompact hook ---
tracker = CompactionTracker()
session = ChatSession.new(user_id="test-user")
# on_compact is a property returning Event.set callable
tracker.on_compact()
# --- Step 8: Next SDK message arrives → emit_start ---
start_events = tracker.emit_start_if_ready()
assert len(start_events) == 3
assert isinstance(start_events[0], StreamStartStep)
assert isinstance(start_events[1], StreamToolInputStart)
assert isinstance(start_events[2], StreamToolInputAvailable)
# Verify tool_call_id is set
tool_call_id = start_events[1].toolCallId
assert tool_call_id.startswith("compaction-")
# --- Step 9: Following message → emit_end ---
end_events = _run(tracker.emit_end_if_ready(session))
assert len(end_events) == 2
assert isinstance(end_events[0], StreamToolOutputAvailable)
assert isinstance(end_events[1], StreamFinishStep)
# Verify same tool_call_id
assert end_events[0].toolCallId == tool_call_id
# Session should have compaction messages persisted
assert len(session.messages) == 2
assert session.messages[0].role == "assistant"
assert session.messages[1].role == "tool"
# --- Step 10: _read_compacted_entries + replace_entries ---
result = _read_compacted_entries(str(session_file))
assert result is not None
compacted_dicts, compacted_jsonl = result
# Should have: COMPACT_SUMMARY + POST_COMPACT_ASST + USER_3 + ASST_3
assert len(compacted_dicts) == 4
assert compacted_dicts[0]["uuid"] == "cs1"
assert compacted_dicts[0]["isCompactSummary"] is True
# Replace builder state with compacted JSONL
old_count = builder.entry_count
builder.replace_entries(compacted_jsonl)
assert builder.entry_count == 4 # Only compacted entries
assert builder.entry_count < old_count # Compaction reduced entries
# --- Step 11: More assistant messages after compaction ---
builder.append_assistant(
[{"type": "text", "text": "Here is file2.py:\n\ndef hello():\n pass"}],
model="claude-sonnet-4-20250514",
stop_reason="end_turn",
)
assert builder.entry_count == 5
# --- Step 12: Export for upload ---
output = builder.to_jsonl()
assert output # Not empty
output_entries = [json.loads(line) for line in output.strip().split("\n")]
assert len(output_entries) == 5
# Verify structure:
# [COMPACT_SUMMARY, POST_COMPACT_ASST, USER_3, ASST_3, new_assistant]
assert output_entries[0]["type"] == "summary"
assert output_entries[0].get("isCompactSummary") is True
assert output_entries[0]["uuid"] == "cs1"
assert output_entries[1]["uuid"] == "a3"
assert output_entries[2]["uuid"] == "u3"
assert output_entries[3]["uuid"] == "a4"
assert output_entries[4]["type"] == "assistant"
# Verify parent chain is intact
assert output_entries[1]["parentUuid"] == "cs1" # a3 → cs1
assert output_entries[2]["parentUuid"] == "a3" # u3 → a3
assert output_entries[3]["parentUuid"] == "u3" # a4 → u3
assert output_entries[4]["parentUuid"] == "a4" # new → a4
# --- Step 13: Roundtrip — next turn loads this export ---
builder2 = TranscriptBuilder()
builder2.load_previous(output)
assert builder2.entry_count == 5
# isCompactSummary survives roundtrip
output2 = builder2.to_jsonl()
first_entry = json.loads(output2.strip().split("\n")[0])
assert first_entry.get("isCompactSummary") is True
# Can append more messages
builder2.append_user("What about file3.py?")
assert builder2.entry_count == 6
final_output = builder2.to_jsonl()
last_entry = json.loads(final_output.strip().split("\n")[-1])
assert last_entry["type"] == "user"
# Parented to the last entry from previous turn
assert last_entry["parentUuid"] == output_entries[-1]["uuid"]
def test_double_compaction_within_session(self, tmp_path):
"""Two compactions in the same session (across reset_for_query)."""
session_dir = tmp_path / "session"
session_dir.mkdir(parents=True)
tracker = CompactionTracker()
session = ChatSession.new(user_id="test")
builder = TranscriptBuilder()
# --- First query with compaction ---
builder.append_user("first question")
builder.append_assistant([{"type": "text", "text": "first answer"}])
# Write session file for first compaction
first_summary = {
"type": "summary",
"uuid": "cs-first",
"isCompactSummary": True,
"message": {"role": "user", "content": "First compaction summary"},
}
first_post = {
"type": "assistant",
"uuid": "a-first",
"parentUuid": "cs-first",
"message": {"role": "assistant", "content": "first post-compact"},
}
file1 = session_dir / "session1.jsonl"
file1.write_text(_make_jsonl(first_summary, first_post))
tracker.on_compact()
tracker.emit_start_if_ready()
end_events1 = _run(tracker.emit_end_if_ready(session))
assert len(end_events1) == 2 # output + finish
result1_entries = _read_compacted_entries(str(file1))
assert result1_entries is not None
_, compacted1_jsonl = result1_entries
builder.replace_entries(compacted1_jsonl)
assert builder.entry_count == 2
# --- Reset for second query ---
tracker.reset_for_query()
# --- Second query with compaction ---
builder.append_user("second question")
builder.append_assistant([{"type": "text", "text": "second answer"}])
second_summary = {
"type": "summary",
"uuid": "cs-second",
"isCompactSummary": True,
"message": {"role": "user", "content": "Second compaction summary"},
}
second_post = {
"type": "assistant",
"uuid": "a-second",
"parentUuid": "cs-second",
"message": {"role": "assistant", "content": "second post-compact"},
}
file2 = session_dir / "session2.jsonl"
file2.write_text(_make_jsonl(second_summary, second_post))
tracker.on_compact()
tracker.emit_start_if_ready()
end_events2 = _run(tracker.emit_end_if_ready(session))
assert len(end_events2) == 2 # output + finish
result2_entries = _read_compacted_entries(str(file2))
assert result2_entries is not None
_, compacted2_jsonl = result2_entries
builder.replace_entries(compacted2_jsonl)
assert builder.entry_count == 2 # Only second compaction entries
# Export and verify
output = builder.to_jsonl()
entries = [json.loads(line) for line in output.strip().split("\n")]
assert entries[0]["uuid"] == "cs-second"
assert entries[0].get("isCompactSummary") is True
def test_strip_progress_then_load_then_compact_roundtrip(self, tmp_path):
"""Full pipeline: strip → load → compact → replace → export → reload.
This tests the exact sequence that happens across two turns:
Turn 1: SDK produces transcript with progress entries
Upload: strip_progress_entries removes progress, upload to cloud
Turn 2: Download → load_previous → compaction fires → replace → export
Turn 3: Download the Turn 2 export → load_previous (roundtrip)
"""
session_dir = tmp_path / "session"
session_dir.mkdir(parents=True)
# --- Turn 1: SDK produces raw transcript ---
raw_content = _make_jsonl(
USER_1,
ASST_1_THINKING,
ASST_1_TOOL,
PROGRESS_1,
TOOL_RESULT_1,
ASST_1_TEXT,
USER_2,
ASST_2,
)
# Strip progress for upload
stripped = strip_progress_entries(raw_content)
stripped_entries = [
json.loads(line) for line in stripped.strip().split("\n") if line.strip()
]
# Progress should be gone
assert not any(e.get("type") == "progress" for e in stripped_entries)
assert len(stripped_entries) == 7 # 8 - 1 progress
# --- Turn 2: Download stripped, load, compaction happens ---
builder = TranscriptBuilder()
builder.load_previous(stripped)
assert builder.entry_count == 7
builder.append_user("Now show file2.py")
builder.append_assistant(
[{"type": "text", "text": "Reading file2.py..."}],
model="claude-sonnet-4-20250514",
)
# CLI writes session file with compaction
session_file = self._write_session_file(
session_dir,
[
USER_1,
ASST_1_TOOL,
TOOL_RESULT_1,
ASST_1_TEXT,
USER_2,
ASST_2,
COMPACT_SUMMARY,
POST_COMPACT_ASST,
],
)
result = _read_compacted_entries(str(session_file))
assert result is not None
_, compacted_jsonl = result
builder.replace_entries(compacted_jsonl)
# Append post-compaction message
builder.append_user("Thanks!")
output = builder.to_jsonl()
# --- Turn 3: Fresh load of Turn 2 export ---
builder3 = TranscriptBuilder()
builder3.load_previous(output)
# Should have: compact_summary + post_compact_asst + "Thanks!"
assert builder3.entry_count == 3
# Compact summary survived the full pipeline
first = json.loads(builder3.to_jsonl().strip().split("\n")[0])
assert first.get("isCompactSummary") is True
assert first["type"] == "summary"

View File

@@ -221,12 +221,12 @@ class SDKResponseAdapter:
responses.append(StreamFinish())
else:
logger.warning(
f"Unexpected ResultMessage subtype: {sdk_message.subtype}"
"Unexpected ResultMessage subtype: %s", sdk_message.subtype
)
responses.append(StreamFinish())
else:
logger.debug(f"Unhandled SDK message type: {type(sdk_message).__name__}")
logger.debug("Unhandled SDK message type: %s", type(sdk_message).__name__)
return responses

View File

@@ -52,7 +52,7 @@ def _validate_workspace_path(
if is_allowed_local_path(path, sdk_cwd):
return {}
logger.warning(f"Blocked {tool_name} outside workspace: {path}")
logger.warning("Blocked %s outside workspace: %s", tool_name, path)
workspace_hint = f" Allowed workspace: {sdk_cwd}" if sdk_cwd else ""
return _deny(
f"[SECURITY] Tool '{tool_name}' can only access files within the workspace "
@@ -71,7 +71,7 @@ def _validate_tool_access(
"""
# Block forbidden tools
if tool_name in BLOCKED_TOOLS:
logger.warning(f"Blocked tool access attempt: {tool_name}")
logger.warning("Blocked tool access attempt: %s", tool_name)
return _deny(
f"[SECURITY] Tool '{tool_name}' is blocked for security. "
"This is enforced by the platform and cannot be bypassed. "
@@ -111,7 +111,9 @@ def _validate_user_isolation(
# the tool itself via _validate_ephemeral_path.
path = tool_input.get("path", "") or tool_input.get("file_path", "")
if path and ".." in path:
logger.warning(f"Blocked path traversal attempt: {path} by user {user_id}")
logger.warning(
"Blocked path traversal attempt: %s by user %s", path, user_id
)
return {
"hookSpecificOutput": {
"hookEventName": "PreToolUse",
@@ -169,7 +171,7 @@ def create_security_hooks(
# Block background task execution first — denied calls
# should not consume a subtask slot.
if tool_input.get("run_in_background"):
logger.info(f"[SDK] Blocked background Task, user={user_id}")
logger.info("[SDK] Blocked background Task, user=%s", user_id)
return cast(
SyncHookJSONOutput,
_deny(
@@ -211,7 +213,7 @@ def create_security_hooks(
if tool_name == "Task" and tool_use_id is not None:
task_tool_use_ids.add(tool_use_id)
logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}")
logger.debug("[SDK] Tool start: %s, user=%s", tool_name, user_id)
return cast(SyncHookJSONOutput, {})
def _release_task_slot(tool_name: str, tool_use_id: str | None) -> None:

View File

@@ -40,11 +40,13 @@ from ..constants import COPILOT_ERROR_PREFIX, COPILOT_SYSTEM_PREFIX
from ..model import (
ChatMessage,
ChatSession,
Usage,
get_chat_session,
update_session_title,
upsert_chat_session,
)
from ..prompting import get_sdk_supplement
from ..rate_limit import record_token_usage
from ..response_model import (
StreamBaseResponse,
StreamError,
@@ -54,6 +56,7 @@ from ..response_model import (
StreamTextDelta,
StreamToolInputAvailable,
StreamToolOutputAvailable,
StreamUsage,
)
from ..service import (
_build_system_prompt,
@@ -75,8 +78,12 @@ from .tool_adapter import (
wait_for_stash,
)
from .transcript import (
COMPACT_THRESHOLD_BYTES,
TranscriptDownload,
cleanup_cli_project_dir,
compact_transcript,
download_transcript,
read_cli_session_file,
upload_transcript,
validate_transcript,
write_transcript_to_tempfile,
@@ -294,7 +301,7 @@ def _cleanup_sdk_tool_results(cwd: str) -> None:
"""
normalized = os.path.normpath(cwd)
if not normalized.startswith(_SDK_CWD_PREFIX):
logger.warning(f"[SDK] Rejecting cleanup for path outside workspace: {cwd}")
logger.warning("[SDK] Rejecting cleanup for path outside workspace: %s", cwd)
return
# Clean the CLI's project directory (transcripts + tool-results).
@@ -388,7 +395,7 @@ async def _compress_messages(
client=client,
)
except Exception as e:
logger.warning(f"[SDK] Context compression with LLM failed: {e}")
logger.warning("[SDK] Context compression with LLM failed: %s", e)
# Fall back to truncation-only (no LLM summarization)
result = await compress_context(
messages=messages_dict,
@@ -624,6 +631,56 @@ async def _prepare_file_attachments(
return PreparedAttachments(hint=hint, image_blocks=image_blocks)
async def _maybe_compact_and_upload(
dl: TranscriptDownload,
user_id: str,
session_id: str,
log_prefix: str = "[Transcript]",
) -> str:
"""Compact an oversized transcript and upload the compacted version.
Returns the (possibly compacted) transcript content, or an empty string
if compaction was needed but failed.
"""
content = dl.content
if len(content) <= COMPACT_THRESHOLD_BYTES:
return content
logger.warning(
"%s Transcript oversized (%dB > %dB), compacting",
log_prefix,
len(content),
COMPACT_THRESHOLD_BYTES,
)
compacted = await compact_transcript(content, log_prefix=log_prefix)
if not compacted:
logger.warning(
"%s Compaction failed, skipping resume for this turn", log_prefix
)
return ""
# Keep the original message_count: it reflects the number of
# session.messages covered by this transcript, which the gap-fill
# logic uses as a slice index. Counting JSONL lines would give a
# smaller number (compacted messages != session message count) and
# cause already-covered messages to be re-injected.
try:
await upload_transcript(
user_id=user_id,
session_id=session_id,
content=compacted,
message_count=dl.message_count,
log_prefix=log_prefix,
)
except Exception:
logger.warning(
"%s Failed to upload compacted transcript",
log_prefix,
exc_info=True,
)
return compacted
async def stream_chat_completion_sdk(
session_id: str,
message: str | None = None,
@@ -735,6 +792,14 @@ async def stream_chat_completion_sdk(
_otel_ctx: Any = None
# Make sure there is no more code between the lock acquisition and try-block.
# Token usage accumulators — populated from ResultMessage at end of turn
turn_prompt_tokens = 0 # uncached input tokens only
turn_completion_tokens = 0
turn_cache_read_tokens = 0
turn_cache_creation_tokens = 0
total_tokens = 0 # computed once before StreamUsage, reused in finally
turn_cost_usd: float | None = None
try:
# Build system prompt (reuses non-SDK path with Langfuse support).
# Pre-compute the cwd here so the exact working directory path can be
@@ -827,20 +892,33 @@ async def stream_chat_completion_sdk(
is_valid,
)
if is_valid:
# Load previous FULL context into builder
transcript_builder.load_previous(dl.content, log_prefix=log_prefix)
resume_file = write_transcript_to_tempfile(
dl.content, session_id, sdk_cwd
transcript_content = await _maybe_compact_and_upload(
dl,
user_id=user_id or "",
session_id=session_id,
log_prefix=log_prefix,
)
# Load previous context into builder (empty string is a no-op)
if transcript_content:
transcript_builder.load_previous(
transcript_content, log_prefix=log_prefix
)
resume_file = (
write_transcript_to_tempfile(
transcript_content, session_id, sdk_cwd
)
if transcript_content
else None
)
if resume_file:
use_resume = True
transcript_msg_count = dl.message_count
logger.debug(
f"{log_prefix} Using --resume ({len(dl.content)}B, "
f"{log_prefix} Using --resume ({len(transcript_content)}B, "
f"msg_count={transcript_msg_count})"
)
else:
logger.warning(f"{log_prefix} Transcript downloaded but invalid")
logger.warning("%s Transcript downloaded but invalid", log_prefix)
elif config.claude_agent_use_resume and user_id and len(session.messages) > 1:
logger.warning(
f"{log_prefix} No transcript available "
@@ -1110,7 +1188,7 @@ async def stream_chat_completion_sdk(
- len(adapter.resolved_tool_calls),
)
# Log ResultMessage details for debugging
# Log ResultMessage details and capture token usage
if isinstance(sdk_msg, ResultMessage):
logger.info(
"%s Received: ResultMessage %s "
@@ -1129,9 +1207,46 @@ async def stream_chat_completion_sdk(
sdk_msg.result or "(no error message provided)",
)
# Emit compaction end if SDK finished compacting
for ev in await compaction.emit_end_if_ready(session):
# Capture token usage from ResultMessage.
# Anthropic reports cached tokens separately:
# input_tokens = uncached only
# cache_read_input_tokens = served from cache
# cache_creation_input_tokens = written to cache
if sdk_msg.usage:
turn_prompt_tokens += sdk_msg.usage.get("input_tokens", 0)
turn_cache_read_tokens += sdk_msg.usage.get(
"cache_read_input_tokens", 0
)
turn_cache_creation_tokens += sdk_msg.usage.get(
"cache_creation_input_tokens", 0
)
turn_completion_tokens += sdk_msg.usage.get(
"output_tokens", 0
)
logger.info(
"%s Token usage: uncached=%d, cache_read=%d, cache_create=%d, output=%d",
log_prefix,
turn_prompt_tokens,
turn_cache_read_tokens,
turn_cache_creation_tokens,
turn_completion_tokens,
)
if sdk_msg.total_cost_usd is not None:
turn_cost_usd = sdk_msg.total_cost_usd
# Emit compaction end if SDK finished compacting.
# When compaction ends, sync TranscriptBuilder with
# the CLI's compacted session file so the uploaded
# transcript reflects compaction.
compaction_events = await compaction.emit_end_if_ready(session)
for ev in compaction_events:
yield ev
if compaction_events and sdk_cwd:
cli_content = await read_cli_session_file(sdk_cwd)
if cli_content:
transcript_builder.replace_entries(
cli_content, log_prefix=log_prefix
)
for response in adapter.convert_message(sdk_msg):
if isinstance(response, StreamStart):
@@ -1325,6 +1440,27 @@ async def stream_chat_completion_sdk(
) and not has_appended_assistant:
session.messages.append(assistant_response)
# Emit token usage to the client (must be in try to reach SSE stream).
# Session persistence of usage is in finally to stay consistent with
# rate-limit recording even if an exception interrupts between here
# and the finally block.
# Compute total_tokens once; reused in the finally block for
# session persistence and rate-limit recording.
total_tokens = (
turn_prompt_tokens
+ turn_cache_read_tokens
+ turn_cache_creation_tokens
+ turn_completion_tokens
)
if total_tokens > 0:
yield StreamUsage(
promptTokens=turn_prompt_tokens,
completionTokens=turn_completion_tokens,
totalTokens=total_tokens,
cacheReadTokens=turn_cache_read_tokens,
cacheCreationTokens=turn_cache_creation_tokens,
)
# Transcript upload is handled exclusively in the finally block
# to avoid double-uploads (the success path used to upload the
# old resume file, then the finally block overwrote it with the
@@ -1389,6 +1525,48 @@ async def stream_chat_completion_sdk(
except Exception:
logger.warning("OTEL context teardown failed", exc_info=True)
# --- Persist token usage to session + rate-limit counters ---
# Both must live in finally so they stay consistent even when an
# exception interrupts the try block after StreamUsage was yielded.
# total_tokens is computed once before StreamUsage yield above.
if total_tokens > 0:
if session is not None:
session.usage.append(
Usage(
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
total_tokens=total_tokens,
cache_read_tokens=turn_cache_read_tokens,
cache_creation_tokens=turn_cache_creation_tokens,
)
)
logger.info(
"%s Turn usage: uncached=%d, cache_read=%d, cache_create=%d, "
"output=%d, total=%d, cost_usd=%s",
log_prefix,
turn_prompt_tokens,
turn_cache_read_tokens,
turn_cache_creation_tokens,
turn_completion_tokens,
total_tokens,
turn_cost_usd,
)
if user_id and total_tokens > 0:
try:
await record_token_usage(
user_id=user_id,
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
cache_read_tokens=turn_cache_read_tokens,
cache_creation_tokens=turn_cache_creation_tokens,
)
except Exception as usage_err:
logger.warning(
"%s Failed to record token usage: %s",
log_prefix,
usage_err,
)
# --- Persist session messages ---
# This MUST run in finally to persist messages even when the generator
# is stopped early (e.g., user clicks stop, processor breaks stream loop).
@@ -1484,6 +1662,6 @@ async def _update_title_async(
)
if title and user_id:
await update_session_title(session_id, user_id, title, only_if_empty=True)
logger.debug(f"[SDK] Generated title for {session_id}: {title}")
logger.debug("[SDK] Generated title for %s: %s", session_id, title)
except Exception as e:
logger.warning(f"[SDK] Failed to update session title: {e}")
logger.warning("[SDK] Failed to update session title: %s", e)

View File

@@ -234,7 +234,9 @@ def create_tool_handler(base_tool: BaseTool):
try:
return await _execute_tool_sync(base_tool, user_id, session, args)
except Exception as e:
logger.error(f"Error executing tool {base_tool.name}: {e}", exc_info=True)
logger.error(
"Error executing tool %s: %s", base_tool.name, e, exc_info=True
)
return _mcp_error(f"Failed to execute {base_tool.name}: {e}")
return tool_handler

View File

@@ -13,10 +13,17 @@ filesystem for self-hosted) — no DB column needed.
import logging
import os
import re
import shutil
import time
from dataclasses import dataclass
from pathlib import Path
from uuid import uuid4
import openai
from backend.copilot.config import ChatConfig
from backend.util import json
from backend.util.prompt import CompressResult, compress_context
logger = logging.getLogger(__name__)
@@ -34,6 +41,11 @@ STRIPPABLE_TYPES = frozenset(
{"progress", "file-history-snapshot", "queue-operation", "summary", "pr-link"}
)
# JSONL protocol values used in transcript serialization.
STOP_REASON_END_TURN = "end_turn"
COMPACT_MSG_ID_PREFIX = "msg_compact_"
ENTRY_TYPE_MESSAGE = "message"
@dataclass
class TranscriptDownload:
@@ -82,7 +94,11 @@ def strip_progress_entries(content: str) -> str:
parent = entry.get("parentUuid", "")
if uid:
uuid_to_parent[uid] = parent
if entry.get("type", "") in STRIPPABLE_TYPES and uid:
if (
entry.get("type", "") in STRIPPABLE_TYPES
and uid
and not entry.get("isCompactSummary")
):
stripped_uuids.add(uid)
# Second pass: keep non-stripped entries, reparenting where needed.
@@ -93,7 +109,9 @@ def strip_progress_entries(content: str) -> str:
continue
parent = entry.get("parentUuid", "")
original_parent = parent
while parent in stripped_uuids:
seen_parents: set[str] = set()
while parent in stripped_uuids and parent not in seen_parents:
seen_parents.add(parent)
parent = uuid_to_parent.get(parent, "")
if parent != original_parent:
entry["parentUuid"] = parent
@@ -106,7 +124,9 @@ def strip_progress_entries(content: str) -> str:
if not isinstance(entry, dict):
result_lines.append(line)
continue
if entry.get("type", "") in STRIPPABLE_TYPES:
if entry.get("type", "") in STRIPPABLE_TYPES and not entry.get(
"isCompactSummary"
):
continue
uid = entry.get("uuid", "")
if uid in reparented:
@@ -137,32 +157,78 @@ def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
_SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-")
def cleanup_cli_project_dir(sdk_cwd: str) -> None:
"""Remove the CLI's project directory for a specific working directory.
def _cli_project_dir(sdk_cwd: str) -> str | None:
"""Return the CLI's project directory for a given working directory.
The CLI stores session data under ``~/.claude/projects/<encoded_cwd>/``.
Each SDK turn uses a unique ``sdk_cwd``, so the project directory is
safe to remove entirely after the transcript has been uploaded.
Returns ``None`` if the path would escape the projects base.
"""
import shutil
# Encode cwd the same way CLI does (replaces non-alphanumeric with -)
cwd_encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
projects_base = os.path.realpath(os.path.join(config_dir, "projects"))
project_dir = os.path.realpath(os.path.join(projects_base, cwd_encoded))
if not project_dir.startswith(projects_base + os.sep):
logger.warning(
f"[Transcript] Cleanup path escaped projects base: {project_dir}"
)
return
logger.warning("[Transcript] Project dir escaped base: %s", project_dir)
return None
return project_dir
async def read_cli_session_file(sdk_cwd: str) -> str | None:
"""Read the CLI's own session file, which reflects any mid-stream compaction.
After the CLI compacts context, its session file contains the compacted
conversation. Reading this file lets ``TranscriptBuilder`` replace its
uncompacted entries with the CLI's compacted version.
"""
import aiofiles
project_dir = _cli_project_dir(sdk_cwd)
if not project_dir or not os.path.isdir(project_dir):
return None
jsonl_files = list(Path(project_dir).glob("*.jsonl"))
if not jsonl_files:
logger.debug("[Transcript] No CLI session file in %s", project_dir)
return None
# Pick the most recently modified file (there should only be one per turn).
# Guard against races where a file is deleted between glob and stat.
candidates: list[tuple[float, Path]] = []
for p in jsonl_files:
try:
candidates.append((p.stat().st_mtime, p))
except OSError:
continue
if not candidates:
logger.debug("[Transcript] No readable CLI session file in %s", project_dir)
return None
# Resolve + prefix check to prevent symlink escapes.
session_file = max(candidates, key=lambda item: item[0])[1]
real_path = str(session_file.resolve())
if not real_path.startswith(project_dir + os.sep):
logger.warning("[Transcript] Session file escaped project dir: %s", real_path)
return None
try:
async with aiofiles.open(real_path) as f:
content = await f.read()
logger.info(
"[Transcript] Read CLI session file: %s (%d bytes)",
real_path,
len(content),
)
return content
except OSError as e:
logger.warning("[Transcript] Failed to read CLI session file: %s", e)
return None
def cleanup_cli_project_dir(sdk_cwd: str) -> None:
"""Remove the CLI's project directory for a specific working directory."""
project_dir = _cli_project_dir(sdk_cwd)
if not project_dir:
return
if os.path.isdir(project_dir):
shutil.rmtree(project_dir, ignore_errors=True)
logger.debug(f"[Transcript] Cleaned up CLI project dir: {project_dir}")
logger.debug("[Transcript] Cleaned up CLI project dir: %s", project_dir)
else:
logger.debug(f"[Transcript] Project dir not found: {project_dir}")
logger.debug("[Transcript] Project dir not found: %s", project_dir)
def write_transcript_to_tempfile(
@@ -180,7 +246,7 @@ def write_transcript_to_tempfile(
# Validate cwd is under the expected sandbox prefix (CodeQL sanitizer).
real_cwd = os.path.realpath(cwd)
if not real_cwd.startswith(_SAFE_CWD_PREFIX):
logger.warning(f"[Transcript] cwd outside sandbox: {cwd}")
logger.warning("[Transcript] cwd outside sandbox: %s", cwd)
return None
try:
@@ -190,17 +256,17 @@ def write_transcript_to_tempfile(
os.path.join(real_cwd, f"transcript-{safe_id}.jsonl")
)
if not jsonl_path.startswith(real_cwd):
logger.warning(f"[Transcript] Path escaped cwd: {jsonl_path}")
logger.warning("[Transcript] Path escaped cwd: %s", jsonl_path)
return None
with open(jsonl_path, "w") as f:
f.write(transcript_content)
logger.info(f"[Transcript] Wrote resume file: {jsonl_path}")
logger.info("[Transcript] Wrote resume file: %s", jsonl_path)
return jsonl_path
except OSError as e:
logger.warning(f"[Transcript] Failed to write resume file: {e}")
logger.warning("[Transcript] Failed to write resume file: %s", e)
return None
@@ -344,11 +410,14 @@ async def upload_transcript(
content=json.dumps(meta).encode("utf-8"),
)
except Exception as e:
logger.warning(f"{log_prefix} Failed to write metadata: {e}")
logger.warning("%s Failed to write metadata: %s", log_prefix, e)
logger.info(
f"{log_prefix} Uploaded {len(encoded)}B "
f"(stripped from {len(content)}B, msg_count={message_count})"
"%s Uploaded %dB (stripped from %dB, msg_count=%d)",
log_prefix,
len(encoded),
len(content),
message_count,
)
@@ -371,10 +440,10 @@ async def download_transcript(
data = await storage.retrieve(path)
content = data.decode("utf-8")
except FileNotFoundError:
logger.debug(f"{log_prefix} No transcript in storage")
logger.debug("%s No transcript in storage", log_prefix)
return None
except Exception as e:
logger.warning(f"{log_prefix} Failed to download transcript: {e}")
logger.warning("%s Failed to download transcript: %s", log_prefix, e)
return None
# Try to load metadata (best-effort — old transcripts won't have it)
@@ -394,10 +463,14 @@ async def download_transcript(
meta = json.loads(meta_data.decode("utf-8"), fallback={})
message_count = meta.get("message_count", 0)
uploaded_at = meta.get("uploaded_at", 0.0)
except (FileNotFoundError, Exception):
except FileNotFoundError:
pass # No metadata — treat as unknown (msg_count=0 → always fill gap)
except Exception as e:
logger.debug("%s Failed to load transcript metadata: %s", log_prefix, e)
logger.info(f"{log_prefix} Downloaded {len(content)}B (msg_count={message_count})")
logger.info(
"%s Downloaded %dB (msg_count=%d)", log_prefix, len(content), message_count
)
return TranscriptDownload(
content=content,
message_count=message_count,
@@ -405,15 +478,171 @@ async def download_transcript(
)
async def delete_transcript(user_id: str, session_id: str) -> None:
"""Delete transcript from bucket storage (e.g. after resume failure)."""
from backend.util.workspace_storage import get_workspace_storage
# ---------------------------------------------------------------------------
# Transcript compaction
# ---------------------------------------------------------------------------
storage = await get_workspace_storage()
path = _build_storage_path(user_id, session_id, storage)
# Transcripts above this byte threshold are compacted at download time.
COMPACT_THRESHOLD_BYTES = 400_000
def _flatten_assistant_content(blocks: list) -> str:
"""Flatten assistant content blocks into a single plain-text string."""
parts: list[str] = []
for block in blocks:
if isinstance(block, dict):
if block.get("type") == "text":
parts.append(block.get("text", ""))
elif block.get("type") == "tool_use":
parts.append(f"[tool_use: {block.get('name', '?')}]")
elif isinstance(block, str):
parts.append(block)
return "\n".join(parts) if parts else ""
def _flatten_tool_result_content(blocks: list) -> str:
"""Flatten tool_result and other content blocks into plain text.
Handles nested tool_result structures, text blocks, and raw strings.
Uses ``json.dumps`` as fallback for dict blocks without a ``text`` key
or where ``text`` is ``None``.
"""
str_parts: list[str] = []
for block in blocks:
if isinstance(block, dict) and block.get("type") == "tool_result":
inner = block.get("content", "")
if isinstance(inner, list):
for sub in inner:
if isinstance(sub, dict):
text = sub.get("text")
str_parts.append(
str(text) if text is not None else json.dumps(sub)
)
else:
str_parts.append(str(sub))
else:
str_parts.append(str(inner))
elif isinstance(block, dict) and block.get("type") == "text":
str_parts.append(str(block.get("text", "")))
elif isinstance(block, str):
str_parts.append(block)
return "\n".join(str_parts) if str_parts else ""
def _transcript_to_messages(content: str) -> list[dict]:
"""Convert JSONL transcript entries to message dicts for compress_context."""
messages: list[dict] = []
for line in content.strip().split("\n"):
if not line.strip():
continue
entry = json.loads(line, fallback=None)
if not isinstance(entry, dict):
continue
if entry.get("type", "") in STRIPPABLE_TYPES and not entry.get(
"isCompactSummary"
):
continue
msg = entry.get("message", {})
role = msg.get("role", "")
if not role:
continue
msg_dict: dict = {"role": role}
raw_content = msg.get("content")
if role == "assistant" and isinstance(raw_content, list):
msg_dict["content"] = _flatten_assistant_content(raw_content)
elif isinstance(raw_content, list):
msg_dict["content"] = _flatten_tool_result_content(raw_content)
else:
msg_dict["content"] = raw_content or ""
messages.append(msg_dict)
return messages
def _messages_to_transcript(messages: list[dict]) -> str:
"""Convert compressed message dicts back to JSONL transcript format."""
lines: list[str] = []
last_uuid: str | None = None
for msg in messages:
role = msg.get("role", "user")
entry_type = "assistant" if role == "assistant" else "user"
uid = str(uuid4())
content = msg.get("content", "")
if role == "assistant":
message: dict = {
"role": "assistant",
"model": "",
"id": f"{COMPACT_MSG_ID_PREFIX}{uuid4().hex[:24]}",
"type": ENTRY_TYPE_MESSAGE,
"content": [{"type": "text", "text": content}] if content else [],
"stop_reason": STOP_REASON_END_TURN,
"stop_sequence": None,
}
else:
message = {"role": role, "content": content}
entry = {
"type": entry_type,
"uuid": uid,
"parentUuid": last_uuid,
"message": message,
}
lines.append(json.dumps(entry, separators=(",", ":")))
last_uuid = uid
return "\n".join(lines) + "\n" if lines else ""
async def _run_compression(
messages: list[dict],
model: str,
cfg: ChatConfig,
log_prefix: str,
) -> CompressResult:
"""Run LLM-based compression with truncation fallback."""
try:
await storage.delete(path)
logger.info(f"[Transcript] Deleted transcript for session {session_id}")
async with openai.AsyncOpenAI(
api_key=cfg.api_key, base_url=cfg.base_url, timeout=30.0
) as client:
return await compress_context(messages=messages, model=model, client=client)
except Exception as e:
logger.warning(f"[Transcript] Failed to delete transcript: {e}")
logger.warning("%s LLM compaction failed, using truncation: %s", log_prefix, e)
return await compress_context(messages=messages, model=model, client=None)
async def compact_transcript(
content: str,
log_prefix: str = "[Transcript]",
) -> str | None:
"""Compact an oversized JSONL transcript using LLM summarization.
Converts transcript entries to plain messages, runs ``compress_context``
(the same compressor used for pre-query history), and rebuilds JSONL.
Returns the compacted JSONL string, or ``None`` on failure.
"""
cfg = ChatConfig()
messages = _transcript_to_messages(content)
if len(messages) < 2:
logger.warning("%s Too few messages to compact (%d)", log_prefix, len(messages))
return None
try:
result = await _run_compression(messages, cfg.model, cfg, log_prefix)
if not result.was_compacted:
logger.info("%s Transcript already within token budget", log_prefix)
return content
logger.info(
"%s Compacted transcript: %d->%d tokens (%d summarized, %d dropped)",
log_prefix,
result.original_token_count,
result.token_count,
result.messages_summarized,
result.messages_dropped,
)
compacted = _messages_to_transcript(result.messages)
if not validate_transcript(compacted):
logger.warning("%s Compacted transcript failed validation", log_prefix)
return None
return compacted
except Exception as e:
logger.error(
"%s Transcript compaction failed: %s", log_prefix, e, exc_info=True
)
return None

View File

@@ -31,6 +31,7 @@ class TranscriptEntry(BaseModel):
uuid: str
parentUuid: str | None
message: dict[str, Any]
isCompactSummary: bool | None = None
class TranscriptBuilder:
@@ -78,10 +79,12 @@ class TranscriptBuilder:
)
continue
# Load all non-strippable entries (user/assistant/system/etc.)
# Skip only STRIPPABLE_TYPES to match strip_progress_entries() behavior
# Skip STRIPPABLE_TYPES unless the entry is a compaction summary.
# Compaction summaries may have type "summary" but must be preserved
# so --resume can reconstruct the compacted conversation.
entry_type = data.get("type", "")
if entry_type in STRIPPABLE_TYPES:
is_compact = data.get("isCompactSummary", False)
if entry_type in STRIPPABLE_TYPES and not is_compact:
continue
entry = TranscriptEntry(
@@ -89,6 +92,7 @@ class TranscriptBuilder:
uuid=data.get("uuid") or str(uuid4()),
parentUuid=data.get("parentUuid"),
message=data.get("message", {}),
isCompactSummary=True if is_compact else None,
)
self._entries.append(entry)
self._last_uuid = entry.uuid
@@ -177,6 +181,33 @@ class TranscriptBuilder:
lines = [entry.model_dump_json(exclude_none=True) for entry in self._entries]
return "\n".join(lines) + "\n"
def replace_entries(self, content: str, log_prefix: str = "[Transcript]") -> None:
"""Replace all entries with compacted JSONL content.
Called after the CLI performs mid-stream compaction so the builder's
state reflects the compacted conversation instead of the full
pre-compaction history.
"""
prev_count = len(self._entries)
temp = TranscriptBuilder()
try:
temp.load_previous(content, log_prefix=log_prefix)
except Exception:
logger.exception(
"%s Failed to parse compacted transcript; keeping %d existing entries",
log_prefix,
prev_count,
)
return
self._entries = temp._entries
self._last_uuid = temp._last_uuid
logger.info(
"%s Replaced %d entries with %d compacted entries",
log_prefix,
prev_count,
len(self._entries),
)
@property
def entry_count(self) -> int:
"""Total number of entries in the complete context."""

View File

@@ -2,14 +2,25 @@
import os
import pytest
from backend.util import json
from .transcript import (
COMPACT_MSG_ID_PREFIX,
STRIPPABLE_TYPES,
_cli_project_dir,
_flatten_assistant_content,
_flatten_tool_result_content,
_messages_to_transcript,
_transcript_to_messages,
compact_transcript,
read_cli_session_file,
strip_progress_entries,
validate_transcript,
write_transcript_to_tempfile,
)
from .transcript_builder import TranscriptBuilder
def _make_jsonl(*entries: dict) -> str:
@@ -35,6 +46,14 @@ PROGRESS_ENTRY = {
"data": {"type": "bash_progress", "stdout": "running..."},
}
COMPACT_SUMMARY = {
"type": "summary",
"uuid": "cs1",
"parentUuid": None,
"isCompactSummary": True,
"message": {"role": "user", "content": "Summary of previous conversation..."},
}
VALID_TRANSCRIPT = _make_jsonl(METADATA_LINE, FILE_HISTORY, USER_MSG, ASST_MSG)
@@ -237,6 +256,121 @@ class TestStripProgressEntries:
# Should return just a newline (empty content stripped)
assert result.strip() == ""
# --- _cli_project_dir ---
class TestCliProjectDir:
def test_returns_path_for_valid_cwd(self, tmp_path, monkeypatch):
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(tmp_path))
projects = tmp_path / "projects"
projects.mkdir()
result = _cli_project_dir("/tmp/copilot-abc")
assert result is not None
assert "projects" in result
def test_returns_none_for_path_traversal(self, tmp_path, monkeypatch):
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(tmp_path))
projects = tmp_path / "projects"
projects.mkdir()
# A cwd that encodes to something with .. shouldn't escape
result = _cli_project_dir("/tmp/copilot-test")
# Should return a valid path (no traversal possible with alphanum encoding)
assert result is None or result.startswith(str(projects))
# --- read_cli_session_file ---
class TestReadCliSessionFile:
@pytest.mark.asyncio
async def test_reads_session_file(self, tmp_path, monkeypatch):
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(tmp_path))
# Create the CLI project directory structure
cwd = "/tmp/copilot-testread"
import re
encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
project_dir = tmp_path / "projects" / encoded
project_dir.mkdir(parents=True)
# Write a session file
session_file = project_dir / "test-session.jsonl"
session_file.write_text(json.dumps(ASST_MSG) + "\n")
result = await read_cli_session_file(cwd)
assert result is not None
assert "assistant" in result
@pytest.mark.asyncio
async def test_returns_none_when_no_files(self, tmp_path, monkeypatch):
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(tmp_path))
cwd = "/tmp/copilot-nofiles"
import re
encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
project_dir = tmp_path / "projects" / encoded
project_dir.mkdir(parents=True)
# No jsonl files
result = await read_cli_session_file(cwd)
assert result is None
@pytest.mark.asyncio
async def test_returns_none_when_dir_missing(self, tmp_path, monkeypatch):
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(tmp_path))
(tmp_path / "projects").mkdir()
result = await read_cli_session_file("/tmp/copilot-nonexistent")
assert result is None
# --- _transcript_to_messages / _messages_to_transcript ---
class TestTranscriptMessageConversion:
def test_roundtrip_preserves_roles(self):
transcript = _make_jsonl(USER_MSG, ASST_MSG)
messages = _transcript_to_messages(transcript)
assert len(messages) == 2
assert messages[0]["role"] == "user"
assert messages[1]["role"] == "assistant"
def test_messages_to_transcript_produces_valid_jsonl(self):
messages = [
{"role": "user", "content": "hi"},
{"role": "assistant", "content": "hello"},
]
result = _messages_to_transcript(messages)
assert validate_transcript(result) is True
def test_strips_strippable_types(self):
transcript = _make_jsonl(
{"type": "progress", "uuid": "p1", "message": {"role": "user"}},
USER_MSG,
ASST_MSG,
)
messages = _transcript_to_messages(transcript)
assert len(messages) == 2 # progress entry skipped
def test_flattens_assistant_content_blocks(self):
asst_with_blocks = {
"type": "assistant",
"uuid": "a1",
"message": {
"role": "assistant",
"content": [
{"type": "text", "text": "hello"},
{"type": "tool_use", "name": "bash"},
],
},
}
messages = _transcript_to_messages(_make_jsonl(asst_with_blocks))
assert len(messages) == 1
assert "hello" in messages[0]["content"]
assert "[tool_use: bash]" in messages[0]["content"]
def test_empty_messages_returns_empty(self):
result = _messages_to_transcript([])
assert result == ""
def test_no_strippable_entries(self):
"""When there's nothing to strip, output matches input structure."""
content = _make_jsonl(USER_MSG, ASST_MSG)
@@ -282,3 +416,654 @@ class TestStripProgressEntries:
lines = result.strip().split("\n")
asst_entry = json.loads(lines[-1])
assert asst_entry["parentUuid"] == "u1" # reparented
# --- TranscriptBuilder ---
class TestTranscriptBuilderReplaceEntries:
"""Tests for TranscriptBuilder.replace_entries — the compaction sync path."""
def test_replace_entries_with_valid_content(self):
builder = TranscriptBuilder()
builder.append_user("hello")
builder.append_assistant([{"type": "text", "text": "world"}])
assert builder.entry_count == 2
# Replace with compacted content (one user + one assistant)
compacted = _make_jsonl(USER_MSG, ASST_MSG)
builder.replace_entries(compacted)
assert builder.entry_count == 2
def test_replace_entries_keeps_old_on_corrupt_content(self):
builder = TranscriptBuilder()
builder.append_user("hello")
assert builder.entry_count == 1
# Corrupt content that fails to parse
builder.replace_entries("not valid json at all\n")
# Should still have old entries (load_previous skips invalid lines,
# but if ALL lines are invalid, temp builder is empty → exception path)
assert builder.entry_count >= 0 # doesn't crash
def test_replace_entries_with_empty_content(self):
builder = TranscriptBuilder()
builder.append_user("hello")
assert builder.entry_count == 1
builder.replace_entries("")
# Empty content → load_previous returns early → temp is empty
# replace_entries swaps to empty (0 entries)
assert builder.entry_count == 0
def test_replace_entries_filters_strippable_types(self):
"""Strippable types (progress, file-history-snapshot) are filtered out."""
builder = TranscriptBuilder()
builder.append_user("hello")
content = _make_jsonl(
{"type": "progress", "uuid": "p1", "message": {}},
USER_MSG,
ASST_MSG,
)
builder.replace_entries(content)
# Only user + assistant should remain (progress filtered)
assert builder.entry_count == 2
def test_replace_entries_preserves_uuids(self):
builder = TranscriptBuilder()
content = _make_jsonl(USER_MSG, ASST_MSG)
builder.replace_entries(content)
jsonl = builder.to_jsonl()
lines = jsonl.strip().split("\n")
first = json.loads(lines[0])
assert first["uuid"] == "u1"
class TestTranscriptBuilderBasic:
def test_append_user_and_assistant(self):
builder = TranscriptBuilder()
builder.append_user("hi")
builder.append_assistant([{"type": "text", "text": "hello"}])
assert builder.entry_count == 2
assert not builder.is_empty
def test_to_jsonl_empty(self):
builder = TranscriptBuilder()
assert builder.to_jsonl() == ""
assert builder.is_empty
def test_load_previous_and_append(self):
builder = TranscriptBuilder()
content = _make_jsonl(USER_MSG, ASST_MSG)
builder.load_previous(content)
assert builder.entry_count == 2
builder.append_user("new message")
assert builder.entry_count == 3
def test_consecutive_assistant_entries_share_message_id(self):
builder = TranscriptBuilder()
builder.append_user("hi")
builder.append_assistant([{"type": "text", "text": "part1"}])
builder.append_assistant([{"type": "text", "text": "part2"}])
jsonl = builder.to_jsonl()
lines = jsonl.strip().split("\n")
asst1 = json.loads(lines[1])
asst2 = json.loads(lines[2])
assert asst1["message"]["id"] == asst2["message"]["id"]
def test_non_consecutive_assistant_entries_get_new_id(self):
builder = TranscriptBuilder()
builder.append_user("hi")
builder.append_assistant([{"type": "text", "text": "response1"}])
builder.append_user("followup")
builder.append_assistant([{"type": "text", "text": "response2"}])
jsonl = builder.to_jsonl()
lines = jsonl.strip().split("\n")
asst1 = json.loads(lines[1])
asst2 = json.loads(lines[3])
assert asst1["message"]["id"] != asst2["message"]["id"]
class TestCompactSummaryRoundtrip:
"""Verify isCompactSummary survives export→reload roundtrip."""
def test_load_previous_preserves_compact_summary(self):
"""Compaction summary with type 'summary' should not be stripped."""
content = _make_jsonl(COMPACT_SUMMARY, USER_MSG, ASST_MSG)
builder = TranscriptBuilder()
builder.load_previous(content)
# summary type is in STRIPPABLE_TYPES, but isCompactSummary keeps it
assert builder.entry_count == 3
def test_export_reload_preserves_compact_summary(self):
"""Critical: isCompactSummary must survive to_jsonl → load_previous."""
content = _make_jsonl(COMPACT_SUMMARY, USER_MSG, ASST_MSG)
builder1 = TranscriptBuilder()
builder1.load_previous(content)
assert builder1.entry_count == 3
exported = builder1.to_jsonl()
# Verify isCompactSummary is in the exported JSONL
first_line = json.loads(exported.strip().split("\n")[0])
assert first_line.get("isCompactSummary") is True
# Reload and verify it's still preserved
builder2 = TranscriptBuilder()
builder2.load_previous(exported)
assert builder2.entry_count == 3
def test_strip_progress_preserves_compact_summary(self):
"""strip_progress_entries should keep isCompactSummary entries."""
content = _make_jsonl(COMPACT_SUMMARY, USER_MSG, ASST_MSG)
stripped = strip_progress_entries(content)
entries = [json.loads(line) for line in stripped.strip().split("\n")]
types = [e.get("type") for e in entries]
assert "summary" in types # Not stripped despite being in STRIPPABLE_TYPES
compact = [e for e in entries if e.get("isCompactSummary")]
assert len(compact) == 1
def test_regular_summary_still_stripped(self):
"""Non-compact summaries should still be stripped."""
regular_summary = {
"type": "summary",
"uuid": "rs1",
"summary": "Session summary",
}
content = _make_jsonl(regular_summary, USER_MSG, ASST_MSG)
stripped = strip_progress_entries(content)
entries = [json.loads(line) for line in stripped.strip().split("\n")]
types = [e.get("type") for e in entries]
assert "summary" not in types
def test_replace_entries_preserves_compact_summary(self):
"""replace_entries should preserve isCompactSummary entries."""
builder = TranscriptBuilder()
builder.append_user("old")
content = _make_jsonl(COMPACT_SUMMARY, USER_MSG, ASST_MSG)
builder.replace_entries(content)
assert builder.entry_count == 3
# Verify by re-exporting
exported = builder.to_jsonl()
first = json.loads(exported.strip().split("\n")[0])
assert first.get("isCompactSummary") is True
# --- _flatten_assistant_content ---
class TestFlattenAssistantContent:
def test_text_blocks(self):
blocks = [
{"type": "text", "text": "Hello"},
{"type": "text", "text": "World"},
]
assert _flatten_assistant_content(blocks) == "Hello\nWorld"
def test_tool_use_blocks(self):
blocks = [{"type": "tool_use", "name": "read_file", "id": "t1", "input": {}}]
assert _flatten_assistant_content(blocks) == "[tool_use: read_file]"
def test_mixed_blocks(self):
blocks = [
{"type": "text", "text": "Let me read that."},
{"type": "tool_use", "name": "read", "id": "t1", "input": {}},
]
result = _flatten_assistant_content(blocks)
assert "Let me read that." in result
assert "[tool_use: read]" in result
def test_string_blocks(self):
"""Plain strings in the list should be included."""
assert _flatten_assistant_content(["hello", "world"]) == "hello\nworld"
def test_empty_list(self):
assert _flatten_assistant_content([]) == ""
def test_tool_use_missing_name(self):
blocks = [{"type": "tool_use", "id": "t1", "input": {}}]
assert _flatten_assistant_content(blocks) == "[tool_use: ?]"
# --- _flatten_tool_result_content ---
class TestFlattenToolResultContent:
def test_tool_result_with_text(self):
blocks = [
{
"type": "tool_result",
"tool_use_id": "t1",
"content": [{"type": "text", "text": "file contents here"}],
}
]
assert _flatten_tool_result_content(blocks) == "file contents here"
def test_tool_result_with_string_content(self):
blocks = [
{"type": "tool_result", "tool_use_id": "t1", "content": "simple result"}
]
assert _flatten_tool_result_content(blocks) == "simple result"
def test_tool_result_with_nested_list(self):
blocks = [
{
"type": "tool_result",
"tool_use_id": "t1",
"content": [
{"type": "text", "text": "line 1"},
{"type": "text", "text": "line 2"},
],
}
]
assert _flatten_tool_result_content(blocks) == "line 1\nline 2"
def test_text_blocks(self):
blocks = [{"type": "text", "text": "some text"}]
assert _flatten_tool_result_content(blocks) == "some text"
def test_string_items(self):
assert _flatten_tool_result_content(["raw string"]) == "raw string"
def test_empty_list(self):
assert _flatten_tool_result_content([]) == ""
def test_tool_result_none_text_uses_json(self):
"""Dicts without text key fall back to json.dumps."""
blocks = [
{
"type": "tool_result",
"tool_use_id": "t1",
"content": [{"type": "image", "source": "data:..."}],
}
]
result = _flatten_tool_result_content(blocks)
assert "image" in result # json.dumps fallback includes the key
# --- _transcript_to_messages ---
class TestTranscriptToMessages:
def test_basic_conversation(self):
content = _make_jsonl(
{
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "hello"},
},
{
"type": "assistant",
"uuid": "a1",
"parentUuid": "u1",
"message": {
"role": "assistant",
"content": [{"type": "text", "text": "hi there"}],
},
},
)
msgs = _transcript_to_messages(content)
assert len(msgs) == 2
assert msgs[0] == {"role": "user", "content": "hello"}
assert msgs[1] == {"role": "assistant", "content": "hi there"}
def test_strips_progress_entries(self):
content = _make_jsonl(
{
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "hi"},
},
{
"type": "progress",
"uuid": "p1",
"message": {"role": "user", "content": "..."},
},
{
"type": "assistant",
"uuid": "a1",
"message": {
"role": "assistant",
"content": [{"type": "text", "text": "ok"}],
},
},
)
msgs = _transcript_to_messages(content)
assert len(msgs) == 2
assert msgs[0]["role"] == "user"
assert msgs[1]["role"] == "assistant"
def test_preserves_compact_summaries(self):
content = _make_jsonl(
{
"type": "summary",
"uuid": "cs1",
"isCompactSummary": True,
"message": {"role": "user", "content": "Summary of previous..."},
},
{
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "hi"},
},
)
msgs = _transcript_to_messages(content)
assert len(msgs) == 2
assert msgs[0]["content"] == "Summary of previous..."
def test_strips_regular_summary(self):
content = _make_jsonl(
{
"type": "summary",
"uuid": "s1",
"message": {"role": "user", "content": "Session summary"},
},
{
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "hi"},
},
)
msgs = _transcript_to_messages(content)
assert len(msgs) == 1
assert msgs[0]["content"] == "hi"
def test_skips_entries_without_role(self):
content = _make_jsonl(
{"type": "user", "uuid": "u1", "message": {}},
{
"type": "user",
"uuid": "u2",
"message": {"role": "user", "content": "hi"},
},
)
msgs = _transcript_to_messages(content)
assert len(msgs) == 1
def test_tool_result_content(self):
content = _make_jsonl(
{
"type": "user",
"uuid": "u1",
"message": {
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "t1",
"content": "file contents",
}
],
},
},
)
msgs = _transcript_to_messages(content)
assert len(msgs) == 1
assert "file contents" in msgs[0]["content"]
def test_empty_content(self):
assert _transcript_to_messages("") == []
assert _transcript_to_messages(" \n ") == []
def test_invalid_json_lines_skipped(self):
content = '{"type":"user","uuid":"u1","message":{"role":"user","content":"hi"}}\nnot json\n'
msgs = _transcript_to_messages(content)
assert len(msgs) == 1
# --- _messages_to_transcript ---
class TestMessagesToTranscript:
def test_basic_roundtrip_structure(self):
messages = [
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "hi there"},
]
result = _messages_to_transcript(messages)
assert result.endswith("\n")
lines = [json.loads(line) for line in result.strip().split("\n")]
assert len(lines) == 2
# User entry
assert lines[0]["type"] == "user"
assert lines[0]["message"]["role"] == "user"
assert lines[0]["message"]["content"] == "hello"
assert lines[0]["parentUuid"] is None
# Assistant entry
assert lines[1]["type"] == "assistant"
assert lines[1]["message"]["role"] == "assistant"
assert lines[1]["message"]["content"] == [{"type": "text", "text": "hi there"}]
assert lines[1]["message"]["id"].startswith(COMPACT_MSG_ID_PREFIX)
assert lines[1]["parentUuid"] == lines[0]["uuid"]
def test_parent_uuid_chain(self):
messages = [
{"role": "user", "content": "q1"},
{"role": "assistant", "content": "a1"},
{"role": "user", "content": "q2"},
]
result = _messages_to_transcript(messages)
lines = [json.loads(line) for line in result.strip().split("\n")]
assert lines[0]["parentUuid"] is None
assert lines[1]["parentUuid"] == lines[0]["uuid"]
assert lines[2]["parentUuid"] == lines[1]["uuid"]
def test_empty_messages(self):
assert _messages_to_transcript([]) == ""
def test_assistant_empty_content(self):
messages = [{"role": "assistant", "content": ""}]
result = _messages_to_transcript(messages)
entry = json.loads(result.strip())
assert entry["message"]["content"] == []
def test_output_is_valid_transcript(self):
messages = [
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "world"},
]
result = _messages_to_transcript(messages)
assert validate_transcript(result)
# --- _transcript_to_messages + _messages_to_transcript roundtrip ---
class TestTranscriptCompactionRoundtrip:
def test_content_preserved_through_roundtrip(self):
"""Messages→transcript→messages preserves content."""
original = [
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "4"},
{"role": "user", "content": "Thanks"},
]
transcript = _messages_to_transcript(original)
recovered = _transcript_to_messages(transcript)
assert len(recovered) == len(original)
for orig, rec in zip(original, recovered):
assert orig["role"] == rec["role"]
assert orig["content"] == rec["content"]
def test_full_transcript_to_messages_and_back(self):
"""Real-ish JSONL → messages → transcript → messages roundtrip."""
source = _make_jsonl(
{
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "explain python"},
},
{
"type": "assistant",
"uuid": "a1",
"parentUuid": "u1",
"message": {
"role": "assistant",
"content": [
{"type": "text", "text": "Python is a programming language."}
],
},
},
{
"type": "user",
"uuid": "u2",
"parentUuid": "a1",
"message": {
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "t1",
"content": "output of ls",
}
],
},
},
)
msgs1 = _transcript_to_messages(source)
assert len(msgs1) == 3
rebuilt = _messages_to_transcript(msgs1)
msgs2 = _transcript_to_messages(rebuilt)
assert len(msgs2) == len(msgs1)
for m1, m2 in zip(msgs1, msgs2):
assert m1["role"] == m2["role"]
# Content may differ in format (list vs string) but text is preserved
assert m1["content"] in m2["content"] or m2["content"] in m1["content"]
# --- compact_transcript ---
class TestCompactTranscript:
@pytest.mark.asyncio
async def test_too_few_messages_returns_none(self):
"""Transcripts with < 2 messages can't be compacted."""
single = _make_jsonl(
{"type": "user", "uuid": "u1", "message": {"role": "user", "content": "hi"}}
)
result = await compact_transcript(single)
assert result is None
@pytest.mark.asyncio
async def test_empty_transcript_returns_none(self):
result = await compact_transcript("")
assert result is None
@pytest.mark.asyncio
async def test_compaction_produces_valid_transcript(self, monkeypatch):
"""When compress_context compacts, result should be valid JSONL."""
from unittest.mock import AsyncMock
from backend.util.prompt import CompressResult
mock_result = CompressResult(
messages=[
{"role": "user", "content": "Summary of conversation"},
{"role": "assistant", "content": "Acknowledged"},
],
token_count=50,
was_compacted=True,
original_token_count=5000,
messages_summarized=10,
messages_dropped=5,
)
monkeypatch.setattr(
"backend.copilot.sdk.transcript._run_compression",
AsyncMock(return_value=mock_result),
)
source = _make_jsonl(
{
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "msg1"},
},
{
"type": "assistant",
"uuid": "a1",
"message": {
"role": "assistant",
"content": [{"type": "text", "text": "reply1"}],
},
},
{
"type": "user",
"uuid": "u2",
"message": {"role": "user", "content": "msg2"},
},
)
result = await compact_transcript(source)
assert result is not None
assert validate_transcript(result)
# Verify compacted content
msgs = _transcript_to_messages(result)
assert len(msgs) == 2
assert msgs[0]["content"] == "Summary of conversation"
@pytest.mark.asyncio
async def test_no_compaction_needed_returns_original(self, monkeypatch):
"""When compress_context says no compaction needed, return original."""
from unittest.mock import AsyncMock
from backend.util.prompt import CompressResult
mock_result = CompressResult(
messages=[], token_count=100, was_compacted=False, original_token_count=100
)
monkeypatch.setattr(
"backend.copilot.sdk.transcript._run_compression",
AsyncMock(return_value=mock_result),
)
source = _make_jsonl(
{
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "hi"},
},
{
"type": "assistant",
"uuid": "a1",
"message": {
"role": "assistant",
"content": [{"type": "text", "text": "hello"}],
},
},
)
result = await compact_transcript(source)
assert result == source # Unchanged
@pytest.mark.asyncio
async def test_compression_failure_returns_none(self, monkeypatch):
"""When _run_compression raises, compact_transcript returns None."""
from unittest.mock import AsyncMock
monkeypatch.setattr(
"backend.copilot.sdk.transcript._run_compression",
AsyncMock(side_effect=RuntimeError("LLM unavailable")),
)
source = _make_jsonl(
{
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "hi"},
},
{
"type": "assistant",
"uuid": "a1",
"message": {
"role": "assistant",
"content": [{"type": "text", "text": "hello"}],
},
},
)
result = await compact_transcript(source)
assert result is None

View File

@@ -8,11 +8,15 @@ from pydantic_core import PydanticUndefined
from backend.blocks._base import AnyBlockSchema
from backend.copilot.constants import COPILOT_NODE_PREFIX, COPILOT_SESSION_PREFIX
from backend.data import db
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from backend.data.db_accessors import workspace_db
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
from backend.executor.utils import block_usage_cost
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import BlockError
from backend.util.clients import get_database_manager_async_client
from backend.util.exceptions import BlockError, InsufficientBalanceError
from backend.util.type import coerce_inputs_to_schema
from .models import BlockOutputResponse, ErrorResponse, ToolResponseBase
@@ -21,6 +25,26 @@ from .utils import match_credentials_to_requirements
logger = logging.getLogger(__name__)
async def _get_credits(user_id: str) -> int:
"""Get user credits using the adapter pattern (RPC when Prisma unavailable)."""
if not db.is_connected():
return await get_database_manager_async_client().get_credits(user_id)
credit_model = await get_user_credit_model(user_id)
return await credit_model.get_credits(user_id)
async def _spend_credits(
user_id: str, cost: int, metadata: UsageTransactionMetadata
) -> int:
"""Spend user credits using the adapter pattern (RPC when Prisma unavailable)."""
if not db.is_connected():
return await get_database_manager_async_client().spend_credits(
user_id, cost, metadata
)
credit_model = await get_user_credit_model(user_id)
return await credit_model.spend_credits(user_id, cost, metadata)
def get_inputs_from_schema(
input_schema: dict[str, Any],
exclude_fields: set[str] | None = None,
@@ -115,6 +139,20 @@ async def execute_block(
# Coerce non-matching data types to the expected input schema.
coerce_inputs_to_schema(input_data, block.input_schema)
# Pre-execution credit check
cost, cost_filter = block_usage_cost(block, input_data)
has_cost = cost > 0
if has_cost:
balance = await _get_credits(user_id)
if balance < cost:
return ErrorResponse(
message=(
f"Insufficient credits to run '{block.name}'. "
"Please top up your credits to continue."
),
session_id=session_id,
)
# Execute the block and collect outputs
outputs: dict[str, list[Any]] = defaultdict(list)
async for output_name, output_data in block.execute(
@@ -123,6 +161,37 @@ async def execute_block(
):
outputs[output_name].append(output_data)
# Charge credits for block execution
if has_cost:
try:
await _spend_credits(
user_id=user_id,
cost=cost,
metadata=UsageTransactionMetadata(
graph_exec_id=synthetic_graph_id,
graph_id=synthetic_graph_id,
node_id=synthetic_node_id,
node_exec_id=node_exec_id,
block_id=block_id,
block=block.name,
input=cost_filter,
reason="copilot_block_execution",
),
)
except InsufficientBalanceError:
logger.warning(
"Post-exec credit charge failed for block %s (cost=%d)",
block.name,
cost,
)
return ErrorResponse(
message=(
f"Insufficient credits to complete '{block.name}'. "
"Please top up your credits to continue."
),
session_id=session_id,
)
return BlockOutputResponse(
message=f"Block '{block.name}' executed successfully",
block_id=block_id,
@@ -133,16 +202,16 @@ async def execute_block(
)
except BlockError as e:
logger.warning(f"Block execution failed: {e}")
logger.warning("Block execution failed: %s", e)
return ErrorResponse(
message=f"Block execution failed: {e}",
error=str(e),
session_id=session_id,
)
except Exception as e:
logger.error(f"Unexpected error executing block: {e}", exc_info=True)
logger.error("Unexpected error executing block: %s", e, exc_info=True)
return ErrorResponse(
message=f"Failed to execute block: {str(e)}",
message="An unexpected error occurred while executing the block",
error=str(e),
session_id=session_id,
)

View File

@@ -1,24 +1,202 @@
"""Tests for execute_block type coercion in helpers.py.
Verifies that execute_block() coerces string input values to match the block's
expected input types, mirroring the executor's validate_exec() logic.
This is critical for @@agptfile: expansion, where file content is always a string
but the block may expect structured types (e.g. list[list[str]]).
"""
"""Tests for execute_block — credit charging and type coercion."""
from collections.abc import AsyncIterator
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.blocks._base import BlockType
from backend.copilot.tools.helpers import execute_block
from backend.copilot.tools.models import BlockOutputResponse
from backend.copilot.tools.models import BlockOutputResponse, ErrorResponse
_USER = "test-user-helpers"
_SESSION = "test-session-helpers"
def _make_block(block_id: str = "block-1", name: str = "TestBlock"):
"""Create a minimal mock block for execute_block()."""
mock = MagicMock()
mock.id = block_id
mock.name = name
mock.block_type = BlockType.STANDARD
mock.input_schema = MagicMock()
mock.input_schema.get_credentials_fields_info.return_value = {}
async def _execute(
input_data: dict, **kwargs: Any
) -> AsyncIterator[tuple[str, Any]]:
yield "result", "ok"
mock.execute = _execute
return mock
def _patch_workspace():
"""Patch workspace_db to return a mock workspace."""
mock_workspace = MagicMock()
mock_workspace.id = "ws-1"
mock_ws_db = MagicMock()
mock_ws_db.get_or_create_workspace = AsyncMock(return_value=mock_workspace)
return patch("backend.copilot.tools.helpers.workspace_db", return_value=mock_ws_db)
# ---------------------------------------------------------------------------
# Credit charging tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
class TestExecuteBlockCreditCharging:
async def test_charges_credits_when_cost_is_positive(self):
"""Block with cost > 0 should call spend_credits after execution."""
block = _make_block()
mock_spend = AsyncMock()
with (
_patch_workspace(),
patch(
"backend.copilot.tools.helpers.block_usage_cost",
return_value=(10, {"key": "val"}),
),
patch(
"backend.copilot.tools.helpers._get_credits",
new_callable=AsyncMock,
return_value=100,
),
patch(
"backend.copilot.tools.helpers._spend_credits",
new_callable=AsyncMock,
side_effect=mock_spend,
),
):
result = await execute_block(
block=block,
block_id="block-1",
input_data={"text": "hello"},
user_id=_USER,
session_id=_SESSION,
node_exec_id="exec-1",
matched_credentials={},
)
assert isinstance(result, BlockOutputResponse)
assert result.success is True
mock_spend.assert_awaited_once()
call_kwargs = mock_spend.call_args.kwargs
assert call_kwargs["cost"] == 10
assert call_kwargs["metadata"].reason == "copilot_block_execution"
async def test_returns_error_when_insufficient_credits_before_exec(self):
"""Pre-execution check should return ErrorResponse when balance < cost."""
block = _make_block()
with (
_patch_workspace(),
patch(
"backend.copilot.tools.helpers.block_usage_cost",
return_value=(10, {}),
),
patch(
"backend.copilot.tools.helpers._get_credits",
new_callable=AsyncMock,
return_value=5, # balance < cost (10)
),
):
result = await execute_block(
block=block,
block_id="block-1",
input_data={},
user_id=_USER,
session_id=_SESSION,
node_exec_id="exec-1",
matched_credentials={},
)
assert isinstance(result, ErrorResponse)
assert "Insufficient credits" in result.message
async def test_no_charge_when_cost_is_zero(self):
"""Block with cost 0 should not call spend_credits."""
block = _make_block()
with (
_patch_workspace(),
patch(
"backend.copilot.tools.helpers.block_usage_cost",
return_value=(0, {}),
),
patch(
"backend.copilot.tools.helpers._get_credits",
) as mock_get_credits,
patch(
"backend.copilot.tools.helpers._spend_credits",
) as mock_spend_credits,
):
result = await execute_block(
block=block,
block_id="block-1",
input_data={},
user_id=_USER,
session_id=_SESSION,
node_exec_id="exec-1",
matched_credentials={},
)
assert isinstance(result, BlockOutputResponse)
assert result.success is True
# Credit functions should not be called at all for zero-cost blocks
mock_get_credits.assert_not_awaited()
mock_spend_credits.assert_not_awaited()
async def test_returns_error_on_post_exec_insufficient_balance(self):
"""If charging fails after execution, return ErrorResponse."""
from backend.util.exceptions import InsufficientBalanceError
block = _make_block()
with (
_patch_workspace(),
patch(
"backend.copilot.tools.helpers.block_usage_cost",
return_value=(10, {}),
),
patch(
"backend.copilot.tools.helpers._get_credits",
new_callable=AsyncMock,
return_value=15, # passes pre-check
),
patch(
"backend.copilot.tools.helpers._spend_credits",
new_callable=AsyncMock,
side_effect=InsufficientBalanceError(
"Low balance", _USER, 5, 10
), # fails during actual charge (race with concurrent spend)
),
):
result = await execute_block(
block=block,
block_id="block-1",
input_data={},
user_id=_USER,
session_id=_SESSION,
node_exec_id="exec-1",
matched_credentials={},
)
assert isinstance(result, ErrorResponse)
assert "Insufficient credits" in result.message
# ---------------------------------------------------------------------------
# Type coercion tests
# ---------------------------------------------------------------------------
def _make_block_schema(annotations: dict[str, Any]) -> MagicMock:
"""Create a mock input_schema with model_fields matching the given annotations."""
schema = MagicMock()
# coerce_inputs_to_schema uses model_fields (Pydantic v2 API)
model_fields = {}
for name, ann in annotations.items():
field = MagicMock()
@@ -28,7 +206,7 @@ def _make_block_schema(annotations: dict[str, Any]) -> MagicMock:
return schema
def _make_block(
def _make_coerce_block(
block_id: str,
name: str,
annotations: dict[str, Any],
@@ -60,7 +238,7 @@ _TEST_USER_ID = "test-user-coerce"
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_json_string_to_nested_list():
"""JSON string → list[list[str]] (Google Sheets CSV import case)."""
block = _make_block(
block = _make_coerce_block(
"sheets-write",
"Google Sheets Write",
{"values": list[list[str]], "spreadsheet_id": str},
@@ -90,7 +268,6 @@ async def test_coerce_json_string_to_nested_list():
assert isinstance(response, BlockOutputResponse)
assert response.success is True
# Verify the input was coerced from string to list[list[str]]
assert block._captured_inputs["values"] == [
["Name", "Score"],
["Alice", "90"],
@@ -103,7 +280,7 @@ async def test_coerce_json_string_to_nested_list():
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_json_string_to_list():
"""JSON string → list[str]."""
block = _make_block(
block = _make_coerce_block(
"list-block",
"List Block",
{"items": list[str]},
@@ -135,7 +312,7 @@ async def test_coerce_json_string_to_list():
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_json_string_to_dict():
"""JSON string → dict[str, str]."""
block = _make_block(
block = _make_coerce_block(
"dict-block",
"Dict Block",
{"config": dict[str, str]},
@@ -167,7 +344,7 @@ async def test_coerce_json_string_to_dict():
@pytest.mark.asyncio(loop_scope="session")
async def test_no_coercion_when_type_matches():
"""Already-correct types pass through without coercion."""
block = _make_block(
block = _make_coerce_block(
"pass-through",
"Pass Through",
{"values": list[list[str]], "name": str},
@@ -201,7 +378,7 @@ async def test_no_coercion_when_type_matches():
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_string_to_int():
"""String number → int."""
block = _make_block(
block = _make_coerce_block(
"int-block",
"Int Block",
{"count": int},
@@ -234,7 +411,7 @@ async def test_coerce_string_to_int():
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_skips_none_values():
"""None values are not coerced (they may be optional fields)."""
block = _make_block(
block = _make_coerce_block(
"optional-block",
"Optional Block",
{"data": list[str], "label": str},
@@ -260,14 +437,13 @@ async def test_coerce_skips_none_values():
)
assert isinstance(response, BlockOutputResponse)
# 'data' was not provided, so it should not appear in captured inputs
assert "data" not in block._captured_inputs
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_union_type_preserves_valid_member():
"""Union-typed fields should not be coerced when the value matches a member."""
block = _make_block(
block = _make_coerce_block(
"union-block",
"Union Block",
{"content": str | list[str]},
@@ -293,7 +469,6 @@ async def test_coerce_union_type_preserves_valid_member():
)
assert isinstance(response, BlockOutputResponse)
# list[str] should NOT be stringified to '["a", "b"]'
assert block._captured_inputs["content"] == ["a", "b"]
assert isinstance(block._captured_inputs["content"], list)
@@ -301,7 +476,7 @@ async def test_coerce_union_type_preserves_valid_member():
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_inner_elements_of_generic():
"""Inner elements of generic containers are recursively coerced."""
block = _make_block(
block = _make_coerce_block(
"inner-coerce",
"Inner Coerce",
{"values": list[str]},
@@ -319,7 +494,6 @@ async def test_coerce_inner_elements_of_generic():
response = await execute_block(
block=block,
block_id="inner-coerce",
# Inner elements are ints, but target is list[str]
input_data={"values": [1, 2, 3]},
user_id=_TEST_USER_ID,
session_id=_TEST_SESSION_ID,
@@ -328,6 +502,5 @@ async def test_coerce_inner_elements_of_generic():
)
assert isinstance(response, BlockOutputResponse)
# Inner elements should be coerced from int to str
assert block._captured_inputs["values"] == ["1", "2", "3"]
assert all(isinstance(v, str) for v in block._captured_inputs["values"])

View File

@@ -512,6 +512,10 @@ class DatabaseManagerAsyncClient(AppServiceClient):
list_workspace_files = d.list_workspace_files
soft_delete_workspace_file = d.soft_delete_workspace_file
# ============ Credits ============ #
spend_credits = d.spend_credits
get_credits = d.get_credits
# ============ Understanding ============ #
get_business_understanding = d.get_business_understanding
upsert_business_understanding = d.upsert_business_understanding

View File

@@ -70,6 +70,9 @@ def _msg_tokens(msg: dict, enc) -> int:
# Count tool result tokens
tool_call_tokens += _tok_len(item.get("tool_use_id", ""), enc)
tool_call_tokens += _tok_len(item.get("content", ""), enc)
elif isinstance(item, dict) and item.get("type") == "text":
# Count text block tokens
tool_call_tokens += _tok_len(item.get("text", ""), enc)
elif isinstance(item, dict) and "content" in item:
# Other content types with content field
tool_call_tokens += _tok_len(item.get("content", ""), enc)
@@ -145,10 +148,14 @@ def _truncate_middle_tokens(text: str, enc, max_tok: int) -> str:
if len(ids) <= max_tok:
return text # nothing to do
# Need at least 3 tokens (head + ellipsis + tail) for meaningful truncation
mid = enc.encode("")
if max_tok < 3:
return enc.decode(mid)
# Split the allowance between the two ends:
head = max_tok // 2 - 1 # -1 for the ellipsis
tail = max_tok - head - 1
mid = enc.encode("")
return enc.decode(ids[:head] + mid + ids[-tail:])
@@ -396,7 +403,7 @@ def validate_and_remove_orphan_tool_responses(
if log_warning:
logger.warning(
f"Removing {len(orphan_ids)} orphan tool response(s): {orphan_ids}"
"Removing %d orphan tool response(s): %s", len(orphan_ids), orphan_ids
)
return _remove_orphan_tool_responses(messages, orphan_ids)
@@ -488,8 +495,9 @@ def _ensure_tool_pairs_intact(
# Some tool_call_ids couldn't be resolved - remove those tool responses
# This shouldn't happen in normal operation but handles edge cases
logger.warning(
f"Could not find assistant messages for tool_call_ids: {orphan_tool_call_ids}. "
"Removing orphan tool responses."
"Could not find assistant messages for tool_call_ids: %s. "
"Removing orphan tool responses.",
orphan_tool_call_ids,
)
recent_messages = _remove_orphan_tool_responses(
recent_messages, orphan_tool_call_ids
@@ -497,8 +505,8 @@ def _ensure_tool_pairs_intact(
if messages_to_prepend:
logger.info(
f"Extended recent messages by {len(messages_to_prepend)} to preserve "
f"tool_call/tool_response pairs"
"Extended recent messages by %d to preserve tool_call/tool_response pairs",
len(messages_to_prepend),
)
return messages_to_prepend + recent_messages
@@ -686,11 +694,15 @@ async def compress_context(
msgs = [summary_msg] + recent_msgs
logger.info(
f"Context summarized: {original_count} -> {total_tokens()} tokens, "
f"summarized {messages_summarized} messages"
"Context summarized: %d -> %d tokens, summarized %d messages",
original_count,
total_tokens(),
messages_summarized,
)
except Exception as e:
logger.warning(f"Summarization failed, continuing with truncation: {e}")
logger.warning(
"Summarization failed, continuing with truncation: %s", e
)
# Fall through to content truncation
# ---- STEP 2: Normalize content ----------------------------------------
@@ -728,6 +740,12 @@ async def compress_context(
# This is more granular than dropping all old messages at once.
while total_tokens() + reserve > target_tokens and len(msgs) > 2:
deletable: list[int] = []
# Count assistant messages to ensure we keep at least one
assistant_indices: set[int] = {
i
for i in range(len(msgs))
if msgs[i] is not None and msgs[i].get("role") == "assistant"
}
for i in range(1, len(msgs) - 1):
msg = msgs[i]
if (
@@ -735,6 +753,9 @@ async def compress_context(
and not _is_tool_message(msg)
and not _is_objective_message(msg)
):
# Skip if this is the last remaining assistant message
if msg.get("role") == "assistant" and len(assistant_indices) <= 1:
continue
deletable.append(i)
if not deletable:
break

View File

@@ -1,14 +1,8 @@
"use client";
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuTrigger,
} from "@/components/molecules/DropdownMenu/DropdownMenu";
import { SidebarProvider } from "@/components/ui/sidebar";
import { cn } from "@/lib/utils";
import { DotsThree, UploadSimple } from "@phosphor-icons/react";
import { UploadSimple } from "@phosphor-icons/react";
import { useCallback, useRef, useState } from "react";
import { ChatContainer } from "./components/ChatContainer/ChatContainer";
import { ChatSidebar } from "./components/ChatSidebar/ChatSidebar";
@@ -92,7 +86,6 @@ export function CopilotPage() {
// Delete functionality
sessionToDelete,
isDeleting,
handleDeleteClick,
handleConfirmDelete,
handleCancelDelete,
} = useCopilotPage();
@@ -148,38 +141,6 @@ export function CopilotPage() {
isUploadingFiles={isUploadingFiles}
droppedFiles={droppedFiles}
onDroppedFilesConsumed={handleDroppedFilesConsumed}
headerSlot={
isMobile && sessionId ? (
<div className="flex justify-end">
<DropdownMenu>
<DropdownMenuTrigger asChild>
<button
className="rounded p-1.5 hover:bg-neutral-100"
aria-label="More actions"
>
<DotsThree className="h-5 w-5 text-neutral-600" />
</button>
</DropdownMenuTrigger>
<DropdownMenuContent align="end">
<DropdownMenuItem
onClick={() => {
const session = sessions.find(
(s) => s.id === sessionId,
);
if (session) {
handleDeleteClick(session.id, session.title);
}
}}
disabled={isDeleting}
className="text-red-600 focus:bg-red-50 focus:text-red-600"
>
Delete chat
</DropdownMenuItem>
</DropdownMenuContent>
</DropdownMenu>
</div>
) : undefined
}
/>
</div>
</div>

View File

@@ -2,7 +2,6 @@
import { ChatInput } from "@/app/(platform)/copilot/components/ChatInput/ChatInput";
import { UIDataTypes, UIMessage, UITools } from "ai";
import { LayoutGroup, motion } from "framer-motion";
import { ReactNode } from "react";
import { ChatMessagesContainer } from "../ChatMessagesContainer/ChatMessagesContainer";
import { CopilotChatActionsProvider } from "../CopilotChatActionsProvider/CopilotChatActionsProvider";
import { EmptySession } from "../EmptySession/EmptySession";
@@ -21,7 +20,6 @@ export interface ChatContainerProps {
onSend: (message: string, files?: File[]) => void | Promise<void>;
onStop: () => void;
isUploadingFiles?: boolean;
headerSlot?: ReactNode;
/** Files dropped onto the chat window. */
droppedFiles?: File[];
/** Called after droppedFiles have been consumed by ChatInput. */
@@ -40,7 +38,6 @@ export const ChatContainer = ({
onSend,
onStop,
isUploadingFiles,
headerSlot,
droppedFiles,
onDroppedFilesConsumed,
}: ChatContainerProps) => {
@@ -63,7 +60,6 @@ export const ChatContainer = ({
status={status}
error={error}
isLoading={isLoadingSession}
headerSlot={headerSlot}
sessionID={sessionId}
/>
<motion.div

View File

@@ -30,7 +30,6 @@ interface Props {
status: string;
error: Error | undefined;
isLoading: boolean;
headerSlot?: React.ReactNode;
sessionID?: string | null;
}
@@ -102,7 +101,6 @@ export function ChatMessagesContainer({
status,
error,
isLoading,
headerSlot,
sessionID,
}: Props) {
const lastMessage = messages[messages.length - 1];
@@ -135,7 +133,6 @@ export function ChatMessagesContainer({
return (
<Conversation className="min-h-0 flex-1">
<ConversationContent className="flex flex-1 flex-col gap-6 px-3 py-6">
{headerSlot}
{isLoading && messages.length === 0 && (
<div
className="flex flex-1 items-center justify-center"

View File

@@ -37,6 +37,7 @@ import { useCopilotUIStore } from "../../store";
import { NotificationToggle } from "./components/NotificationToggle/NotificationToggle";
import { DeleteChatDialog } from "../DeleteChatDialog/DeleteChatDialog";
import { PulseLoader } from "../PulseLoader/PulseLoader";
import { UsageLimits } from "../UsageLimits/UsageLimits";
export function ChatSidebar() {
const { state } = useSidebar();
@@ -256,11 +257,10 @@ export function ChatSidebar() {
<Text variant="h3" size="body-medium">
Your chats
</Text>
<div className="relative left-5 flex items-center gap-1">
<div className="flex items-center">
<UsageLimits />
<NotificationToggle />
<div className="relative left-1">
<SidebarTrigger />
</div>
<SidebarTrigger />
</div>
</div>
{sessionId ? (

View File

@@ -7,6 +7,7 @@ import {
PopoverTrigger,
} from "@/components/molecules/Popover/Popover";
import { toast } from "@/components/molecules/Toast/use-toast";
import { Button } from "@/components/ui/button";
import { cn } from "@/lib/utils";
import { Bell, BellRinging, BellSlash } from "@phosphor-icons/react";
import { useCopilotUIStore } from "../../../../store";
@@ -48,10 +49,7 @@ export function NotificationToggle() {
return (
<Popover>
<PopoverTrigger asChild>
<button
className="rounded p-1 text-black transition-colors hover:bg-zinc-50"
aria-label="Notification settings"
>
<Button variant="ghost" size="icon" aria-label="Notification settings">
{!isNotificationsEnabled ? (
<BellSlash className="!size-5" />
) : isSoundEnabled ? (
@@ -59,7 +57,7 @@ export function NotificationToggle() {
) : (
<Bell className="!size-5" />
)}
</button>
</Button>
</PopoverTrigger>
<PopoverContent align="start" className="w-56 p-3">
<div className="flex flex-col gap-3">

View File

@@ -0,0 +1,146 @@
import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus";
import {
Popover,
PopoverContent,
PopoverTrigger,
} from "@/components/molecules/Popover/Popover";
import { Button } from "@/components/ui/button";
import { ChartBar } from "@phosphor-icons/react";
import { useUsageLimits } from "./useUsageLimits";
const MS_PER_MINUTE = 60_000;
const MS_PER_HOUR = 3_600_000;
function formatResetTime(resetsAt: Date | string): string {
const resetDate =
typeof resetsAt === "string" ? new Date(resetsAt) : resetsAt;
const now = new Date();
const diffMs = resetDate.getTime() - now.getTime();
if (diffMs <= 0) return "now";
const hours = Math.floor(diffMs / MS_PER_HOUR);
// Under 24h: show relative time ("in 4h 23m")
if (hours < 24) {
const minutes = Math.floor((diffMs % MS_PER_HOUR) / MS_PER_MINUTE);
if (hours > 0) return `in ${hours}h ${minutes}m`;
return `in ${minutes}m`;
}
// Over 24h: show day and time in local timezone ("Mon 12:00 AM PST")
return resetDate.toLocaleString(undefined, {
weekday: "short",
hour: "numeric",
minute: "2-digit",
timeZoneName: "short",
});
}
function UsageBar({
label,
used,
limit,
resetsAt,
}: {
label: string;
used: number;
limit: number;
resetsAt: Date | string;
}) {
if (limit <= 0) return null;
const rawPercent = (used / limit) * 100;
const percent = Math.min(100, Math.round(rawPercent));
const isHigh = percent >= 80;
const percentLabel =
used > 0 && percent === 0 ? "<1% used" : `${percent}% used`;
return (
<div className="flex flex-col gap-1">
<div className="flex items-baseline justify-between">
<span className="text-xs font-medium text-neutral-700">{label}</span>
<span className="text-[11px] tabular-nums text-neutral-500">
{percentLabel}
</span>
</div>
<div className="text-[10px] text-neutral-400">
Resets {formatResetTime(resetsAt)}
</div>
<div className="h-2 w-full overflow-hidden rounded-full bg-neutral-200">
<div
className={`h-full rounded-full transition-[width] duration-300 ease-out ${
isHigh ? "bg-orange-500" : "bg-blue-500"
}`}
style={{ width: `${Math.max(used > 0 ? 1 : 0, percent)}%` }}
/>
</div>
</div>
);
}
export function UsagePanelContent({
usage,
showBillingLink = true,
}: {
usage: CoPilotUsageStatus;
showBillingLink?: boolean;
}) {
const hasDailyLimit = usage.daily.limit > 0;
const hasWeeklyLimit = usage.weekly.limit > 0;
if (!hasDailyLimit && !hasWeeklyLimit) {
return (
<div className="text-xs text-neutral-500">No usage limits configured</div>
);
}
return (
<div className="flex flex-col gap-3">
<div className="text-xs font-semibold text-neutral-800">Usage limits</div>
{hasDailyLimit && (
<UsageBar
label="Today"
used={usage.daily.used}
limit={usage.daily.limit}
resetsAt={usage.daily.resets_at}
/>
)}
{hasWeeklyLimit && (
<UsageBar
label="This week"
used={usage.weekly.used}
limit={usage.weekly.limit}
resetsAt={usage.weekly.resets_at}
/>
)}
{showBillingLink && (
<a
href="/profile/credits"
className="text-[11px] text-blue-600 hover:underline"
>
Learn more about usage limits
</a>
)}
</div>
);
}
export function UsageLimits() {
const { data: usage, isLoading } = useUsageLimits();
if (isLoading || !usage) return null;
if (usage.daily.limit <= 0 && usage.weekly.limit <= 0) return null;
return (
<Popover>
<PopoverTrigger asChild>
<Button variant="ghost" size="icon" aria-label="Usage limits">
<ChartBar className="!size-5" weight="light" />
</Button>
</PopoverTrigger>
<PopoverContent align="start" className="w-64 p-3">
<UsagePanelContent usage={usage} />
</PopoverContent>
</Popover>
);
}

View File

@@ -0,0 +1,121 @@
import { render, screen, cleanup } from "@/tests/integrations/test-utils";
import { afterEach, describe, expect, it, vi } from "vitest";
import { UsageLimits } from "../UsageLimits";
// Mock the useUsageLimits hook
const mockUseUsageLimits = vi.fn();
vi.mock("../useUsageLimits", () => ({
useUsageLimits: () => mockUseUsageLimits(),
}));
// Mock Popover to render children directly (Radix portals don't work in happy-dom)
vi.mock("@/components/molecules/Popover/Popover", () => ({
Popover: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
PopoverTrigger: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
PopoverContent: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
}));
afterEach(() => {
cleanup();
mockUseUsageLimits.mockReset();
});
function makeUsage({
dailyUsed = 500,
dailyLimit = 10000,
weeklyUsed = 2000,
weeklyLimit = 50000,
}: {
dailyUsed?: number;
dailyLimit?: number;
weeklyUsed?: number;
weeklyLimit?: number;
} = {}) {
const future = new Date(Date.now() + 3600 * 1000); // 1h from now
return {
daily: { used: dailyUsed, limit: dailyLimit, resets_at: future },
weekly: { used: weeklyUsed, limit: weeklyLimit, resets_at: future },
};
}
describe("UsageLimits", () => {
it("renders nothing while loading", () => {
mockUseUsageLimits.mockReturnValue({ data: undefined, isLoading: true });
const { container } = render(<UsageLimits />);
expect(container.innerHTML).toBe("");
});
it("renders nothing when no limits are configured", () => {
mockUseUsageLimits.mockReturnValue({
data: makeUsage({ dailyLimit: 0, weeklyLimit: 0 }),
isLoading: false,
});
const { container } = render(<UsageLimits />);
expect(container.innerHTML).toBe("");
});
it("renders the usage button when limits exist", () => {
mockUseUsageLimits.mockReturnValue({
data: makeUsage(),
isLoading: false,
});
render(<UsageLimits />);
expect(screen.getByRole("button", { name: /usage limits/i })).toBeDefined();
});
it("displays daily and weekly usage percentages", () => {
mockUseUsageLimits.mockReturnValue({
data: makeUsage({ dailyUsed: 5000, dailyLimit: 10000 }),
isLoading: false,
});
render(<UsageLimits />);
expect(screen.getByText("50% used")).toBeDefined();
expect(screen.getByText("Today")).toBeDefined();
expect(screen.getByText("This week")).toBeDefined();
expect(screen.getByText("Usage limits")).toBeDefined();
});
it("shows only weekly bar when daily limit is 0", () => {
mockUseUsageLimits.mockReturnValue({
data: makeUsage({
dailyLimit: 0,
weeklyUsed: 25000,
weeklyLimit: 50000,
}),
isLoading: false,
});
render(<UsageLimits />);
expect(screen.getByText("This week")).toBeDefined();
expect(screen.queryByText("Today")).toBeNull();
});
it("caps percentage at 100% when over limit", () => {
mockUseUsageLimits.mockReturnValue({
data: makeUsage({ dailyUsed: 15000, dailyLimit: 10000 }),
isLoading: false,
});
render(<UsageLimits />);
expect(screen.getByText("100% used")).toBeDefined();
});
it("shows learn more link to credits page", () => {
mockUseUsageLimits.mockReturnValue({
data: makeUsage(),
isLoading: false,
});
render(<UsageLimits />);
const link = screen.getByText("Learn more about usage limits");
expect(link).toBeDefined();
expect(link.closest("a")?.getAttribute("href")).toBe("/profile/credits");
});
});

View File

@@ -0,0 +1,12 @@
import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus";
import { useGetV2GetCopilotUsage } from "@/app/api/__generated__/endpoints/chat/chat";
export function useUsageLimits() {
return useGetV2GetCopilotUsage({
query: {
select: (res) => res.data as CoPilotUsageStatus,
refetchInterval: 30000,
staleTime: 10000,
},
});
}

View File

@@ -1,4 +1,5 @@
import {
getGetV2GetCopilotUsageQueryKey,
getGetV2GetSessionQueryKey,
postV2CancelSessionTask,
} from "@/app/api/__generated__/endpoints/chat/chat";
@@ -307,6 +308,9 @@ export function useCopilotStream({
queryClient.invalidateQueries({
queryKey: getGetV2GetSessionQueryKey(sessionId),
});
queryClient.invalidateQueries({
queryKey: getGetV2GetCopilotUsageQueryKey(),
});
if (status === "ready") {
reconnectAttemptsRef.current = 0;
hasShownDisconnectToast.current = false;

View File

@@ -11,6 +11,8 @@ import {
import { RefundModal } from "./RefundModal";
import { CreditTransaction } from "@/lib/autogpt-server-api";
import { UsagePanelContent } from "@/app/(platform)/copilot/components/UsageLimits/UsageLimits";
import { useUsageLimits } from "@/app/(platform)/copilot/components/UsageLimits/useUsageLimits";
import {
Table,
@@ -21,6 +23,26 @@ import {
TableRow,
} from "@/components/__legacy__/ui/table";
function CoPilotUsageSection() {
const { data: usage, isLoading } = useUsageLimits();
const router = useRouter();
if (isLoading || !usage) return null;
if (usage.daily.limit <= 0 && usage.weekly.limit <= 0) return null;
return (
<div className="my-6 space-y-4">
<h3 className="text-lg font-medium">CoPilot Usage Limits</h3>
<div className="rounded-lg border border-neutral-200 p-4 dark:border-neutral-700">
<UsagePanelContent usage={usage} showBillingLink={false} />
</div>
<Button className="w-full" onClick={() => router.push("/copilot")}>
Open CoPilot
</Button>
</div>
);
}
export default function CreditsPage() {
const api = useBackendAPI();
const {
@@ -237,11 +259,13 @@ export default function CreditsPage() {
</Button>
)}
</form>
{/* CoPilot Usage Limits */}
<CoPilotUsageSection />
</div>
<div className="my-6 space-y-4">
{/* Payment Portal */}
<h3 className="text-lg font-medium">Manage Your Payment Methods</h3>
<p className="text-neutral-600">
You can manage your cards and see your payment history in the

View File

@@ -1382,6 +1382,28 @@
"security": [{ "HTTPBearerJWT": [] }]
}
},
"/api/chat/usage": {
"get": {
"tags": ["v2", "chat", "chat"],
"summary": "Get Copilot Usage",
"description": "Get CoPilot usage status for the authenticated user.\n\nReturns current token usage vs limits for daily and weekly windows.",
"operationId": "getV2GetCopilotUsage",
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/CoPilotUsageStatus" }
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
}
},
"security": [{ "HTTPBearerJWT": [] }]
}
},
"/api/credits": {
"get": {
"tags": ["v1", "credits"],
@@ -8455,6 +8477,16 @@
"title": "ClarifyingQuestion",
"description": "A question that needs user clarification."
},
"CoPilotUsageStatus": {
"properties": {
"daily": { "$ref": "#/components/schemas/UsageWindow" },
"weekly": { "$ref": "#/components/schemas/UsageWindow" }
},
"type": "object",
"required": ["daily", "weekly"],
"title": "CoPilotUsageStatus",
"description": "Current usage status for a user across all windows."
},
"ContentType": {
"type": "string",
"enum": [
@@ -12190,6 +12222,16 @@
{ "$ref": "#/components/schemas/ActiveStreamInfo" },
{ "type": "null" }
]
},
"total_prompt_tokens": {
"type": "integer",
"title": "Total Prompt Tokens",
"default": 0
},
"total_completion_tokens": {
"type": "integer",
"title": "Total Completion Tokens",
"default": 0
}
},
"type": "object",
@@ -14587,6 +14629,25 @@
"required": ["timezone"],
"title": "UpdateTimezoneRequest"
},
"UsageWindow": {
"properties": {
"used": { "type": "integer", "title": "Used" },
"limit": {
"type": "integer",
"title": "Limit",
"description": "Maximum tokens allowed in this window. 0 means unlimited."
},
"resets_at": {
"type": "string",
"format": "date-time",
"title": "Resets At"
}
},
"type": "object",
"required": ["used", "limit", "resets_at"],
"title": "UsageWindow",
"description": "Usage within a single time window."
},
"UserHistoryResponse": {
"properties": {
"history": {

View File

@@ -288,6 +288,7 @@ const SidebarTrigger = React.forwardRef<
ref={ref}
data-sidebar="trigger"
variant="ghost"
size="icon"
onClick={(event) => {
onClick?.(event);
toggleSidebar();