Compare commits

49 Commits

Author SHA1 Message Date
Zamil Majdy
9684b99949 fix(backend/copilot): fix logging, token counting, and compaction edge cases
- Convert all remaining f-string logger calls to %-style across SDK files
- Fix _msg_tokens not counting Anthropic text blocks (underestimated tokens)
- Fix _truncate_middle_tokens producing wrong output when max_tok < 3
- Fix CompactionTracker.reset_for_query not clearing _compact_start event
- Add cycle guard to strip_progress_entries reparenting loop
- Replace bare except with logged exception in transcript metadata loading
2026-03-14 21:58:42 +07:00
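The cycle guard for the reparenting loop could look roughly like the sketch below. The names `resolve_parent`, `parents`, and `stripped` are hypothetical and are not taken from the actual `strip_progress_entries` code; the sketch only shows how tracking visited UUIDs while walking a `parentUuid` chain keeps a malformed transcript from looping forever.

```python
from typing import Optional


def resolve_parent(
    uuid: str,
    parents: dict[str, Optional[str]],
    stripped: set[str],
) -> Optional[str]:
    """Return the nearest ancestor of `uuid` that was not stripped.

    Hypothetical helper: `parents` maps each entry UUID to its parentUuid,
    and `stripped` holds the UUIDs of removed (e.g. progress) entries.
    """
    seen: set[str] = set()
    current = parents.get(uuid)
    while current is not None and current in stripped:
        if current in seen:
            # Cycle guard: the chain points back at an entry already visited.
            return None
        seen.add(current)
        current = parents.get(current)
    return current
```

Returning `None` on a cycle is an assumption of this sketch; the real code may log or handle the case differently.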
Zamil Majdy
d00059dc94 test(backend/copilot): add tests for transcript compaction helpers
Add comprehensive unit tests for the JSONL-message conversion layer:
- _flatten_assistant_content and _flatten_tool_result_content
- _transcript_to_messages: strippable types, compact summaries, edge cases
- _messages_to_transcript: structure, parentUuid chain, validation
- Roundtrip: messages to transcript to messages content preservation
- compact_transcript: too few messages, mock compaction, no-op, failure
2026-03-14 04:34:32 +07:00
Zamil Majdy
3a14077d52 refactor(backend/copilot): remove dead delete_transcript function
delete_transcript is no longer needed — oversized transcripts are now
compacted proactively at download time via _maybe_compact_and_upload
(400KB threshold), preventing the "prompt too long" error before it
can occur. If compaction fails, resume is gracefully skipped.
2026-03-14 03:58:36 +07:00
Zamil Majdy
4446be94ae test(backend/copilot): add e2e compaction lifecycle tests
Exercises the full service.py compaction flow end-to-end:
TranscriptBuilder load → CompactionTracker state machine →
read compacted entries → replace_entries → export → roundtrip.
2026-03-14 01:38:11 +07:00
Zamil Majdy
983aed2b0a fix(backend/copilot): preserve isCompactSummary through transcript roundtrip
TranscriptEntry didn't carry isCompactSummary, so exported JSONL had
bare "type": "summary" entries. On next turn's load_previous, these
were stripped as regular summaries, losing compaction context.

- Add isCompactSummary field to TranscriptEntry
- Preserve isCompactSummary entries in load_previous, strip_progress_entries,
  and _transcript_to_messages
- Add 6 roundtrip tests verifying isCompactSummary survives export→reload
2026-03-14 01:11:56 +07:00
Zamil Majdy
66a8cf69be fix(backend/copilot): async file I/O, add TranscriptBuilder tests
- Convert read_cli_session_file to async using aiofiles (addresses
  review comment about sync I/O in async streaming loop)
- Add 10 tests for TranscriptBuilder (replace_entries, load_previous,
  append operations, message ID consistency)
- Update existing read_cli_session_file tests to async
2026-03-14 00:16:51 +07:00
Zamil Majdy
d1b8766fa4 fix(platform): restore compaction code lost in merge, address remaining review comments
The merge of feat/tracking-cost-block reverted transcript compaction code
and several review fixes from 90b7edf1f. This commit restores the lost
code and applies additional improvements requested in review:

- Restore transcript compaction functions (_transcript_to_messages,
  _messages_to_transcript, compact_transcript, _flatten_* helpers)
- Restore _maybe_compact_and_upload helper in service.py to flatten
  deep nesting (5 levels -> 2) in transcript compaction block
- Restore CLI session file reading (read_cli_session_file) for
  mid-stream compaction sync
- Restore total_tokens DRY fix (compute once, reuse in finally)
- Extract _run_compression() helper to eliminate nested try blocks
- Add STOP_REASON_END_TURN, COMPACT_MSG_ID_PREFIX, ENTRY_TYPE_MESSAGE
  named constants replacing magic strings
- Add MS_PER_MINUTE and MS_PER_HOUR constants in UsageLimits.tsx
- Add docstring explaining Monday edge case in _weekly_reset_time
2026-03-13 22:56:19 +07:00
Zamil Majdy
628b779128 fix(backend): preserve at least one assistant message during middle-out deletion
The middle-out deletion step in compress_context could remove all assistant
messages when client=None (fallback compaction), causing validate_transcript
to fail and returning None (context loss). Now skips deleting the last
remaining assistant message.
2026-03-13 22:09:28 +07:00
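A minimal sketch of the guard described in this commit, assuming messages are dicts with a `role` key. The function `delete_middle_out` and its `target_len` parameter are hypothetical; the real `compress_context` works on token budgets and is not shown in this diff. The sketch only illustrates skipping the last remaining assistant message while deleting from the middle outward.

```python
def delete_middle_out(messages: list[dict], target_len: int) -> list[dict]:
    """Drop messages nearest the middle first until `target_len` remain.

    The last remaining assistant message is never deleted, so the result
    always keeps one assistant turn if the input had any.
    """
    kept = list(messages)
    while len(kept) > target_len:
        assistant_count = sum(1 for m in kept if m["role"] == "assistant")
        mid = len(kept) // 2
        # Candidate indices ordered by distance from the middle.
        order = sorted(range(len(kept)), key=lambda i: abs(i - mid))
        for i in order:
            if kept[i]["role"] == "assistant" and assistant_count == 1:
                continue  # guard: keep the only assistant message
            del kept[i]
            break
        else:
            break  # nothing deletable is left
    return kept
```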
Zamil Majdy
90b7edf1f1 fix(backend/copilot): address review comments — top-level imports, bug fixes, refactoring
- Move all local imports to top-level (transcript.py, helpers.py,
  baseline/service.py) per project style rules
- Fix sub.get("text") bug: use `or` instead of default arg to handle
  None values in tool_result content blocks
- Extract _flatten_assistant_content and _flatten_tool_result_content
  helpers from _transcript_to_messages to reduce nesting
- Use early returns in _get_credits/_spend_credits
- Extract _fetch_counters helper in rate_limit.py to DRY the Redis
  fetch pattern between get_usage_status and check_rate_limit
- Fix DRY violation: compute total_tokens once before StreamUsage,
  reuse in finally block for session persistence
- Fix chunk.choices guard: use early continue instead of ternary
  to prevent IndexError on empty list
- Make the error message in execute_block generic to avoid leaking internals
2026-03-13 19:49:28 +07:00
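The `sub.get("text")` fix rests on a general property of `dict.get`: the default argument applies only when the key is missing, not when it is present with a `None` value. The two function names below are illustrative only and do not appear in the codebase.

```python
from typing import Optional


def text_with_default(sub: dict) -> Optional[str]:
    # The default is used only for a missing key, so an explicit None leaks through.
    return sub.get("text", "")


def text_with_or(sub: dict) -> str:
    # `or` also covers a key that is present but None (or empty).
    return sub.get("text") or ""
```

With `{"text": None}` the first form returns `None`, which breaks later string handling, while the second returns an empty string.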
Zamil Majdy
8b970c4c3d Merge remote-tracking branch 'origin/dev' into fix/copilot-transcript-compaction-v2
2026-03-13 19:25:12 +07:00
Zamil Majdy
601fed93b8 fix(backend/copilot): resolve stash conflicts in transcript.py, handle file-stat race condition
2026-03-13 19:10:07 +07:00
Zamil Majdy
e3f9fa3648 fix(platform): merge dev branch, resolve conflicts in routes_test and openapi.json
2026-03-13 19:03:20 +07:00
Zamil Majdy
809ba56f0b fix(backend/copilot): preserve structured tool_result content during transcript flattening
When flattening tool_result blocks for summarisation, dict blocks without a
"text" key were silently dropped (replaced with empty string). Now falls back
to json.dumps(sub) so JSON/structured payloads are preserved for compaction.
2026-03-13 18:30:23 +07:00
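The fallback described in this commit might be sketched as follows. `flatten_blocks` is a hypothetical stand-in for the real `_flatten_tool_result_content`, whose exact signature is not shown here, and treating a `None` or empty `text` value like a missing key is an assumption of this sketch.

```python
import json
from typing import Any


def flatten_blocks(blocks: list[Any]) -> str:
    """Flatten tool_result content blocks into one string for summarisation."""
    parts: list[str] = []
    for sub in blocks:
        if isinstance(sub, str):
            parts.append(sub)
        elif isinstance(sub, dict):
            text = sub.get("text")
            # Fall back to a JSON dump so structured payloads are not dropped.
            parts.append(text if isinstance(text, str) and text else json.dumps(sub))
        else:
            parts.append(str(sub))
    return "\n".join(parts)
```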
Zamil Majdy
9a467e1dba fix(platform): preserve message_count after transcript compaction, clean up imports
- Fix bug where message_count was reset to JSONL line count after
  compaction instead of preserving the original session.messages
  watermark. The JSONL line count is smaller than the original message
  count, causing the gap-fill logic to re-inject already-covered
  messages into the prompt.
- Move pathlib.Path and uuid.uuid4 imports to module level in
  transcript.py.
- Use Next.js router instead of window.location.href in credits page.
2026-03-13 18:21:15 +07:00
Zamil Majdy
0200748225 fix(platform): address PR review comments
- transcript.py: properly flatten tool_result content blocks including
  nested list structures instead of losing tool_use_id silently
- transcript_builder.py: make replace_entries atomic — parse into temp
  builder first, only swap on success
- service.py: skip resume when compaction fails (avoid re-triggering
  "Prompt is too long"), wrap upload in try/except for best-effort
- baseline/service.py: use floor of 0 instead of 1 for token estimates
- rate_limit.py: broaden Redis exception handling to catch all errors
  (timeouts, etc.) for true fail-open behavior; remove unused import
- helpers.py: return ErrorResponse on post-exec InsufficientBalanceError
  instead of swallowing; remove raw exception text from user message
- helpers_test.py: update test to expect ErrorResponse
- UsageLimits.tsx: remove dark:* classes (copilot has no dark mode yet)
2026-03-13 18:12:24 +07:00
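The atomic `replace_entries` behaviour can be illustrated with a small sketch. `MiniTranscriptBuilder` is a hypothetical reduction of `TranscriptBuilder`: it parses into a temporary list rather than a temporary builder, and returning a bool is an assumption. The point is only that existing entries are swapped out after every line has parsed successfully.

```python
import json


class MiniTranscriptBuilder:
    """Hypothetical, reduced stand-in for TranscriptBuilder."""

    def __init__(self) -> None:
        self.entries: list[dict] = []

    def replace_entries(self, jsonl: str) -> bool:
        """Replace all entries from JSONL text, or leave them untouched on error."""
        try:
            parsed = [json.loads(line) for line in jsonl.splitlines() if line.strip()]
        except json.JSONDecodeError:
            return False  # nothing was swapped, old entries remain
        self.entries = parsed  # swap only after the whole input parsed
        return True
```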
Zamil Majdy
1704214394 fix(backend/copilot): recount message_count after transcript compaction
After compacting a transcript at download time, the message_count was
stale (from the original download), causing the gap-fill logic to miss
messages on subsequent turns. Now recount from the compacted content
before uploading.
2026-03-13 18:01:48 +07:00
Zamil Majdy
f2efd3ad7f fix(backend/copilot): address review - path traversal fix, pathlib, tests
- Use pathlib.Path.glob instead of import glob
- Add realpath + prefix check on glob results to prevent symlink escapes
- Add unit tests for _cli_project_dir, read_cli_session_file,
  _transcript_to_messages, and _messages_to_transcript
2026-03-13 17:51:45 +07:00
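A rough sketch of the realpath plus prefix check on glob results, assuming a helper named `safe_glob` (hypothetical; the actual check lives in the CLI session file reading code, which is not reproduced here). Resolving each match and comparing it against the resolved base directory rejects symlinks that point outside it.

```python
import os
from pathlib import Path


def safe_glob(base_dir: str, pattern: str) -> list[str]:
    """Return resolved paths matching `pattern` that stay inside `base_dir`."""
    base = os.path.realpath(base_dir)
    results: list[str] = []
    for match in Path(base).glob(pattern):
        resolved = os.path.realpath(match)
        # Trailing separator so that "/data/proj" does not accept "/data/project2".
        if resolved.startswith(base + os.sep):
            results.append(resolved)
    return sorted(results)
```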
Zamil Majdy
ee841d1515 fix(backend/copilot): make TranscriptBuilder compaction-aware and compact oversized transcripts
TranscriptBuilder accumulated all messages including pre-compaction content,
causing uploaded transcripts to grow unbounded. When the CLI compacted
mid-stream, TranscriptBuilder kept the full uncompacted history. On the next
turn, --resume would fail with "Prompt is too long".

Fix 1 (prevent): After the CLI's PreCompact hook fires and compaction ends,
read the CLI's session file (which reflects compaction) and replace
TranscriptBuilder's entries via new replace_entries() method.

Fix 2 (mitigate): At download time, if transcript exceeds 400KB threshold,
compact it using compress_context (LLM summarization + truncation fallback)
before passing to --resume. Upload the compacted version for future turns.
2026-03-13 17:43:48 +07:00
Zamil Majdy
5966d3669d fix(platform): use round() for cache weights, accept string dates in UsageLimits
2026-03-13 17:38:25 +07:00
Zamil Majdy
c81ab1fc3b fix(backend/copilot): use adapter pattern for credit operations in executor
The CoPilot executor runs without a Prisma connection, so direct calls
to get_user_credit_model() caused "Client is not connected to the query
engine" errors. Replace with _get_credits/_spend_credits adapters that
fall back to RPC via DatabaseManagerAsyncClient when Prisma is unavailable.
Also add missing spend_credits/get_credits to DatabaseManagerAsyncClient.
2026-03-13 17:14:08 +07:00
Zamil Majdy
5446c7f18f fix(platform): consistent header icon sizing, halve rate limits
- Add size="icon" to SidebarTrigger for uniform h-9 w-9 button sizing
- Remove extra positioning wrappers (relative left-1, left-5) around header icons
- Halve daily token limit to 2.5M and weekly to 12.5M for more reasonable defaults
2026-03-13 15:55:18 +07:00
Zamil Majdy
2b0c9ba703 fix(frontend): use Button component for icon buttons, add timezone and <1% display
- UsageLimits and NotificationToggle now use <Button variant="ghost" size="icon">
  to match SidebarTrigger's padding/sizing
- Weekly reset time shows timezone abbreviation (e.g., "Mon 7:00 AM PST")
- Usage below 1% shows "<1% used" instead of "0% used" with a 1% min bar width
2026-03-13 15:42:05 +07:00
Zamil Majdy
195c7011ae fix(frontend): update usage after chat, show reset day, fix icon weight
- Invalidate usage query when stream completes so the usage bar
  updates immediately after chatting instead of waiting 30s.
- Show reset time as day/time in local timezone when over 24h away
  (e.g. "Mon 12:00 AM") instead of unclear "63h 41m".
- Use weight="light" on ChartBar icon to match other header icons.
2026-03-13 15:21:42 +07:00
Zamil Majdy
d4944fb22b fix(platform): emit StreamUsage as SSE comment, move usage to popover
StreamUsage events crashed the frontend because the Vercel AI SDK uses
z.strictObject() and rejects unknown event types. Fix by overriding
to_sse() to emit as an SSE comment (invisible to the parser). Usage
data is already recorded server-side (session DB + Redis counters).

Move usage limits from sidebar footer back to a ChartBar icon button
in the sidebar header that opens a popover on click.
2026-03-13 15:14:14 +07:00
Zamil Majdy
a5ed8fefa9 feat(copilot): cost-weighted token rate limiting with cache breakdown
- Rate limiter now uses Anthropic's cost model: cache_read at 10%,
  cache_creation at 25%, uncached and output at 100%
- Track cache_read_tokens and cache_creation_tokens separately in
  Usage model, StreamUsage response, and SDK token extraction
- Pass cache breakdown through to record_token_usage() for accurate
  weighted counting
- Add test for cost-weighted counting (10K cache_read → 1K weighted)

This makes multi-turn conversations fairer: cached system prompts
and tool schemas don't penalize users at full token cost.
2026-03-13 14:36:04 +07:00
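The weighting in this commit can be worked through with a short sketch. The function name `weighted_tokens` and its parameters are hypothetical; only the weights (10% for cache reads, 25% for cache creation, 100% for uncached input and output) come from the commit message, and the use of `round()` follows the later "use round() for cache weights" commit.

```python
CACHE_READ_WEIGHT = 0.10
CACHE_CREATION_WEIGHT = 0.25


def weighted_tokens(
    uncached_input: int,
    output: int,
    cache_read: int = 0,
    cache_creation: int = 0,
) -> int:
    """Cost-weighted token count used for rate limiting (illustrative)."""
    return round(
        uncached_input
        + output
        + cache_read * CACHE_READ_WEIGHT
        + cache_creation * CACHE_CREATION_WEIGHT
    )
```

For example, 10K cache-read tokens count as 1K weighted tokens, which matches the test case named in the commit.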
Zamil Majdy
a52a777b29 fix(copilot): increase rate limits from 500K/5M to 5M/25M daily/weekly
A single CoPilot turn consumes ~10-15K tokens (system prompt + tool
schemas), so 500K daily only allowed ~35-50 turns which is too
restrictive for normal use.
2026-03-13 14:21:01 +07:00
Zamil Majdy
8bec7a6933 fix(frontend): move usage limits to left column on billing page
Move CoPilot Usage Limits below Automatic Refill Settings in the left
column and add "Open CoPilot" button for consistency with other sections.
2026-03-13 14:17:57 +07:00
Zamil Majdy
e73791efed fix(copilot): move usage limits to sidebar bottom, add to billing page
- Move UsageLimits from top-right headerSlot to left sidebar footer
- Show usage limits on both main copilot page and inside chat sessions
- Add CoPilot Usage Limits section to billing/credits page
- Change "Learn more about usage limits" to "Manage billing & credits"
- Remove unused headerSlot prop from ChatContainer/ChatMessagesContainer
- Clean up unused imports in CopilotPage
2026-03-13 14:08:57 +07:00
Zamil Majdy
2d161ce2b9 refactor(platform): replace per-session rate limits with daily fixed-window
Per-session limits were gameable (create new session to reset) and had
confusing reset semantics (12h inactivity TTL that refreshed on use).
Replace with daily fixed-window counter that resets at midnight UTC,
matching the weekly window pattern.

- session_token_limit → daily_token_limit (500K tokens/day)
- Redis key: copilot:usage:daily:{user_id}:{YYYY-MM-DD} with TTL
- Remove session_id from usage API endpoint and record_token_usage()
- Frontend: "Current session" → "Today", "Weekly limits" → "This week"
2026-03-13 13:22:59 +07:00
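A minimal sketch of the daily fixed-window key, assuming `now` is a timezone-aware UTC datetime. The key layout is the one given in the commit message; the helper names and the TTL calculation are assumptions, since the commit only says the key carries a TTL.

```python
from datetime import datetime, timedelta, timezone


def daily_key(user_id: str, now: datetime) -> str:
    """Redis key for the user's usage counter on the UTC day of `now`."""
    return f"copilot:usage:daily:{user_id}:{now:%Y-%m-%d}"


def seconds_until_utc_midnight(now: datetime) -> int:
    """Seconds left in the current UTC day, usable as the counter's TTL."""
    tomorrow = (now + timedelta(days=1)).replace(
        hour=0, minute=0, second=0, microsecond=0
    )
    return int((tomorrow - now).total_seconds())
```

Because the date is part of the key, a new counter starts at midnight UTC without any explicit reset, and creating a new chat session no longer affects the window.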
Zamil Majdy
6fc4989654 fix(copilot): include full context window in fallback token estimation
Each API call sends the complete openai_messages list (system prompt +
history + turn), so the fallback estimator should count all of it to
match what prompt_tokens would have reported.
2026-03-13 12:36:51 +07:00
Zamil Majdy
976443bf6e refactor(copilot): use tiktoken for fallback token estimation
Replace rough chars/4 heuristic with proper tiktoken tokenizer via
estimate_token_count/estimate_token_count_str from backend.util.prompt.
2026-03-13 05:24:53 +07:00
Zamil Majdy
4ceb15b3f1 fix(copilot): scope fallback token estimation to current turn only
The fallback estimator was counting the entire openai_messages history
(system prompt + all previous turns) instead of just the messages added
during the current turn. This caused overcounting and overly strict
rate limiting when providers don't return streaming usage data.
2026-03-13 03:44:30 +07:00
Zamil Majdy
3096f94996 feat(copilot): set default rate limits based on observed usage data
Session: 500K tokens (P90 session usage ~550K)
Weekly: 5M tokens (~10x heaviest observed weekly user)
2026-03-13 03:34:05 +07:00
Zamil Majdy
6f90729612 merge: resolve conflicts with dev (coerce_inputs_to_schema)
Merge dev's coerce_inputs_to_schema into execute_block alongside
credit charging. Both features coexist: coercion runs before
credit check. Test file combines both test suites.
2026-03-13 03:15:37 +07:00
Zamil Majdy
ebf89dde8b fix(copilot): include cached tokens in SDK token tracking
Anthropic's API reports cached tokens separately (cache_read_input_tokens,
cache_creation_input_tokens) from input_tokens. The previous code only read
input_tokens, undercounting total tokens for rate limiting.

OpenRouter (baseline path) already includes cached tokens in prompt_tokens
per OpenAI-compatible format — added clarifying comment.
2026-03-13 02:46:57 +07:00
Zamil Majdy
5d057e97e5 fix(copilot): move StreamUsage/StreamFinish back out of finally block
PEP 525 prohibits yielding from finally in async generators during
aclose() — doing so raises RuntimeError on client disconnect. Move
yields after try/finally where they work on normal completion and are
harmlessly unreachable on GeneratorExit.
2026-03-13 00:40:58 +07:00
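The behaviour behind this commit can be reproduced with plain asyncio. The generator names below are illustrative and stand in for the real streaming functions: closing an async generator that yields from `finally` raises `RuntimeError`, whereas yields placed after the `try/finally` are simply never reached on close.

```python
import asyncio


async def yields_in_finally():
    try:
        yield "chunk"
    finally:
        # Yielding while the generator is being closed is an error.
        yield "usage"


async def yields_after_finally():
    try:
        yield "chunk"
    finally:
        pass  # cleanup only, no yield here
    yield "usage"  # reached on normal completion, skipped on close


async def close_early(gen) -> str:
    """Consume one item, then close the generator as a disconnect would."""
    await gen.__anext__()
    try:
        await gen.aclose()
    except RuntimeError as exc:
        return f"RuntimeError: {exc}"
    return "closed cleanly"


async def drain(gen) -> list[str]:
    """Run the generator to normal completion."""
    return [item async for item in gen]
```

Running `asyncio.run(close_early(yields_in_finally()))` returns a string starting with `RuntimeError`, while the same call with `yields_after_finally()` closes cleanly.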
Zamil Majdy
1d2f641a26 fix(copilot): move session.usage.append to finally block in SDK service
Ensures session usage persistence and rate-limit recording stay
consistent even when an exception interrupts the try block. Mirrors
the baseline service pattern.
2026-03-13 00:35:09 +07:00
Zamil Majdy
dcb71ab0b9 test(platform): add tests for usage endpoint and UsageLimits component
- 3 new tests for GET /usage endpoint (with/without session_id, config limits)
- 9 new tests for UsageLimits frontend component (loading, empty, rendering,
  percentage capping, session-only/weekly-only, link, hook args)
2026-03-13 00:30:52 +07:00
Zamil Majdy
8136b90860 fix(copilot): move StreamUsage/StreamFinish yields into finally block
On client disconnect, GeneratorExit terminates the generator after the
finally block, making yields after it unreachable. Moving them inside
finally ensures they are at least attempted.
2026-03-13 00:05:41 +07:00
Zamil Majdy
4d179a7c37 Merge branch 'dev' into feat/tracking-cost-block
2026-03-12 23:56:38 +07:00
Zamil Majdy
f78adcdc65 fix(platform): address all open review items — clock skew, token recording, generated hooks
- Fix clock skew: share single `now` timestamp across `_weekly_key` and
  `_weekly_reset_time` calls to prevent ISO week boundary race condition
- SDK: move `record_token_usage` to finally block so tokens are always
  recorded even when exceptions interrupt the stream (prevents rate limit bypass)
- Baseline: wrap `record_token_usage` in try/except so it cannot block
  session persistence
- Routes: pass `session_id=None` instead of empty string to avoid
  malformed Redis keys; `get_usage_status` now skips session lookup when
  session_id is None
- Tests: add partial None counter tests and no-session-id test
- Frontend: replace raw fetch with generated Orval hook
  (`useGetV2GetCopilotUsage`), use generated `CoPilotUsageStatus` type,
  fix `Date` vs `string` type for `resets_at`
2026-03-12 23:18:31 +07:00
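The clock-skew fix amounts to reading the clock once and passing the same value to both helpers. In the sketch below the helper names mirror `_weekly_key` and `_weekly_reset_time`, but the bodies and the weekly key layout are assumptions; the `or 7` expression follows the days_until_monday simplification from an earlier commit in this list.

```python
from datetime import datetime, timedelta, timezone


def weekly_key(user_id: str, now: datetime) -> str:
    """Counter key for the ISO week containing `now` (key layout is assumed)."""
    year, week, _ = now.isocalendar()
    return f"copilot:usage:weekly:{user_id}:{year}-W{week:02d}"


def weekly_reset_time(now: datetime) -> datetime:
    """Next Monday 00:00; on a Monday this is the following Monday."""
    days_until_monday = (7 - now.weekday()) % 7 or 7
    next_monday = now + timedelta(days=days_until_monday)
    return next_monday.replace(hour=0, minute=0, second=0, microsecond=0)


def weekly_window(user_id: str) -> tuple[str, datetime]:
    # Read the clock once so key and reset time cannot straddle a week boundary.
    now = datetime.now(timezone.utc)
    return weekly_key(user_id, now), weekly_reset_time(now)
```

If each helper called `datetime.now()` on its own, a request arriving right at the week boundary could pair last week's key with next week's reset time.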
Zamil Majdy
40388b7520 fix(platform): address review — orphan keys, TTL gather, dark mode, lazy logging
- Guard record_token_usage with user_id check to prevent orphan Redis keys
- Fold session TTL lookup into asyncio.gather (eliminate serial round-trip)
- Add dark: variants to UsageLimits component
- Use lazy %s formatting in logger calls instead of eager f-strings
2026-03-12 23:03:36 +07:00
Zamil Majdy
dd7be1158b test(backend): expand rate limit test coverage, fix token estimation null safety
- Add tests for: past resets_at (negative time), expired TTL fallback,
  _session_reset_from_ttl Redis error, pipeline TTL assertions,
  pipeline execute RedisError
- Fix null safety in baseline token estimation (m.get("content") or "")
- Use MagicMock for pipeline sync methods to eliminate coroutine warnings
2026-03-12 22:38:31 +07:00
Zamil Majdy
c0e59f0a6b fix(backend): address pushback on credit charging and token estimation
- Post-exec InsufficientBalanceError: return output + log warning
  instead of ErrorResponse (block already executed with side effects)
- Add token estimation fallback for providers that don't support
  stream_options include_usage (1 token ≈ 4 chars)
- Remove unnecessary # type: ignore on AsyncRedis param
2026-03-12 22:25:14 +07:00
Zamil Majdy
104d1f1bf4 fix(backend): revert to += for prompt tokens, add negative time guard
- Revert prompt token tracking back to += (sum all rounds): each API
  call bills independently, so total billed tokens is the correct
  metric for rate limiting
- Guard against negative reset time in RateLimitExceeded message
  (clock skew / stale Redis TTL)
2026-03-12 22:09:05 +07:00
Zamil Majdy
d9e9cd4c98 fix(backend): address human + Sentry review feedback on CoPilot tracking
- Fix double-counting of prompt tokens in multi-round tool calls:
  use = (not +=) since each API call's prompt_tokens includes full
  conversation history (both baseline and SDK paths)
- Fix docstring: "sliding-window" → "fixed-window" counters
- Type redis param as AsyncRedis instead of object
- Use pipeline(transaction=False) for independent INCRBY+EXPIRE
- Simplify days_until_monday calculation with `or 7`
- Flatten time_str into ternary, merge f-strings
- Parallelize Redis gets with asyncio.gather
- Document session_id=None behavior in usage endpoint
2026-03-12 21:54:31 +07:00
Zamil Majdy
ca416300ec fix(backend): address second round CodeRabbit review feedback
- Narrow except blocks to (RedisError, ConnectionError, OSError) instead
  of bare Exception to avoid hiding coding bugs
- Remove raw user_id from log messages to prevent PII leaks under Redis
  outages
- Reuse credit_model instead of fetching twice in execute_block()
- Treat post-execution InsufficientBalanceError as fatal (matches
  executor behavior) instead of silently swallowing it
2026-03-12 21:37:32 +07:00
Zamil Majdy
c589cd0c43 fix(backend): address CodeRabbit review feedback
- Derive session reset time from Redis TTL instead of hardcoded 3h
- Add description to UsageWindow.limit documenting 0 = unlimited
- Compare balance < cost instead of balance <= 0 in pre-exec check
- Document TOCTOU behavior in check_rate_limit docstring
2026-03-12 21:10:52 +07:00
Zamil Majdy
b6d863fcd2 feat(platform): add CoPilot credit charging, token tracking, and rate limiting
- Charge credits for block execution in CoPilot (matching graph executor behavior)
- Track LLM token usage for both SDK (Claude) and baseline (OpenAI) paths
- Add Redis-based per-user rate limiting with session and weekly token windows
- Expose usage status via GET /api/chat/usage endpoint
- Add frontend UsageLimits component with progress bars in CoPilot header
- Include unit tests for rate limiting and block credit charging
2026-03-12 20:50:21 +07:00
79 changed files with 4000 additions and 2320 deletions

View File

@@ -0,0 +1,17 @@
---
name: backend-check
description: Run the full backend formatting, linting, and test suite. Ensures code quality before commits and PRs. TRIGGER when backend Python code has been modified and needs validation.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# Backend Check
## Steps
1. **Format**: `poetry run format` — runs formatting AND linting. NEVER run ruff/black/isort individually
2. **Fix** any remaining errors manually, re-run until clean
3. **Test**: `poetry run test` (runs DB setup + pytest). For specific files: `poetry run pytest -s -vvv <test_files>`
4. **Snapshots** (if needed): `poetry run pytest path/to/test.py --snapshot-update` — review with `git diff`

View File

@@ -0,0 +1,35 @@
---
name: code-style
description: Python code style preferences for the AutoGPT backend. Apply when writing or reviewing Python code. TRIGGER when writing new Python code, reviewing PRs, or refactoring backend code.
user-invocable: false
metadata:
author: autogpt-team
version: "1.0.0"
---
# Code Style
## Imports
- **Top-level only** — no local/inner imports. Move all imports to the top of the file.
## Typing
- **No duck typing** — avoid `hasattr`, `getattr`, `isinstance` for type dispatch. Use proper typed interfaces, unions, or protocols.
- **Pydantic models** over dataclass, namedtuple, or raw dict for structured data.
- **No linter suppressors** — avoid `# type: ignore`, `# noqa`, `# pyright: ignore` etc. 99% of the time the right fix is fixing the type/code, not silencing the tool.
## Code Structure
- **List comprehensions** over manual loop-and-append.
- **Early return** — guard clauses first, avoid deep nesting.
- **Flatten inline** — prefer short, concise expressions. Reduce `if/else` chains with direct returns or ternaries when readable.
- **Modular functions** — break complex logic into small, focused functions rather than long blocks with nested conditionals.
## Review Checklist
Before finishing, always ask:
- Can any function be split into smaller pieces?
- Is there unnecessary nesting that an early return would eliminate?
- Can any loop be a comprehension?
- Is there a simpler way to express this logic?

View File

@@ -0,0 +1,16 @@
---
name: frontend-check
description: Run the full frontend formatting, linting, and type checking suite. Ensures code quality before commits and PRs. TRIGGER when frontend TypeScript/React code has been modified and needs validation.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# Frontend Check
## Steps (in order)
1. **Format**: `pnpm format` — NEVER run individual formatters
2. **Lint**: `pnpm lint` — fix errors, re-run until clean
3. **Types**: `pnpm types` — if it keeps failing after multiple attempts, stop and ask the user

View File

@@ -0,0 +1,29 @@
---
name: new-block
description: Create a new backend block following the Block SDK Guide. Guides through provider configuration, schema definition, authentication, and testing. TRIGGER when user asks to create a new block, add a new integration, or build a new node for the graph editor.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# New Block Creation
Read `docs/platform/block-sdk-guide.md` first for the full guide.
## Steps
1. **Provider config** (if external service): create `_config.py` with `ProviderBuilder`
2. **Block file** in `backend/blocks/` (from `autogpt_platform/backend/`):
- Generate a UUID once with `uuid.uuid4()`, then **hard-code that string** as `id` (IDs must be stable across imports)
- `Input(BlockSchema)` and `Output(BlockSchema)` classes
- `async def run` that `yield`s output fields
3. **Files**: use `store_media_file()` with `"for_block_output"` for outputs
4. **Test**: `poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[MyBlock]' -xvs`
5. **Format**: `poetry run format`
## Rules
- Analyze interfaces: do inputs/outputs connect well with other blocks in a graph?
- Use top-level imports, avoid duck typing
- Always use `for_block_output` for block outputs

View File

@@ -0,0 +1,28 @@
---
name: openapi-regen
description: Regenerate the OpenAPI spec and frontend API client. Starts the backend REST server, fetches the spec, and regenerates the typed frontend hooks. TRIGGER when API routes change, new endpoints are added, or frontend API types are stale.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# OpenAPI Spec Regeneration
## Steps
1. **Run end-to-end** in a single shell block (so `REST_PID` persists):
```bash
cd autogpt_platform/backend || exit 1; poetry run rest &  # background only the server, not the cd
REST_PID=$!
WAIT=0; until curl -sf http://localhost:8006/health > /dev/null 2>&1; do sleep 1; WAIT=$((WAIT+1)); [ $WAIT -ge 60 ] && echo "Timed out" && kill $REST_PID && exit 1; done
cd ../frontend && pnpm generate:api:force
kill $REST_PID
pnpm types && pnpm lint && pnpm format
```
## Rules
- Always use `pnpm generate:api:force` (not `pnpm generate:api`)
- Don't manually edit files in `src/app/api/__generated__/`
- Generated hooks follow: `use{Method}{Version}{OperationName}`

View File

@@ -1,79 +0,0 @@
---
name: pr-address
description: Address PR review comments and loop until CI green and all comments resolved. TRIGGER when user asks to address comments, fix PR feedback, respond to reviewers, or babysit/monitor a PR.
user-invocable: true
args: "[PR number or URL] — if omitted, finds PR for current branch."
metadata:
author: autogpt-team
version: "1.0.0"
---
# PR Address
## Find the PR
```bash
gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT
gh pr view {N}
```
## Fetch comments (all sources)
```bash
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews # top-level reviews
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments # inline review comments
gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments # PR conversation comments
```
**Bots to watch for:**
- `autogpt-reviewer` — posts "Blockers", "Should Fix", "Nice to Have". Address ALL of them.
- `sentry[bot]` — bug predictions. Fix real bugs, explain false positives.
- `coderabbitai[bot]` — automated review. Address actionable items.
## For each unaddressed comment
Address comments **one at a time**: fix → commit → push → inline reply → next.
1. Read the referenced code, make the fix (or reply explaining why it's not needed)
2. Commit and push the fix
3. Reply **inline** (not as a new top-level comment) referencing the fixing commit — this is what resolves the conversation for bot reviewers (coderabbitai, sentry):
| Comment type | How to reply |
|---|---|
| Inline review (`pulls/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID}/replies -f body="Fixed in <commit-sha>: <description>"` |
| Conversation (`issues/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments -f body="Fixed in <commit-sha>: <description>"` |
## Format and commit
After fixing, format the changed code:
- **Backend** (from `autogpt_platform/backend/`): `poetry run format`
- **Frontend** (from `autogpt_platform/frontend/`): `pnpm format && pnpm lint && pnpm types`
If API routes changed, regenerate the frontend client:
```bash
cd autogpt_platform/backend || exit 1; poetry run rest &  # background only the server, not the cd
REST_PID=$!
trap "kill $REST_PID 2>/dev/null" EXIT
WAIT=0; until curl -sf http://localhost:8006/health > /dev/null 2>&1; do sleep 1; WAIT=$((WAIT+1)); [ $WAIT -ge 60 ] && echo "Timed out" && exit 1; done
cd ../frontend && pnpm generate:api:force
kill $REST_PID 2>/dev/null; trap - EXIT
```
Never manually edit files in `src/app/api/__generated__/`.
Then commit and **push immediately** — never batch commits without pushing.
For backend commits in worktrees: `poetry run git commit` (pre-commit hooks).
## The loop
```text
address comments → format → commit → push
→ re-check comments → fix new ones → push
→ wait for CI → re-check comments after CI settles
→ repeat until: all comments addressed AND CI green AND no new comments arriving
```
While CI runs, stay productive: run local tests, address remaining comments.
**The loop ends when:** CI fully green + all comments addressed + no new comments since CI settled.

View File

@@ -0,0 +1,31 @@
---
name: pr-create
description: Create a pull request for the current branch. TRIGGER when user asks to create a PR, open a pull request, push changes for review, or submit work for merging.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# Create Pull Request
## Steps
1. **Check for existing PR**: `gh pr view --json url -q .url 2>/dev/null` — if a PR already exists, output its URL and stop
2. **Understand changes**: `git status`, `git diff dev...HEAD`, `git log dev..HEAD --oneline`
3. **Read PR template**: `.github/PULL_REQUEST_TEMPLATE.md`
4. **Draft PR title**: Use conventional commits format (see CLAUDE.md for types and scopes)
5. **Fill out PR template** as the body — be thorough in the Changes section
6. **Format first** (if relevant changes exist):
- Backend: `cd autogpt_platform/backend && poetry run format`
- Frontend: `cd autogpt_platform/frontend && pnpm format`
- Fix any lint errors, then commit formatting changes before pushing
7. **Push**: `git push -u origin HEAD`
8. **Create PR**: `gh pr create --base dev`
9. **Output** the PR URL
## Rules
- Always target `dev` branch
- Do NOT run tests — CI will handle that
- Use the PR template from `.github/PULL_REQUEST_TEMPLATE.md`

View File

@@ -1,74 +1,51 @@
---
name: pr-review
description: Review a PR for correctness, security, code quality, and testing issues. TRIGGER when user asks to review a PR, check PR quality, or give feedback on a PR.
description: Address all open PR review comments systematically. Fetches comments, addresses each one, reacts +1/-1, and replies when clarification is needed. Keeps iterating until all comments are addressed and CI is green. TRIGGER when user shares a PR URL, asks to address review comments, fix PR feedback, or respond to reviewer comments.
user-invocable: true
args: "[PR number or URL] — if omitted, finds PR for current branch."
metadata:
author: autogpt-team
version: "1.0.0"
---
# PR Review
# PR Review Comment Workflow
## Find the PR
## Steps
```bash
gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT
gh pr view {N}
```
1. **Find PR**: `gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT`
2. **Fetch comments** (all three sources):
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews` (top-level reviews)
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments` (inline review comments)
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` (PR conversation comments)
3. **Skip** comments already reacted to by PR author
4. **For each unreacted comment**:
- Read referenced code, make the fix (or reply if you disagree/need info)
- **Inline review comments** (`pulls/{N}/comments`):
- React: `gh api repos/.../pulls/comments/{ID}/reactions -f content="+1"` (or `-1`)
- Reply: `gh api repos/.../pulls/{N}/comments/{ID}/replies -f body="..."`
- **PR conversation comments** (`issues/{N}/comments`):
- React: `gh api repos/.../issues/comments/{ID}/reactions -f content="+1"` (or `-1`)
- No threaded replies — post a new issue comment if needed
- **Top-level reviews**: no reaction API — address in code, reply via issue comment if needed
5. **Include autogpt-reviewer bot fixes** too
6. **Format**: `cd autogpt_platform/backend && poetry run format`, `cd autogpt_platform/frontend && pnpm format`
7. **Commit & push**
8. **Re-fetch comments** immediately — address any new unreacted ones before waiting on CI
9. **Stay productive while CI runs** — don't idle. In priority order:
- Run any pending local tests (`poetry run pytest`, e2e, etc.) and fix failures
- Address any remaining comments
- Only poll `gh pr checks {N}` as the last resort when there's truly nothing left to do
10. **If CI fails** — fix, go back to step 6
11. **Re-fetch comments again** after CI is green — address anything that appeared while CI was running
12. **Done** only when: all comments reacted AND CI is green.
## Read the diff
## CRITICAL: Do Not Stop
```bash
gh pr diff {N}
```
**Loop is: address → format → commit → push → re-check comments → run local tests → wait CI → re-check comments → repeat.**
## Fetch existing review comments
Never idle. If CI is running and you have nothing to address, run local tests. Waiting on CI is the last resort.
Before posting anything, fetch existing inline comments to avoid duplicates:
## Rules
```bash
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews
```
## What to check
**Correctness:** logic errors, off-by-one, missing edge cases, race conditions (TOCTOU in file access, credit charging), error handling gaps, async correctness (missing `await`, unclosed resources).
**Security:** input validation at boundaries, no injection (command, XSS, SQL), secrets not logged, file paths sanitized (`os.path.basename()` in error messages).
**Code quality:** apply rules from backend/frontend CLAUDE.md files.
**Architecture:** DRY, single responsibility, modular functions. `Security()` vs `Depends()` for FastAPI auth. `data:` for SSE events, `: comment` for heartbeats. `transaction=True` for Redis pipelines.
**Testing:** edge cases covered, colocated `*_test.py` (backend) / `__tests__/` (frontend), mocks target where symbol is **used** not defined, `AsyncMock` for async.
## Output format
Every comment **must** be prefixed with `🤖` and a criticality badge:
| Tier | Badge | Meaning |
|---|---|---|
| Blocker | `🔴 **Blocker**` | Must fix before merge |
| Should Fix | `🟠 **Should Fix**` | Important improvement |
| Nice to Have | `🟡 **Nice to Have**` | Minor suggestion |
| Nit | `🔵 **Nit**` | Style / wording |
Example: `🤖 🔴 **Blocker**: Missing error handling for X — suggest wrapping in try/except.`
## Post inline comments
For each finding, post an inline comment on the PR (do not just write a local report):
```bash
# Get the latest commit SHA for the PR
COMMIT_SHA=$(gh api repos/Significant-Gravitas/AutoGPT/pulls/{N} --jq '.head.sha')
# Post an inline comment on a specific file/line
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments \
-f body="🤖 🔴 **Blocker**: <description>" \
-f commit_id="$COMMIT_SHA" \
-f path="<file path>" \
-F line=<line number>
```
- One todo per comment
- For inline review comments: reply on existing threads. For PR conversation comments: post a new issue comment (API doesn't support threaded replies)
- React to every comment: +1 addressed, -1 disagreed (with explanation)

@@ -0,0 +1,45 @@
---
name: worktree-setup
description: Set up a new git worktree for parallel development. Creates the worktree, copies .env files, installs dependencies, generates Prisma client, and optionally starts the app (with port conflict resolution) or runs tests. TRIGGER when user asks to set up a worktree, work on a branch in isolation, or needs a separate environment for a branch or PR.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# Worktree Setup
## Preferred: Use Branchlet
The repo has a `.branchlet.json` config — it handles env file copying, dependency installation, and Prisma generation automatically.
```bash
npm install -g branchlet # install once
branchlet create -n <name> -s <source-branch> -b <new-branch>
branchlet list --json # list all worktrees
```
## Manual Fallback
If branchlet isn't available:
1. `git worktree add ../<RepoName><N> <branch-name>`
2. Copy `.env` files: `backend/.env`, `frontend/.env`, `autogpt_platform/.env`, `db/docker/.env`
3. Install deps:
- `cd autogpt_platform/backend && poetry install && poetry run prisma generate`
- `cd autogpt_platform/frontend && pnpm install`
## Running the App
Free ports first — backend uses: 8001, 8002, 8003, 8005, 8006, 8007, 8008.
```bash
for port in 8001 8002 8003 8005 8006 8007 8008; do
lsof -ti :$port | xargs kill -9 2>/dev/null || true
done
cd <worktree>/autogpt_platform/backend && poetry run app
```
## CoPilot Testing Gotcha
SDK mode spawns a Claude subprocess — **won't work inside Claude Code**. Set `CHAT_USE_CLAUDE_AGENT_SDK=false` in `backend/.env` to use baseline mode.

@@ -1,85 +0,0 @@
---
name: worktree
description: Set up a new git worktree for parallel development. Creates the worktree, copies .env files, installs dependencies, and generates Prisma client. TRIGGER when user asks to set up a worktree, work on a branch in isolation, or needs a separate environment for a branch or PR.
user-invocable: true
args: "[name] — optional worktree name (e.g., 'AutoGPT7'). If omitted, uses next available AutoGPT<N>."
metadata:
author: autogpt-team
version: "3.0.0"
---
# Worktree Setup
## Create the worktree
Derive paths from the git toplevel. If a name is provided as argument, use it. Otherwise, check `git worktree list` and pick the next `AutoGPT<N>`.
```bash
ROOT=$(git rev-parse --show-toplevel)
PARENT=$(dirname "$ROOT")
# From an existing branch
git worktree add "$PARENT/<NAME>" <branch-name>
# From a new branch off dev
git worktree add -b <new-branch> "$PARENT/<NAME>" dev
```
## Copy environment files
Copy `.env` from the root worktree. Falls back to `.env.default` if `.env` doesn't exist.
```bash
ROOT=$(git rev-parse --show-toplevel)
TARGET="$(dirname "$ROOT")/<NAME>"
for envpath in autogpt_platform/backend autogpt_platform/frontend autogpt_platform; do
if [ -f "$ROOT/$envpath/.env" ]; then
cp "$ROOT/$envpath/.env" "$TARGET/$envpath/.env"
elif [ -f "$ROOT/$envpath/.env.default" ]; then
cp "$ROOT/$envpath/.env.default" "$TARGET/$envpath/.env"
fi
done
```
## Install dependencies
```bash
TARGET="$(dirname "$(git rev-parse --show-toplevel)")/<NAME>"
cd "$TARGET/autogpt_platform/autogpt_libs" && poetry install
cd "$TARGET/autogpt_platform/backend" && poetry install && poetry run prisma generate
cd "$TARGET/autogpt_platform/frontend" && pnpm install
```
Replace `<NAME>` with the actual worktree name (e.g., `AutoGPT7`).
## Running the app (optional)
Backend uses ports: 8001, 8002, 8003, 8005, 8006, 8007, 8008. Free them first if needed:
```bash
TARGET="$(dirname "$(git rev-parse --show-toplevel)")/<NAME>"
for port in 8001 8002 8003 8005 8006 8007 8008; do
lsof -ti :$port | xargs kill -9 2>/dev/null || true
done
cd "$TARGET/autogpt_platform/backend" && poetry run app
```
## CoPilot testing
SDK mode spawns a Claude subprocess — won't work inside Claude Code. Set `CHAT_USE_CLAUDE_AGENT_SDK=false` in `backend/.env` to use baseline mode.
## Cleanup
```bash
# Replace <NAME> with the actual worktree name (e.g., AutoGPT7)
git worktree remove "$(dirname "$(git rev-parse --show-toplevel)")/<NAME>"
```
## Alternative: Branchlet (optional)
If [branchlet](https://www.npmjs.com/package/branchlet) is installed:
```bash
branchlet create -n <name> -s <source-branch> -b <new-branch>
```

@@ -60,12 +60,9 @@ AutoGPT Platform is a monorepo containing:
### Reviewing/Revising Pull Requests
Use `/pr-review` to review a PR or `/pr-address` to address comments.
When fetching comments manually:
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews` — top-level reviews
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments` — inline review comments
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` — PR conversation comments
- When the user runs /pr-comments or tries to fetch them, also run gh api /repos/Significant-Gravitas/AutoGPT/pulls/[issuenum]/reviews to get the reviews
- Use gh api /repos/Significant-Gravitas/AutoGPT/pulls/[issuenum]/reviews/[review_id]/comments to get the review contents
- Use gh api /repos/Significant-Gravitas/AutoGPT/issues/9924/comments to get the pr specific comments
### Conventional Commits

@@ -58,31 +58,10 @@ poetry run pytest path/to/test.py --snapshot-update
- **Authentication**: JWT-based with Supabase integration
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
## Code Style
- **Top-level imports only** — no local/inner imports (lazy imports only for heavy optional deps like `openpyxl`)
- **No duck typing** — no `hasattr`/`getattr`/`isinstance` for type dispatch; use typed interfaces/unions/protocols
- **Pydantic models** over dataclass/namedtuple/dict for structured data
- **No linter suppressors** — no `# type: ignore`, `# noqa`, `# pyright: ignore`; fix the type/code
- **List comprehensions** over manual loop-and-append
- **Early return** — guard clauses first, avoid deep nesting
- **Lazy `%s` logging** — `logger.info("Processing %s items", count)` not `logger.info(f"Processing {count} items")`
- **Sanitize error paths** — `os.path.basename()` in error messages to avoid leaking directory structure
- **TOCTOU awareness** — avoid check-then-act patterns for file access and credit charging
- **`Security()` vs `Depends()`** — use `Security()` for auth deps to get proper OpenAPI security spec
- **Redis pipelines** — `transaction=True` for atomicity on multi-step operations
- **`max(0, value)` guards** — for computed values that should never be negative
- **SSE protocol** — `data:` lines for frontend-parsed events (must match Zod schema), `: comment` lines for heartbeats/status
- **File length** — keep files under ~300 lines; if a file grows beyond this, split by responsibility (e.g. extract helpers, models, or a sub-module into a new file). Never keep appending to a long file.
- **Function length** — keep functions under ~40 lines; extract named helpers when a function grows longer. Long functions are a sign of mixed concerns, not complexity.
## Testing Approach
- Uses pytest with snapshot testing for API responses
- Test files are colocated with source files (`*_test.py`)
- Mock at boundaries — mock where the symbol is **used**, not where it's **defined**
- After refactoring, update mock targets to match new module paths
- Use `AsyncMock` for async functions (`from unittest.mock import AsyncMock`)
## Database Schema

@@ -27,6 +27,12 @@ from backend.copilot.model import (
get_user_sessions,
update_session_title,
)
from backend.copilot.rate_limit import (
CoPilotUsageStatus,
RateLimitExceeded,
check_rate_limit,
get_usage_status,
)
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
from backend.copilot.tools.e2b_sandbox import kill_sandbox
from backend.copilot.tools.models import (
@@ -120,6 +126,8 @@ class SessionDetailResponse(BaseModel):
user_id: str | None
messages: list[dict]
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
total_prompt_tokens: int = 0
total_completion_tokens: int = 0
class SessionSummaryResponse(BaseModel):
@@ -389,6 +397,10 @@ async def get_session(
last_message_id=last_message_id,
)
# Sum token usage from session
total_prompt = sum(u.prompt_tokens for u in session.usage)
total_completion = sum(u.completion_tokens for u in session.usage)
return SessionDetailResponse(
id=session.session_id,
created_at=session.started_at.isoformat(),
@@ -396,6 +408,26 @@ async def get_session(
user_id=session.user_id or None,
messages=messages,
active_stream=active_stream_info,
total_prompt_tokens=total_prompt,
total_completion_tokens=total_completion,
)
@router.get("/usage")
async def get_copilot_usage(
user_id: Annotated[str | None, Depends(auth.get_user_id)],
) -> CoPilotUsageStatus:
"""Get CoPilot usage status for the authenticated user.
Returns current token usage vs limits for daily and weekly windows.
"""
if not user_id:
raise HTTPException(status_code=401, detail="Authentication required")
return await get_usage_status(
user_id=user_id,
daily_token_limit=config.daily_token_limit,
weekly_token_limit=config.weekly_token_limit,
)
@@ -496,6 +528,17 @@ async def stream_chat_post(
},
)
# Pre-turn rate limit check (token-based)
if user_id and (config.daily_token_limit > 0 or config.weekly_token_limit > 0):
try:
await check_rate_limit(
user_id=user_id,
daily_token_limit=config.daily_token_limit,
weekly_token_limit=config.weekly_token_limit,
)
except RateLimitExceeded as e:
raise HTTPException(status_code=429, detail=str(e)) from e
# Enrich message with file metadata if file_ids are provided.
# Also sanitise file_ids so only validated, workspace-scoped IDs are
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
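
The pre-turn check added in this hunk can be read as the minimal sketch below. It is an illustration under stated assumptions, not the project's code: `fetch_counters`, `LimitReached`, and `run_check` are hypothetical stand-ins for the Redis lookup, `RateLimitExceeded`, and the route handler. The `0 = unlimited` rule, the `>=` comparison, and the fail-open behaviour follow `check_rate_limit` as it appears later in this diff.

```python
import asyncio
from collections.abc import Awaitable, Callable


class LimitReached(Exception):
    """Hypothetical stand-in for RateLimitExceeded."""

    def __init__(self, window: str) -> None:
        super().__init__(f"{window} limit reached")
        self.window = window


async def pre_turn_check(
    fetch_counters: Callable[[], Awaitable[tuple[int, int]]],
    daily_limit: int,
    weekly_limit: int,
) -> None:
    """Raise LimitReached if a window is exhausted; a limit of 0 disables it."""
    try:
        daily_used, weekly_used = await fetch_counters()
    except Exception:
        # Fail open: a broken counter store must not block the user.
        return
    if daily_limit > 0 and daily_used >= daily_limit:
        raise LimitReached("daily")
    if weekly_limit > 0 and weekly_used >= weekly_limit:
        raise LimitReached("weekly")


def run_check(
    daily_used: int,
    weekly_used: int,
    daily_limit: int,
    weekly_limit: int,
    *,
    broken: bool = False,
) -> str:
    """Drive the check with fixed counters and report the outcome as a string."""

    async def fetch() -> tuple[int, int]:
        if broken:
            raise ConnectionError("counter store down")
        return daily_used, weekly_used

    try:
        asyncio.run(pre_turn_check(fetch, daily_limit, weekly_limit))
    except LimitReached as exc:
        return f"blocked:{exc.window}"
    return "allowed"
```

In the route itself, the blocked case is the one translated into an HTTP 429 response; every other outcome lets the turn proceed.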

@@ -1,5 +1,6 @@
"""Tests for chat API routes: session title update, file attachment validation, and suggested prompts."""
"""Tests for chat API routes: session title update, file attachment validation, usage, and suggested prompts."""
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock
import fastapi
@@ -251,6 +252,74 @@ def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
assert call_kwargs["where"]["isDeleted"] is False
# ─── Usage endpoint ───────────────────────────────────────────────────
def _mock_usage(
mocker: pytest_mock.MockerFixture,
*,
daily_used: int = 500,
weekly_used: int = 2000,
) -> AsyncMock:
"""Mock get_usage_status to return a predictable CoPilotUsageStatus."""
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
resets_at = datetime.now(UTC) + timedelta(days=1)
status = CoPilotUsageStatus(
daily=UsageWindow(used=daily_used, limit=10000, resets_at=resets_at),
weekly=UsageWindow(used=weekly_used, limit=50000, resets_at=resets_at),
)
return mocker.patch(
"backend.api.features.chat.routes.get_usage_status",
new_callable=AsyncMock,
return_value=status,
)
def test_usage_returns_daily_and_weekly(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""GET /usage returns daily and weekly usage."""
mock_get = _mock_usage(mocker, daily_used=500, weekly_used=2000)
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
response = client.get("/usage")
assert response.status_code == 200
data = response.json()
assert data["daily"]["used"] == 500
assert data["weekly"]["used"] == 2000
mock_get.assert_called_once_with(
user_id=test_user_id,
daily_token_limit=10000,
weekly_token_limit=50000,
)
def test_usage_uses_config_limits(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""The endpoint forwards daily_token_limit and weekly_token_limit from config."""
mock_get = _mock_usage(mocker)
mocker.patch.object(chat_routes.config, "daily_token_limit", 99999)
mocker.patch.object(chat_routes.config, "weekly_token_limit", 77777)
response = client.get("/usage")
assert response.status_code == 200
mock_get.assert_called_once_with(
user_id=test_user_id,
daily_token_limit=99999,
weekly_token_limit=77777,
)
# ─── Suggested prompts endpoint ──────────────────────────────────────

@@ -18,11 +18,13 @@ from langfuse import propagate_attributes
from backend.copilot.model import (
ChatMessage,
ChatSession,
Usage,
get_chat_session,
update_session_title,
upsert_chat_session,
)
from backend.copilot.prompting import get_baseline_supplement
from backend.copilot.rate_limit import record_token_usage
from backend.copilot.response_model import (
StreamBaseResponse,
StreamError,
@@ -36,6 +38,7 @@ from backend.copilot.response_model import (
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
StreamUsage,
)
from backend.copilot.service import (
_build_system_prompt,
@@ -46,7 +49,11 @@ from backend.copilot.service import (
from backend.copilot.tools import execute_tool, get_available_tools
from backend.copilot.tracking import track_user_message
from backend.util.exceptions import NotFoundError
from backend.util.prompt import compress_context
from backend.util.prompt import (
compress_context,
estimate_token_count,
estimate_token_count_str,
)
logger = logging.getLogger(__name__)
@@ -221,6 +228,9 @@ async def stream_chat_completion_baseline(
text_block_id = str(uuid.uuid4())
text_started = False
step_open = False
# Token usage accumulators — populated from streaming chunks
turn_prompt_tokens = 0
turn_completion_tokens = 0
try:
for _round in range(_MAX_TOOL_ROUNDS):
# Open a new step for each LLM round
@@ -232,6 +242,7 @@ async def stream_chat_completion_baseline(
model=config.model,
messages=openai_messages,
stream=True,
stream_options={"include_usage": True},
)
if tools:
create_kwargs["tools"] = tools
@@ -242,7 +253,18 @@ async def stream_chat_completion_baseline(
tool_calls_by_index: dict[int, dict[str, str]] = {}
async for chunk in response:
delta = chunk.choices[0].delta if chunk.choices else None
# Capture token usage from the streaming chunk.
# OpenRouter normalises all providers into OpenAI format
# where prompt_tokens already includes cached tokens
# (unlike Anthropic's native API). Use += to sum all
# tool-call rounds since each API call is independent.
if chunk.usage:
turn_prompt_tokens += chunk.usage.prompt_tokens or 0
turn_completion_tokens += chunk.usage.completion_tokens or 0
if not chunk.choices:
continue
delta = chunk.choices[0].delta
if not delta:
continue
@@ -411,6 +433,53 @@ async def stream_chat_completion_baseline(
except Exception:
logger.warning("[Baseline] Langfuse trace context teardown failed")
# Fallback: estimate tokens via tiktoken when the provider does
# not honour stream_options={"include_usage": True}.
# Count the full message list (system + history + turn) since
# each API call sends the complete context window.
if turn_prompt_tokens == 0 and turn_completion_tokens == 0:
turn_prompt_tokens = max(
estimate_token_count(openai_messages, model=config.model), 0
)
turn_completion_tokens = max(
estimate_token_count_str(assistant_text, model=config.model), 0
)
logger.info(
"[Baseline] No streaming usage reported; estimated tokens: "
"prompt=%d, completion=%d",
turn_prompt_tokens,
turn_completion_tokens,
)
# Emit token usage and update session for persistence
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
total_tokens = turn_prompt_tokens + turn_completion_tokens
session.usage.append(
Usage(
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
total_tokens=total_tokens,
)
)
logger.info(
"[Baseline] Turn usage: prompt=%d, completion=%d, total=%d",
turn_prompt_tokens,
turn_completion_tokens,
total_tokens,
)
# Record for rate limiting counters
if user_id:
try:
await record_token_usage(
user_id=user_id,
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
)
except Exception as usage_err:
logger.warning(
"[Baseline] Failed to record token usage: %s", usage_err
)
# Persist assistant response
if assistant_text:
session.messages.append(
@@ -421,4 +490,16 @@ async def stream_chat_completion_baseline(
except Exception as persist_err:
logger.error("[Baseline] Failed to persist session: %s", persist_err)
# Yield usage and finish AFTER try/finally (not inside finally).
# PEP 525 prohibits yielding from finally in async generators during
# aclose() — doing so raises RuntimeError on client disconnect.
# On GeneratorExit the client is already gone, so unreachable yields
# are harmless; on normal completion they reach the SSE stream.
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
yield StreamUsage(
promptTokens=turn_prompt_tokens,
completionTokens=turn_completion_tokens,
totalTokens=turn_prompt_tokens + turn_completion_tokens,
)
yield StreamFinish()
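
The comment above about yielding after `try/finally` rather than inside `finally` can be checked with a small standalone sketch. The generator and helper names here are hypothetical and unrelated to the service code; the sketch only assumes standard `asyncio` behaviour for async generators that are closed early with `aclose()`.

```python
import asyncio
from collections.abc import AsyncIterator, Callable


async def yields_inside_finally() -> AsyncIterator[str]:
    try:
        yield "text"
    finally:
        # This also runs during aclose(); yielding here ignores GeneratorExit.
        yield "usage"


async def yields_after_finally() -> AsyncIterator[str]:
    try:
        yield "text"
    finally:
        pass  # cleanup only, no yields
    # Reached on normal completion; skipped when the consumer closes early.
    yield "usage"
    yield "finish"


async def close_after_first(make_gen: Callable[[], AsyncIterator[str]]) -> str:
    """Consume one item, then close the generator as a disconnecting client would."""
    gen = make_gen()
    await gen.__anext__()
    try:
        await gen.aclose()
    except RuntimeError as exc:
        return f"RuntimeError: {exc}"
    return "closed cleanly"


async def drain(make_gen: Callable[[], AsyncIterator[str]]) -> list[str]:
    """Consume the generator to completion and collect everything it yields."""
    return [item async for item in make_gen()]
```

Closing `yields_inside_finally` early surfaces a `RuntimeError`, while `yields_after_finally` closes cleanly and still delivers its trailing items when it is drained normally.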

@@ -70,6 +70,20 @@ class ChatConfig(BaseSettings):
description="Cache TTL in seconds for Langfuse prompt (0 to disable caching)",
)
# Rate limiting — token-based limits per day and per week.
# Each CoPilot turn consumes ~10-15K tokens (system prompt + tool schemas + response),
# so 2.5M daily allows ~170-250 turns/day which is reasonable for normal use.
# TODO: These are global deploy-time constants. For per-user or per-plan limits,
# move to the database (e.g. UserPlan table) and look up in get_usage_status.
daily_token_limit: int = Field(
default=2_500_000,
description="Max tokens per day, resets at midnight UTC (0 = unlimited)",
)
weekly_token_limit: int = Field(
default=12_500_000,
description="Max tokens per week, resets Monday 00:00 UTC (0 = unlimited)",
)
# Claude Agent SDK Configuration
use_claude_agent_sdk: bool = Field(
default=True,
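
The turn estimate in the rate-limiting comment above can be reproduced with simple arithmetic. This is a sketch only: `turns_allowed` is a hypothetical helper, and the per-turn costs of 10K and 15K tokens are the rough figures quoted in that comment, not measured values.

```python
DAILY_TOKEN_LIMIT = 2_500_000
WEEKLY_TOKEN_LIMIT = 12_500_000


def turns_allowed(token_limit: int, tokens_per_turn: int) -> int:
    """Whole turns that fit in a window at a given per-turn token cost."""
    return token_limit // tokens_per_turn


# Heavier turns (15K tokens) give the lower bound, lighter turns (10K) the upper.
daily_turns_low = turns_allowed(DAILY_TOKEN_LIMIT, 15_000)
daily_turns_high = turns_allowed(DAILY_TOKEN_LIMIT, 10_000)

# How many fully used daily budgets fit in one weekly budget.
full_days_per_week = WEEKLY_TOKEN_LIMIT / DAILY_TOKEN_LIMIT
```

With the defaults in this hunk the weekly budget equals five full daily budgets, so a user who exhausts the daily limit every day would reach the weekly limit before the week ends.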

@@ -73,6 +73,9 @@ class Usage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
# Cache breakdown (Anthropic-specific; zero for non-Anthropic models)
cache_read_tokens: int = 0
cache_creation_tokens: int = 0
class ChatSessionInfo(BaseModel):

@@ -0,0 +1,253 @@
"""CoPilot rate limiting based on token usage.
Uses Redis fixed-window counters to track per-user token consumption
with configurable daily and weekly limits. Daily windows reset at
midnight UTC; weekly windows reset at ISO week boundary (Monday 00:00
UTC). Fails open when Redis is unavailable to avoid blocking users.
"""
import asyncio
import logging
from datetime import UTC, datetime, timedelta
from pydantic import BaseModel, Field
from backend.data.redis_client import get_redis_async
logger = logging.getLogger(__name__)
# Redis key prefix shared by the daily and weekly counters
_PREFIX = "copilot:usage"
class UsageWindow(BaseModel):
"""Usage within a single time window."""
used: int
limit: int = Field(
description="Maximum tokens allowed in this window. 0 means unlimited."
)
resets_at: datetime
class CoPilotUsageStatus(BaseModel):
"""Current usage status for a user across all windows."""
daily: UsageWindow
weekly: UsageWindow
class RateLimitExceeded(Exception):
"""Raised when a user exceeds their CoPilot usage limit."""
def __init__(self, window: str, resets_at: datetime):
self.window = window
self.resets_at = resets_at
delta = resets_at - datetime.now(UTC)
total_secs = delta.total_seconds()
if total_secs <= 0:
time_str = "now"
else:
hours = int(total_secs // 3600)
minutes = int((total_secs % 3600) // 60)
time_str = f"{hours}h {minutes}m" if hours > 0 else f"{minutes}m"
super().__init__(
f"You've reached your {window} usage limit. Resets in {time_str}."
)
def _daily_key(user_id: str, now: datetime | None = None) -> str:
if now is None:
now = datetime.now(UTC)
return f"{_PREFIX}:daily:{user_id}:{now.strftime('%Y-%m-%d')}"
def _weekly_key(user_id: str, now: datetime | None = None) -> str:
if now is None:
now = datetime.now(UTC)
year, week, _ = now.isocalendar()
return f"{_PREFIX}:weekly:{user_id}:{year}-W{week:02d}"
def _daily_reset_time(now: datetime | None = None) -> datetime:
"""Calculate when the current daily window resets (next midnight UTC)."""
if now is None:
now = datetime.now(UTC)
return now.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(days=1)
def _weekly_reset_time(now: datetime | None = None) -> datetime:
"""Calculate when the current weekly window resets (next Monday 00:00 UTC).
On Monday itself, ``(7 - weekday) % 7`` is 0; the ``or 7`` fallback
pushes to *next* Monday so the current week's window stays open.
"""
if now is None:
now = datetime.now(UTC)
days_until_monday = (7 - now.weekday()) % 7 or 7
return now.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(
days=days_until_monday
)
async def _fetch_counters(user_id: str, now: datetime) -> tuple[int, int]:
"""Fetch daily and weekly token counters from Redis.
Returns (daily_used, weekly_used); a missing key counts as 0. Redis errors
propagate to the caller, which decides how to fail open.
"""
redis = await get_redis_async()
daily_raw, weekly_raw = await asyncio.gather(
redis.get(_daily_key(user_id, now=now)),
redis.get(_weekly_key(user_id, now=now)),
)
return int(daily_raw or 0), int(weekly_raw or 0)
async def get_usage_status(
user_id: str,
daily_token_limit: int,
weekly_token_limit: int,
) -> CoPilotUsageStatus:
"""Get current usage status for a user.
Args:
user_id: The user's ID.
daily_token_limit: Max tokens per day (0 = unlimited).
weekly_token_limit: Max tokens per week (0 = unlimited).
Returns:
CoPilotUsageStatus with current usage and limits.
"""
now = datetime.now(UTC)
try:
daily_used, weekly_used = await _fetch_counters(user_id, now)
except Exception:
logger.warning(
"Redis unavailable for usage status, returning zeros", exc_info=True
)
daily_used, weekly_used = 0, 0
return CoPilotUsageStatus(
daily=UsageWindow(
used=daily_used,
limit=daily_token_limit,
resets_at=_daily_reset_time(now=now),
),
weekly=UsageWindow(
used=weekly_used,
limit=weekly_token_limit,
resets_at=_weekly_reset_time(now=now),
),
)
async def check_rate_limit(
user_id: str,
daily_token_limit: int,
weekly_token_limit: int,
) -> None:
"""Check if user is within rate limits. Raises RateLimitExceeded if not.
This is a pre-turn soft check. The authoritative usage counter is updated
by ``record_token_usage()`` after the turn completes. Under concurrency,
two parallel turns may both pass this check against the same snapshot.
This is acceptable because token-based limits are approximate by nature
(the exact token count is unknown until after generation).
Fails open: if Redis is unavailable, allows the request.
"""
now = datetime.now(UTC)
try:
daily_used, weekly_used = await _fetch_counters(user_id, now)
except Exception:
logger.warning(
"Redis unavailable for rate limit check, allowing request", exc_info=True
)
return
if daily_token_limit > 0 and daily_used >= daily_token_limit:
raise RateLimitExceeded("daily", _daily_reset_time(now=now))
if weekly_token_limit > 0 and weekly_used >= weekly_token_limit:
raise RateLimitExceeded("weekly", _weekly_reset_time(now=now))
async def record_token_usage(
user_id: str,
prompt_tokens: int,
completion_tokens: int,
*,
cache_read_tokens: int = 0,
cache_creation_tokens: int = 0,
) -> None:
"""Record token usage for a user across all windows.
Uses cost-weighted counting so cached tokens don't unfairly penalise
multi-turn conversations. Weights applied per token type:
- uncached input: 100%
- cache creation: 25%
- cache read: 10%
- output: 100%
``prompt_tokens`` should be the *uncached* input count (``input_tokens``
from the API response). Cache counts are passed separately.
Args:
user_id: The user's ID.
prompt_tokens: Uncached input tokens.
completion_tokens: Output tokens.
cache_read_tokens: Tokens served from prompt cache (10% cost).
cache_creation_tokens: Tokens written to prompt cache (25% cost).
"""
weighted_input = (
prompt_tokens
+ round(cache_creation_tokens * 0.25)
+ round(cache_read_tokens * 0.1)
)
total = weighted_input + completion_tokens
if total <= 0:
return
raw_total = (
prompt_tokens + cache_read_tokens + cache_creation_tokens + completion_tokens
)
logger.info(
"Recording token usage for %s: raw=%d, weighted=%d "
"(uncached=%d, cache_read=%d@10%%, cache_create=%d@25%%, output=%d)",
user_id[:8],
raw_total,
total,
prompt_tokens,
cache_read_tokens,
cache_creation_tokens,
completion_tokens,
)
now = datetime.now(UTC)
try:
redis = await get_redis_async()
pipe = redis.pipeline(transaction=False)
# Daily counter (expires at next midnight UTC)
d_key = _daily_key(user_id, now=now)
pipe.incrby(d_key, total)
seconds_until_daily_reset = int(
(_daily_reset_time(now=now) - now).total_seconds()
)
pipe.expire(d_key, max(seconds_until_daily_reset, 1))
# Weekly counter (expires at next Monday 00:00 UTC)
w_key = _weekly_key(user_id, now=now)
pipe.incrby(w_key, total)
seconds_until_weekly_reset = int(
(_weekly_reset_time(now=now) - now).total_seconds()
)
pipe.expire(w_key, max(seconds_until_weekly_reset, 1))
await pipe.execute()
except Exception:
logger.warning(
"Redis unavailable for recording token usage (tokens=%d)",
total,
exc_info=True,
)
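
The weighting and window arithmetic in this file can be exercised without Redis. The sketch below restates the two pure calculations under hypothetical names (`weighted_tokens`, `next_monday_utc`); the 0.25 and 0.1 weights are copied from `record_token_usage` above, and whether they match the provider's current price list is not verified here.

```python
from datetime import datetime, timedelta, timezone


def weighted_tokens(
    prompt_tokens: int,
    completion_tokens: int,
    *,
    cache_read_tokens: int = 0,
    cache_creation_tokens: int = 0,
) -> int:
    """Cost-weighted token count using the weights from the diff above."""
    weighted_input = (
        prompt_tokens
        + round(cache_creation_tokens * 0.25)
        + round(cache_read_tokens * 0.1)
    )
    return weighted_input + completion_tokens


def next_monday_utc(now: datetime) -> datetime:
    """Start of the next weekly window; on a Monday this is the following Monday."""
    days_until_monday = (7 - now.weekday()) % 7 or 7
    midnight = now.replace(hour=0, minute=0, second=0, microsecond=0)
    return midnight + timedelta(days=days_until_monday)


# A turn with a large cache hit: 13,500 raw tokens are charged as 3,000.
charged = weighted_tokens(
    1_000, 500, cache_read_tokens=10_000, cache_creation_tokens=2_000
)

# 2026-03-09 is a Monday, so the window stays open until the next Monday.
monday_noon = datetime(2026, 3, 9, 12, 0, tzinfo=timezone.utc)
sunday_late = datetime(2026, 3, 15, 23, 0, tzinfo=timezone.utc)
```

The key's TTL in `record_token_usage` is the gap between `now` and this reset time, which is why a counter written late on Sunday expires within the hour.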

@@ -0,0 +1,334 @@
"""Unit tests for CoPilot rate limiting."""
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from redis.exceptions import RedisError
from .rate_limit import (
CoPilotUsageStatus,
RateLimitExceeded,
check_rate_limit,
get_usage_status,
record_token_usage,
)
_USER = "test-user-rl"
# ---------------------------------------------------------------------------
# RateLimitExceeded
# ---------------------------------------------------------------------------
class TestRateLimitExceeded:
def test_message_contains_window_name(self):
exc = RateLimitExceeded("daily", datetime.now(UTC) + timedelta(hours=1))
assert "daily" in str(exc)
def test_message_contains_reset_time(self):
exc = RateLimitExceeded(
"weekly", datetime.now(UTC) + timedelta(hours=2, minutes=30)
)
msg = str(exc)
# Allow for slight timing drift (29m or 30m)
assert "2h " in msg
assert "Resets in" in msg
def test_message_minutes_only_when_under_one_hour(self):
exc = RateLimitExceeded("daily", datetime.now(UTC) + timedelta(minutes=15))
msg = str(exc)
assert "Resets in" in msg
# Should not have "0h"
assert "0h" not in msg
def test_message_says_now_when_resets_at_is_in_the_past(self):
"""Negative delta (clock skew / stale TTL) should say 'now', not '-1h -30m'."""
exc = RateLimitExceeded("daily", datetime.now(UTC) - timedelta(minutes=5))
assert "Resets in now" in str(exc)
# ---------------------------------------------------------------------------
# get_usage_status
# ---------------------------------------------------------------------------
class TestGetUsageStatus:
@pytest.mark.asyncio
async def test_returns_redis_values(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["500", "2000"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert isinstance(status, CoPilotUsageStatus)
assert status.daily.used == 500
assert status.daily.limit == 10000
assert status.weekly.used == 2000
assert status.weekly.limit == 50000
@pytest.mark.asyncio
async def test_returns_zeros_when_redis_unavailable(self):
with patch(
"backend.copilot.rate_limit.get_redis_async",
side_effect=ConnectionError("Redis down"),
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert status.daily.used == 0
assert status.weekly.used == 0
@pytest.mark.asyncio
async def test_partial_none_daily_counter(self):
"""Daily counter is None (new day), weekly has usage."""
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=[None, "3000"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert status.daily.used == 0
assert status.weekly.used == 3000
@pytest.mark.asyncio
async def test_partial_none_weekly_counter(self):
"""Weekly counter is None (start of week), daily has usage."""
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["500", None])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert status.daily.used == 500
assert status.weekly.used == 0
@pytest.mark.asyncio
async def test_resets_at_daily_is_next_midnight_utc(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["0", "0"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
now = datetime.now(UTC)
# Daily reset should be within 24h
assert status.daily.resets_at > now
assert status.daily.resets_at <= now + timedelta(hours=24, seconds=5)
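The reset-time test above only checks that the daily reset falls within the next 24 hours. One way such boundaries could be computed is sketched below, assuming the daily window ends at the next UTC midnight and the weekly window at the next Monday 00:00 UTC (as the test names and TTL comments suggest). The function names are hypothetical.

```python
from datetime import datetime, timedelta, timezone


def next_daily_reset(now: datetime) -> datetime:
    """Next UTC midnight strictly after `now` (assumed daily boundary)."""
    midnight_today = now.replace(hour=0, minute=0, second=0, microsecond=0)
    return midnight_today + timedelta(days=1)


def next_weekly_reset(now: datetime) -> datetime:
    """Next Monday 00:00 UTC strictly after `now` (assumed weekly boundary)."""
    midnight_today = now.replace(hour=0, minute=0, second=0, microsecond=0)
    # weekday() is 0 for Monday, so a Monday maps to the Monday a week later.
    return midnight_today + timedelta(days=7 - now.weekday())


def ttl_seconds(now: datetime, resets_at: datetime) -> int:
    """Seconds until the reset, clamped to at least 1 for use as a key TTL."""
    return max(1, int((resets_at - now).total_seconds()))
```

Clamping to 1 keeps the TTL valid even when the call lands exactly on the boundary.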
# ---------------------------------------------------------------------------
# check_rate_limit
# ---------------------------------------------------------------------------
class TestCheckRateLimit:
@pytest.mark.asyncio
async def test_allows_when_under_limit(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["100", "200"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
# Should not raise
await check_rate_limit(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
@pytest.mark.asyncio
async def test_raises_when_daily_limit_exceeded(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["10000", "200"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
with pytest.raises(RateLimitExceeded) as exc_info:
await check_rate_limit(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert exc_info.value.window == "daily"
@pytest.mark.asyncio
async def test_raises_when_weekly_limit_exceeded(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["100", "50000"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
with pytest.raises(RateLimitExceeded) as exc_info:
await check_rate_limit(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert exc_info.value.window == "weekly"
@pytest.mark.asyncio
async def test_allows_when_redis_unavailable(self):
"""Fail-open: allow requests when Redis is down."""
with patch(
"backend.copilot.rate_limit.get_redis_async",
side_effect=ConnectionError("Redis down"),
):
# Should not raise
await check_rate_limit(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
@pytest.mark.asyncio
async def test_skips_check_when_limit_is_zero(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["999999", "999999"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
# Should not raise — limits of 0 mean unlimited
await check_rate_limit(_USER, daily_token_limit=0, weekly_token_limit=0)
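Taken together, the `check_rate_limit` tests describe a small decision rule: a missing counter counts as zero, a limit of 0 means unlimited, reaching the limit exactly already blocks, and the daily window is reported before the weekly one. A sketch of that rule as a pure function, with hypothetical names and without the Redis access or fail-open handling, might look like this:

```python
from typing import Optional


def parse_counter(raw: Optional[str]) -> int:
    """Treat a missing Redis value (None) as zero usage."""
    return int(raw) if raw is not None else 0


def exceeded_window(
    daily_used: int,
    weekly_used: int,
    daily_limit: int,
    weekly_limit: int,
) -> Optional[str]:
    """Return 'daily' or 'weekly' if that window is exhausted, else None.

    Illustrative only: a limit of 0 disables the check for that window,
    and usage equal to the limit already counts as exceeded.
    """
    if daily_limit > 0 and daily_used >= daily_limit:
        return "daily"
    if weekly_limit > 0 and weekly_used >= weekly_limit:
        return "weekly"
    return None
```

Keeping the comparison in a pure function like this makes the boundary cases testable without mocking Redis.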
# ---------------------------------------------------------------------------
# record_token_usage
# ---------------------------------------------------------------------------
class TestRecordTokenUsage:
@staticmethod
def _make_pipeline_mock() -> MagicMock:
"""Create a pipeline mock with sync methods and async execute."""
pipe = MagicMock()
pipe.execute = AsyncMock(return_value=[])
return pipe
@pytest.mark.asyncio
async def test_increments_redis_counters(self):
mock_pipe = self._make_pipeline_mock()
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
# Should call incrby twice (daily + weekly) with total=150
incrby_calls = mock_pipe.incrby.call_args_list
assert len(incrby_calls) == 2
assert incrby_calls[0].args[1] == 150 # daily
assert incrby_calls[1].args[1] == 150 # weekly
@pytest.mark.asyncio
async def test_skips_when_zero_tokens(self):
mock_redis = AsyncMock()
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await record_token_usage(_USER, prompt_tokens=0, completion_tokens=0)
# Should not call pipeline at all
mock_redis.pipeline.assert_not_called()
@pytest.mark.asyncio
async def test_sets_expire_on_both_keys(self):
"""Pipeline should call expire for both daily and weekly keys."""
mock_pipe = self._make_pipeline_mock()
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
expire_calls = mock_pipe.expire.call_args_list
assert len(expire_calls) == 2
# Daily key TTL should be positive (seconds until next midnight)
daily_ttl = expire_calls[0].args[1]
assert daily_ttl >= 1
# Weekly key TTL should be positive (seconds until next Monday)
weekly_ttl = expire_calls[1].args[1]
assert weekly_ttl >= 1
@pytest.mark.asyncio
async def test_handles_redis_failure_gracefully(self):
"""Should not raise when Redis is unavailable."""
with patch(
"backend.copilot.rate_limit.get_redis_async",
side_effect=ConnectionError("Redis down"),
):
# Should not raise
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
@pytest.mark.asyncio
async def test_cost_weighted_counting(self):
"""Cached tokens should be weighted: cache_read=10%, cache_create=25%."""
mock_pipe = self._make_pipeline_mock()
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await record_token_usage(
_USER,
prompt_tokens=100, # uncached → 100
completion_tokens=50, # output → 50
cache_read_tokens=10000, # 10% → 1000
cache_creation_tokens=400, # 25% → 100
)
# Expected weighted total: 100 + 1000 + 100 + 50 = 1250
incrby_calls = mock_pipe.incrby.call_args_list
assert len(incrby_calls) == 2
assert incrby_calls[0].args[1] == 1250 # daily
assert incrby_calls[1].args[1] == 1250 # weekly
@pytest.mark.asyncio
async def test_handles_redis_error_during_pipeline_execute(self):
"""Should not raise when pipeline.execute() fails with RedisError."""
mock_pipe = self._make_pipeline_mock()
mock_pipe.execute = AsyncMock(side_effect=RedisError("Pipeline failed"))
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
# Should not raise — fail-open
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
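The cost-weighted test above expects cache reads to count at 10% and cache writes at 25% of their raw token count. The arithmetic can be sketched as follows; `weighted_token_total` is a hypothetical name, and the integer rounding is an assumption since the real rounding rule is not shown here.

```python
# Percent weights taken from the test docstring above.
CACHE_READ_WEIGHT_PCT = 10
CACHE_CREATION_WEIGHT_PCT = 25


def weighted_token_total(
    prompt_tokens: int,
    completion_tokens: int,
    cache_read_tokens: int = 0,
    cache_creation_tokens: int = 0,
) -> int:
    """Combine raw token counts into one cost-weighted total.

    Integer floor division is an assumption made for this sketch.
    """
    weighted_read = cache_read_tokens * CACHE_READ_WEIGHT_PCT // 100
    weighted_creation = cache_creation_tokens * CACHE_CREATION_WEIGHT_PCT // 100
    return prompt_tokens + weighted_read + weighted_creation + completion_tokens
```

With the values from the test (100 uncached, 50 output, 10000 cache reads, 400 cache writes) this gives 100 + 1000 + 100 + 50 = 1250, the amount both counters are expected to be incremented by.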

@@ -186,12 +186,29 @@ class StreamToolOutputAvailable(StreamBaseResponse):
class StreamUsage(StreamBaseResponse):
"""Token usage statistics."""
"""Token usage statistics.
Emitted as an SSE comment so the Vercel AI SDK parser ignores it
(it uses z.strictObject() and rejects unknown event types).
Usage data is recorded server-side (session DB + Redis counters).
"""
type: ResponseType = ResponseType.USAGE
promptTokens: int = Field(..., description="Number of prompt tokens")
promptTokens: int = Field(..., description="Number of uncached prompt tokens")
completionTokens: int = Field(..., description="Number of completion tokens")
totalTokens: int = Field(..., description="Total number of tokens")
totalTokens: int = Field(
..., description="Total number of tokens (raw, not weighted)"
)
cacheReadTokens: int = Field(
default=0, description="Prompt tokens served from cache (10% cost)"
)
cacheCreationTokens: int = Field(
default=0, description="Prompt tokens written to cache (25% cost)"
)
def to_sse(self) -> str:
"""Emit as SSE comment so the AI SDK parser ignores it."""
return f": usage {self.model_dump_json(exclude_none=True)}\n\n"
class StreamError(StreamBaseResponse):
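The `to_sse` override relies on a property of the Server-Sent Events format: a line that begins with a colon is a comment, and conforming parsers skip it. The sketch below illustrates the idea with a deliberately tiny reader; `to_sse_comment` and `data_lines` are hypothetical helpers, not part of the codebase or of the AI SDK.

```python
import json


def to_sse_comment(label: str, payload: dict) -> str:
    """Serialize a payload as an SSE comment frame (leading ':')."""
    return f": {label} {json.dumps(payload)}\n\n"


def data_lines(stream: str) -> list[str]:
    """Minimal SSE reader: keep 'data:' payloads, skip ':' comment lines."""
    payloads = []
    for line in stream.splitlines():
        if line.startswith(":"):
            continue  # comment line, invisible to event consumers
        if line.startswith("data:"):
            payloads.append(line[len("data:"):].strip())
    return payloads
```

A client that only consumes `data:` frames never sees the usage payload, while the server can still record it.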

@@ -11,7 +11,7 @@ persistence, and the ``CompactionTracker`` state machine.
import asyncio
import logging
import uuid
from dataclasses import dataclass, field
from collections.abc import Callable
from ..constants import COMPACTION_DONE_MSG, COMPACTION_TOOL_NAME
from ..model import ChatMessage, ChatSession
@@ -27,19 +27,6 @@ from ..response_model import (
logger = logging.getLogger(__name__)
@dataclass
class CompactionResult:
"""Result of emit_end_if_ready — bundles events with compaction metadata.
Eliminates the need for separate ``compaction_just_ended`` checks,
preventing TOCTOU races between the emit call and the flag read.
"""
events: list[StreamBaseResponse] = field(default_factory=list)
just_ended: bool = False
transcript_path: str = ""
# ---------------------------------------------------------------------------
# Event builders (private — use CompactionTracker or compaction_events)
# ---------------------------------------------------------------------------
@@ -190,22 +177,11 @@ class CompactionTracker:
self._start_emitted = False
self._done = False
self._tool_call_id = ""
self._transcript_path: str = ""
def on_compact(self, transcript_path: str = "") -> None:
"""Callback for the PreCompact hook. Stores transcript_path."""
if (
self._transcript_path
and transcript_path
and self._transcript_path != transcript_path
):
logger.warning(
"[Compaction] Overwriting transcript_path %s -> %s",
self._transcript_path,
transcript_path,
)
self._transcript_path = transcript_path
self._compact_start.set()
@property
def on_compact(self) -> Callable[[], None]:
"""Callback for the PreCompact hook."""
return self._compact_start.set
# ------------------------------------------------------------------
# Pre-query compaction
@@ -222,10 +198,10 @@ class CompactionTracker:
def reset_for_query(self) -> None:
"""Reset per-query state before a new SDK query."""
self._compact_start.clear()
self._done = False
self._start_emitted = False
self._tool_call_id = ""
self._transcript_path = ""
def emit_start_if_ready(self) -> list[StreamBaseResponse]:
"""If the PreCompact hook fired, emit start events (spinning tool)."""
@@ -236,20 +212,15 @@ class CompactionTracker:
return _start_events(self._tool_call_id)
return []
async def emit_end_if_ready(self, session: ChatSession) -> CompactionResult:
"""If compaction is in progress, emit end events and persist.
Returns a ``CompactionResult`` with ``just_ended=True`` and the
captured ``transcript_path`` when a compaction cycle completes.
This avoids a separate flag check (TOCTOU-safe).
"""
async def emit_end_if_ready(self, session: ChatSession) -> list[StreamBaseResponse]:
"""If compaction is in progress, emit end events and persist."""
# Yield so pending hook tasks can set compact_start
await asyncio.sleep(0)
if self._done:
return CompactionResult()
return []
if not self._start_emitted and not self._compact_start.is_set():
return CompactionResult()
return []
if self._start_emitted:
# Close the open spinner
@@ -262,12 +233,8 @@ class CompactionTracker:
COMPACTION_DONE_MSG, tool_call_id=persist_id
)
transcript_path = self._transcript_path
self._compact_start.clear()
self._start_emitted = False
self._done = True
self._transcript_path = ""
_persist(session, persist_id, COMPACTION_DONE_MSG)
return CompactionResult(
events=done_events, just_ended=True, transcript_path=transcript_path
)
return done_events
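The tracker's flag handling is easier to follow as a stripped-down model. The sketch below is a simplified, synchronous stand-in with plain booleans instead of an `asyncio.Event`, string labels instead of stream events, and no persistence. `MiniTracker`, the label names, and the exact guard in `emit_start_if_ready` are assumptions inferred from the hunks and tests shown here.

```python
# Placeholder labels; the real code emits StreamBaseResponse objects.
START = ["start-step", "tool-input-start", "tool-input-available"]
END = ["tool-output-available", "finish-step"]


class MiniTracker:
    """Simplified model of the per-query compaction state machine."""

    def __init__(self) -> None:
        self.hook_fired = False     # stands in for the _compact_start event
        self.start_emitted = False
        self.done = False

    def on_compact(self) -> None:
        self.hook_fired = True

    def reset_for_query(self) -> None:
        self.hook_fired = False
        self.start_emitted = False
        self.done = False

    def emit_start_if_ready(self) -> list[str]:
        # Assumed guard: only one spinner per query, and none after done.
        if self.hook_fired and not self.start_emitted and not self.done:
            self.start_emitted = True
            return list(START)
        return []

    def emit_end_if_ready(self) -> list[str]:
        if self.done:
            return []
        if not self.start_emitted and not self.hook_fired:
            return []
        # Close the open spinner, or emit a self-contained start+end block.
        events = list(END) if self.start_emitted else list(START) + list(END)
        self.hook_fired = False
        self.start_emitted = False
        self.done = True
        return events
```

In this model a second `on_compact()` within the same query is ignored until `reset_for_query()` clears `done`.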

@@ -195,11 +195,10 @@ class TestCompactionTracker:
session = _make_session()
tracker.on_compact()
tracker.emit_start_if_ready()
result = await tracker.emit_end_if_ready(session)
assert result.just_ended is True
assert len(result.events) == 2
assert isinstance(result.events[0], StreamToolOutputAvailable)
assert isinstance(result.events[1], StreamFinishStep)
evts = await tracker.emit_end_if_ready(session)
assert len(evts) == 2
assert isinstance(evts[0], StreamToolOutputAvailable)
assert isinstance(evts[1], StreamFinishStep)
# Should persist
assert len(session.messages) == 2
@@ -211,32 +210,28 @@ class TestCompactionTracker:
session = _make_session()
tracker.on_compact()
# Don't call emit_start_if_ready
result = await tracker.emit_end_if_ready(session)
assert result.just_ended is True
assert len(result.events) == 5 # Full self-contained event
assert isinstance(result.events[0], StreamStartStep)
evts = await tracker.emit_end_if_ready(session)
assert len(evts) == 5 # Full self-contained event
assert isinstance(evts[0], StreamStartStep)
assert len(session.messages) == 2
@pytest.mark.asyncio
async def test_emit_end_no_op_when_no_new_compaction(self):
async def test_emit_end_no_op_when_done(self):
tracker = CompactionTracker()
session = _make_session()
tracker.on_compact()
tracker.emit_start_if_ready()
result1 = await tracker.emit_end_if_ready(session)
assert result1.just_ended is True
# Second call should be no-op (no new on_compact)
result2 = await tracker.emit_end_if_ready(session)
assert result2.just_ended is False
assert result2.events == []
await tracker.emit_end_if_ready(session)
# Second call should be no-op
evts = await tracker.emit_end_if_ready(session)
assert evts == []
@pytest.mark.asyncio
async def test_emit_end_no_op_when_nothing_happened(self):
tracker = CompactionTracker()
session = _make_session()
result = await tracker.emit_end_if_ready(session)
assert result.just_ended is False
assert result.events == []
evts = await tracker.emit_end_if_ready(session)
assert evts == []
def test_emit_pre_query(self):
tracker = CompactionTracker()
@@ -251,29 +246,20 @@ class TestCompactionTracker:
tracker._done = True
tracker._start_emitted = True
tracker._tool_call_id = "old"
tracker._transcript_path = "/some/path"
tracker.reset_for_query()
assert tracker._done is False
assert tracker._start_emitted is False
assert tracker._tool_call_id == ""
assert tracker._transcript_path == ""
@pytest.mark.asyncio
async def test_pre_query_blocks_sdk_compaction_until_reset(self):
"""After pre-query compaction, SDK compaction is blocked until
reset_for_query is called."""
async def test_pre_query_blocks_sdk_compaction(self):
"""After pre-query compaction, SDK compaction events are suppressed."""
tracker = CompactionTracker()
session = _make_session()
tracker.emit_pre_query(session)
tracker.on_compact()
# _done is True so emit_start_if_ready is blocked
evts = tracker.emit_start_if_ready()
assert evts == []
# Reset clears _done, allowing subsequent compaction
tracker.reset_for_query()
tracker.on_compact()
evts = tracker.emit_start_if_ready()
assert len(evts) == 3
assert evts == [] # _done blocks it
@pytest.mark.asyncio
async def test_reset_allows_new_compaction(self):
@@ -293,9 +279,9 @@ class TestCompactionTracker:
session = _make_session()
tracker.on_compact()
start_evts = tracker.emit_start_if_ready()
result = await tracker.emit_end_if_ready(session)
end_evts = await tracker.emit_end_if_ready(session)
start_evt = start_evts[1]
end_evt = result.events[0]
end_evt = end_evts[0]
assert isinstance(start_evt, StreamToolInputStart)
assert isinstance(end_evt, StreamToolOutputAvailable)
assert start_evt.toolCallId == end_evt.toolCallId
@@ -303,105 +289,3 @@ class TestCompactionTracker:
tool_calls = session.messages[0].tool_calls
assert tool_calls is not None
assert tool_calls[0]["id"] == start_evt.toolCallId
@pytest.mark.asyncio
async def test_multiple_compactions_within_query(self):
"""A second mid-stream compaction within a single query is suppressed."""
tracker = CompactionTracker()
session = _make_session()
# First compaction cycle
tracker.on_compact("/path/1")
tracker.emit_start_if_ready()
result1 = await tracker.emit_end_if_ready(session)
assert result1.just_ended is True
assert len(result1.events) == 2
assert result1.transcript_path == "/path/1"
# Second compaction cycle within the same query. emit_end_if_ready
# set _done=True at the end of the first cycle and nothing clears it
# until reset_for_query, so the second on_compact is recorded but
# emit_start_if_ready and emit_end_if_ready both stay no-ops.
# Verify that behavior.
tracker.on_compact("/path/2")
# _done is True from first compaction, so start is blocked
start_evts = tracker.emit_start_if_ready()
assert start_evts == []
# But emit_end returns no-op because _done is True
result2 = await tracker.emit_end_if_ready(session)
assert result2.just_ended is False
@pytest.mark.asyncio
async def test_multiple_compactions_with_intervening_message(self):
"""Multiple compactions work when reset_for_query runs between them.
In the real service.py flow:
1. PreCompact fires → on_compact()
2. emit_start shows spinner
3. Next message arrives → emit_end completes compaction (_done=True)
4. Stream continues processing messages...
5. If a second PreCompact fires, _done=True blocks emit_start
6. The next message triggers emit_end, which sees _done=True → no-op
So a second compaction only takes effect after _done is cleared.
service.py uses CompactionResult.just_ended to trigger replace_entries,
and _done stays True until reset_for_query.
"""
tracker = CompactionTracker()
session = _make_session()
# First compaction
tracker.on_compact("/path/1")
tracker.emit_start_if_ready()
result1 = await tracker.emit_end_if_ready(session)
assert result1.just_ended is True
assert result1.transcript_path == "/path/1"
# Simulate reset between queries
tracker.reset_for_query()
# Second compaction in new query
tracker.on_compact("/path/2")
start_evts = tracker.emit_start_if_ready()
assert len(start_evts) == 3
result2 = await tracker.emit_end_if_ready(session)
assert result2.just_ended is True
assert result2.transcript_path == "/path/2"
def test_on_compact_stores_transcript_path(self):
tracker = CompactionTracker()
tracker.on_compact("/some/path.jsonl")
assert tracker._transcript_path == "/some/path.jsonl"
@pytest.mark.asyncio
async def test_emit_end_returns_transcript_path(self):
"""CompactionResult includes the transcript_path from on_compact."""
tracker = CompactionTracker()
session = _make_session()
tracker.on_compact("/my/session.jsonl")
tracker.emit_start_if_ready()
result = await tracker.emit_end_if_ready(session)
assert result.just_ended is True
assert result.transcript_path == "/my/session.jsonl"
# transcript_path is cleared after emit_end
assert tracker._transcript_path == ""
@pytest.mark.asyncio
async def test_emit_end_clears_transcript_path(self):
"""After emit_end, _transcript_path is reset so it doesn't leak to
subsequent non-compaction emit_end calls."""
tracker = CompactionTracker()
session = _make_session()
tracker.on_compact("/first/path.jsonl")
tracker.emit_start_if_ready()
await tracker.emit_end_if_ready(session)
# After compaction, _transcript_path is cleared
assert tracker._transcript_path == ""

@@ -7,8 +7,8 @@ JSONL session files — no SDK subprocess needed. Exercises:
2. User query appended, assistant response streamed
3. PreCompact hook fires → CompactionTracker.on_compact()
4. Next message → emit_start_if_ready() yields spinner events
5. Message after that → emit_end_if_ready() returns CompactionResult
6. read_compacted_entries() reads the CLI session file
5. Message after that → emit_end_if_ready() returns end events
6. _read_compacted_entries() reads the CLI session file
7. TranscriptBuilder.replace_entries() syncs state
8. More messages appended post-compaction
9. to_jsonl() exports full state for upload
@@ -16,6 +16,7 @@ JSONL session files — no SDK subprocess needed. Exercises:
"""
import asyncio
from pathlib import Path
from backend.copilot.model import ChatSession
from backend.copilot.response_model import (
@@ -26,10 +27,7 @@ from backend.copilot.response_model import (
StreamToolOutputAvailable,
)
from backend.copilot.sdk.compaction import CompactionTracker
from backend.copilot.sdk.transcript import (
read_compacted_entries,
strip_progress_entries,
)
from backend.copilot.sdk.transcript import strip_progress_entries
from backend.copilot.sdk.transcript_builder import TranscriptBuilder
from backend.util import json
@@ -43,6 +41,32 @@ def _run(coro):
return asyncio.run(coro)
def _read_compacted_entries(path: str) -> tuple[list[dict], str] | None:
"""Test-only: read compacted entries from a session JSONL file.
Returns (parsed_dicts, jsonl_string) from the first ``isCompactSummary``
entry onward, or ``None`` if no summary is found.
"""
content = Path(path).read_text()
lines = content.strip().split("\n")
compact_idx: int | None = None
parsed: list[dict] = []
raw_lines: list[str] = []
for line in lines:
if not line.strip():
continue
entry = json.loads(line, fallback=None)
if not isinstance(entry, dict):
continue
parsed.append(entry)
raw_lines.append(line.strip())
if compact_idx is None and entry.get("isCompactSummary"):
compact_idx = len(parsed) - 1
if compact_idx is None:
return None
return parsed[compact_idx:], "\n".join(raw_lines[compact_idx:]) + "\n"
# ---------------------------------------------------------------------------
# Fixtures: realistic CLI session file content
# ---------------------------------------------------------------------------
@@ -205,7 +229,7 @@ class TestCompactionE2E:
path.write_text(_make_jsonl(*entries))
return path
def test_full_compaction_lifecycle(self, tmp_path, monkeypatch):
def test_full_compaction_lifecycle(self, tmp_path):
"""Simulate the complete service.py compaction flow.
Timeline:
@@ -216,18 +240,14 @@ class TestCompactionE2E:
5. Mid-stream: PreCompact hook fires (context too large)
6. CLI writes compaction summary to session file
7. Next SDK message → emit_start (spinner)
8. Following message → emit_end (CompactionResult)
9. read_compacted_entries reads the session file
8. Following message → emit_end (end events)
9. _read_compacted_entries reads the session file
10. replace_entries syncs TranscriptBuilder
11. More assistant messages appended
12. Export → upload → next turn downloads it
"""
# --- Setup CLI projects directory ---
config_dir = tmp_path / "config"
projects_dir = config_dir / "projects"
session_dir = projects_dir / "proj"
session_dir = tmp_path / "session"
session_dir.mkdir(parents=True)
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
# --- Step 1-2: Load "downloaded" transcript from previous turn ---
previous_transcript = _make_jsonl(
@@ -276,7 +296,8 @@ class TestCompactionE2E:
# --- Step 7: CompactionTracker receives PreCompact hook ---
tracker = CompactionTracker()
session = ChatSession.new(user_id="test-user")
tracker.on_compact(str(session_file))
# on_compact is a property that returns the Event.set callable
tracker.on_compact()
# --- Step 8: Next SDK message arrives → emit_start ---
start_events = tracker.emit_start_if_ready()
@@ -290,31 +311,30 @@ class TestCompactionE2E:
assert tool_call_id.startswith("compaction-")
# --- Step 9: Following message → emit_end ---
result = _run(tracker.emit_end_if_ready(session))
assert result.just_ended is True
assert result.transcript_path == str(session_file)
assert len(result.events) == 2
assert isinstance(result.events[0], StreamToolOutputAvailable)
assert isinstance(result.events[1], StreamFinishStep)
end_events = _run(tracker.emit_end_if_ready(session))
assert len(end_events) == 2
assert isinstance(end_events[0], StreamToolOutputAvailable)
assert isinstance(end_events[1], StreamFinishStep)
# Verify same tool_call_id
assert result.events[0].toolCallId == tool_call_id
assert end_events[0].toolCallId == tool_call_id
# Session should have compaction messages persisted
assert len(session.messages) == 2
assert session.messages[0].role == "assistant"
assert session.messages[1].role == "tool"
# --- Step 10: read_compacted_entries + replace_entries ---
compacted = read_compacted_entries(str(session_file))
assert compacted is not None
# --- Step 10: _read_compacted_entries + replace_entries ---
result = _read_compacted_entries(str(session_file))
assert result is not None
compacted_dicts, compacted_jsonl = result
# Should have: COMPACT_SUMMARY + POST_COMPACT_ASST + USER_3 + ASST_3
assert len(compacted) == 4
assert compacted[0]["uuid"] == "cs1"
assert compacted[0]["isCompactSummary"] is True
assert len(compacted_dicts) == 4
assert compacted_dicts[0]["uuid"] == "cs1"
assert compacted_dicts[0]["isCompactSummary"] is True
# Replace builder state with compacted entries
# Replace builder state with compacted JSONL
old_count = builder.entry_count
builder.replace_entries(compacted)
builder.replace_entries(compacted_jsonl)
assert builder.entry_count == 4 # Only compacted entries
assert builder.entry_count < old_count # Compaction reduced entries
@@ -367,13 +387,10 @@ class TestCompactionE2E:
# Parented to the last entry from previous turn
assert last_entry["parentUuid"] == output_entries[-1]["uuid"]
def test_double_compaction_within_session(self, tmp_path, monkeypatch):
def test_double_compaction_within_session(self, tmp_path):
"""Two compactions in the same session (across reset_for_query)."""
config_dir = tmp_path / "config"
projects_dir = config_dir / "projects"
session_dir = projects_dir / "proj"
session_dir = tmp_path / "session"
session_dir.mkdir(parents=True)
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
tracker = CompactionTracker()
session = ChatSession.new(user_id="test")
@@ -399,14 +416,15 @@ class TestCompactionE2E:
file1 = session_dir / "session1.jsonl"
file1.write_text(_make_jsonl(first_summary, first_post))
tracker.on_compact(str(file1))
tracker.on_compact()
tracker.emit_start_if_ready()
result1 = _run(tracker.emit_end_if_ready(session))
assert result1.just_ended is True
end_events1 = _run(tracker.emit_end_if_ready(session))
assert len(end_events1) == 2 # output + finish
compacted1 = read_compacted_entries(str(file1))
assert compacted1 is not None
builder.replace_entries(compacted1)
result1_entries = _read_compacted_entries(str(file1))
assert result1_entries is not None
_, compacted1_jsonl = result1_entries
builder.replace_entries(compacted1_jsonl)
assert builder.entry_count == 2
# --- Reset for second query ---
@@ -431,14 +449,15 @@ class TestCompactionE2E:
file2 = session_dir / "session2.jsonl"
file2.write_text(_make_jsonl(second_summary, second_post))
tracker.on_compact(str(file2))
tracker.on_compact()
tracker.emit_start_if_ready()
result2 = _run(tracker.emit_end_if_ready(session))
assert result2.just_ended is True
end_events2 = _run(tracker.emit_end_if_ready(session))
assert len(end_events2) == 2 # output + finish
compacted2 = read_compacted_entries(str(file2))
assert compacted2 is not None
builder.replace_entries(compacted2)
result2_entries = _read_compacted_entries(str(file2))
assert result2_entries is not None
_, compacted2_jsonl = result2_entries
builder.replace_entries(compacted2_jsonl)
assert builder.entry_count == 2 # Only second compaction entries
# Export and verify
@@ -447,9 +466,7 @@ class TestCompactionE2E:
assert entries[0]["uuid"] == "cs-second"
assert entries[0].get("isCompactSummary") is True
def test_strip_progress_then_load_then_compact_roundtrip(
self, tmp_path, monkeypatch
):
def test_strip_progress_then_load_then_compact_roundtrip(self, tmp_path):
"""Full pipeline: strip → load → compact → replace → export → reload.
This tests the exact sequence that happens across two turns:
@@ -458,11 +475,8 @@ class TestCompactionE2E:
Turn 2: Download → load_previous → compaction fires → replace → export
Turn 3: Download the Turn 2 export → load_previous (roundtrip)
"""
config_dir = tmp_path / "config"
projects_dir = config_dir / "projects"
session_dir = projects_dir / "proj"
session_dir = tmp_path / "session"
session_dir.mkdir(parents=True)
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
# --- Turn 1: SDK produces raw transcript ---
raw_content = _make_jsonl(
@@ -511,9 +525,10 @@ class TestCompactionE2E:
],
)
compacted = read_compacted_entries(str(session_file))
assert compacted is not None
builder.replace_entries(compacted)
result = _read_compacted_entries(str(session_file))
assert result is not None
_, compacted_jsonl = result
builder.replace_entries(compacted_jsonl)
# Append post-compaction message
builder.append_user("Thanks!")

@@ -20,24 +20,7 @@ Use these URLs directly without asking the user:
| Cloudflare | `https://mcp.cloudflare.com/mcp` |
| Atlassian / Jira | `https://mcp.atlassian.com/mcp` |
For other services, search the MCP registry API:
```http
GET https://registry.modelcontextprotocol.io/v0/servers?q=<search_term>
```
Each result includes a `remotes` array with the exact server URL to use.
### Important: Check blocks first
Before using `run_mcp_tool`, always check if the platform already has blocks for the service
using `find_block`. The platform has hundreds of built-in blocks (Google Sheets, Google Docs,
Google Calendar, Gmail, etc.) that work without MCP setup.
Only use `run_mcp_tool` when:
- The service is in the known hosted MCP servers list above, OR
- You searched `find_block` first and found no matching blocks
**Never guess or construct MCP server URLs.** Only use URLs from the known servers list above
or from the `remotes[].url` field in MCP registry search results.
For other services, search the MCP registry at https://registry.modelcontextprotocol.io/.
### Authentication

@@ -221,12 +221,12 @@ class SDKResponseAdapter:
responses.append(StreamFinish())
else:
logger.warning(
f"Unexpected ResultMessage subtype: {sdk_message.subtype}"
"Unexpected ResultMessage subtype: %s", sdk_message.subtype
)
responses.append(StreamFinish())
else:
logger.debug(f"Unhandled SDK message type: {type(sdk_message).__name__}")
logger.debug("Unhandled SDK message type: %s", type(sdk_message).__name__)
return responses
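The switch from f-strings to %-style arguments in these logger calls matters because the `logging` module only formats the message if the record is actually emitted. The sketch below shows the difference using a throwaway logger name and a hypothetical `Expensive` class that counts how often it is rendered.

```python
import logging


class Expensive:
    """Counts how many times it is converted to a string."""

    def __init__(self) -> None:
        self.renders = 0

    def __str__(self) -> str:
        self.renders += 1
        return "expensive"


logger = logging.getLogger("lazy_logging_demo")
logger.setLevel(logging.WARNING)  # DEBUG records are dropped

value = Expensive()

# %-style: the argument is stored and only formatted if the record is emitted.
logger.debug("Unhandled SDK message type: %s", value)
lazy_renders = value.renders

# f-string: the string is built before logger.debug() is even called.
logger.debug(f"Unhandled SDK message type: {value}")
eager_renders = value.renders - lazy_renders
```

With the level set to WARNING, the %-style call never renders the value, while the f-string call renders it once even though nothing is logged.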

@@ -52,7 +52,7 @@ def _validate_workspace_path(
if is_allowed_local_path(path, sdk_cwd):
return {}
logger.warning(f"Blocked {tool_name} outside workspace: {path}")
logger.warning("Blocked %s outside workspace: %s", tool_name, path)
workspace_hint = f" Allowed workspace: {sdk_cwd}" if sdk_cwd else ""
return _deny(
f"[SECURITY] Tool '{tool_name}' can only access files within the workspace "
@@ -71,7 +71,7 @@ def _validate_tool_access(
"""
# Block forbidden tools
if tool_name in BLOCKED_TOOLS:
logger.warning(f"Blocked tool access attempt: {tool_name}")
logger.warning("Blocked tool access attempt: %s", tool_name)
return _deny(
f"[SECURITY] Tool '{tool_name}' is blocked for security. "
"This is enforced by the platform and cannot be bypassed. "
@@ -111,7 +111,9 @@ def _validate_user_isolation(
# the tool itself via _validate_ephemeral_path.
path = tool_input.get("path", "") or tool_input.get("file_path", "")
if path and ".." in path:
logger.warning(f"Blocked path traversal attempt: {path} by user {user_id}")
logger.warning(
"Blocked path traversal attempt: %s by user %s", path, user_id
)
return {
"hookSpecificOutput": {
"hookEventName": "PreToolUse",
@@ -127,7 +129,7 @@ def create_security_hooks(
user_id: str | None,
sdk_cwd: str | None = None,
max_subtasks: int = 3,
on_compact: Callable[[str], None] | None = None,
on_compact: Callable[[], None] | None = None,
) -> dict[str, Any]:
"""Create the security hooks configuration for Claude Agent SDK.
@@ -142,7 +144,6 @@ def create_security_hooks(
sdk_cwd: SDK working directory for workspace-scoped tool validation
max_subtasks: Maximum concurrent Task (sub-agent) spawns allowed per session
on_compact: Callback invoked when SDK starts compacting context.
Receives the transcript_path from the hook input.
Returns:
Hooks configuration dict for ClaudeAgentOptions
@@ -170,7 +171,7 @@ def create_security_hooks(
# Block background task execution first — denied calls
# should not consume a subtask slot.
if tool_input.get("run_in_background"):
logger.info(f"[SDK] Blocked background Task, user={user_id}")
logger.info("[SDK] Blocked background Task, user=%s", user_id)
return cast(
SyncHookJSONOutput,
_deny(
@@ -212,7 +213,7 @@ def create_security_hooks(
if tool_name == "Task" and tool_use_id is not None:
task_tool_use_ids.add(tool_use_id)
logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}")
logger.debug("[SDK] Tool start: %s, user=%s", tool_name, user_id)
return cast(SyncHookJSONOutput, {})
def _release_task_slot(tool_name: str, tool_use_id: str | None) -> None:
@@ -302,21 +303,11 @@ def create_security_hooks(
"""
_ = context, tool_use_id
trigger = input_data.get("trigger", "auto")
# Sanitize untrusted input before logging to prevent log injection
transcript_path = (
str(input_data.get("transcript_path", ""))
.replace("\n", "")
.replace("\r", "")
)
logger.info(
"[SDK] Context compaction triggered: %s, user=%s, "
"transcript_path=%s",
trigger,
user_id,
transcript_path,
f"[SDK] Context compaction triggered: {trigger}, user={user_id}"
)
if on_compact is not None:
on_compact(transcript_path)
on_compact()
return cast(SyncHookJSONOutput, {})
hooks: dict[str, Any] = {
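
Several hunks above swap f-string logger calls for %-style arguments. The following is a minimal sketch of one practical difference, assuming only the standard `logging` module: with %-style arguments the message is formatted only if a record is actually emitted, whereas an f-string is evaluated before the logger is even called. The `Expensive` class and the `lazy_demo` logger name are hypothetical and exist only for illustration.

```python
import logging


class Expensive:
    """Hypothetical value that counts how often it is rendered to text."""

    def __init__(self) -> None:
        self.calls = 0

    def __str__(self) -> str:
        self.calls += 1
        return "expensive"


logger = logging.getLogger("lazy_demo")
logger.setLevel(logging.INFO)  # DEBUG records are dropped
logger.addHandler(logging.NullHandler())
logger.propagate = False

lazy = Expensive()
# %-style: the argument is passed through untouched; since DEBUG is
# disabled, the message is never formatted and __str__ is never called.
logger.debug("value=%s", lazy)

eager = Expensive()
# f-string: the string is built before logger.debug() runs, so __str__
# is called even though the record is then discarded.
logger.debug(f"value={eager}")
```

Keeping the template and its arguments separate also means every occurrence of a message shares one constant format string, which is easier to group in log aggregation.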


@@ -40,11 +40,13 @@ from ..constants import COPILOT_ERROR_PREFIX, COPILOT_SYSTEM_PREFIX
from ..model import (
ChatMessage,
ChatSession,
Usage,
get_chat_session,
update_session_title,
upsert_chat_session,
)
from ..prompting import get_sdk_supplement
from ..rate_limit import record_token_usage
from ..response_model import (
StreamBaseResponse,
StreamError,
@@ -54,6 +56,7 @@ from ..response_model import (
StreamTextDelta,
StreamToolInputAvailable,
StreamToolOutputAvailable,
StreamUsage,
)
from ..service import (
_build_system_prompt,
@@ -75,9 +78,12 @@ from .tool_adapter import (
wait_for_stash,
)
from .transcript import (
COMPACT_THRESHOLD_BYTES,
TranscriptDownload,
cleanup_cli_project_dir,
compact_transcript,
download_transcript,
read_compacted_entries,
read_cli_session_file,
upload_transcript,
validate_transcript,
write_transcript_to_tempfile,
@@ -295,7 +301,7 @@ def _cleanup_sdk_tool_results(cwd: str) -> None:
"""
normalized = os.path.normpath(cwd)
if not normalized.startswith(_SDK_CWD_PREFIX):
logger.warning(f"[SDK] Rejecting cleanup for path outside workspace: {cwd}")
logger.warning("[SDK] Rejecting cleanup for path outside workspace: %s", cwd)
return
# Clean the CLI's project directory (transcripts + tool-results).
@@ -389,7 +395,7 @@ async def _compress_messages(
client=client,
)
except Exception as e:
logger.warning(f"[SDK] Context compression with LLM failed: {e}")
logger.warning("[SDK] Context compression with LLM failed: %s", e)
# Fall back to truncation-only (no LLM summarization)
result = await compress_context(
messages=messages_dict,
@@ -625,6 +631,56 @@ async def _prepare_file_attachments(
return PreparedAttachments(hint=hint, image_blocks=image_blocks)
async def _maybe_compact_and_upload(
dl: TranscriptDownload,
user_id: str,
session_id: str,
log_prefix: str = "[Transcript]",
) -> str:
"""Compact an oversized transcript and upload the compacted version.
Returns the (possibly compacted) transcript content, or an empty string
if compaction was needed but failed.
"""
content = dl.content
if len(content) <= COMPACT_THRESHOLD_BYTES:
return content
logger.warning(
"%s Transcript oversized (%dB > %dB), compacting",
log_prefix,
len(content),
COMPACT_THRESHOLD_BYTES,
)
compacted = await compact_transcript(content, log_prefix=log_prefix)
if not compacted:
logger.warning(
"%s Compaction failed, skipping resume for this turn", log_prefix
)
return ""
# Keep the original message_count: it reflects the number of
# session.messages covered by this transcript, which the gap-fill
# logic uses as a slice index. Counting JSONL lines would give a
# smaller number (compacted messages != session message count) and
# cause already-covered messages to be re-injected.
try:
await upload_transcript(
user_id=user_id,
session_id=session_id,
content=compacted,
message_count=dl.message_count,
log_prefix=log_prefix,
)
except Exception:
logger.warning(
"%s Failed to upload compacted transcript",
log_prefix,
exc_info=True,
)
return compacted
async def stream_chat_completion_sdk(
session_id: str,
message: str | None = None,
@@ -736,6 +792,14 @@ async def stream_chat_completion_sdk(
_otel_ctx: Any = None
# Make sure there is no more code between the lock acquisition and try-block.
# Token usage accumulators — populated from ResultMessage at end of turn
turn_prompt_tokens = 0 # uncached input tokens only
turn_completion_tokens = 0
turn_cache_read_tokens = 0
turn_cache_creation_tokens = 0
total_tokens = 0 # computed once before StreamUsage, reused in finally
turn_cost_usd: float | None = None
try:
# Build system prompt (reuses non-SDK path with Langfuse support).
# Pre-compute the cwd here so the exact working directory path can be
@@ -828,20 +892,33 @@ async def stream_chat_completion_sdk(
is_valid,
)
if is_valid:
# Load previous FULL context into builder
transcript_builder.load_previous(dl.content, log_prefix=log_prefix)
resume_file = write_transcript_to_tempfile(
dl.content, session_id, sdk_cwd
transcript_content = await _maybe_compact_and_upload(
dl,
user_id=user_id or "",
session_id=session_id,
log_prefix=log_prefix,
)
# Load previous context into builder (empty string is a no-op)
if transcript_content:
transcript_builder.load_previous(
transcript_content, log_prefix=log_prefix
)
resume_file = (
write_transcript_to_tempfile(
transcript_content, session_id, sdk_cwd
)
if transcript_content
else None
)
if resume_file:
use_resume = True
transcript_msg_count = dl.message_count
logger.debug(
f"{log_prefix} Using --resume ({len(dl.content)}B, "
f"{log_prefix} Using --resume ({len(transcript_content)}B, "
f"msg_count={transcript_msg_count})"
)
else:
logger.warning(f"{log_prefix} Transcript downloaded but invalid")
logger.warning("%s Transcript downloaded but invalid", log_prefix)
elif config.claude_agent_use_resume and user_id and len(session.messages) > 1:
logger.warning(
f"{log_prefix} No transcript available "
@@ -1046,7 +1123,6 @@ async def stream_chat_completion_sdk(
exc_info=True,
)
ended_with_stream_error = True
yield StreamError(
errorText=f"SDK stream error: {stream_err}",
code="sdk_stream_error",
@@ -1112,7 +1188,7 @@ async def stream_chat_completion_sdk(
- len(adapter.resolved_tool_calls),
)
# Log ResultMessage details for debugging
# Log ResultMessage details and capture token usage
if isinstance(sdk_msg, ResultMessage):
logger.info(
"%s Received: ResultMessage %s "
@@ -1131,26 +1207,46 @@ async def stream_chat_completion_sdk(
sdk_msg.result or "(no error message provided)",
)
# Emit compaction end if SDK finished compacting.
# When compaction ends, sync TranscriptBuilder with the
# CLI's active context so they stay identical.
compact_result = await compaction.emit_end_if_ready(session)
for ev in compact_result.events:
yield ev
# After replace_entries, skip append_assistant for this
# sdk_msg — the CLI session file already contains it,
# so appending again would create a duplicate.
entries_replaced = False
if compact_result.just_ended:
compacted = await asyncio.to_thread(
read_compacted_entries,
compact_result.transcript_path,
)
if compacted is not None:
transcript_builder.replace_entries(
compacted, log_prefix=log_prefix
# Capture token usage from ResultMessage.
# Anthropic reports cached tokens separately:
# input_tokens = uncached only
# cache_read_input_tokens = served from cache
# cache_creation_input_tokens = written to cache
if sdk_msg.usage:
turn_prompt_tokens += sdk_msg.usage.get("input_tokens", 0)
turn_cache_read_tokens += sdk_msg.usage.get(
"cache_read_input_tokens", 0
)
turn_cache_creation_tokens += sdk_msg.usage.get(
"cache_creation_input_tokens", 0
)
turn_completion_tokens += sdk_msg.usage.get(
"output_tokens", 0
)
logger.info(
"%s Token usage: uncached=%d, cache_read=%d, cache_create=%d, output=%d",
log_prefix,
turn_prompt_tokens,
turn_cache_read_tokens,
turn_cache_creation_tokens,
turn_completion_tokens,
)
if sdk_msg.total_cost_usd is not None:
turn_cost_usd = sdk_msg.total_cost_usd
# Emit compaction end if SDK finished compacting.
# When compaction ends, sync TranscriptBuilder with
# the CLI's compacted session file so the uploaded
# transcript reflects compaction.
compaction_events = await compaction.emit_end_if_ready(session)
for ev in compaction_events:
yield ev
if compaction_events and sdk_cwd:
cli_content = await read_cli_session_file(sdk_cwd)
if cli_content:
transcript_builder.replace_entries(
cli_content, log_prefix=log_prefix
)
entries_replaced = True
for response in adapter.convert_message(sdk_msg):
if isinstance(response, StreamStart):
@@ -1237,11 +1333,10 @@ async def stream_chat_completion_sdk(
tool_call_id=response.toolCallId,
)
)
if not entries_replaced:
transcript_builder.append_tool_result(
tool_use_id=response.toolCallId,
content=content,
)
transcript_builder.append_tool_result(
tool_use_id=response.toolCallId,
content=content,
)
has_tool_results = True
elif isinstance(response, StreamFinish):
@@ -1251,9 +1346,7 @@ async def stream_chat_completion_sdk(
# any stashed tool results from the previous turn are
# recorded first, preserving the required API order:
# assistant(tool_use) → tool_result → assistant(text).
# Skip if replace_entries just ran — the CLI session
# file already contains this message.
if isinstance(sdk_msg, AssistantMessage) and not entries_replaced:
if isinstance(sdk_msg, AssistantMessage):
transcript_builder.append_assistant(
content_blocks=_format_sdk_content_blocks(sdk_msg.content),
model=sdk_msg.model,
@@ -1347,6 +1440,27 @@ async def stream_chat_completion_sdk(
) and not has_appended_assistant:
session.messages.append(assistant_response)
# Emit token usage to the client (must be in try to reach SSE stream).
# Session persistence of usage is in finally to stay consistent with
# rate-limit recording even if an exception interrupts between here
# and the finally block.
# Compute total_tokens once; reused in the finally block for
# session persistence and rate-limit recording.
total_tokens = (
turn_prompt_tokens
+ turn_cache_read_tokens
+ turn_cache_creation_tokens
+ turn_completion_tokens
)
if total_tokens > 0:
yield StreamUsage(
promptTokens=turn_prompt_tokens,
completionTokens=turn_completion_tokens,
totalTokens=total_tokens,
cacheReadTokens=turn_cache_read_tokens,
cacheCreationTokens=turn_cache_creation_tokens,
)
# Transcript upload is handled exclusively in the finally block
# to avoid double-uploads (the success path used to upload the
# old resume file, then the finally block overwrote it with the
@@ -1411,6 +1525,48 @@ async def stream_chat_completion_sdk(
except Exception:
logger.warning("OTEL context teardown failed", exc_info=True)
# --- Persist token usage to session + rate-limit counters ---
# Both must live in finally so they stay consistent even when an
# exception interrupts the try block after StreamUsage was yielded.
# total_tokens is computed once before StreamUsage yield above.
if total_tokens > 0:
if session is not None:
session.usage.append(
Usage(
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
total_tokens=total_tokens,
cache_read_tokens=turn_cache_read_tokens,
cache_creation_tokens=turn_cache_creation_tokens,
)
)
logger.info(
"%s Turn usage: uncached=%d, cache_read=%d, cache_create=%d, "
"output=%d, total=%d, cost_usd=%s",
log_prefix,
turn_prompt_tokens,
turn_cache_read_tokens,
turn_cache_creation_tokens,
turn_completion_tokens,
total_tokens,
turn_cost_usd,
)
if user_id and total_tokens > 0:
try:
await record_token_usage(
user_id=user_id,
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
cache_read_tokens=turn_cache_read_tokens,
cache_creation_tokens=turn_cache_creation_tokens,
)
except Exception as usage_err:
logger.warning(
"%s Failed to record token usage: %s",
log_prefix,
usage_err,
)
# --- Persist session messages ---
# This MUST run in finally to persist messages even when the generator
# is stopped early (e.g., user clicks stop, processor breaks stream loop).
@@ -1444,13 +1600,13 @@ async def stream_chat_completion_sdk(
task.add_done_callback(_background_tasks.discard)
# --- Upload transcript for next-turn --resume ---
# TranscriptBuilder is the single source of truth. It mirrors the
# CLI's active context: on compaction, replace_entries() syncs it
# with the compacted session file. No CLI file read needed here.
# This MUST run in finally so the transcript is uploaded even when
# the streaming loop raises an exception.
# The transcript represents the COMPLETE active context (atomic).
if config.claude_agent_use_resume and user_id and session is not None:
try:
# Build complete transcript from captured SDK messages
transcript_content = transcript_builder.to_jsonl()
entry_count = transcript_builder.entry_count
if not transcript_content:
logger.warning(
@@ -1460,15 +1616,18 @@ async def stream_chat_completion_sdk(
logger.warning(
"%s Transcript invalid, skipping upload (entries=%d)",
log_prefix,
entry_count,
transcript_builder.entry_count,
)
else:
logger.info(
"%s Uploading transcript (entries=%d, bytes=%d)",
"%s Uploading complete transcript (entries=%d, bytes=%d)",
log_prefix,
entry_count,
transcript_builder.entry_count,
len(transcript_content),
)
# Shield upload from cancellation - let it complete even if
# the finally block is interrupted. No timeout to avoid race
# conditions where backgrounded uploads overwrite newer transcripts.
await asyncio.shield(
upload_transcript(
user_id=user_id,
@@ -1503,6 +1662,6 @@ async def _update_title_async(
)
if title and user_id:
await update_session_title(session_id, user_id, title, only_if_empty=True)
logger.debug(f"[SDK] Generated title for {session_id}: {title}")
logger.debug("[SDK] Generated title for %s: %s", session_id, title)
except Exception as e:
logger.warning(f"[SDK] Failed to update session title: {e}")
logger.warning("[SDK] Failed to update session title: %s", e)
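
The token accounting added to `stream_chat_completion_sdk` above keeps four per-turn counters and derives `total_tokens` as their sum. The sketch below isolates that arithmetic, assuming usage arrives as a dict with the four keys read in the diff (`input_tokens`, `cache_read_input_tokens`, `cache_creation_input_tokens`, `output_tokens`). The `TurnUsage` class is a hypothetical helper, not part of the codebase.

```python
from dataclasses import dataclass


@dataclass
class TurnUsage:
    """Hypothetical accumulator mirroring the four turn_* counters."""

    prompt_tokens: int = 0  # uncached input tokens only
    completion_tokens: int = 0
    cache_read_tokens: int = 0
    cache_creation_tokens: int = 0

    def add(self, usage: dict) -> None:
        # Missing keys count as zero, matching the .get(..., 0) calls above.
        self.prompt_tokens += usage.get("input_tokens", 0)
        self.cache_read_tokens += usage.get("cache_read_input_tokens", 0)
        self.cache_creation_tokens += usage.get("cache_creation_input_tokens", 0)
        self.completion_tokens += usage.get("output_tokens", 0)

    @property
    def total_tokens(self) -> int:
        # Cached tokens are reported separately from input_tokens, so all
        # four counters are summed to get the full turn total.
        return (
            self.prompt_tokens
            + self.cache_read_tokens
            + self.cache_creation_tokens
            + self.completion_tokens
        )


turn = TurnUsage()
turn.add({"input_tokens": 10, "cache_read_input_tokens": 100, "output_tokens": 5})
turn.add({"input_tokens": 1, "cache_creation_input_tokens": 20})
```

Because cached input is reported outside `input_tokens`, a total built from prompt and completion tokens alone would undercount turns that are mostly served from cache.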


@@ -234,7 +234,9 @@ def create_tool_handler(base_tool: BaseTool):
try:
return await _execute_tool_sync(base_tool, user_id, session, args)
except Exception as e:
logger.error(f"Error executing tool {base_tool.name}: {e}", exc_info=True)
logger.error(
"Error executing tool %s: %s", base_tool.name, e, exc_info=True
)
return _mcp_error(f"Failed to execute {base_tool.name}: {e}")
return tool_handler


@@ -17,8 +17,13 @@ import shutil
import time
from dataclasses import dataclass
from pathlib import Path
from uuid import uuid4
import openai
from backend.copilot.config import ChatConfig
from backend.util import json
from backend.util.prompt import CompressResult, compress_context
logger = logging.getLogger(__name__)
@@ -36,6 +41,11 @@ STRIPPABLE_TYPES = frozenset(
{"progress", "file-history-snapshot", "queue-operation", "summary", "pr-link"}
)
# JSONL protocol values used in transcript serialization.
STOP_REASON_END_TURN = "end_turn"
COMPACT_MSG_ID_PREFIX = "msg_compact_"
ENTRY_TYPE_MESSAGE = "message"
@dataclass
class TranscriptDownload:
@@ -99,7 +109,9 @@ def strip_progress_entries(content: str) -> str:
continue
parent = entry.get("parentUuid", "")
original_parent = parent
while parent in stripped_uuids:
seen_parents: set[str] = set()
while parent in stripped_uuids and parent not in seen_parents:
seen_parents.add(parent)
parent = uuid_to_parent.get(parent, "")
if parent != original_parent:
entry["parentUuid"] = parent
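
The hunk above adds a `seen_parents` set so the reparenting loop cannot spin forever when stripped entries point at each other. A minimal sketch of that guard in isolation follows; `resolve_parent` is a hypothetical standalone function, and the short uuid strings are made-up test data.

```python
def resolve_parent(
    parent: str,
    stripped_uuids: set[str],
    uuid_to_parent: dict[str, str],
) -> str:
    """Walk up the parent chain past stripped entries, stopping on a cycle."""
    seen: set[str] = set()
    while parent in stripped_uuids and parent not in seen:
        seen.add(parent)
        parent = uuid_to_parent.get(parent, "")
    return parent


# Normal chain: b and c are stripped, so an entry whose parent is "b"
# gets reparented to the first surviving ancestor, "d".
chain_result = resolve_parent("b", {"b", "c"}, {"b": "c", "c": "d"})

# Malformed input: b and c reference each other. Without the seen set
# this loop would never terminate; with it, the walk stops on the first
# repeated uuid.
cycle_result = resolve_parent("b", {"b", "c"}, {"b": "c", "c": "b"})
```

In the cyclic case the returned parent is still a stripped uuid, so the guard bounds the loop but does not repair the malformed chain.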
@@ -145,147 +157,60 @@ def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
_SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-")
def _projects_base() -> str:
"""Return the resolved path to the CLI's projects directory."""
config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
return os.path.realpath(os.path.join(config_dir, "projects"))
def _cli_project_dir(sdk_cwd: str) -> str | None:
"""Return the CLI's project directory for a given working directory.
Returns ``None`` if the path would escape the projects base.
"""
cwd_encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
projects_base = _projects_base()
config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
projects_base = os.path.realpath(os.path.join(config_dir, "projects"))
project_dir = os.path.realpath(os.path.join(projects_base, cwd_encoded))
if not project_dir.startswith(projects_base + os.sep):
logger.warning(
"[Transcript] Project dir escaped projects base: %s", project_dir
)
logger.warning("[Transcript] Project dir escaped base: %s", project_dir)
return None
return project_dir
def _safe_glob_jsonl(project_dir: str) -> list[Path]:
"""Glob ``*.jsonl`` files, filtering out symlinks that escape the directory."""
try:
resolved_base = Path(project_dir).resolve()
except OSError as e:
logger.warning("[Transcript] Failed to resolve project dir: %s", e)
return []
async def read_cli_session_file(sdk_cwd: str) -> str | None:
"""Read the CLI's own session file, which reflects any mid-stream compaction.
result: list[Path] = []
for candidate in Path(project_dir).glob("*.jsonl"):
try:
resolved = candidate.resolve()
if resolved.is_relative_to(resolved_base):
result.append(resolved)
except (OSError, RuntimeError) as e:
logger.debug(
"[Transcript] Skipping invalid CLI session candidate %s: %s",
candidate,
e,
)
return result
def read_compacted_entries(transcript_path: str) -> list[dict] | None:
"""Read compacted entries from the CLI session file after compaction.
Parses the JSONL file line-by-line, finds the ``isCompactSummary: true``
entry, and returns it plus all entries after it.
The CLI writes the compaction summary BEFORE sending the next message,
so the file is guaranteed to be flushed by the time we read it.
Returns a list of parsed dicts, or ``None`` if the file cannot be read
or no compaction summary is found.
After the CLI compacts context, its session file contains the compacted
conversation. Reading this file lets ``TranscriptBuilder`` replace its
uncompacted entries with the CLI's compacted version.
"""
if not transcript_path:
return None
import aiofiles
projects_base = _projects_base()
real_path = os.path.realpath(transcript_path)
if not real_path.startswith(projects_base + os.sep):
logger.warning(
"[Transcript] transcript_path outside projects base: %s", transcript_path
)
return None
try:
content = Path(real_path).read_text()
except OSError as e:
logger.warning(
"[Transcript] Failed to read session file %s: %s", transcript_path, e
)
return None
lines = content.strip().split("\n")
compact_idx: int | None = None
for idx, line in enumerate(lines):
if not line.strip():
continue
entry = json.loads(line, fallback=None)
if not isinstance(entry, dict):
continue
if entry.get("isCompactSummary"):
compact_idx = idx # don't break — find the LAST summary
if compact_idx is None:
logger.debug("[Transcript] No compaction summary found in %s", transcript_path)
return None
entries: list[dict] = []
for line in lines[compact_idx:]:
if not line.strip():
continue
entry = json.loads(line, fallback=None)
if isinstance(entry, dict):
entries.append(entry)
logger.info(
"[Transcript] Read %d compacted entries from %s (summary at line %d)",
len(entries),
transcript_path,
compact_idx + 1,
)
return entries
def read_cli_session_file(sdk_cwd: str) -> str | None:
"""Read the CLI's own session file, which reflects any compaction.
The CLI writes its session transcript to
``~/.claude/projects/<encoded_cwd>/<session_id>.jsonl``.
Since each SDK turn uses a unique ``sdk_cwd``, there should be
exactly one ``.jsonl`` file in that directory.
Returns the file content, or ``None`` if not found.
"""
project_dir = _cli_project_dir(sdk_cwd)
if not project_dir or not os.path.isdir(project_dir):
return None
jsonl_files = _safe_glob_jsonl(project_dir)
jsonl_files = list(Path(project_dir).glob("*.jsonl"))
if not jsonl_files:
logger.debug("[Transcript] No CLI session file found in %s", project_dir)
logger.debug("[Transcript] No CLI session file in %s", project_dir)
return None
# Pick the most recently modified file (should be only one per turn).
try:
session_file = max(jsonl_files, key=lambda p: p.stat().st_mtime)
except OSError as e:
logger.warning("[Transcript] Failed to inspect CLI session files: %s", e)
# Pick the most recently modified file (there should only be one per turn).
# Guard against races where a file is deleted between glob and stat.
candidates: list[tuple[float, Path]] = []
for p in jsonl_files:
try:
candidates.append((p.stat().st_mtime, p))
except OSError:
continue
if not candidates:
logger.debug("[Transcript] No readable CLI session file in %s", project_dir)
return None
# Resolve + prefix check to prevent symlink escapes.
session_file = max(candidates, key=lambda item: item[0])[1]
real_path = str(session_file.resolve())
if not real_path.startswith(project_dir + os.sep):
logger.warning("[Transcript] Session file escaped project dir: %s", real_path)
return None
try:
content = session_file.read_text()
async with aiofiles.open(real_path) as f:
content = await f.read()
logger.info(
"[Transcript] Read CLI session file: %s (%d bytes)",
session_file,
real_path,
len(content),
)
return content
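
The containment checks in the hunk above compare a resolved path against `base + os.sep`, not against the bare base. The sketch below shows why the separator matters, under two assumptions: it runs on a POSIX-style filesystem layout, and it uses `os.path.normpath` so the example does not touch the filesystem (the diff itself uses `os.path.realpath`, which also resolves symlinks). The `is_within` helper and the example paths are hypothetical.

```python
import os


def is_within(base: str, candidate: str) -> bool:
    """Return True if candidate lies strictly inside base (lexical check only)."""
    base = os.path.normpath(base)
    candidate = os.path.normpath(candidate)
    # Appending os.sep prevents a sibling such as "projects-evil" from
    # matching the prefix "projects".
    return candidate.startswith(base + os.sep)


base = "/home/u/.claude/projects"

inside = is_within(base, "/home/u/.claude/projects/tmp-copilot-1/s.jsonl")
sibling = is_within(base, "/home/u/.claude/projects-evil/s.jsonl")
traversal = is_within(base, "/home/u/.claude/projects/../secrets")

# A bare prefix test wrongly accepts the sibling directory.
naive_sibling = "/home/u/.claude/projects-evil/s.jsonl".startswith(base)
```

A lexical check like this does not detect symlinks that point outside the base, which is why the code in the diff resolves the path first.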
@@ -295,16 +220,10 @@ def read_cli_session_file(sdk_cwd: str) -> str | None:
def cleanup_cli_project_dir(sdk_cwd: str) -> None:
"""Remove the CLI's project directory for a specific working directory.
The CLI stores session data under ``~/.claude/projects/<encoded_cwd>/``.
Each SDK turn uses a unique ``sdk_cwd``, so the project directory is
safe to remove entirely after the transcript has been uploaded.
"""
"""Remove the CLI's project directory for a specific working directory."""
project_dir = _cli_project_dir(sdk_cwd)
if not project_dir:
return
if os.path.isdir(project_dir):
shutil.rmtree(project_dir, ignore_errors=True)
logger.debug("[Transcript] Cleaned up CLI project dir: %s", project_dir)
@@ -327,7 +246,7 @@ def write_transcript_to_tempfile(
# Validate cwd is under the expected sandbox prefix (CodeQL sanitizer).
real_cwd = os.path.realpath(cwd)
if not real_cwd.startswith(_SAFE_CWD_PREFIX):
logger.warning(f"[Transcript] cwd outside sandbox: {cwd}")
logger.warning("[Transcript] cwd outside sandbox: %s", cwd)
return None
try:
@@ -337,17 +256,17 @@ def write_transcript_to_tempfile(
os.path.join(real_cwd, f"transcript-{safe_id}.jsonl")
)
if not jsonl_path.startswith(real_cwd):
logger.warning(f"[Transcript] Path escaped cwd: {jsonl_path}")
logger.warning("[Transcript] Path escaped cwd: %s", jsonl_path)
return None
with open(jsonl_path, "w") as f:
f.write(transcript_content)
logger.info(f"[Transcript] Wrote resume file: {jsonl_path}")
logger.info("[Transcript] Wrote resume file: %s", jsonl_path)
return jsonl_path
except OSError as e:
logger.warning(f"[Transcript] Failed to write resume file: {e}")
logger.warning("[Transcript] Failed to write resume file: %s", e)
return None
@@ -406,27 +325,24 @@ def _meta_storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, s
)
def _build_path_from_parts(parts: tuple[str, str, str], backend: object) -> str:
"""Build a full storage path from (workspace_id, file_id, filename) parts."""
def _build_storage_path(user_id: str, session_id: str, backend: object) -> str:
"""Build the full storage path string that ``retrieve()`` expects.
``store()`` returns a path like ``gcs://bucket/workspaces/...`` or
``local://workspace_id/file_id/filename``. Since we use deterministic
arguments we can reconstruct the same path for download/delete without
having stored the return value.
"""
from backend.util.workspace_storage import GCSWorkspaceStorage
wid, fid, fname = parts
wid, fid, fname = _storage_path_parts(user_id, session_id)
if isinstance(backend, GCSWorkspaceStorage):
blob = f"workspaces/{wid}/{fid}/{fname}"
return f"gcs://{backend.bucket_name}/{blob}"
return f"local://{wid}/{fid}/{fname}"
def _build_storage_path(user_id: str, session_id: str, backend: object) -> str:
"""Build the full storage path string that ``retrieve()`` expects."""
return _build_path_from_parts(_storage_path_parts(user_id, session_id), backend)
def _build_meta_storage_path(user_id: str, session_id: str, backend: object) -> str:
"""Build the full storage path for the companion .meta.json file."""
return _build_path_from_parts(
_meta_storage_path_parts(user_id, session_id), backend
)
else:
# LocalWorkspaceStorage returns local://{relative_path}
return f"local://{wid}/{fid}/{fname}"
async def upload_transcript(
@@ -494,11 +410,14 @@ async def upload_transcript(
content=json.dumps(meta).encode("utf-8"),
)
except Exception as e:
logger.warning(f"{log_prefix} Failed to write metadata: {e}")
logger.warning("%s Failed to write metadata: %s", log_prefix, e)
logger.info(
f"{log_prefix} Uploaded {len(encoded)}B "
f"(stripped from {len(content)}B, msg_count={message_count})"
"%s Uploaded %dB (stripped from %dB, msg_count=%d)",
log_prefix,
len(encoded),
len(content),
message_count,
)
@@ -521,25 +440,37 @@ async def download_transcript(
data = await storage.retrieve(path)
content = data.decode("utf-8")
except FileNotFoundError:
logger.debug(f"{log_prefix} No transcript in storage")
logger.debug("%s No transcript in storage", log_prefix)
return None
except Exception as e:
logger.warning(f"{log_prefix} Failed to download transcript: {e}")
logger.warning("%s Failed to download transcript: %s", log_prefix, e)
return None
# Try to load metadata (best-effort — old transcripts won't have it)
message_count = 0
uploaded_at = 0.0
try:
meta_path = _build_meta_storage_path(user_id, session_id, storage)
from backend.util.workspace_storage import GCSWorkspaceStorage
mwid, mfid, mfname = _meta_storage_path_parts(user_id, session_id)
if isinstance(storage, GCSWorkspaceStorage):
blob = f"workspaces/{mwid}/{mfid}/{mfname}"
meta_path = f"gcs://{storage.bucket_name}/{blob}"
else:
meta_path = f"local://{mwid}/{mfid}/{mfname}"
meta_data = await storage.retrieve(meta_path)
meta = json.loads(meta_data.decode("utf-8"), fallback={})
message_count = meta.get("message_count", 0)
uploaded_at = meta.get("uploaded_at", 0.0)
except (FileNotFoundError, Exception):
except FileNotFoundError:
pass # No metadata — treat as unknown (msg_count=0 → always fill gap)
except Exception as e:
logger.debug("%s Failed to load transcript metadata: %s", log_prefix, e)
logger.info(f"{log_prefix} Downloaded {len(content)}B (msg_count={message_count})")
logger.info(
"%s Downloaded %dB (msg_count=%d)", log_prefix, len(content), message_count
)
return TranscriptDownload(
content=content,
message_count=message_count,
@@ -547,27 +478,171 @@ async def download_transcript(
)
async def delete_transcript(user_id: str, session_id: str) -> None:
"""Delete transcript and its metadata from bucket storage.
# ---------------------------------------------------------------------------
# Transcript compaction
# ---------------------------------------------------------------------------
Removes both the ``.jsonl`` transcript and the companion ``.meta.json``
so stale ``message_count`` watermarks cannot corrupt gap-fill logic.
# Transcripts above this byte threshold are compacted at download time.
COMPACT_THRESHOLD_BYTES = 400_000
def _flatten_assistant_content(blocks: list) -> str:
"""Flatten assistant content blocks into a single plain-text string."""
parts: list[str] = []
for block in blocks:
if isinstance(block, dict):
if block.get("type") == "text":
parts.append(block.get("text", ""))
elif block.get("type") == "tool_use":
parts.append(f"[tool_use: {block.get('name', '?')}]")
elif isinstance(block, str):
parts.append(block)
return "\n".join(parts) if parts else ""
def _flatten_tool_result_content(blocks: list) -> str:
"""Flatten tool_result and other content blocks into plain text.
Handles nested tool_result structures, text blocks, and raw strings.
Uses ``json.dumps`` as fallback for dict blocks without a ``text`` key
or where ``text`` is ``None``.
"""
from backend.util.workspace_storage import get_workspace_storage
str_parts: list[str] = []
for block in blocks:
if isinstance(block, dict) and block.get("type") == "tool_result":
inner = block.get("content", "")
if isinstance(inner, list):
for sub in inner:
if isinstance(sub, dict):
text = sub.get("text")
str_parts.append(
str(text) if text is not None else json.dumps(sub)
)
else:
str_parts.append(str(sub))
else:
str_parts.append(str(inner))
elif isinstance(block, dict) and block.get("type") == "text":
str_parts.append(str(block.get("text", "")))
elif isinstance(block, str):
str_parts.append(block)
return "\n".join(str_parts) if str_parts else ""
storage = await get_workspace_storage()
path = _build_storage_path(user_id, session_id, storage)
def _transcript_to_messages(content: str) -> list[dict]:
"""Convert JSONL transcript entries to message dicts for compress_context."""
messages: list[dict] = []
for line in content.strip().split("\n"):
if not line.strip():
continue
entry = json.loads(line, fallback=None)
if not isinstance(entry, dict):
continue
if entry.get("type", "") in STRIPPABLE_TYPES and not entry.get(
"isCompactSummary"
):
continue
msg = entry.get("message", {})
role = msg.get("role", "")
if not role:
continue
msg_dict: dict = {"role": role}
raw_content = msg.get("content")
if role == "assistant" and isinstance(raw_content, list):
msg_dict["content"] = _flatten_assistant_content(raw_content)
elif isinstance(raw_content, list):
msg_dict["content"] = _flatten_tool_result_content(raw_content)
else:
msg_dict["content"] = raw_content or ""
messages.append(msg_dict)
return messages
def _messages_to_transcript(messages: list[dict]) -> str:
"""Convert compressed message dicts back to JSONL transcript format."""
lines: list[str] = []
last_uuid: str | None = None
for msg in messages:
role = msg.get("role", "user")
entry_type = "assistant" if role == "assistant" else "user"
uid = str(uuid4())
content = msg.get("content", "")
if role == "assistant":
message: dict = {
"role": "assistant",
"model": "",
"id": f"{COMPACT_MSG_ID_PREFIX}{uuid4().hex[:24]}",
"type": ENTRY_TYPE_MESSAGE,
"content": [{"type": "text", "text": content}] if content else [],
"stop_reason": STOP_REASON_END_TURN,
"stop_sequence": None,
}
else:
message = {"role": role, "content": content}
entry = {
"type": entry_type,
"uuid": uid,
"parentUuid": last_uuid,
"message": message,
}
lines.append(json.dumps(entry, separators=(",", ":")))
last_uuid = uid
return "\n".join(lines) + "\n" if lines else ""
async def _run_compression(
messages: list[dict],
model: str,
cfg: ChatConfig,
log_prefix: str,
) -> CompressResult:
"""Run LLM-based compression with truncation fallback."""
try:
await storage.delete(path)
logger.info("[Transcript] Deleted transcript for session %s", session_id)
async with openai.AsyncOpenAI(
api_key=cfg.api_key, base_url=cfg.base_url, timeout=30.0
) as client:
return await compress_context(messages=messages, model=model, client=client)
except Exception as e:
logger.warning("[Transcript] Failed to delete transcript: %s", e)
logger.warning("%s LLM compaction failed, using truncation: %s", log_prefix, e)
return await compress_context(messages=messages, model=model, client=None)
# Also delete the companion .meta.json to avoid orphaned metadata.
async def compact_transcript(
content: str,
log_prefix: str = "[Transcript]",
) -> str | None:
"""Compact an oversized JSONL transcript using LLM summarization.
Converts transcript entries to plain messages, runs ``compress_context``
(the same compressor used for pre-query history), and rebuilds JSONL.
Returns the compacted JSONL string, or ``None`` on failure.
"""
cfg = ChatConfig()
messages = _transcript_to_messages(content)
if len(messages) < 2:
logger.warning("%s Too few messages to compact (%d)", log_prefix, len(messages))
return None
try:
meta_path = _build_meta_storage_path(user_id, session_id, storage)
await storage.delete(meta_path)
logger.info("[Transcript] Deleted metadata for session %s", session_id)
result = await _run_compression(messages, cfg.model, cfg, log_prefix)
if not result.was_compacted:
logger.info("%s Transcript already within token budget", log_prefix)
return content
logger.info(
"%s Compacted transcript: %d->%d tokens (%d summarized, %d dropped)",
log_prefix,
result.original_token_count,
result.token_count,
result.messages_summarized,
result.messages_dropped,
)
compacted = _messages_to_transcript(result.messages)
if not validate_transcript(compacted):
logger.warning("%s Compacted transcript failed validation", log_prefix)
return None
return compacted
except Exception as e:
logger.warning("[Transcript] Failed to delete metadata: %s", e)
logger.error(
"%s Transcript compaction failed: %s", log_prefix, e, exc_info=True
)
return None
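
Reviewer note: `_run_compression` above tries the LLM-backed compressor first and falls back to truncation when that call raises. The sketch below isolates that control flow only. `summarize_with_llm` and `truncate` are hypothetical stand-ins invented for illustration. They are not the real `compress_context`, whose behaviour with `client=None` is not shown in this diff.

```python
import asyncio
import logging

logger = logging.getLogger("transcript_sketch")


async def summarize_with_llm(messages: list[dict]) -> list[dict]:
    """Hypothetical primary path; here it always fails to exercise the fallback."""
    raise TimeoutError("LLM endpoint did not respond")


async def truncate(messages: list[dict], keep: int = 2) -> list[dict]:
    """Hypothetical fallback: keep only the most recent `keep` messages."""
    return messages[-keep:]


async def run_compression(
    messages: list[dict], log_prefix: str = "[Sketch]"
) -> list[dict]:
    try:
        return await summarize_with_llm(messages)
    except Exception as e:
        # Same shape as the diff: log at warning level, then degrade gracefully.
        logger.warning("%s LLM compaction failed, using truncation: %s", log_prefix, e)
        return await truncate(messages)


history = [{"role": "user", "content": str(i)} for i in range(5)]
compacted = asyncio.run(run_compression(history))
```

With this shape the caller always receives a result and never has to handle the LLM failure itself. The cost is that a truncated result and a summarized one look the same to the caller unless the return value carries a flag for it.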

View File

@@ -30,8 +30,8 @@ class TranscriptEntry(BaseModel):
type: str
uuid: str
parentUuid: str | None
isCompactSummary: bool | None = None
message: dict[str, Any]
isCompactSummary: bool | None = None
class TranscriptBuilder:
@@ -54,24 +54,6 @@ class TranscriptBuilder:
return self._entries[-1].message.get("id", "")
return ""
@staticmethod
def _parse_entry(data: dict) -> TranscriptEntry | None:
"""Parse a single transcript entry, filtering strippable types.
Returns ``None`` for entries that should be skipped (strippable types
that are not compaction summaries).
"""
entry_type = data.get("type", "")
if entry_type in STRIPPABLE_TYPES and not data.get("isCompactSummary"):
return None
return TranscriptEntry(
type=entry_type,
uuid=data.get("uuid") or str(uuid4()),
parentUuid=data.get("parentUuid"),
isCompactSummary=data.get("isCompactSummary") or None,
message=data.get("message", {}),
)
def load_previous(self, content: str, log_prefix: str = "[Transcript]") -> None:
"""Load complete previous transcript.
@@ -97,9 +79,21 @@ class TranscriptBuilder:
)
continue
entry = self._parse_entry(data)
if entry is None:
# Skip STRIPPABLE_TYPES unless the entry is a compaction summary.
# Compaction summaries may have type "summary" but must be preserved
# so --resume can reconstruct the compacted conversation.
entry_type = data.get("type", "")
is_compact = data.get("isCompactSummary", False)
if entry_type in STRIPPABLE_TYPES and not is_compact:
continue
entry = TranscriptEntry(
type=data["type"],
uuid=data.get("uuid") or str(uuid4()),
parentUuid=data.get("parentUuid"),
message=data.get("message", {}),
isCompactSummary=True if is_compact else None,
)
self._entries.append(entry)
self._last_uuid = entry.uuid
@@ -172,43 +166,6 @@ class TranscriptBuilder:
)
self._last_uuid = msg_uuid
def replace_entries(
self, compacted_entries: list[dict], log_prefix: str = "[Transcript]"
) -> None:
"""Replace all entries with compacted entries from the CLI session file.
Called after mid-stream compaction so TranscriptBuilder mirrors the
CLI's active context (compaction summary + post-compaction entries).
Builds the new list first and validates it's non-empty before swapping,
so corrupt input cannot wipe the conversation history.
"""
new_entries: list[TranscriptEntry] = []
for data in compacted_entries:
entry = self._parse_entry(data)
if entry is not None:
new_entries.append(entry)
if not new_entries:
logger.warning(
"%s replace_entries produced 0 entries from %d inputs, keeping old (%d entries)",
log_prefix,
len(compacted_entries),
len(self._entries),
)
return
old_count = len(self._entries)
self._entries = new_entries
self._last_uuid = new_entries[-1].uuid
logger.info(
"%s TranscriptBuilder compacted: %d entries -> %d entries",
log_prefix,
old_count,
len(self._entries),
)
def to_jsonl(self) -> str:
"""Export complete context as JSONL.
@@ -224,6 +181,33 @@ class TranscriptBuilder:
lines = [entry.model_dump_json(exclude_none=True) for entry in self._entries]
return "\n".join(lines) + "\n"
def replace_entries(self, content: str, log_prefix: str = "[Transcript]") -> None:
"""Replace all entries with compacted JSONL content.
Called after the CLI performs mid-stream compaction so the builder's
state reflects the compacted conversation instead of the full
pre-compaction history.
"""
prev_count = len(self._entries)
temp = TranscriptBuilder()
try:
temp.load_previous(content, log_prefix=log_prefix)
except Exception:
logger.exception(
"%s Failed to parse compacted transcript; keeping %d existing entries",
log_prefix,
prev_count,
)
return
self._entries = temp._entries
self._last_uuid = temp._last_uuid
logger.info(
"%s Replaced %d entries with %d compacted entries",
log_prefix,
prev_count,
len(self._entries),
)
@property
def entry_count(self) -> int:
"""Total number of entries in the complete context."""
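
Reviewer note: one of the two `replace_entries` versions in this hunk builds the new list first and refuses to swap when it is empty, so corrupt input cannot wipe the history. The other version swaps in whatever the temporary builder parsed. A minimal sketch of the build-then-validate-then-swap idea follows, assuming a hypothetical `EntryStore` class and a simplified "has a uuid" filter in place of the real `STRIPPABLE_TYPES` check.

```python
from dataclasses import dataclass, field


@dataclass
class EntryStore:
    """Hypothetical stand-in for a builder that owns a list of entries."""

    entries: list[dict] = field(default_factory=list)

    def replace_entries(self, candidates: list[dict]) -> bool:
        # Build the replacement list without touching current state.
        new_entries = [c for c in candidates if c.get("uuid")]
        if not new_entries:
            # Nothing usable was parsed: keep the existing entries.
            return False
        # Swap only after the new list is known to be non-empty.
        self.entries = new_entries
        return True


store = EntryStore(entries=[{"uuid": "a"}, {"uuid": "b"}])
rejected = store.replace_entries([{"type": "progress"}])
kept_after_reject = list(store.entries)
accepted = store.replace_entries([{"uuid": "c"}])
```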

File diff suppressed because it is too large.

View File

@@ -935,5 +935,5 @@ class AgentValidator:
for i, error in enumerate(self.errors, 1):
error_message += f"{i}. {error}\n"
logger.warning(f"Agent validation failed: {error_message}")
logger.error(f"Agent validation failed: {error_message}")
return False, error_message
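
Reviewer note: the head commit message mentions converting f-string logger calls to %-style. One practical difference is that an f-string is rendered before the logging call runs, while %-style arguments are only rendered if the record is actually emitted. The sketch below demonstrates this with a hypothetical `Counted` helper that counts how often it is turned into a string. Nothing here comes from the codebase.

```python
import logging

logger = logging.getLogger("lazy_logging_sketch")
logger.setLevel(logging.WARNING)  # DEBUG records are disabled for this logger
logger.addHandler(logging.NullHandler())


class Counted:
    """Hypothetical helper that records how many times __str__ is called."""

    def __init__(self) -> None:
        self.renders = 0

    def __str__(self) -> str:
        self.renders += 1
        return "payload"


eager = Counted()
logger.debug(f"value: {eager}")  # the f-string is rendered before debug() is called

lazy = Counted()
logger.debug("value: %s", lazy)  # no rendering, because DEBUG is not enabled
```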

View File

@@ -8,11 +8,15 @@ from pydantic_core import PydanticUndefined
from backend.blocks._base import AnyBlockSchema
from backend.copilot.constants import COPILOT_NODE_PREFIX, COPILOT_SESSION_PREFIX
from backend.data import db
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from backend.data.db_accessors import workspace_db
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
from backend.executor.utils import block_usage_cost
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import BlockError
from backend.util.clients import get_database_manager_async_client
from backend.util.exceptions import BlockError, InsufficientBalanceError
from backend.util.type import coerce_inputs_to_schema
from .models import BlockOutputResponse, ErrorResponse, ToolResponseBase
@@ -21,6 +25,26 @@ from .utils import match_credentials_to_requirements
logger = logging.getLogger(__name__)
async def _get_credits(user_id: str) -> int:
"""Get user credits using the adapter pattern (RPC when Prisma unavailable)."""
if not db.is_connected():
return await get_database_manager_async_client().get_credits(user_id)
credit_model = await get_user_credit_model(user_id)
return await credit_model.get_credits(user_id)
async def _spend_credits(
user_id: str, cost: int, metadata: UsageTransactionMetadata
) -> int:
"""Spend user credits using the adapter pattern (RPC when Prisma unavailable)."""
if not db.is_connected():
return await get_database_manager_async_client().spend_credits(
user_id, cost, metadata
)
credit_model = await get_user_credit_model(user_id)
return await credit_model.spend_credits(user_id, cost, metadata)
def get_inputs_from_schema(
input_schema: dict[str, Any],
exclude_fields: set[str] | None = None,
@@ -115,6 +139,20 @@ async def execute_block(
# Coerce non-matching data types to the expected input schema.
coerce_inputs_to_schema(input_data, block.input_schema)
# Pre-execution credit check
cost, cost_filter = block_usage_cost(block, input_data)
has_cost = cost > 0
if has_cost:
balance = await _get_credits(user_id)
if balance < cost:
return ErrorResponse(
message=(
f"Insufficient credits to run '{block.name}'. "
"Please top up your credits to continue."
),
session_id=session_id,
)
# Execute the block and collect outputs
outputs: dict[str, list[Any]] = defaultdict(list)
async for output_name, output_data in block.execute(
@@ -123,6 +161,37 @@ async def execute_block(
):
outputs[output_name].append(output_data)
# Charge credits for block execution
if has_cost:
try:
await _spend_credits(
user_id=user_id,
cost=cost,
metadata=UsageTransactionMetadata(
graph_exec_id=synthetic_graph_id,
graph_id=synthetic_graph_id,
node_id=synthetic_node_id,
node_exec_id=node_exec_id,
block_id=block_id,
block=block.name,
input=cost_filter,
reason="copilot_block_execution",
),
)
except InsufficientBalanceError:
logger.warning(
"Post-exec credit charge failed for block %s (cost=%d)",
block.name,
cost,
)
return ErrorResponse(
message=(
f"Insufficient credits to complete '{block.name}'. "
"Please top up your credits to continue."
),
session_id=session_id,
)
return BlockOutputResponse(
message=f"Block '{block.name}' executed successfully",
block_id=block_id,
@@ -133,16 +202,16 @@ async def execute_block(
)
except BlockError as e:
logger.warning(f"Block execution failed: {e}")
logger.warning("Block execution failed: %s", e)
return ErrorResponse(
message=f"Block execution failed: {e}",
error=str(e),
session_id=session_id,
)
except Exception as e:
logger.error(f"Unexpected error executing block: {e}", exc_info=True)
logger.error("Unexpected error executing block: %s", e, exc_info=True)
return ErrorResponse(
message=f"Failed to execute block: {str(e)}",
message="An unexpected error occurred while executing the block",
error=str(e),
session_id=session_id,
)
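
Reviewer note: `execute_block` checks the balance before running the block and charges afterwards, so the charge can still fail if the balance drops while the block runs. The tests further down call this a race with a concurrent spend. The sketch below reproduces that flow with an in-memory `Ledger` and an `InsufficientBalance` exception, both hypothetical stand-ins for the real credit model and `InsufficientBalanceError`.

```python
import asyncio


class InsufficientBalance(Exception):
    """Hypothetical error, standing in for InsufficientBalanceError."""


class Ledger:
    """In-memory stand-in for a credit store (an assumption, not the real model)."""

    def __init__(self, balance: int) -> None:
        self.balance = balance

    async def get_credits(self) -> int:
        return self.balance

    async def spend_credits(self, cost: int) -> int:
        if self.balance < cost:
            raise InsufficientBalance(f"balance {self.balance} < cost {cost}")
        self.balance -= cost
        return self.balance


async def run_paid(ledger: Ledger, cost: int, work) -> str:
    # Pre-check: refuse before doing any work if the balance is too low.
    if cost > 0 and await ledger.get_credits() < cost:
        return "error: insufficient credits (pre-check)"
    result = await work()
    # Charge afterwards; the balance may have changed while the work ran.
    if cost > 0:
        try:
            await ledger.spend_credits(cost)
        except InsufficientBalance:
            return "error: insufficient credits (post-exec)"
    return result


async def demo() -> tuple[str, str, str]:
    async def work() -> str:
        return "ok"

    rich, poor, racy = Ledger(100), Ledger(5), Ledger(15)

    async def racing_work() -> str:
        racy.balance = 5  # simulate a concurrent spend during execution
        return "ok"

    return (
        await run_paid(rich, 10, work),
        await run_paid(poor, 10, work),
        await run_paid(racy, 10, racing_work),
    )


outcomes = asyncio.run(demo())
```

In the third case the work has already run by the time the charge fails. The same holds for the diff's post-execution `ErrorResponse`: the block's side effects have happened, but its outputs are not returned to the caller.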

View File

@@ -1,24 +1,202 @@
"""Tests for execute_block type coercion in helpers.py.
Verifies that execute_block() coerces string input values to match the block's
expected input types, mirroring the executor's validate_exec() logic.
This is critical for @@agptfile: expansion, where file content is always a string
but the block may expect structured types (e.g. list[list[str]]).
"""
"""Tests for execute_block — credit charging and type coercion."""
from collections.abc import AsyncIterator
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.blocks._base import BlockType
from backend.copilot.tools.helpers import execute_block
from backend.copilot.tools.models import BlockOutputResponse
from backend.copilot.tools.models import BlockOutputResponse, ErrorResponse
_USER = "test-user-helpers"
_SESSION = "test-session-helpers"
def _make_block(block_id: str = "block-1", name: str = "TestBlock"):
"""Create a minimal mock block for execute_block()."""
mock = MagicMock()
mock.id = block_id
mock.name = name
mock.block_type = BlockType.STANDARD
mock.input_schema = MagicMock()
mock.input_schema.get_credentials_fields_info.return_value = {}
async def _execute(
input_data: dict, **kwargs: Any
) -> AsyncIterator[tuple[str, Any]]:
yield "result", "ok"
mock.execute = _execute
return mock
def _patch_workspace():
"""Patch workspace_db to return a mock workspace."""
mock_workspace = MagicMock()
mock_workspace.id = "ws-1"
mock_ws_db = MagicMock()
mock_ws_db.get_or_create_workspace = AsyncMock(return_value=mock_workspace)
return patch("backend.copilot.tools.helpers.workspace_db", return_value=mock_ws_db)
# ---------------------------------------------------------------------------
# Credit charging tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
class TestExecuteBlockCreditCharging:
async def test_charges_credits_when_cost_is_positive(self):
"""Block with cost > 0 should call spend_credits after execution."""
block = _make_block()
mock_spend = AsyncMock()
with (
_patch_workspace(),
patch(
"backend.copilot.tools.helpers.block_usage_cost",
return_value=(10, {"key": "val"}),
),
patch(
"backend.copilot.tools.helpers._get_credits",
new_callable=AsyncMock,
return_value=100,
),
patch(
"backend.copilot.tools.helpers._spend_credits",
new_callable=AsyncMock,
side_effect=mock_spend,
),
):
result = await execute_block(
block=block,
block_id="block-1",
input_data={"text": "hello"},
user_id=_USER,
session_id=_SESSION,
node_exec_id="exec-1",
matched_credentials={},
)
assert isinstance(result, BlockOutputResponse)
assert result.success is True
mock_spend.assert_awaited_once()
call_kwargs = mock_spend.call_args.kwargs
assert call_kwargs["cost"] == 10
assert call_kwargs["metadata"].reason == "copilot_block_execution"
async def test_returns_error_when_insufficient_credits_before_exec(self):
"""Pre-execution check should return ErrorResponse when balance < cost."""
block = _make_block()
with (
_patch_workspace(),
patch(
"backend.copilot.tools.helpers.block_usage_cost",
return_value=(10, {}),
),
patch(
"backend.copilot.tools.helpers._get_credits",
new_callable=AsyncMock,
return_value=5, # balance < cost (10)
),
):
result = await execute_block(
block=block,
block_id="block-1",
input_data={},
user_id=_USER,
session_id=_SESSION,
node_exec_id="exec-1",
matched_credentials={},
)
assert isinstance(result, ErrorResponse)
assert "Insufficient credits" in result.message
async def test_no_charge_when_cost_is_zero(self):
"""Block with cost 0 should not call spend_credits."""
block = _make_block()
with (
_patch_workspace(),
patch(
"backend.copilot.tools.helpers.block_usage_cost",
return_value=(0, {}),
),
patch(
"backend.copilot.tools.helpers._get_credits",
) as mock_get_credits,
patch(
"backend.copilot.tools.helpers._spend_credits",
) as mock_spend_credits,
):
result = await execute_block(
block=block,
block_id="block-1",
input_data={},
user_id=_USER,
session_id=_SESSION,
node_exec_id="exec-1",
matched_credentials={},
)
assert isinstance(result, BlockOutputResponse)
assert result.success is True
# Credit functions should not be called at all for zero-cost blocks
mock_get_credits.assert_not_awaited()
mock_spend_credits.assert_not_awaited()
async def test_returns_error_on_post_exec_insufficient_balance(self):
"""If charging fails after execution, return ErrorResponse."""
from backend.util.exceptions import InsufficientBalanceError
block = _make_block()
with (
_patch_workspace(),
patch(
"backend.copilot.tools.helpers.block_usage_cost",
return_value=(10, {}),
),
patch(
"backend.copilot.tools.helpers._get_credits",
new_callable=AsyncMock,
return_value=15, # passes pre-check
),
patch(
"backend.copilot.tools.helpers._spend_credits",
new_callable=AsyncMock,
side_effect=InsufficientBalanceError(
"Low balance", _USER, 5, 10
), # fails during actual charge (race with concurrent spend)
),
):
result = await execute_block(
block=block,
block_id="block-1",
input_data={},
user_id=_USER,
session_id=_SESSION,
node_exec_id="exec-1",
matched_credentials={},
)
assert isinstance(result, ErrorResponse)
assert "Insufficient credits" in result.message
# ---------------------------------------------------------------------------
# Type coercion tests
# ---------------------------------------------------------------------------
def _make_block_schema(annotations: dict[str, Any]) -> MagicMock:
"""Create a mock input_schema with model_fields matching the given annotations."""
schema = MagicMock()
# coerce_inputs_to_schema uses model_fields (Pydantic v2 API)
model_fields = {}
for name, ann in annotations.items():
field = MagicMock()
@@ -28,7 +206,7 @@ def _make_block_schema(annotations: dict[str, Any]) -> MagicMock:
return schema
def _make_block(
def _make_coerce_block(
block_id: str,
name: str,
annotations: dict[str, Any],
@@ -60,7 +238,7 @@ _TEST_USER_ID = "test-user-coerce"
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_json_string_to_nested_list():
"""JSON string → list[list[str]] (Google Sheets CSV import case)."""
block = _make_block(
block = _make_coerce_block(
"sheets-write",
"Google Sheets Write",
{"values": list[list[str]], "spreadsheet_id": str},
@@ -90,7 +268,6 @@ async def test_coerce_json_string_to_nested_list():
assert isinstance(response, BlockOutputResponse)
assert response.success is True
# Verify the input was coerced from string to list[list[str]]
assert block._captured_inputs["values"] == [
["Name", "Score"],
["Alice", "90"],
@@ -103,7 +280,7 @@ async def test_coerce_json_string_to_nested_list():
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_json_string_to_list():
"""JSON string → list[str]."""
block = _make_block(
block = _make_coerce_block(
"list-block",
"List Block",
{"items": list[str]},
@@ -135,7 +312,7 @@ async def test_coerce_json_string_to_list():
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_json_string_to_dict():
"""JSON string → dict[str, str]."""
block = _make_block(
block = _make_coerce_block(
"dict-block",
"Dict Block",
{"config": dict[str, str]},
@@ -167,7 +344,7 @@ async def test_coerce_json_string_to_dict():
@pytest.mark.asyncio(loop_scope="session")
async def test_no_coercion_when_type_matches():
"""Already-correct types pass through without coercion."""
block = _make_block(
block = _make_coerce_block(
"pass-through",
"Pass Through",
{"values": list[list[str]], "name": str},
@@ -201,7 +378,7 @@ async def test_no_coercion_when_type_matches():
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_string_to_int():
"""String number → int."""
block = _make_block(
block = _make_coerce_block(
"int-block",
"Int Block",
{"count": int},
@@ -234,7 +411,7 @@ async def test_coerce_string_to_int():
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_skips_none_values():
"""None values are not coerced (they may be optional fields)."""
block = _make_block(
block = _make_coerce_block(
"optional-block",
"Optional Block",
{"data": list[str], "label": str},
@@ -260,14 +437,13 @@ async def test_coerce_skips_none_values():
)
assert isinstance(response, BlockOutputResponse)
# 'data' was not provided, so it should not appear in captured inputs
assert "data" not in block._captured_inputs
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_union_type_preserves_valid_member():
"""Union-typed fields should not be coerced when the value matches a member."""
block = _make_block(
block = _make_coerce_block(
"union-block",
"Union Block",
{"content": str | list[str]},
@@ -293,7 +469,6 @@ async def test_coerce_union_type_preserves_valid_member():
)
assert isinstance(response, BlockOutputResponse)
# list[str] should NOT be stringified to '["a", "b"]'
assert block._captured_inputs["content"] == ["a", "b"]
assert isinstance(block._captured_inputs["content"], list)
@@ -301,7 +476,7 @@ async def test_coerce_union_type_preserves_valid_member():
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_inner_elements_of_generic():
"""Inner elements of generic containers are recursively coerced."""
block = _make_block(
block = _make_coerce_block(
"inner-coerce",
"Inner Coerce",
{"values": list[str]},
@@ -319,7 +494,6 @@ async def test_coerce_inner_elements_of_generic():
response = await execute_block(
block=block,
block_id="inner-coerce",
# Inner elements are ints, but target is list[str]
input_data={"values": [1, 2, 3]},
user_id=_TEST_USER_ID,
session_id=_TEST_SESSION_ID,
@@ -328,6 +502,5 @@ async def test_coerce_inner_elements_of_generic():
)
assert isinstance(response, BlockOutputResponse)
# Inner elements should be coerced from int to str
assert block._captured_inputs["values"] == ["1", "2", "3"]
assert all(isinstance(v, str) for v in block._captured_inputs["values"])

View File

@@ -184,12 +184,10 @@ class RunMCPToolTool(BaseTool):
if e.status_code in _AUTH_STATUS_CODES and not creds:
# Server requires auth and user has no stored credentials
return self._build_setup_requirements(server_url, session_id)
host = server_host(server_url)
logger.warning("MCP HTTP error for %s: status=%s", host, e.status_code)
logger.warning("MCP HTTP error for %s: %s", server_host(server_url), e)
return ErrorResponse(
message=(f"MCP request to {host} failed with HTTP {e.status_code}."),
message=f"MCP server returned HTTP {e.status_code}: {e}",
session_id=session_id,
error=f"HTTP {e.status_code}: {str(e)[:300]}",
)
except MCPClientError as e:

View File

@@ -580,49 +580,6 @@ async def test_auth_error_with_existing_creds_returns_error():
assert "403" in response.message
@pytest.mark.asyncio(loop_scope="session")
async def test_http_error_returns_clean_message_with_collapsible_detail():
"""Non-auth HTTP errors return a clean message with raw detail in the `error` field."""
from backend.util.request import HTTPClientError
tool = RunMCPToolTool()
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
new_callable=AsyncMock,
return_value=None,
):
mock_client = AsyncMock()
mock_client.initialize = AsyncMock(
side_effect=HTTPClientError(
"<!doctype html><html><body>Not Found</body></html>",
status_code=404,
)
)
with patch(
"backend.copilot.tools.run_mcp_tool.MCPClient",
return_value=mock_client,
):
response = await tool._execute(
user_id=_USER_ID,
session=session,
server_url=_SERVER_URL,
)
assert isinstance(response, ErrorResponse)
assert "404" in response.message
# Raw HTML body must NOT leak into the user-facing message
assert "<!doctype" not in response.message
# Raw detail (including original body) goes in the collapsible `error` field
assert response.error is not None
assert "404" in response.error
assert "<!doctype" in response.error.lower()
@pytest.mark.asyncio(loop_scope="session")
async def test_mcp_client_error_returns_error_response():
"""MCPClientError (protocol-level) maps to a clean ErrorResponse."""

View File

@@ -512,6 +512,10 @@ class DatabaseManagerAsyncClient(AppServiceClient):
list_workspace_files = d.list_workspace_files
soft_delete_workspace_file = d.soft_delete_workspace_file
# ============ Credits ============ #
spend_credits = d.spend_credits
get_credits = d.get_credits
# ============ Understanding ============ #
get_business_understanding = d.get_business_understanding
upsert_business_understanding = d.upsert_business_understanding

View File

@@ -61,12 +61,7 @@ from backend.util.decorator import (
error_logged,
time_measured,
)
from backend.util.exceptions import (
GraphNotFoundError,
InsufficientBalanceError,
ModerationError,
NotFoundError,
)
from backend.util.exceptions import InsufficientBalanceError, ModerationError
from backend.util.file import clean_exec_files
from backend.util.logging import TruncatedLogger, configure_logging
from backend.util.metrics import DiscordChannel
@@ -380,16 +375,9 @@ async def execute_node(
log_metadata.debug("Node produced output", **{output_name: output_data})
yield output_name, output_data
except Exception as ex:
# Only capture unexpected errors to Sentry, not user-caused ones.
# Most ValueError subclasses here are expected (BlockExecutionError,
# InsufficientBalanceError, plain ValueError for auth/disabled blocks, etc.)
# but NotFoundError/GraphNotFoundError could indicate real platform issues.
is_expected = isinstance(ex, ValueError) and not isinstance(
ex, (NotFoundError, GraphNotFoundError)
)
if not is_expected:
sentry_sdk.capture_exception(error=ex, scope=scope)
sentry_sdk.flush()
# Capture exception WITH context still set before restoring scope
sentry_sdk.capture_exception(error=ex, scope=scope)
sentry_sdk.flush() # Ensure it's sent before we restore scope
# Re-raise to maintain normal error flow
raise
finally:
@@ -1490,7 +1478,7 @@ class ExecutionProcessor:
alert_message, DiscordChannel.PRODUCT
)
except Exception as e:
logger.warning(f"Failed to send low balance Discord alert: {e}")
logger.error(f"Failed to send low balance Discord alert: {e}")
class ExecutionManager(AppProcess):
@@ -1912,16 +1900,17 @@ class ExecutionManager(AppProcess):
channel = client.get_channel()
channel.connection.add_callback_threadsafe(lambda: channel.stop_consuming())
thread.join(timeout=300)
if thread.is_alive():
logger.warning(
try:
thread.join(timeout=300)
except TimeoutError:
logger.error(
f"{prefix} ⚠️ Run thread did not finish in time, forcing disconnect"
)
client.disconnect()
logger.info(f"{prefix} ✅ Run client disconnected")
except Exception as e:
logger.warning(f"{prefix} ⚠️ Error disconnecting run client: {type(e)} {e}")
logger.error(f"{prefix} ⚠️ Error disconnecting run client: {type(e)} {e}")
def cleanup(self):
"""Override cleanup to implement graceful shutdown with active execution waiting."""
@@ -1937,9 +1926,7 @@ class ExecutionManager(AppProcess):
)
logger.info(f"{prefix} ✅ Exec consumer has been signaled to stop")
except Exception as e:
logger.warning(
f"{prefix} ⚠️ Error signaling consumer to stop: {type(e)} {e}"
)
logger.error(f"{prefix} ⚠️ Error signaling consumer to stop: {type(e)} {e}")
# Wait for active executions to complete
if self.active_graph_runs:
@@ -1970,7 +1957,7 @@ class ExecutionManager(AppProcess):
waited += wait_interval
if self.active_graph_runs:
logger.warning(
logger.error(
f"{prefix} ⚠️ {len(self.active_graph_runs)} executions still running after {max_wait}s"
)
else:
@@ -1981,7 +1968,7 @@ class ExecutionManager(AppProcess):
self.executor.shutdown(cancel_futures=True, wait=False)
logger.info(f"{prefix} ✅ Executor shutdown completed")
except Exception as e:
logger.warning(f"{prefix} ⚠️ Error during executor shutdown: {type(e)} {e}")
logger.error(f"{prefix} ⚠️ Error during executor shutdown: {type(e)} {e}")
# Release remaining execution locks
try:
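
Reviewer note: one side of the hunk above wraps `thread.join(timeout=300)` in `try ... except TimeoutError`. In the standard library, `threading.Thread.join` always returns `None` and does not raise on timeout, so that `except` branch would never run. The `is_alive()` check on the other side of the hunk is the documented way to detect a timeout. A small self-contained demonstration, with an illustrative `worker` function and a much shorter timeout:

```python
import threading


def worker(stop: threading.Event) -> None:
    stop.wait()  # blocks until the event is set


stop = threading.Event()
thread = threading.Thread(target=worker, args=(stop,), daemon=True)
thread.start()

raised = False
try:
    thread.join(timeout=0.05)  # returns None whether or not the thread finished
except TimeoutError:
    raised = True  # never reached: join() does not raise on timeout

timed_out = thread.is_alive()  # True here, because the worker is still blocked

# Clean up so the worker can exit.
stop.set()
thread.join()
```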

View File

@@ -94,7 +94,7 @@ SCHEDULER_OPERATION_TIMEOUT_SECONDS = 300 # 5 minutes for scheduler operations
def job_listener(event):
"""Logs job execution outcomes for better monitoring."""
if event.exception:
logger.warning(
logger.error(
f"Job {event.job_id} failed: {type(event.exception).__name__}: {event.exception}"
)
else:
@@ -137,7 +137,7 @@ def run_async(coro, timeout: float = SCHEDULER_OPERATION_TIMEOUT_SECONDS):
try:
return future.result(timeout=timeout)
except Exception as e:
logger.warning(f"Async operation failed: {type(e).__name__}: {e}")
logger.error(f"Async operation failed: {type(e).__name__}: {e}")
raise
@@ -186,7 +186,7 @@ async def _execute_graph(**kwargs):
async def _handle_graph_validation_error(args: "GraphExecutionJobArgs") -> None:
logger.warning(
logger.error(
f"Scheduled Graph {args.graph_id} failed validation. Unscheduling graph"
)
if args.schedule_id:
@@ -196,9 +196,8 @@ async def _handle_graph_validation_error(args: "GraphExecutionJobArgs") -> None:
user_id=args.user_id,
)
else:
logger.warning(
f"Unable to unschedule graph: {args.graph_id} as this is an old job "
f"with no associated schedule_id please remove manually"
logger.error(
f"Unable to unschedule graph: {args.graph_id} as this is an old job with no associated schedule_id please remove manually"
)

View File

@@ -303,9 +303,9 @@ class NotificationManager(AppService):
)
if not oldest_message:
logger.warning(
f"Batch for user {batch.user_id} and type {notification_type} "
f"has no oldest message — batch may have been cleared concurrently"
# this should never happen
logger.error(
f"Batch for user {batch.user_id} and type {notification_type} has no oldest message, which should never happen!"
)
continue
@@ -318,7 +318,7 @@ class NotificationManager(AppService):
).get_user_email_by_id(batch.user_id)
if not recipient_email:
logger.warning(
logger.error(
f"User email not found for user {batch.user_id}"
)
continue
@@ -344,7 +344,7 @@ class NotificationManager(AppService):
).get_user_notification_batch(batch.user_id, notification_type)
if not batch_data or not batch_data.notifications:
logger.warning(
logger.error(
f"Batch data not found for user {batch.user_id}"
)
# Clear the batch
@@ -372,7 +372,7 @@ class NotificationManager(AppService):
)
)
except Exception as e:
logger.warning(
logger.error(
f"Error parsing notification event: {e=}, {db_event=}"
)
continue
@@ -415,10 +415,7 @@ class NotificationManager(AppService):
async def discord_system_alert(
self, content: str, channel: DiscordChannel = DiscordChannel.PLATFORM
):
try:
await discord_send_alert(content, channel)
except Exception as e:
logger.warning(f"Failed to send Discord system alert: {e}")
await discord_send_alert(content, channel)
async def _queue_scheduled_notification(self, event: SummaryParamsEventModel):
"""Queue a scheduled notification - exposed method for other services to call"""
@@ -519,7 +516,7 @@ class NotificationManager(AppService):
raise ValueError("Invalid event type or params")
except Exception as e:
logger.warning(f"Failed to gather summary data: {e}")
logger.error(f"Failed to gather summary data: {e}")
# Return sensible defaults in case of error
if event_type == NotificationType.DAILY_SUMMARY and isinstance(
params, DailySummaryParams
@@ -565,9 +562,8 @@ class NotificationManager(AppService):
should_retry=False
).get_user_notification_oldest_message_in_batch(user_id, event_type)
if not oldest_message:
logger.warning(
f"Batch for user {user_id} and type {event_type} "
f"has no oldest message — batch may have been cleared concurrently"
logger.error(
f"Batch for user {user_id} and type {event_type} has no oldest message, which should never happen!"
)
return False
oldest_age = oldest_message.created_at
@@ -589,7 +585,7 @@ class NotificationManager(AppService):
get_notif_data_type(event.type)
].model_validate_json(message)
except Exception as e:
logger.warning(f"Error parsing message due to non matching schema {e}")
logger.error(f"Error parsing message due to non matching schema {e}")
return None
async def _process_admin_message(self, message: str) -> bool:
@@ -618,7 +614,7 @@ class NotificationManager(AppService):
should_retry=False
).get_user_email_by_id(event.user_id)
if not recipient_email:
logger.warning(f"User email not found for user {event.user_id}")
logger.error(f"User email not found for user {event.user_id}")
return False
should_send = await self._should_email_user_based_on_preference(
@@ -655,7 +651,7 @@ class NotificationManager(AppService):
should_retry=False
).get_user_email_by_id(event.user_id)
if not recipient_email:
logger.warning(f"User email not found for user {event.user_id}")
logger.error(f"User email not found for user {event.user_id}")
return False
should_send = await self._should_email_user_based_on_preference(
@@ -676,7 +672,7 @@ class NotificationManager(AppService):
should_retry=False
).get_user_notification_batch(event.user_id, event.type)
if not batch or not batch.notifications:
logger.warning(f"Batch not found for user {event.user_id}")
logger.error(f"Batch not found for user {event.user_id}")
return False
unsub_link = generate_unsubscribe_link(event.user_id)
@@ -749,7 +745,7 @@ class NotificationManager(AppService):
f"Removed {len(chunk_ids)} sent notifications from batch"
)
except Exception as e:
logger.warning(
logger.error(
f"Failed to remove sent notifications: {e}"
)
# Continue anyway - better to risk duplicates than lose emails
@@ -774,7 +770,7 @@ class NotificationManager(AppService):
else:
# Message is too large even after size reduction
if attempt_size == 1:
logger.warning(
logger.error(
f"Failed to send notification at index {i}: "
f"Single notification exceeds email size limit "
f"({len(test_message):,} chars > {MAX_EMAIL_SIZE:,} chars). "
@@ -793,7 +789,7 @@ class NotificationManager(AppService):
f"Removed oversized notification {chunk_ids[0]} from batch permanently"
)
except Exception as e:
logger.warning(
logger.error(
f"Failed to remove oversized notification: {e}"
)
@@ -827,7 +823,7 @@ class NotificationManager(AppService):
f"Set email verification to false for user {event.user_id}"
)
except Exception as deactivation_error:
logger.warning(
logger.error(
f"Failed to deactivate email for user {event.user_id}: "
f"{deactivation_error}"
)
@@ -839,7 +835,7 @@ class NotificationManager(AppService):
f"Disabled all notification preferences for user {event.user_id}"
)
except Exception as disable_error:
logger.warning(
logger.error(
f"Failed to disable notification preferences: {disable_error}"
)
@@ -852,7 +848,7 @@ class NotificationManager(AppService):
f"Cleared ALL notification batches for user {event.user_id}"
)
except Exception as remove_error:
logger.warning(
logger.error(
f"Failed to clear batches for inactive recipient: {remove_error}"
)
@@ -863,7 +859,7 @@ class NotificationManager(AppService):
"422" in error_message
or "unprocessable" in error_message
):
logger.warning(
logger.error(
f"Failed to send notification at index {i}: "
f"Malformed notification data rejected by Postmark. "
f"Error: {e}. Removing from batch permanently."
@@ -881,7 +877,7 @@ class NotificationManager(AppService):
"Removed malformed notification from batch permanently"
)
except Exception as remove_error:
logger.warning(
logger.error(
f"Failed to remove malformed notification: {remove_error}"
)
# Check if it's a ValueError for size limit
@@ -889,14 +885,14 @@ class NotificationManager(AppService):
isinstance(e, ValueError)
and "too large" in error_message
):
logger.warning(
logger.error(
f"Failed to send notification at index {i}: "
f"Notification size exceeds email limit. "
f"Error: {e}. Skipping this notification."
)
# Other API errors
else:
logger.warning(
logger.error(
f"Failed to send notification at index {i}: "
f"Email API error ({error_type}): {e}. "
f"Skipping this notification."
@@ -911,9 +907,7 @@ class NotificationManager(AppService):
if not chunk_sent:
# Should not reach here due to single notification handling
logger.warning(
f"Failed to send notifications starting at index {i}"
)
logger.error(f"Failed to send notifications starting at index {i}")
failed_indices.append(i)
i += 1
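The error branches in the hunks above sort a failed send into three cases: malformed data (422), oversized payload, and any other API error. A minimal sketch of that classification follows. It assumes `error_message` is the lowercased `str(e)`, which these hunks do not show, and `classify_send_error` is a hypothetical helper, not a function in the codebase.

```python
def classify_send_error(e: Exception) -> str:
    """Return a coarse category for a failed email send (hypothetical helper)."""
    # Assumption: the real code derives error_message roughly like this.
    error_message = str(e).lower()

    # Malformed payload rejected by the email API: removed from the batch.
    if "422" in error_message or "unprocessable" in error_message:
        return "malformed"
    # Payload too large for a single email: this notification is skipped.
    if isinstance(e, ValueError) and "too large" in error_message:
        return "oversized"
    # Everything else is treated as a generic API error.
    return "api_error"
```

Note that the oversized branch needs both the exception type and the message to match, so a non-`ValueError` mentioning "too large" falls through to the generic case.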
@@ -952,7 +946,7 @@ class NotificationManager(AppService):
should_retry=False
).get_user_email_by_id(event.user_id)
if not recipient_email:
logger.warning(f"User email not found for user {event.user_id}")
logger.error(f"User email not found for user {event.user_id}")
return False
should_send = await self._should_email_user_based_on_preference(
event.user_id, event.type
@@ -1013,10 +1007,7 @@ class NotificationManager(AppService):
# Let message.process() handle the rejection
pass
except Exception as e:
logger.warning(
f"Error processing message in {queue_name}: {e}",
exc_info=True,
)
logger.error(f"Error processing message in {queue_name}: {e}")
# Let message.process() handle the rejection
raise
except asyncio.CancelledError:


@@ -256,9 +256,9 @@ class TestNotificationErrorHandling:
assert 2 not in successful_indices # Index 2 failed
# Verify 422 error was logged
warning_calls = [call[0][0] for call in mock_logger.warning.call_args_list]
error_calls = [call[0][0] for call in mock_logger.error.call_args_list]
assert any(
"422" in call or "malformed" in call.lower() for call in warning_calls
"422" in call or "malformed" in call.lower() for call in error_calls
)
# Verify all notifications were removed (4 successful + 1 malformed)
@@ -371,10 +371,10 @@ class TestNotificationErrorHandling:
assert 3 not in successful_indices # Index 3 was not sent
# Verify oversized error was logged
warning_calls = [call[0][0] for call in mock_logger.warning.call_args_list]
error_calls = [call[0][0] for call in mock_logger.error.call_args_list]
assert any(
"exceeds email size limit" in call or "oversized" in call.lower()
for call in warning_calls
for call in error_calls
)
@pytest.mark.asyncio
@@ -478,10 +478,10 @@ class TestNotificationErrorHandling:
assert 1 in failed_indices # Index 1 failed
# Verify generic error was logged
warning_calls = [call[0][0] for call in mock_logger.warning.call_args_list]
error_calls = [call[0][0] for call in mock_logger.error.call_args_list]
assert any(
"api error" in call.lower() or "skipping" in call.lower()
for call in warning_calls
for call in error_calls
)
# Only successful ones should be removed from batch (failed one stays for retry)
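The assertions in these tests read log messages back from a mocked logger. A small standalone sketch of that pattern with `unittest.mock.MagicMock` follows; the messages are made up for illustration.

```python
from unittest.mock import MagicMock

mock_logger = MagicMock()

# Code under test would make calls like these (illustrative messages).
mock_logger.error("Failed to send notification at index 2: 422 Unprocessable")
mock_logger.warning("Retrying later")

# Each entry in call_args_list is a call object: call[0] is the tuple of
# positional arguments, so call[0][0] is the first argument (the message).
error_calls = [call[0][0] for call in mock_logger.error.call_args_list]

assert any("422" in msg or "malformed" in msg.lower() for msg in error_calls)
```

One caveat: when the code under test logs %-style (`logger.error("... %s", value)`), `call[0][0]` is the unformatted template, so substring checks on interpolated values need to look at the remaining positional arguments as well.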


@@ -613,5 +613,5 @@ async def cleanup_expired_files_async() -> int:
)
return deleted_count
except Exception as e:
logger.warning(f"[CloudStorage] Error during cloud storage cleanup: {e}")
logger.error(f"[CloudStorage] Error during cloud storage cleanup: {e}")
return 0


@@ -10,7 +10,7 @@ from sentry_sdk.integrations.launchdarkly import LaunchDarklyIntegration
from sentry_sdk.integrations.logging import LoggingIntegration
from backend.util import feature_flag
from backend.util.settings import BehaveAs, Settings
from backend.util.settings import Settings
settings = Settings()
logger = logging.getLogger(__name__)
@@ -21,95 +21,6 @@ class DiscordChannel(str, Enum):
PRODUCT = "product" # For product alerts (low balance, zero balance, etc.)
def _before_send(event, hint):
"""Filter out expected/transient errors from Sentry to reduce noise."""
if "exc_info" in hint:
exc_type, exc_value, _ = hint["exc_info"]
exc_msg = str(exc_value).lower() if exc_value else ""
# AMQP/RabbitMQ transient connection errors — expected during deploys
amqp_keywords = [
"amqpconnection",
"amqpconnector",
"connection_forced",
"channelinvalidstateerror",
"no active transport",
]
if any(kw in exc_msg for kw in amqp_keywords):
return None
# "connection refused" only for AMQP-related exceptions (not other services)
if "connection refused" in exc_msg:
exc_module = getattr(exc_type, "__module__", "") or ""
exc_name = getattr(exc_type, "__name__", "") or ""
amqp_indicators = ["aio_pika", "aiormq", "amqp", "pika", "rabbitmq"]
if any(
ind in exc_module.lower() or ind in exc_name.lower()
for ind in amqp_indicators
) or any(kw in exc_msg for kw in ["amqp", "pika", "rabbitmq"]):
return None
# User-caused credential/auth errors — not platform bugs
user_auth_keywords = [
"incorrect api key",
"invalid x-api-key",
"missing authentication header",
"invalid api token",
"authentication_error",
]
if any(kw in exc_msg for kw in user_auth_keywords):
return None
# Expected business logic — insufficient balance
if "insufficient balance" in exc_msg or "no credits left" in exc_msg:
return None
# Expected security check — blocked IP access
if "access to blocked or private ip" in exc_msg:
return None
# Discord bot token misconfiguration — not a platform error
if "improper token has been passed" in exc_msg or (
exc_type and exc_type.__name__ == "Forbidden" and "50001" in exc_msg
):
return None
# Google metadata DNS errors — expected in non-GCP environments
if (
"metadata.google.internal" in exc_msg
and settings.config.behave_as != BehaveAs.CLOUD
):
return None
# Inactive email recipients — expected for bounced addresses
if "marked as inactive" in exc_msg or "inactive addresses" in exc_msg:
return None
# Also filter log-based events for known noisy messages.
# Sentry's LoggingIntegration stores log messages under "logentry", not "message".
logentry = event.get("logentry") or {}
log_msg = (
logentry.get("formatted") or logentry.get("message") or event.get("message")
)
if event.get("logger") and log_msg:
msg = log_msg.lower()
noisy_patterns = [
"amqpconnection",
"connection_forced",
"unclosed client session",
"unclosed connector",
]
if any(p in msg for p in noisy_patterns):
return None
# "connection refused" in logs only when AMQP-related context is present
if "connection refused" in msg and any(
ind in msg for ind in ("amqp", "pika", "rabbitmq", "aio_pika", "aiormq")
):
return None
return event
def sentry_init():
sentry_dsn = settings.secrets.sentry_dsn
integrations = []
@@ -124,7 +35,6 @@ def sentry_init():
profiles_sample_rate=1.0,
environment=f"app:{settings.config.app_env.value}-behave:{settings.config.behave_as.value}",
_experiments={"enable_logs": True},
before_send=_before_send,
integrations=[
AsyncioIntegration(),
LoggingIntegration(sentry_logs_level=logging.INFO),
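`before_send` is the Sentry SDK hook that receives each event plus a hint and can return `None` to drop the event. A reduced sketch of the keyword filter shown in this hunk follows. The keyword list is a shortened, illustrative subset, and nothing is imported from Sentry so the example runs on its own.

```python
from typing import Any, Optional

# Illustrative subset of the keywords listed in the hunk above.
DROP_KEYWORDS = ["connection_forced", "insufficient balance", "incorrect api key"]


def before_send(event: dict[str, Any], hint: dict[str, Any]) -> Optional[dict[str, Any]]:
    """Return None to drop the event, or the event itself to let it through."""
    if "exc_info" in hint:
        _exc_type, exc_value, _tb = hint["exc_info"]
        # Match case-insensitively on the exception message.
        exc_msg = str(exc_value).lower() if exc_value else ""
        if any(kw in exc_msg for kw in DROP_KEYWORDS):
            return None
    return event
```

Substring matching on messages is coarse: it is cheap to extend, but any unrelated exception whose text happens to contain a keyword is dropped too.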


@@ -70,6 +70,9 @@ def _msg_tokens(msg: dict, enc) -> int:
# Count tool result tokens
tool_call_tokens += _tok_len(item.get("tool_use_id", ""), enc)
tool_call_tokens += _tok_len(item.get("content", ""), enc)
elif isinstance(item, dict) and item.get("type") == "text":
# Count text block tokens
tool_call_tokens += _tok_len(item.get("text", ""), enc)
elif isinstance(item, dict) and "content" in item:
# Other content types with content field
tool_call_tokens += _tok_len(item.get("content", ""), enc)
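To see why the extra `text` branch matters, here is a simplified sketch of per-block token counting. It assumes the preceding branch matches `type == "tool_result"` (the hunk does not show its condition) and uses a whitespace word count as a stand-in for the real encoder; `tok_len` and `content_tokens` are illustrative names.

```python
def tok_len(text: object) -> int:
    """Toy token counter: whitespace-separated words (not a real tokenizer)."""
    return len(str(text).split())


def content_tokens(content: list) -> int:
    """Sum tokens over a list of content blocks."""
    total = 0
    for item in content:
        if not isinstance(item, dict):
            continue
        if item.get("type") == "tool_result":  # assumed condition
            total += tok_len(item.get("tool_use_id", ""))
            total += tok_len(item.get("content", ""))
        elif item.get("type") == "text":
            # Text blocks keep their payload under "text", not "content".
            total += tok_len(item.get("text", ""))
        elif "content" in item:
            total += tok_len(item.get("content", ""))
    return total
```

Without the `text` branch, a block such as `{"type": "text", "text": "..."}` has no `content` key and contributes zero tokens, which is the underestimate the commit message describes.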
@@ -145,10 +148,14 @@ def _truncate_middle_tokens(text: str, enc, max_tok: int) -> str:
if len(ids) <= max_tok:
return text # nothing to do
# Need at least 3 tokens (head + ellipsis + tail) for meaningful truncation
mid = enc.encode("")
if max_tok < 3:
return enc.decode(mid)
# Split the allowance between the two ends:
head = max_tok // 2 - 1 # -1 for the ellipsis
tail = max_tok - head - 1
mid = enc.encode("")
return enc.decode(ids[:head] + mid + ids[-tail:])
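The effect of the `max_tok < 3` guard can be checked with a toy encoder. The sketch below assumes one token per whitespace-separated word and uses `"…"` as the middle marker; the marker string appears empty in this listing, so the ellipsis is an assumption, and `WordEncoder` is not a real tokenizer.

```python
class WordEncoder:
    """Toy encoder: one token per whitespace-separated word."""

    def encode(self, text: str) -> list[str]:
        return text.split()

    def decode(self, ids: list[str]) -> str:
        return " ".join(ids)


def truncate_middle_tokens(text: str, enc: WordEncoder, max_tok: int) -> str:
    ids = enc.encode(text)
    if len(ids) <= max_tok:
        return text  # nothing to do
    mid = enc.encode("…")  # assumed ellipsis marker
    # Fewer than 3 tokens cannot hold head + ellipsis + tail.
    if max_tok < 3:
        return enc.decode(mid)
    head = max_tok // 2 - 1  # -1 leaves room for the ellipsis
    tail = max_tok - head - 1
    return enc.decode(ids[:head] + mid + ids[-tail:])


enc = WordEncoder()
short = truncate_middle_tokens("a b c d e f g h", enc, 5)
```

Without the guard, `max_tok` of 0 or 1 gives `head = -1`, so `ids[:head]` keeps almost the whole input instead of an empty prefix, and the result is longer than the limit.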
@@ -396,7 +403,7 @@ def validate_and_remove_orphan_tool_responses(
if log_warning:
logger.warning(
f"Removing {len(orphan_ids)} orphan tool response(s): {orphan_ids}"
"Removing %d orphan tool response(s): %s", len(orphan_ids), orphan_ids
)
return _remove_orphan_tool_responses(messages, orphan_ids)
@@ -488,8 +495,9 @@ def _ensure_tool_pairs_intact(
# Some tool_call_ids couldn't be resolved - remove those tool responses
# This shouldn't happen in normal operation but handles edge cases
logger.warning(
f"Could not find assistant messages for tool_call_ids: {orphan_tool_call_ids}. "
"Removing orphan tool responses."
"Could not find assistant messages for tool_call_ids: %s. "
"Removing orphan tool responses.",
orphan_tool_call_ids,
)
recent_messages = _remove_orphan_tool_responses(
recent_messages, orphan_tool_call_ids
@@ -497,8 +505,8 @@ def _ensure_tool_pairs_intact(
if messages_to_prepend:
logger.info(
f"Extended recent messages by {len(messages_to_prepend)} to preserve "
f"tool_call/tool_response pairs"
"Extended recent messages by %d to preserve tool_call/tool_response pairs",
len(messages_to_prepend),
)
return messages_to_prepend + recent_messages
@@ -686,11 +694,15 @@ async def compress_context(
msgs = [summary_msg] + recent_msgs
logger.info(
f"Context summarized: {original_count} -> {total_tokens()} tokens, "
f"summarized {messages_summarized} messages"
"Context summarized: %d -> %d tokens, summarized %d messages",
original_count,
total_tokens(),
messages_summarized,
)
except Exception as e:
logger.warning(f"Summarization failed, continuing with truncation: {e}")
logger.warning(
"Summarization failed, continuing with truncation: %s", e
)
# Fall through to content truncation
# ---- STEP 2: Normalize content ----------------------------------------
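The switch from f-strings to %-style arguments in these logger calls defers formatting until a record is actually emitted. A small demonstration follows; the `Expensive` class and the logger name are made up for the example.

```python
import logging


class Expensive:
    """Counts how often it is converted to a string."""

    def __init__(self) -> None:
        self.str_calls = 0

    def __str__(self) -> str:
        self.str_calls += 1
        return "expensive"


logger = logging.getLogger("lazy_logging_demo")
logger.setLevel(logging.WARNING)  # INFO records are discarded

eager, lazy = Expensive(), Expensive()

logger.info(f"value: {eager}")  # the f-string is built before logging sees it
logger.info("value: %s", lazy)  # arguments are formatted only if emitted
```

After both calls, only `eager` has been stringified. The saving is small for plain integers, but it adds up when arguments are large lists or objects with costly `__str__` methods.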
@@ -728,6 +740,12 @@ async def compress_context(
# This is more granular than dropping all old messages at once.
while total_tokens() + reserve > target_tokens and len(msgs) > 2:
deletable: list[int] = []
# Count assistant messages to ensure we keep at least one
assistant_indices: set[int] = {
i
for i in range(len(msgs))
if msgs[i] is not None and msgs[i].get("role") == "assistant"
}
for i in range(1, len(msgs) - 1):
msg = msgs[i]
if (
@@ -735,6 +753,9 @@ async def compress_context(
and not _is_tool_message(msg)
and not _is_objective_message(msg)
):
# Skip if this is the last remaining assistant message
if msg.get("role") == "assistant" and len(assistant_indices) <= 1:
continue
deletable.append(i)
if not deletable:
break
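The guard added here keeps the last assistant message from being dropped during progressive truncation. A simplified sketch of one selection pass follows. It treats `role == "tool"` as a stand-in for `_is_tool_message`, omits the objective-message check, and `deletable_indices` is an illustrative name rather than a function in the codebase.

```python
def deletable_indices(msgs: list) -> list[int]:
    """Return indices that one truncation pass may delete."""
    # Counted once per pass, as in the hunk above.
    assistant_indices = {
        i
        for i in range(len(msgs))
        if msgs[i] is not None and msgs[i].get("role") == "assistant"
    }
    deletable: list[int] = []
    # The first and last messages are never candidates.
    for i in range(1, len(msgs) - 1):
        msg = msgs[i]
        if msg is None or msg.get("role") == "tool":
            continue
        # Skip the only remaining assistant message.
        if msg.get("role") == "assistant" and len(assistant_indices) <= 1:
            continue
        deletable.append(i)
    return deletable
```

Because the count is taken once per pass, two assistant messages are both reported as deletable in the same pass; the guard only takes effect once a single assistant message is left.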


@@ -64,7 +64,7 @@ def send_rate_limited_discord_alert(
return True
except Exception as alert_error:
logger.warning(f"Failed to send Discord alert: {alert_error}")
logger.error(f"Failed to send Discord alert: {alert_error}")
return False
@@ -182,8 +182,7 @@ def conn_retry(
func_name = getattr(retry_state.fn, "__name__", "unknown")
if retry_state.outcome.failed and retry_state.next_action is None:
# Final failure is logged by sync_wrapper/async_wrapper — skip here to avoid duplicates
pass
logger.error(f"{prefix} {action_name} failed after retries: {exception}")
else:
if attempt_number == EXCESSIVE_RETRY_THRESHOLD:
if send_rate_limited_discord_alert(
@@ -226,7 +225,7 @@ def conn_retry(
logger.info(f"{prefix} {action_name} completed successfully.")
return result
except Exception as e:
logger.warning(f"{prefix} {action_name} failed after retries: {e}")
logger.error(f"{prefix} {action_name} failed after retries: {e}")
raise
@wraps(func)
@@ -238,7 +237,7 @@ def conn_retry(
logger.info(f"{prefix} {action_name} completed successfully.")
return result
except Exception as e:
logger.warning(f"{prefix} {action_name} failed after retries: {e}")
logger.error(f"{prefix} {action_name} failed after retries: {e}")
raise
return async_wrapper if is_coroutine else sync_wrapper
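The hunks above touch two places that can report the same final failure: the retry callback and the wrapper's `except` block. As a point of comparison, here is a dependency-free sketch of a retry decorator that logs the final failure exactly once. It is not the `conn_retry` implementation; `retry` and its `max_attempts` parameter are illustrative.

```python
import functools
import logging

logger = logging.getLogger("retry_demo")


def retry(max_attempts: int):
    """Retry a function, logging the final failure once before re-raising."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(1, max_attempts + 1):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    if attempt == max_attempts:
                        # Single place where the final failure is reported.
                        logger.error("%s failed after retries: %s", func.__name__, e)
                        raise
                    logger.warning("%s attempt %d failed: %s", func.__name__, attempt, e)

        return wrapper

    return decorator


calls = {"n": 0}


@retry(max_attempts=3)
def flaky() -> str:
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("down")
    return "ok"
```

Keeping the final-failure log in one place avoids the duplicate entries that the comment in the hunk mentions.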


@@ -44,12 +44,6 @@ Do NOT skip these steps. If any command reports errors, fix them and re-run unti
- Fully capitalize acronyms in symbols, e.g. `graphID`, `useBackendAPI`
- Use function declarations (not arrow functions) for components/handlers
- No `dark:` Tailwind classes — the design system handles dark mode
- Use Next.js `<Link>` for internal navigation — never raw `<a>` tags
- No `any` types unless the value genuinely can be anything
- No linter suppressors (`// @ts-ignore`, `// eslint-disable`) — fix the actual issue
- **File length** — keep files under ~200 lines; extract sub-components or hooks into their own files when a file grows beyond this
- **Function/component length** — keep render functions and hooks under ~50 lines; extract named helpers or sub-components when they grow longer
## Architecture


@@ -1,14 +1,8 @@
"use client";
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuTrigger,
} from "@/components/molecules/DropdownMenu/DropdownMenu";
import { SidebarProvider } from "@/components/ui/sidebar";
import { cn } from "@/lib/utils";
import { DotsThree, UploadSimple } from "@phosphor-icons/react";
import { UploadSimple } from "@phosphor-icons/react";
import { useCallback, useRef, useState } from "react";
import { ChatContainer } from "./components/ChatContainer/ChatContainer";
import { ChatSidebar } from "./components/ChatSidebar/ChatSidebar";
@@ -92,7 +86,6 @@ export function CopilotPage() {
// Delete functionality
sessionToDelete,
isDeleting,
handleDeleteClick,
handleConfirmDelete,
handleCancelDelete,
} = useCopilotPage();
@@ -148,38 +141,6 @@ export function CopilotPage() {
isUploadingFiles={isUploadingFiles}
droppedFiles={droppedFiles}
onDroppedFilesConsumed={handleDroppedFilesConsumed}
headerSlot={
isMobile && sessionId ? (
<div className="flex justify-end">
<DropdownMenu>
<DropdownMenuTrigger asChild>
<button
className="rounded p-1.5 hover:bg-neutral-100"
aria-label="More actions"
>
<DotsThree className="h-5 w-5 text-neutral-600" />
</button>
</DropdownMenuTrigger>
<DropdownMenuContent align="end">
<DropdownMenuItem
onClick={() => {
const session = sessions.find(
(s) => s.id === sessionId,
);
if (session) {
handleDeleteClick(session.id, session.title);
}
}}
disabled={isDeleting}
className="text-red-600 focus:bg-red-50 focus:text-red-600"
>
Delete chat
</DropdownMenuItem>
</DropdownMenuContent>
</DropdownMenu>
</div>
) : undefined
}
/>
</div>
</div>


@@ -2,7 +2,6 @@
import { ChatInput } from "@/app/(platform)/copilot/components/ChatInput/ChatInput";
import { UIDataTypes, UIMessage, UITools } from "ai";
import { LayoutGroup, motion } from "framer-motion";
import { ReactNode } from "react";
import { ChatMessagesContainer } from "../ChatMessagesContainer/ChatMessagesContainer";
import { CopilotChatActionsProvider } from "../CopilotChatActionsProvider/CopilotChatActionsProvider";
import { EmptySession } from "../EmptySession/EmptySession";
@@ -21,7 +20,6 @@ export interface ChatContainerProps {
onSend: (message: string, files?: File[]) => void | Promise<void>;
onStop: () => void;
isUploadingFiles?: boolean;
headerSlot?: ReactNode;
/** Files dropped onto the chat window. */
droppedFiles?: File[];
/** Called after droppedFiles have been consumed by ChatInput. */
@@ -40,7 +38,6 @@ export const ChatContainer = ({
onSend,
onStop,
isUploadingFiles,
headerSlot,
droppedFiles,
onDroppedFilesConsumed,
}: ChatContainerProps) => {
@@ -63,7 +60,6 @@ export const ChatContainer = ({
status={status}
error={error}
isLoading={isLoadingSession}
headerSlot={headerSlot}
sessionID={sessionId}
/>
<motion.div


@@ -30,7 +30,6 @@ interface Props {
status: string;
error: Error | undefined;
isLoading: boolean;
headerSlot?: React.ReactNode;
sessionID?: string | null;
}
@@ -102,7 +101,6 @@ export function ChatMessagesContainer({
status,
error,
isLoading,
headerSlot,
sessionID,
}: Props) {
const lastMessage = messages[messages.length - 1];
@@ -135,7 +133,6 @@ export function ChatMessagesContainer({
return (
<Conversation className="min-h-0 flex-1">
<ConversationContent className="flex flex-1 flex-col gap-6 px-3 py-6">
{headerSlot}
{isLoading && messages.length === 0 && (
<div
className="flex flex-1 items-center justify-center"


@@ -37,6 +37,7 @@ import { useCopilotUIStore } from "../../store";
import { NotificationToggle } from "./components/NotificationToggle/NotificationToggle";
import { DeleteChatDialog } from "../DeleteChatDialog/DeleteChatDialog";
import { PulseLoader } from "../PulseLoader/PulseLoader";
import { UsageLimits } from "../UsageLimits/UsageLimits";
export function ChatSidebar() {
const { state } = useSidebar();
@@ -256,11 +257,10 @@ export function ChatSidebar() {
<Text variant="h3" size="body-medium">
Your chats
</Text>
<div className="relative left-5 flex items-center gap-1">
<div className="flex items-center">
<UsageLimits />
<NotificationToggle />
<div className="relative left-1">
<SidebarTrigger />
</div>
<SidebarTrigger />
</div>
</div>
{sessionId ? (


@@ -7,6 +7,7 @@ import {
PopoverTrigger,
} from "@/components/molecules/Popover/Popover";
import { toast } from "@/components/molecules/Toast/use-toast";
import { Button } from "@/components/ui/button";
import { cn } from "@/lib/utils";
import { Bell, BellRinging, BellSlash } from "@phosphor-icons/react";
import { useCopilotUIStore } from "../../../../store";
@@ -48,10 +49,7 @@ export function NotificationToggle() {
return (
<Popover>
<PopoverTrigger asChild>
<button
className="rounded p-1 text-black transition-colors hover:bg-zinc-50"
aria-label="Notification settings"
>
<Button variant="ghost" size="icon" aria-label="Notification settings">
{!isNotificationsEnabled ? (
<BellSlash className="!size-5" />
) : isSoundEnabled ? (
@@ -59,7 +57,7 @@ export function NotificationToggle() {
) : (
<Bell className="!size-5" />
)}
</button>
</Button>
</PopoverTrigger>
<PopoverContent align="start" className="w-56 p-3">
<div className="flex flex-col gap-3">


@@ -5,7 +5,7 @@ const TOOL_TO_CATEGORY: Record<string, string> = {
find_agent: "search",
find_library_agent: "search",
run_agent: "agent run",
run_block: "action",
run_block: "block run",
create_agent: "agent created",
edit_agent: "agent edited",
schedule_agent: "agent scheduled",


@@ -0,0 +1,146 @@
import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus";
import {
Popover,
PopoverContent,
PopoverTrigger,
} from "@/components/molecules/Popover/Popover";
import { Button } from "@/components/ui/button";
import { ChartBar } from "@phosphor-icons/react";
import { useUsageLimits } from "./useUsageLimits";
const MS_PER_MINUTE = 60_000;
const MS_PER_HOUR = 3_600_000;
function formatResetTime(resetsAt: Date | string): string {
const resetDate =
typeof resetsAt === "string" ? new Date(resetsAt) : resetsAt;
const now = new Date();
const diffMs = resetDate.getTime() - now.getTime();
if (diffMs <= 0) return "now";
const hours = Math.floor(diffMs / MS_PER_HOUR);
// Under 24h: show relative time ("in 4h 23m")
if (hours < 24) {
const minutes = Math.floor((diffMs % MS_PER_HOUR) / MS_PER_MINUTE);
if (hours > 0) return `in ${hours}h ${minutes}m`;
return `in ${minutes}m`;
}
// Over 24h: show day and time in local timezone ("Mon 12:00 AM PST")
return resetDate.toLocaleString(undefined, {
weekday: "short",
hour: "numeric",
minute: "2-digit",
timeZoneName: "short",
});
}
function UsageBar({
label,
used,
limit,
resetsAt,
}: {
label: string;
used: number;
limit: number;
resetsAt: Date | string;
}) {
if (limit <= 0) return null;
const rawPercent = (used / limit) * 100;
const percent = Math.min(100, Math.round(rawPercent));
const isHigh = percent >= 80;
const percentLabel =
used > 0 && percent === 0 ? "<1% used" : `${percent}% used`;
return (
<div className="flex flex-col gap-1">
<div className="flex items-baseline justify-between">
<span className="text-xs font-medium text-neutral-700">{label}</span>
<span className="text-[11px] tabular-nums text-neutral-500">
{percentLabel}
</span>
</div>
<div className="text-[10px] text-neutral-400">
Resets {formatResetTime(resetsAt)}
</div>
<div className="h-2 w-full overflow-hidden rounded-full bg-neutral-200">
<div
className={`h-full rounded-full transition-[width] duration-300 ease-out ${
isHigh ? "bg-orange-500" : "bg-blue-500"
}`}
style={{ width: `${Math.max(used > 0 ? 1 : 0, percent)}%` }}
/>
</div>
</div>
);
}
export function UsagePanelContent({
usage,
showBillingLink = true,
}: {
usage: CoPilotUsageStatus;
showBillingLink?: boolean;
}) {
const hasDailyLimit = usage.daily.limit > 0;
const hasWeeklyLimit = usage.weekly.limit > 0;
if (!hasDailyLimit && !hasWeeklyLimit) {
return (
<div className="text-xs text-neutral-500">No usage limits configured</div>
);
}
return (
<div className="flex flex-col gap-3">
<div className="text-xs font-semibold text-neutral-800">Usage limits</div>
{hasDailyLimit && (
<UsageBar
label="Today"
used={usage.daily.used}
limit={usage.daily.limit}
resetsAt={usage.daily.resets_at}
/>
)}
{hasWeeklyLimit && (
<UsageBar
label="This week"
used={usage.weekly.used}
limit={usage.weekly.limit}
resetsAt={usage.weekly.resets_at}
/>
)}
{showBillingLink && (
<a
href="/profile/credits"
className="text-[11px] text-blue-600 hover:underline"
>
Learn more about usage limits
</a>
)}
</div>
);
}
export function UsageLimits() {
const { data: usage, isLoading } = useUsageLimits();
if (isLoading || !usage) return null;
if (usage.daily.limit <= 0 && usage.weekly.limit <= 0) return null;
return (
<Popover>
<PopoverTrigger asChild>
<Button variant="ghost" size="icon" aria-label="Usage limits">
<ChartBar className="!size-5" weight="light" />
</Button>
</PopoverTrigger>
<PopoverContent align="start" className="w-64 p-3">
<UsagePanelContent usage={usage} />
</PopoverContent>
</Popover>
);
}
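The arithmetic in `UsageBar` and the under-24-hour branch of `formatResetTime` can be checked in isolation. The sketch below ports them to Python for illustration only; the component itself is TypeScript. `math.floor(x + 0.5)` stands in for JavaScript's `Math.round`, the locale-formatted branch for 24 hours or more is left out, and the function names are made up.

```python
import math
from typing import Optional

MS_PER_MINUTE = 60_000
MS_PER_HOUR = 3_600_000


def percent_label(used: float, limit: float) -> Optional[str]:
    """Mirror UsageBar: no bar without a limit, cap at 100, show '<1%' for tiny usage."""
    if limit <= 0:
        return None
    percent = min(100, math.floor(used / limit * 100 + 0.5))
    if used > 0 and percent == 0:
        return "<1% used"
    return f"{percent}% used"


def format_reset_delta(diff_ms: int) -> Optional[str]:
    """Mirror the relative branch of formatResetTime for a millisecond delta."""
    if diff_ms <= 0:
        return "now"
    hours = diff_ms // MS_PER_HOUR
    if hours < 24:
        minutes = (diff_ms % MS_PER_HOUR) // MS_PER_MINUTE
        return f"in {hours}h {minutes}m" if hours > 0 else f"in {minutes}m"
    return None  # the component formats a weekday and local time here
```

The `<1% used` case exists so that a small but non-zero usage does not read as "0% used" after rounding.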


@@ -0,0 +1,121 @@
import { render, screen, cleanup } from "@/tests/integrations/test-utils";
import { afterEach, describe, expect, it, vi } from "vitest";
import { UsageLimits } from "../UsageLimits";
// Mock the useUsageLimits hook
const mockUseUsageLimits = vi.fn();
vi.mock("../useUsageLimits", () => ({
useUsageLimits: () => mockUseUsageLimits(),
}));
// Mock Popover to render children directly (Radix portals don't work in happy-dom)
vi.mock("@/components/molecules/Popover/Popover", () => ({
Popover: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
PopoverTrigger: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
PopoverContent: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
}));
afterEach(() => {
cleanup();
mockUseUsageLimits.mockReset();
});
function makeUsage({
dailyUsed = 500,
dailyLimit = 10000,
weeklyUsed = 2000,
weeklyLimit = 50000,
}: {
dailyUsed?: number;
dailyLimit?: number;
weeklyUsed?: number;
weeklyLimit?: number;
} = {}) {
const future = new Date(Date.now() + 3600 * 1000); // 1h from now
return {
daily: { used: dailyUsed, limit: dailyLimit, resets_at: future },
weekly: { used: weeklyUsed, limit: weeklyLimit, resets_at: future },
};
}
describe("UsageLimits", () => {
it("renders nothing while loading", () => {
mockUseUsageLimits.mockReturnValue({ data: undefined, isLoading: true });
const { container } = render(<UsageLimits />);
expect(container.innerHTML).toBe("");
});
it("renders nothing when no limits are configured", () => {
mockUseUsageLimits.mockReturnValue({
data: makeUsage({ dailyLimit: 0, weeklyLimit: 0 }),
isLoading: false,
});
const { container } = render(<UsageLimits />);
expect(container.innerHTML).toBe("");
});
it("renders the usage button when limits exist", () => {
mockUseUsageLimits.mockReturnValue({
data: makeUsage(),
isLoading: false,
});
render(<UsageLimits />);
expect(screen.getByRole("button", { name: /usage limits/i })).toBeDefined();
});
it("displays daily and weekly usage percentages", () => {
mockUseUsageLimits.mockReturnValue({
data: makeUsage({ dailyUsed: 5000, dailyLimit: 10000 }),
isLoading: false,
});
render(<UsageLimits />);
expect(screen.getByText("50% used")).toBeDefined();
expect(screen.getByText("Today")).toBeDefined();
expect(screen.getByText("This week")).toBeDefined();
expect(screen.getByText("Usage limits")).toBeDefined();
});
it("shows only weekly bar when daily limit is 0", () => {
mockUseUsageLimits.mockReturnValue({
data: makeUsage({
dailyLimit: 0,
weeklyUsed: 25000,
weeklyLimit: 50000,
}),
isLoading: false,
});
render(<UsageLimits />);
expect(screen.getByText("This week")).toBeDefined();
expect(screen.queryByText("Today")).toBeNull();
});
it("caps percentage at 100% when over limit", () => {
mockUseUsageLimits.mockReturnValue({
data: makeUsage({ dailyUsed: 15000, dailyLimit: 10000 }),
isLoading: false,
});
render(<UsageLimits />);
expect(screen.getByText("100% used")).toBeDefined();
});
it("shows learn more link to credits page", () => {
mockUseUsageLimits.mockReturnValue({
data: makeUsage(),
isLoading: false,
});
render(<UsageLimits />);
const link = screen.getByText("Learn more about usage limits");
expect(link).toBeDefined();
expect(link.closest("a")?.getAttribute("href")).toBe("/profile/credits");
});
});


@@ -0,0 +1,12 @@
import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus";
import { useGetV2GetCopilotUsage } from "@/app/api/__generated__/endpoints/chat/chat";
export function useUsageLimits() {
return useGetV2GetCopilotUsage({
query: {
select: (res) => res.data as CoPilotUsageStatus,
refetchInterval: 30000,
staleTime: 10000,
},
});
}


@@ -706,8 +706,8 @@ export default function StyleguidePage() {
input: { block_id: "weather-block-123" },
output: {
type: ResponseType.error,
message: "Something went wrong while running this step.",
error: "Execution timed out after 30 seconds.",
message: "Failed to run the block.",
error: "Block execution timed out after 30 seconds.",
details: {
block_id: "weather-block-123",
timeout_ms: 30000,


@@ -61,7 +61,7 @@ export function FindBlocksTool({ part }: Props) {
const query = (part.input as FindBlockInput | undefined)?.query?.trim();
const accordionDescription = parsed
? `Found ${parsed.count} action${parsed.count === 1 ? "" : "s"}${query ? ` for "${query}"` : ""}`
? `Found ${parsed.count} block${parsed.count === 1 ? "" : "s"}${query ? ` for "${query}"` : ""}`
: undefined;
return (
@@ -77,7 +77,7 @@ export function FindBlocksTool({ part }: Props) {
{hasBlocks && parsed && (
<ToolAccordion
icon={<AccordionIcon />}
title="Results"
title="Block results"
description={accordionDescription}
>
<HorizontalScroll dependencyList={[parsed.blocks.length]}>


@@ -30,21 +30,21 @@ export function getAnimationText(part: FindBlockToolPart): string {
switch (part.state) {
case "input-streaming":
case "input-available":
return `Searching for actions${queryText}`;
return `Searching for blocks${queryText}`;
case "output-available": {
const parsed = parseOutput(part.output);
if (parsed) {
return `Found ${parsed.count} action${parsed.count === 1 ? "" : "s"}${queryText}`;
return `Found ${parsed.count} block${parsed.count === 1 ? "" : "s"}${queryText}`;
}
return `Searching for actions${queryText}`;
return `Searching for blocks${queryText}`;
}
case "output-error":
return `Search failed${query ? ` for "${query}"` : ""}`;
return `Error finding blocks${queryText}`;
default:
return "Searching for actions";
return "Searching for blocks";
}
}


@@ -144,23 +144,6 @@ export function truncate(text: string, maxLen: number): string {
return text.slice(0, maxLen).trimEnd() + "\u2026";
}
const STRIPPABLE_EXTENSIONS =
/\.(md|csv|json|txt|yaml|yml|xml|html|js|ts|py|sh|toml|cfg|ini|log|pdf|png|jpg|jpeg|gif|svg|mp4|mp3|wav|zip|tar|gz)$/i;
export function humanizeFileName(filePath: string): string {
const fileName = filePath.split("/").pop() ?? filePath;
const stem = fileName.replace(STRIPPABLE_EXTENSIONS, "");
const words = stem
.replace(/[_-]/g, " ")
.split(/\s+/)
.filter(Boolean)
.map((w) => {
if (w === w.toUpperCase()) return w;
return w.charAt(0).toUpperCase() + w.slice(1).toLowerCase();
});
return `"${words.join(" ")}"`;
}
/* ------------------------------------------------------------------ */
/* Exit code helper */
/* ------------------------------------------------------------------ */
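The `humanizeFileName` helper in this hunk turns a file path into a quoted, title-cased label. The same steps as a Python sketch, for illustration only: the original is TypeScript, the extension list here is shortened, and `humanize_file_name` is a made-up name.

```python
import re

# Shortened, illustrative subset of the extensions in the hunk above.
STRIPPABLE_EXTENSIONS = re.compile(r"\.(md|csv|json|txt|yaml|yml|py|pdf|png)$", re.IGNORECASE)


def humanize_file_name(file_path: str) -> str:
    """Return a quoted, human-readable label for a file path."""
    file_name = file_path.split("/")[-1]
    stem = STRIPPABLE_EXTENSIONS.sub("", file_name)
    words = []
    for w in re.sub(r"[_-]", " ", stem).split():
        # Keep all-caps words (acronyms) as they are; capitalise the rest.
        words.append(w if w == w.upper() else w[0].upper() + w[1:].lower())
    label = " ".join(words)
    return f'"{label}"'
```

For example, `humanize_file_name("/tmp/my_report-final.md")` yields `"My Report Final"`, and an acronym such as `API` in `docs/API_notes.TXT` is left uppercase.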
@@ -208,16 +191,16 @@ export function getAnimationText(
? `Browsing ${shortSummary}`
: "Interacting with browser\u2026";
case "file-read":
return summary
? `Reading ${humanizeFileName(summary)}`
return shortSummary
? `Reading ${shortSummary}`
: "Reading file\u2026";
case "file-write":
return summary
? `Writing ${humanizeFileName(summary)}`
return shortSummary
? `Writing ${shortSummary}`
: "Writing file\u2026";
case "file-delete":
return summary
? `Deleting ${humanizeFileName(summary)}`
return shortSummary
? `Deleting ${shortSummary}`
: "Deleting file\u2026";
case "file-list":
return shortSummary
@@ -228,8 +211,8 @@ export function getAnimationText(
? `Searching for "${shortSummary}"`
: "Searching\u2026";
case "edit":
return summary
? `Editing ${humanizeFileName(summary)}`
return shortSummary
? `Editing ${shortSummary}`
: "Editing file\u2026";
case "todo":
return shortSummary ? `${shortSummary}` : "Updating task list\u2026";
@@ -263,17 +246,11 @@ export function getAnimationText(
? `Browsed ${shortSummary}`
: "Browser action completed";
case "file-read":
return summary
? `Read ${humanizeFileName(summary)}`
: "File read completed";
return shortSummary ? `Read ${shortSummary}` : "File read completed";
case "file-write":
return summary
? `Wrote ${humanizeFileName(summary)}`
: "File written";
return shortSummary ? `Wrote ${shortSummary}` : "File written";
case "file-delete":
return summary
? `Deleted ${humanizeFileName(summary)}`
: "File deleted";
return shortSummary ? `Deleted ${shortSummary}` : "File deleted";
case "file-list":
return "Listed files";
case "search":
@@ -281,9 +258,7 @@ export function getAnimationText(
? `Searched for "${shortSummary}"`
: "Search completed";
case "edit":
return summary
? `Edited ${humanizeFileName(summary)}`
: "Edit completed";
return shortSummary ? `Edited ${shortSummary}` : "Edit completed";
case "todo":
return "Updated task list";
case "compaction":


@@ -149,10 +149,10 @@ export function getAnimationText(part: {
}
if (isRunAgentNeedLoginOutput(output))
return "Sign in required to run agent";
return "Something went wrong";
return "Error running agent";
}
case "output-error":
return "Something went wrong";
return "Error running agent";
default:
return actionPhrase;
}


@@ -18,10 +18,10 @@ import {
interface Props {
output: SetupRequirementsResponse;
/** Override the message sent to the chat when the user clicks Proceed after connecting credentials.
* Defaults to "Please re-run this step now." */
* Defaults to "Please re-run the block now." */
retryInstruction?: string;
/** Override the label shown above the credentials section.
* Defaults to "Credentials". */
* Defaults to "Block credentials". */
credentialsLabel?: string;
}
@@ -87,9 +87,11 @@ export function SetupRequirementsCard({
([, v]) => v !== undefined && v !== null && v !== "",
),
);
parts.push(`Run with these inputs: ${JSON.stringify(nonEmpty, null, 2)}`);
parts.push(
`Run the block with these inputs: ${JSON.stringify(nonEmpty, null, 2)}`,
);
} else {
parts.push(retryInstruction ?? "Please re-run this step now.");
parts.push(retryInstruction ?? "Please re-run the block now.");
}
onSend(parts.join(" "));
@@ -103,7 +105,7 @@ export function SetupRequirementsCard({
{needsCredentials && (
<div className="rounded-2xl border bg-background p-3">
<Text variant="small" className="w-fit border-b text-zinc-500">
{credentialsLabel ?? "Credentials"}
{credentialsLabel ?? "Block credentials"}
</Text>
<div className="mt-6">
<CredentialsGroupedView
@@ -120,7 +122,7 @@ export function SetupRequirementsCard({
{inputSchema && (
<div className="rounded-2xl border bg-background p-3 pt-4">
<Text variant="small" className="w-fit border-b text-zinc-500">
Inputs
Block inputs
</Text>
<FormRenderer
jsonSchema={inputSchema}


@@ -165,12 +165,12 @@ export function getAnimationText(part: {
if (isRunBlockReviewRequiredOutput(output)) {
return `Review needed for "${output.block_name}"`;
}
return "Action failed";
return "Error running block";
}
case "output-error":
return "Action failed";
return "Error running block";
default:
return "Running";
return "Running the block";
}
}


@@ -1,4 +1,5 @@
import {
getGetV2GetCopilotUsageQueryKey,
getGetV2GetSessionQueryKey,
postV2CancelSessionTask,
} from "@/app/api/__generated__/endpoints/chat/chat";
@@ -307,6 +308,9 @@ export function useCopilotStream({
queryClient.invalidateQueries({
queryKey: getGetV2GetSessionQueryKey(sessionId),
});
queryClient.invalidateQueries({
queryKey: getGetV2GetCopilotUsageQueryKey(),
});
if (status === "ready") {
reconnectAttemptsRef.current = 0;
hasShownDisconnectToast.current = false;

View File

@@ -209,7 +209,6 @@ export function NewAgentLibraryView() {
agent={agent}
scheduleId={activeItem}
onScheduleDeleted={handleScheduleDeleted}
onSelectRun={(id) => handleSelectRun(id, "runs")}
banner={renderMarketplaceUpdateBanner()}
/>
) : activeTab === "templates" ? (

View File

@@ -20,7 +20,6 @@ interface Props {
agent: LibraryAgent;
scheduleId: string;
onScheduleDeleted?: (deletedScheduleId: string) => void;
onSelectRun?: (id: string) => void;
banner?: React.ReactNode;
}
@@ -28,7 +27,6 @@ export function SelectedScheduleView({
agent,
scheduleId,
onScheduleDeleted,
onSelectRun,
banner,
}: Props) {
const { schedule, isLoading, error } = useSelectedScheduleView(
@@ -91,9 +89,7 @@ export function SelectedScheduleView({
<SelectedScheduleActions
agent={agent}
scheduleId={schedule.id}
schedule={schedule}
onDeleted={() => onScheduleDeleted?.(schedule.id)}
onSelectRun={onSelectRun}
/>
</div>
) : null}
@@ -172,9 +168,7 @@ export function SelectedScheduleView({
<SelectedScheduleActions
agent={agent}
scheduleId={schedule.id}
schedule={schedule}
onDeleted={() => onScheduleDeleted?.(schedule.id)}
onSelectRun={onSelectRun}
/>
</div>
) : null}

View File

@@ -1,12 +1,11 @@
"use client";
import type { GraphExecutionJobInfo } from "@/app/api/__generated__/models/graphExecutionJobInfo";
import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import { Button } from "@/components/atoms/Button/Button";
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { Text } from "@/components/atoms/Text/Text";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { EyeIcon, Play, TrashIcon } from "@phosphor-icons/react";
import { EyeIcon, TrashIcon } from "@phosphor-icons/react";
import { AgentActionsDropdown } from "../../../AgentActionsDropdown";
import { SelectedActionsWrap } from "../../../SelectedActionsWrap";
import { useSelectedScheduleActions } from "./useSelectedScheduleActions";
@@ -14,17 +13,13 @@ import { useSelectedScheduleActions } from "./useSelectedScheduleActions";
type Props = {
agent: LibraryAgent;
scheduleId: string;
schedule?: GraphExecutionJobInfo;
onDeleted?: () => void;
onSelectRun?: (id: string) => void;
};
export function SelectedScheduleActions({
agent,
scheduleId,
schedule,
onDeleted,
onSelectRun,
}: Props) {
const {
openInBuilderHref,
@@ -32,32 +27,11 @@ export function SelectedScheduleActions({
setShowDeleteDialog,
handleDelete,
isDeleting,
handleRunNow,
isRunning,
} = useSelectedScheduleActions({
agent,
scheduleId,
schedule,
onDeleted,
onSelectRun,
});
} = useSelectedScheduleActions({ agent, scheduleId, onDeleted });
return (
<>
<SelectedActionsWrap>
<Button
variant="icon"
size="icon"
aria-label="Run now"
onClick={handleRunNow}
disabled={isRunning || !schedule}
>
{isRunning ? (
<LoadingSpinner size="small" />
) : (
<Play weight="bold" size={18} className="text-zinc-700" />
)}
</Button>
{openInBuilderHref && (
<Button
variant="icon"

View File

@@ -1,16 +1,10 @@
"use client";
import {
getGetV1ListGraphExecutionsQueryKey,
usePostV1ExecuteGraphAgent,
} from "@/app/api/__generated__/endpoints/graphs/graphs";
import {
getGetV1ListExecutionSchedulesForAGraphQueryOptions,
useDeleteV1DeleteExecutionSchedule,
} from "@/app/api/__generated__/endpoints/schedules/schedules";
import type { GraphExecutionJobInfo } from "@/app/api/__generated__/models/graphExecutionJobInfo";
import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import { okData } from "@/app/api/helpers";
import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import { useToast } from "@/components/molecules/Toast/use-toast";
import { useQueryClient } from "@tanstack/react-query";
import { useState } from "react";
@@ -18,17 +12,13 @@ import { useState } from "react";
interface UseSelectedScheduleActionsProps {
agent: LibraryAgent;
scheduleId: string;
schedule?: GraphExecutionJobInfo;
onDeleted?: () => void;
onSelectRun?: (id: string) => void;
}
export function useSelectedScheduleActions({
agent,
scheduleId,
schedule,
onDeleted,
onSelectRun,
}: UseSelectedScheduleActionsProps) {
const { toast } = useToast();
const queryClient = useQueryClient();
@@ -60,58 +50,11 @@ export function useSelectedScheduleActions({
},
});
const { mutateAsync: executeAgent, isPending: isRunning } =
usePostV1ExecuteGraphAgent();
function handleDelete() {
if (!scheduleId) return;
deleteMutation.mutate({ scheduleId });
}
async function handleRunNow() {
if (!schedule) {
toast({
title: "Schedule not loaded",
description: "Please wait for the schedule to load.",
variant: "destructive",
});
return;
}
try {
toast({ title: "Run started" });
const res = await executeAgent({
graphId: schedule.graph_id,
graphVersion: schedule.graph_version,
data: {
inputs: schedule.input_data || {},
credentials_inputs: schedule.input_credentials || {},
source: "library",
},
});
const newRunID = okData(res)?.id;
await queryClient.invalidateQueries({
queryKey: getGetV1ListGraphExecutionsQueryKey(agent.graph_id),
});
if (newRunID && onSelectRun) {
onSelectRun(newRunID);
}
} catch (error: unknown) {
toast({
title: "Failed to start run",
description:
error instanceof Error
? error.message
: "An unexpected error occurred.",
variant: "destructive",
});
}
}
const openInBuilderHref = `/build?flowID=${agent.graph_id}&flowVersion=${agent.graph_version}`;
return {
@@ -120,7 +63,5 @@ export function useSelectedScheduleActions({
setShowDeleteDialog,
handleDelete,
isDeleting: deleteMutation.isPending,
handleRunNow,
isRunning,
};
}

View File

@@ -186,7 +186,6 @@ export function SidebarRunsList({
selected={selectedRunId === s.id}
onClick={() => onSelectRun(s.id, "scheduled")}
onDeleted={() => onScheduleDeleted?.(s.id)}
onRunCreated={(runID) => onSelectRun(runID, "runs")}
/>
</div>
))

View File

@@ -1,16 +1,11 @@
"use client";
import {
getGetV1ListGraphExecutionsQueryKey,
usePostV1ExecuteGraphAgent,
} from "@/app/api/__generated__/endpoints/graphs/graphs";
import {
getGetV1ListExecutionSchedulesForAGraphQueryOptions,
useDeleteV1DeleteExecutionSchedule,
} from "@/app/api/__generated__/endpoints/schedules/schedules";
import type { GraphExecutionJobInfo } from "@/app/api/__generated__/models/graphExecutionJobInfo";
import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import { okData } from "@/app/api/helpers";
import { Button } from "@/components/atoms/Button/Button";
import { Text } from "@/components/atoms/Text/Text";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
@@ -18,7 +13,6 @@ import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuSeparator,
DropdownMenuTrigger,
} from "@/components/molecules/DropdownMenu/DropdownMenu";
import { useToast } from "@/components/molecules/Toast/use-toast";
@@ -30,15 +24,9 @@ interface Props {
agent: LibraryAgent;
schedule: GraphExecutionJobInfo;
onDeleted?: () => void;
onRunCreated?: (runID: string) => void;
}
export function ScheduleActionsDropdown({
agent,
schedule,
onDeleted,
onRunCreated,
}: Props) {
export function ScheduleActionsDropdown({ agent, schedule, onDeleted }: Props) {
const { toast } = useToast();
const queryClient = useQueryClient();
const [showDeleteDialog, setShowDeleteDialog] = useState(false);
@@ -46,9 +34,6 @@ export function ScheduleActionsDropdown({
const { mutateAsync: deleteSchedule, isPending: isDeleting } =
useDeleteV1DeleteExecutionSchedule();
const { mutateAsync: executeAgent, isPending: isRunning } =
usePostV1ExecuteGraphAgent();
async function handleDelete() {
try {
await deleteSchedule({ scheduleId: schedule.id });
@@ -75,43 +60,6 @@ export function ScheduleActionsDropdown({
}
}
async function handleRunNow(e: React.MouseEvent) {
e.stopPropagation();
try {
toast({ title: "Run started" });
const res = await executeAgent({
graphId: schedule.graph_id,
graphVersion: schedule.graph_version,
data: {
inputs: schedule.input_data || {},
credentials_inputs: schedule.input_credentials || {},
source: "library",
},
});
const newRunID = okData(res)?.id;
await queryClient.invalidateQueries({
queryKey: getGetV1ListGraphExecutionsQueryKey(agent.graph_id),
});
if (newRunID) {
onRunCreated?.(newRunID);
}
} catch (error: unknown) {
toast({
title: "Failed to start run",
description:
error instanceof Error
? error.message
: "An unexpected error occurred.",
variant: "destructive",
});
}
}
return (
<>
<DropdownMenu>
@@ -125,14 +73,6 @@ export function ScheduleActionsDropdown({
</button>
</DropdownMenuTrigger>
<DropdownMenuContent align="end">
<DropdownMenuItem
onClick={handleRunNow}
disabled={isRunning}
className="flex items-center gap-2"
>
{isRunning ? "Running..." : "Run now"}
</DropdownMenuItem>
<DropdownMenuSeparator />
<DropdownMenuItem
onClick={(e) => {
e.stopPropagation();

View File

@@ -14,7 +14,6 @@ interface Props {
selected?: boolean;
onClick?: () => void;
onDeleted?: () => void;
onRunCreated?: (runID: string) => void;
}
export function ScheduleListItem({
@@ -23,7 +22,6 @@ export function ScheduleListItem({
selected,
onClick,
onDeleted,
onRunCreated,
}: Props) {
return (
<SidebarItemCard
@@ -48,7 +46,6 @@ export function ScheduleListItem({
agent={agent}
schedule={schedule}
onDeleted={onDeleted}
onRunCreated={onRunCreated}
/>
}
/>

View File

@@ -1,40 +0,0 @@
"use client";
import { ArrowRight, Lightning } from "@phosphor-icons/react";
import NextLink from "next/link";
import { Button } from "@/components/atoms/Button/Button";
import { Text } from "@/components/atoms/Text/Text";
import { useJumpBackIn } from "./useJumpBackIn";
export function JumpBackIn() {
const { agent, isLoading } = useJumpBackIn();
if (isLoading || !agent) {
return null;
}
return (
<div className="flex items-center justify-between rounded-large border border-zinc-200 bg-gradient-to-r from-zinc-50 to-white px-5 py-4">
<div className="flex items-center gap-3">
<div className="flex h-9 w-9 items-center justify-center rounded-full bg-zinc-900">
<Lightning size={18} weight="fill" className="text-white" />
</div>
<div className="flex flex-col">
<Text variant="small" className="text-zinc-500">
Continue where you left off
</Text>
<Text variant="body-medium" className="text-zinc-900">
{agent.name}
</Text>
</div>
</div>
<NextLink href={`/library/agents/${agent.id}`}>
<Button variant="primary" size="small" className="gap-1.5">
Jump Back In
<ArrowRight size={16} />
</Button>
</NextLink>
</div>
);
}

View File

@@ -1,28 +0,0 @@
"use client";
import { useGetV2ListLibraryAgents } from "@/app/api/__generated__/endpoints/library/library";
import { okData } from "@/app/api/helpers";
export function useJumpBackIn() {
const { data, isLoading } = useGetV2ListLibraryAgents(
{
page: 1,
page_size: 1,
sort_by: "updatedAt",
},
{
query: { select: okData },
},
);
// The API doesn't include execution data by default (include_executions is
// internal to the backend), so recent_executions is always empty here.
// We use the most recently updated agent as the "jump back in" candidate
// instead — updatedAt is the best available proxy for recent activity.
const agent = data?.agents[0] ?? null;
return {
agent,
isLoading,
};
}

View File

@@ -2,7 +2,6 @@
import { useEffect, useState, useCallback } from "react";
import { HeartIcon, ListIcon } from "@phosphor-icons/react";
import { JumpBackIn } from "./components/JumpBackIn/JumpBackIn";
import { LibraryActionHeader } from "./components/LibraryActionHeader/LibraryActionHeader";
import { LibraryAgentList } from "./components/LibraryAgentList/LibraryAgentList";
import { Tab } from "./components/LibraryTabs/LibraryTabs";
@@ -39,7 +38,6 @@ export default function LibraryPage() {
onAnimationComplete={handleFavoriteAnimationComplete}
>
<main className="pt-160 container min-h-screen space-y-4 pb-20 pt-16 sm:px-8 md:px-12">
<JumpBackIn />
<LibraryActionHeader setSearchTerm={setSearchTerm} />
<LibraryAgentList
searchTerm={searchTerm}

View File

@@ -11,6 +11,8 @@ import {
import { RefundModal } from "./RefundModal";
import { CreditTransaction } from "@/lib/autogpt-server-api";
import { UsagePanelContent } from "@/app/(platform)/copilot/components/UsageLimits/UsageLimits";
import { useUsageLimits } from "@/app/(platform)/copilot/components/UsageLimits/useUsageLimits";
import {
Table,
@@ -21,6 +23,26 @@ import {
TableRow,
} from "@/components/__legacy__/ui/table";
function CoPilotUsageSection() {
const { data: usage, isLoading } = useUsageLimits();
const router = useRouter();
if (isLoading || !usage) return null;
if (usage.daily.limit <= 0 && usage.weekly.limit <= 0) return null;
return (
<div className="my-6 space-y-4">
<h3 className="text-lg font-medium">CoPilot Usage Limits</h3>
<div className="rounded-lg border border-neutral-200 p-4 dark:border-neutral-700">
<UsagePanelContent usage={usage} showBillingLink={false} />
</div>
<Button className="w-full" onClick={() => router.push("/copilot")}>
Open CoPilot
</Button>
</div>
);
}
export default function CreditsPage() {
const api = useBackendAPI();
const {
@@ -237,11 +259,13 @@ export default function CreditsPage() {
</Button>
)}
</form>
{/* CoPilot Usage Limits */}
<CoPilotUsageSection />
</div>
<div className="my-6 space-y-4">
{/* Payment Portal */}
<h3 className="text-lg font-medium">Manage Your Payment Methods</h3>
<p className="text-neutral-600">
You can manage your cards and see your payment history in the

View File

@@ -1382,6 +1382,28 @@
"security": [{ "HTTPBearerJWT": [] }]
}
},
"/api/chat/usage": {
"get": {
"tags": ["v2", "chat", "chat"],
"summary": "Get Copilot Usage",
"description": "Get CoPilot usage status for the authenticated user.\n\nReturns current token usage vs limits for daily and weekly windows.",
"operationId": "getV2GetCopilotUsage",
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/CoPilotUsageStatus" }
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
}
},
"security": [{ "HTTPBearerJWT": [] }]
}
},
"/api/credits": {
"get": {
"tags": ["v1", "credits"],
@@ -8455,6 +8477,16 @@
"title": "ClarifyingQuestion",
"description": "A question that needs user clarification."
},
"CoPilotUsageStatus": {
"properties": {
"daily": { "$ref": "#/components/schemas/UsageWindow" },
"weekly": { "$ref": "#/components/schemas/UsageWindow" }
},
"type": "object",
"required": ["daily", "weekly"],
"title": "CoPilotUsageStatus",
"description": "Current usage status for a user across all windows."
},
"ContentType": {
"type": "string",
"enum": [
@@ -12190,6 +12222,16 @@
{ "$ref": "#/components/schemas/ActiveStreamInfo" },
{ "type": "null" }
]
},
"total_prompt_tokens": {
"type": "integer",
"title": "Total Prompt Tokens",
"default": 0
},
"total_completion_tokens": {
"type": "integer",
"title": "Total Completion Tokens",
"default": 0
}
},
"type": "object",
@@ -14587,6 +14629,25 @@
"required": ["timezone"],
"title": "UpdateTimezoneRequest"
},
"UsageWindow": {
"properties": {
"used": { "type": "integer", "title": "Used" },
"limit": {
"type": "integer",
"title": "Limit",
"description": "Maximum tokens allowed in this window. 0 means unlimited."
},
"resets_at": {
"type": "string",
"format": "date-time",
"title": "Resets At"
}
},
"type": "object",
"required": ["used", "limit", "resets_at"],
"title": "UsageWindow",
"description": "Usage within a single time window."
},
"UserHistoryResponse": {
"properties": {
"history": {
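
Review note: the `UsageWindow` and `CoPilotUsageStatus` schemas added in this hunk can be read as the TypeScript shapes below. This is a minimal sketch, not the generated client code. The helper names `usageFraction` and `hasAnyLimit` are hypothetical and do not appear in the diff. Treating `limit <= 0` as "unlimited" follows the schema description ("0 means unlimited") and the guard in `CoPilotUsageSection`.

```typescript
// Shapes mirroring the UsageWindow / CoPilotUsageStatus schemas in this hunk.
interface UsageWindow {
  used: number;
  limit: number; // 0 means unlimited, per the schema description
  resets_at: string; // date-time string
}

interface CoPilotUsageStatus {
  daily: UsageWindow;
  weekly: UsageWindow;
}

// Hypothetical helper: fraction of the window consumed, clamped to [0, 1].
// Returns null when the window has no finite limit (limit <= 0).
function usageFraction(window: UsageWindow): number | null {
  if (window.limit <= 0) return null;
  return Math.min(1, Math.max(0, window.used / window.limit));
}

// Hypothetical helper: true when at least one window has a finite limit.
// This is the inverse of the early-return guard in CoPilotUsageSection.
function hasAnyLimit(status: CoPilotUsageStatus): boolean {
  return status.daily.limit > 0 || status.weekly.limit > 0;
}
```

A caller could hide the usage panel when `hasAnyLimit` is false and render a progress bar from `usageFraction` otherwise.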

View File

@@ -288,6 +288,7 @@ const SidebarTrigger = React.forwardRef<
ref={ref}
data-sidebar="trigger"
variant="ghost"
size="icon"
onClick={(event) => {
onClick?.(event);
toggleSidebar();

View File

@@ -18,8 +18,11 @@ test.beforeEach(async ({ page }) => {
await page.goto("/build");
await buildPage.closeTutorial();
await buildPage.addBlockByClick("Add to Dictionary");
await buildPage.waitForNodeOnCanvas(1);
const [dictionaryBlock] = await buildPage.getFilteredBlocksFromAPI(
(block) => block.name === "AddToDictionaryBlock",
);
await buildPage.addBlock(dictionaryBlock);
await buildPage.saveAgent("Test Agent", "Test Description");
await test

View File

@@ -1,134 +1,363 @@
import test, { expect } from "@playwright/test";
// TODO: These tests were written for the old (legacy) builder.
// They need to be updated to work with the new flow editor.
// Note: all the comments with //(number)! are for the docs
//ignore them when reading the code, but if you change something,
//make sure to update the docs! Your autoformatter will break this page,
// so don't run it on this file.
// --8<-- [start:BuildPageExample]
import test from "@playwright/test";
import { BuildPage } from "./pages/build.page";
import { LoginPage } from "./pages/login.page";
import { hasUrl } from "./utils/assertion";
import { getTestUser } from "./utils/auth";
test.describe("Builder", () => {
let buildPage: BuildPage;
// Reason Ignore: admonishment is in the wrong place visually with correct prettier rules
// prettier-ignore
test.describe.skip("Build", () => { //(1)!
let buildPage: BuildPage; //(2)!
test.beforeEach(async ({ page }) => {
test.setTimeout(60000);
// Reason Ignore: admonishment is in the wrong place visually with correct prettier rules
// prettier-ignore
test.beforeEach(async ({ page }) => { //(3)! ts-ignore
test.setTimeout(25000);
const loginPage = new LoginPage(page);
const testUser = await getTestUser();
buildPage = new BuildPage(page);
await page.goto("/login");
// Start each test with login using worker auth
await page.goto("/login"); //(4)!
await loginPage.login(testUser.email, testUser.password);
await hasUrl(page, "/marketplace");
await page.goto("/build");
await page.waitForLoadState("domcontentloaded");
await hasUrl(page, "/marketplace"); //(5)!
await buildPage.navbar.clickBuildLink();
await hasUrl(page, "/build");
await buildPage.closeTutorial();
});
// --- Core tests ---
// Helper function to add blocks starting with a specific letter, split into parts for parallelization
async function addBlocksStartingWithSplit(letter: string, part: number, totalParts: number): Promise<void> {
const blockIdsToSkip = await buildPage.getBlocksToSkip();
const blockTypesToSkip = ["Input", "Output", "Agent", "AI"];
const targetLetter = letter.toLowerCase();
const allBlocks = await buildPage.getFilteredBlocksFromAPI(block =>
block.name[0].toLowerCase() === targetLetter &&
!blockIdsToSkip.includes(block.id) &&
!blockTypesToSkip.includes(block.type)
);
test("build page loads successfully", async () => {
await expect(buildPage.isLoaded()).resolves.toBeTruthy();
await expect(
buildPage.getPlaywrightPage().getByTestId("blocks-control-blocks-button"),
).toBeVisible();
await expect(
buildPage.getPlaywrightPage().getByTestId("save-control-save-button"),
).toBeVisible();
const blocksToAdd = allBlocks.filter((_, index) =>
index % totalParts === (part - 1)
);
console.log(`Adding ${blocksToAdd.length} blocks starting with "${letter}" (part ${part}/${totalParts})`);
for (const block of blocksToAdd) {
await buildPage.addBlock(block);
}
}
// Reason Ignore: admonishment is in the wrong place visually with correct prettier rules
// prettier-ignore
test("user can add a block", async ({ page: _page }) => { //(6)!
await buildPage.openBlocksPanel(); //(10)!
const blocks = await buildPage.getFilteredBlocksFromAPI(block => block.name[0].toLowerCase() === "a");
const block = blocks.at(-1);
if (!block) throw new Error("No block found");
await buildPage.addBlock(block); //(11)!
await buildPage.closeBlocksPanel(); //(12)!
await buildPage.hasBlock(block); //(13)!
});
// --8<-- [end:BuildPageExample]
test("user can add blocks starting with a (part 1)", async () => {
await addBlocksStartingWithSplit("a", 1, 2);
});
test("user can add a block via block menu", async () => {
const initialCount = await buildPage.getNodeCount();
await buildPage.addBlockByClick("Store Value");
await buildPage.waitForNodeOnCanvas(initialCount + 1);
expect(await buildPage.getNodeCount()).toBe(initialCount + 1);
test("user can add blocks starting with a (part 2)", async () => {
await addBlocksStartingWithSplit("a", 2, 2);
});
test("user can add multiple blocks", async () => {
await buildPage.addBlockByClick("Store Value");
await buildPage.waitForNodeOnCanvas(1);
await buildPage.addBlockByClick("Store Value");
await buildPage.waitForNodeOnCanvas(2);
expect(await buildPage.getNodeCount()).toBe(2);
test("user can add blocks starting with b", async () => {
await addBlocksStartingWithSplit("b", 1, 1);
});
test("user can remove a block", async () => {
await buildPage.addBlockByClick("Store Value");
await buildPage.waitForNodeOnCanvas(1);
// Deselect, then re-select the node and delete
await buildPage.clickCanvas();
await buildPage.selectNode(0);
await buildPage.deleteSelectedNodes();
await expect(buildPage.getNodeLocator()).toHaveCount(0, { timeout: 5000 });
test("user can add blocks starting with c", async () => {
await addBlocksStartingWithSplit("c", 1, 1);
});
test("user can save an agent", async ({ page }) => {
await buildPage.addBlockByClick("Store Value");
await buildPage.waitForNodeOnCanvas(1);
await buildPage.saveAgent("E2E Test Agent", "Created by e2e test");
await buildPage.waitForSaveComplete();
expect(page.url()).toContain("flowID=");
test("user can add blocks starting with d", async () => {
await addBlocksStartingWithSplit("d", 1, 1);
});
test("user can save and run button becomes enabled", async () => {
await buildPage.addBlockByClick("Store Value");
await buildPage.waitForNodeOnCanvas(1);
test("user can add blocks starting with e", async () => {
test.setTimeout(60000); // Increase timeout for many Exa blocks
await addBlocksStartingWithSplit("e", 1, 2);
});
await buildPage.saveAgent("Runnable Agent", "Test run button");
await buildPage.waitForSaveComplete();
test("user can add blocks starting with e pt 2", async () => {
test.setTimeout(60000); // Increase timeout for many Exa blocks
await addBlocksStartingWithSplit("e", 2, 2);
});
test("user can add blocks starting with f", async () => {
await addBlocksStartingWithSplit("f", 1, 1);
});
test("user can add blocks starting with g (part 1)", async () => {
await addBlocksStartingWithSplit("g", 1, 3);
});
test("user can add blocks starting with g (part 2)", async () => {
await addBlocksStartingWithSplit("g", 2, 3);
});
test("user can add blocks starting with g (part 3)", async () => {
await addBlocksStartingWithSplit("g", 3, 3);
});
test("user can add blocks starting with h", async () => {
await addBlocksStartingWithSplit("h", 1, 1);
});
test("user can add blocks starting with i", async () => {
await addBlocksStartingWithSplit("i", 1, 1);
});
test("user can add blocks starting with j", async () => {
await addBlocksStartingWithSplit("j", 1, 1);
});
test("user can add blocks starting with k", async () => {
await addBlocksStartingWithSplit("k", 1, 1);
});
test("user can add blocks starting with l", async () => {
await addBlocksStartingWithSplit("l", 1, 1);
});
test("user can add blocks starting with m", async () => {
await addBlocksStartingWithSplit("m", 1, 1);
});
test("user can add blocks starting with n", async () => {
await addBlocksStartingWithSplit("n", 1, 1);
});
test("user can add blocks starting with o", async () => {
await addBlocksStartingWithSplit("o", 1, 1);
});
test("user can add blocks starting with p", async () => {
await addBlocksStartingWithSplit("p", 1, 1);
});
test("user can add blocks starting with q", async () => {
await addBlocksStartingWithSplit("q", 1, 1);
});
test("user can add blocks starting with r", async () => {
await addBlocksStartingWithSplit("r", 1, 1);
});
test("user can add blocks starting with s (part 1)", async () => {
await addBlocksStartingWithSplit("s", 1, 3);
});
test("user can add blocks starting with s (part 2)", async () => {
await addBlocksStartingWithSplit("s", 2, 3);
});
test("user can add blocks starting with s (part 3)", async () => {
await addBlocksStartingWithSplit("s", 3, 3);
});
test("user can add blocks starting with t", async () => {
await addBlocksStartingWithSplit("t", 1, 1);
});
test("user can add blocks starting with u", async () => {
await addBlocksStartingWithSplit("u", 1, 1);
});
test("user can add blocks starting with v", async () => {
await addBlocksStartingWithSplit("v", 1, 1);
});
test("user can add blocks starting with w", async () => {
await addBlocksStartingWithSplit("w", 1, 1);
});
test("user can add blocks starting with x", async () => {
await addBlocksStartingWithSplit("x", 1, 1);
});
test("user can add blocks starting with y", async () => {
await addBlocksStartingWithSplit("y", 1, 1);
});
test("user can add blocks starting with z", async () => {
await addBlocksStartingWithSplit("z", 1, 1);
});
test("build navigation is accessible from navbar", async ({ page }) => {
// Navigate somewhere else first
await page.goto("/marketplace"); //(4)!
// Check that navigation to the Builder is available on the page
await buildPage.navbar.clickBuildLink();
await hasUrl(page, "/build");
await test.expect(buildPage.isLoaded()).resolves.toBeTruthy();
});
test("user can add two blocks and connect them", async ({ page }) => {
await buildPage.openBlocksPanel();
// Define the blocks to add
const block1 = {
id: "1ff065e9-88e8-4358-9d82-8dc91f622ba9",
name: "Store Value 1",
description: "Store Value Block 1",
type: "Standard",
};
const block2 = {
id: "1ff065e9-88e8-4358-9d82-8dc91f622ba9",
name: "Store Value 2",
description: "Store Value Block 2",
type: "Standard",
};
// Add the blocks
await buildPage.addBlock(block1);
await buildPage.addBlock(block2);
await buildPage.closeBlocksPanel();
// Connect the blocks
await buildPage.connectBlockOutputToBlockInputViaDataId(
"1-1-output-source",
"1-2-input-target",
);
// Fill in the input for the first block
await buildPage.fillBlockInputByPlaceholder(
block1.id,
"Enter input",
"Test Value",
"1",
);
// Save the agent and wait for the URL to update
await buildPage.saveAgent(
"Connected Blocks Test",
"Testing block connections",
);
await test.expect(page).toHaveURL(({ searchParams }) => !!searchParams.get("flowID"));
// Wait for the save button to be enabled again
await buildPage.waitForSaveButton();
await expect(buildPage.isRunButtonEnabled()).resolves.toBeTruthy();
// Ensure the run button is enabled
await test.expect(buildPage.isRunButtonEnabled()).resolves.toBeTruthy();
});
// --- Copy / Paste test ---
test.skip("user can build an agent with inputs and output blocks", async ({ page }, testInfo) => {
test.setTimeout(testInfo.timeout * 10);
test("user can copy and paste a node", async ({ context }) => {
await context.grantPermissions(["clipboard-read", "clipboard-write"]);
// prep
await buildPage.openBlocksPanel();
await buildPage.addBlockByClick("Store Value");
await buildPage.waitForNodeOnCanvas(1);
// Get input block from Input category
const inputBlocks = await buildPage.getBlocksForCategory("Input");
const inputBlock = inputBlocks.find((b) => b.name === "Agent Input");
if (!inputBlock) throw new Error("Input block not found");
await buildPage.addBlock(inputBlock);
await buildPage.selectNode(0);
await buildPage.copyViaKeyboard();
await buildPage.pasteViaKeyboard();
// Get output block from Output category
const outputBlocks = await buildPage.getBlocksForCategory("Output");
const outputBlock = outputBlocks.find((b) => b.name === "Agent Output");
if (!outputBlock) throw new Error("Output block not found");
await buildPage.addBlock(outputBlock);
await buildPage.waitForNodeOnCanvas(2);
expect(await buildPage.getNodeCount()).toBe(2);
});
// Get calculator block from Logic category
const logicBlocks = await buildPage.getBlocksForCategory("Logic");
const calculatorBlock = logicBlocks.find((b) => b.name === "Calculator");
if (!calculatorBlock) throw new Error("Calculator block not found");
await buildPage.addBlock(calculatorBlock);
// --- Run agent test ---
await buildPage.closeBlocksPanel();
test("user can run an agent from the builder", async () => {
await buildPage.addBlockByClick("Store Value");
await buildPage.waitForNodeOnCanvas(1);
// Wait for blocks to be fully loaded
await page.waitForTimeout(1000);
// Save the agent (required before running)
await buildPage.saveAgent("Run Test Agent", "Testing run from builder");
await buildPage.waitForSaveComplete();
await buildPage.waitForSaveButton();
// Wait for blocks to be ready for connections
await page.waitForTimeout(1000);
// Click run button
await buildPage.clickRunButton();
await buildPage.connectBlockOutputToBlockInputViaName(
inputBlock.id,
"Result",
calculatorBlock.id,
"A",
);
await buildPage.connectBlockOutputToBlockInputViaName(
inputBlock.id,
"Result",
calculatorBlock.id,
"B",
);
await buildPage.connectBlockOutputToBlockInputViaName(
calculatorBlock.id,
"Result",
outputBlock.id,
"Value",
);
// Either the run dialog appears or the agent starts running directly
const runDialogOrRunning = await Promise.race([
buildPage
.getPlaywrightPage()
.locator('[data-id="run-input-dialog-content"]')
.waitFor({ state: "visible", timeout: 10000 })
.then(() => "dialog"),
buildPage
.getPlaywrightPage()
.locator('[data-id="stop-graph-button"]')
.waitFor({ state: "visible", timeout: 10000 })
.then(() => "running"),
]).catch(() => "timeout");
// Wait for connections to stabilize
await page.waitForTimeout(1000);
expect(["dialog", "running"]).toContain(runDialogOrRunning);
await buildPage.fillBlockInputByPlaceholder(
inputBlock.id,
"Enter Name",
"Value",
);
await buildPage.fillBlockInputByPlaceholder(
outputBlock.id,
"Enter Name",
"Doubled",
);
// Wait before changing dropdown
await page.waitForTimeout(500);
await buildPage.selectBlockInputValue(
calculatorBlock.id,
"Operation",
"Add",
);
// Wait before saving
await page.waitForTimeout(1000);
await buildPage.saveAgent(
"Input and Output Blocks Test",
"Testing input and output blocks",
);
await test.expect(page).toHaveURL(({ searchParams }) => !!searchParams.get("flowID"));
// Wait for save to complete
await page.waitForTimeout(1000);
// await buildPage.runAgent();
// await buildPage.fillRunDialog({
// Value: "10",
// });
// await buildPage.clickRunDialogRunButton();
// await buildPage.waitForCompletionBadge();
// await test
// .expect(buildPage.isCompletionBadgeVisible())
// .resolves.toBeTruthy();
});
});
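
Review note: the `addBlocksStartingWithSplit` helper in this file shards the block list with `index % totalParts === (part - 1)`, so each test adds an interleaved slice of the blocks. The sketch below isolates that partitioning step. `takePart` is a hypothetical name, and the argument validation is an addition that the helper in the diff does not perform.

```typescript
// Round-robin sharding: item i goes to part (i % totalParts) + 1.
// `part` is 1-based, matching the calls in the spec file.
function takePart<T>(
  items: readonly T[],
  part: number,
  totalParts: number,
): T[] {
  if (!Number.isInteger(totalParts) || totalParts < 1) {
    throw new RangeError("totalParts must be a positive integer");
  }
  if (!Number.isInteger(part) || part < 1 || part > totalParts) {
    throw new RangeError("part must be between 1 and totalParts");
  }
  return items.filter((_, index) => index % totalParts === part - 1);
}

// Usage: split five placeholder block names across two tests.
const sampleBlocks = ["a1", "a2", "a3", "a4", "a5"];
const firstPart = takePart(sampleBlocks, 1, 2); // indices 0, 2, 4
const secondPart = takePart(sampleBlocks, 2, 2); // indices 1, 3
```

The parts are disjoint and together cover every item. They stay the same between runs only if the block list is returned in a stable order, which this sketch assumes.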

View File

@@ -1,47 +1,44 @@
import { expect, Locator, Page } from "@playwright/test";
import { Locator, Page } from "@playwright/test";
import { Block as APIBlock } from "../../lib/autogpt-server-api/types";
import { beautifyString } from "../../lib/utils";
import { BasePage } from "./base.page";
export interface Block {
id: string;
name: string;
description: string;
type: string;
}
export class BuildPage extends BasePage {
private cachedBlocks: Record<string, Block> = {};
constructor(page: Page) {
super(page);
}
// --- Navigation ---
async goto(): Promise<void> {
await this.page.goto("/build");
await this.page.waitForLoadState("domcontentloaded");
}
async isLoaded(): Promise<boolean> {
try {
await this.page.waitForLoadState("domcontentloaded", { timeout: 10_000 });
await this.page
.locator(".react-flow")
.waitFor({ state: "visible", timeout: 10_000 });
return true;
} catch {
return false;
}
private getDisplayName(blockName: string): string {
return beautifyString(blockName).replace(/ Block$/, "");
}
async closeTutorial(): Promise<void> {
console.log(`closing tutorial`);
try {
await this.page
.getByRole("button", { name: "Skip Tutorial", exact: true })
.click({ timeout: 3000 });
} catch {
// Tutorial not shown or already dismissed
} catch (_error) {
console.info("Tutorial not shown or already dismissed");
}
}
// --- Block Menu ---
async openBlocksPanel(): Promise<void> {
const popoverContent = this.page.locator(
'[data-id="blocks-control-popover-content"]',
);
if (!(await popoverContent.isVisible())) {
const isPanelOpen = await popoverContent.isVisible();
if (!isPanelOpen) {
await this.page.getByTestId("blocks-control-blocks-button").click();
await popoverContent.waitFor({ state: "visible", timeout: 5000 });
}
@@ -53,258 +50,501 @@ export class BuildPage extends BasePage {
);
if (await popoverContent.isVisible()) {
await this.page.getByTestId("blocks-control-blocks-button").click();
await popoverContent.waitFor({ state: "hidden", timeout: 5000 });
}
}
async searchBlock(searchTerm: string): Promise<void> {
const searchInput = this.page.locator(
'[data-id="blocks-control-search-bar"] input[type="text"]',
);
await searchInput.clear();
await searchInput.fill(searchTerm);
await this.page.waitForTimeout(300);
}
private getBlockCardByName(name: string): Locator {
const escapedName = name.replace(/[.*+?^${}()|[\]\\]/g, "\\$&");
const exactName = new RegExp(`^\\s*${escapedName}\\s*$`, "i");
return this.page
.locator('[data-id^="block-card-"]')
.filter({ has: this.page.locator("span", { hasText: exactName }) })
.first();
}
async addBlockByClick(searchTerm: string): Promise<void> {
await this.openBlocksPanel();
await this.searchBlock(searchTerm);
// Wait for any search results to appear
const anyCard = this.page.locator('[data-id^="block-card-"]').first();
await anyCard.waitFor({ state: "visible", timeout: 10000 });
// Click the card matching the search term name
const blockCard = this.getBlockCardByName(searchTerm);
await blockCard.waitFor({ state: "visible", timeout: 5000 });
await blockCard.click();
// Close the panel so it doesn't overlay the canvas
await this.closeBlocksPanel();
}
async dragBlockToCanvas(searchTerm: string): Promise<void> {
await this.openBlocksPanel();
await this.searchBlock(searchTerm);
const anyCard = this.page.locator('[data-id^="block-card-"]').first();
await anyCard.waitFor({ state: "visible", timeout: 10000 });
const blockCard = this.getBlockCardByName(searchTerm);
await blockCard.waitFor({ state: "visible", timeout: 5000 });
const canvas = this.page.locator(".react-flow__pane").first();
await blockCard.dragTo(canvas);
}
// --- Nodes on Canvas ---
getNodeLocator(index?: number): Locator {
const locator = this.page.locator('[data-id^="custom-node-"]');
return index !== undefined ? locator.nth(index) : locator;
}
async getNodeCount(): Promise<number> {
return await this.getNodeLocator().count();
}
async waitForNodeOnCanvas(expectedCount?: number): Promise<void> {
if (expectedCount !== undefined) {
await expect(this.getNodeLocator()).toHaveCount(expectedCount, {
timeout: 10000,
});
} else {
await this.getNodeLocator()
.first()
.waitFor({ state: "visible", timeout: 10000 });
}
}
async selectNode(index: number = 0): Promise<void> {
const node = this.getNodeLocator(index);
await node.click();
}
async selectAllNodes(): Promise<void> {
await this.page.locator(".react-flow__pane").first().click();
const isMac = process.platform === "darwin";
await this.page.keyboard.press(isMac ? "Meta+a" : "Control+a");
}
async deleteSelectedNodes(): Promise<void> {
await this.page.keyboard.press("Backspace");
}
// --- Connections (Edges) ---
async connectNodes(
sourceNodeIndex: number,
targetNodeIndex: number,
): Promise<void> {
// Get the node wrapper elements to scope handle search
const sourceNode = this.getNodeLocator(sourceNodeIndex);
const targetNode = this.getNodeLocator(targetNodeIndex);
// ReactFlow renders Handle components as .react-flow__handle elements
// Output handles have class .react-flow__handle-right (Position.Right)
// Input handles have class .react-flow__handle-left (Position.Left)
const sourceHandle = sourceNode
.locator(".react-flow__handle-right")
.first();
const targetHandle = targetNode.locator(".react-flow__handle-left").first();
// Get precise center coordinates using evaluate to avoid CSS transform issues
const getHandleCenter = async (locator: Locator) => {
const el = await locator.elementHandle();
if (!el) throw new Error("Handle element not found");
const rect = await el.evaluate((node) => {
const r = node.getBoundingClientRect();
return { x: r.x + r.width / 2, y: r.y + r.height / 2 };
});
return rect;
};
const source = await getHandleCenter(sourceHandle);
const target = await getHandleCenter(targetHandle);
// ReactFlow requires a proper drag sequence with intermediate moves
await this.page.mouse.move(source.x, source.y);
await this.page.mouse.down();
// Move in steps to trigger ReactFlow's connection detection
const steps = 20;
for (let i = 1; i <= steps; i++) {
const ratio = i / steps;
await this.page.mouse.move(
source.x + (target.x - source.x) * ratio,
source.y + (target.y - source.y) * ratio,
);
}
await this.page.mouse.up();
}
async getEdgeCount(): Promise<number> {
return await this.page.locator(".react-flow__edge").count();
}
// --- Save ---
async saveAgent(
name: string = "Test Agent",
description: string = "",
): Promise<void> {
console.log(`Saving agent '${name}' with description '${description}'`);
await this.page.getByTestId("save-control-save-button").click();
const nameInput = this.page.getByTestId("save-control-name-input");
await nameInput.waitFor({ state: "visible", timeout: 5000 });
await nameInput.fill(name);
if (description) {
await this.page
.getByTestId("save-control-description-input")
.fill(description);
}
await this.page.getByTestId("save-control-name-input").fill(name);
await this.page
.getByTestId("save-control-description-input")
.fill(description);
await this.page.getByTestId("save-control-save-agent-button").click();
}
async waitForSaveComplete(): Promise<void> {
await expect(this.page).toHaveURL(/flowID=/, { timeout: 15000 });
}
async getBlocksFromAPI(): Promise<Block[]> {
if (Object.keys(this.cachedBlocks).length > 0) {
return Object.values(this.cachedBlocks);
}
async waitForSaveButton(): Promise<void> {
await this.page.waitForSelector(
'[data-testid="save-control-save-button"]:not([disabled])',
{ timeout: 10000 },
console.log(`Getting blocks from API request`);
// Make direct API request using the page's request context
const response = await this.page.request.get(
"http://localhost:3000/api/proxy/api/blocks",
);
const apiBlocks: APIBlock[] = await response.json();
console.log(`Found ${apiBlocks.length} blocks from API`);
// Convert API blocks to test Block format
const blocks = apiBlocks.map((block) => ({
id: block.id,
name: block.name,
description: block.description,
type: block.uiType,
}));
this.cachedBlocks = blocks.reduce(
(acc, block) => {
acc[block.id] = block;
return acc;
},
{} as Record<string, Block>,
);
return blocks;
}
// --- Run ---
async getFilteredBlocksFromAPI(
filterFn: (block: Block) => boolean,
): Promise<Block[]> {
console.log(`Getting filtered blocks from API`);
const blocks = await this.getBlocksFromAPI();
return blocks.filter(filterFn);
}
async addBlock(block: Block): Promise<void> {
console.log(`Adding block ${block.name} (${block.id}) to agent`);
await this.openBlocksPanel();
const searchInput = this.page.locator(
'[data-id="blocks-control-search-bar"] input[type="text"]',
);
const displayName = this.getDisplayName(block.name);
await searchInput.clear();
await searchInput.fill(displayName);
const blockCardId = block.id.replace(/[^a-zA-Z0-9]/g, "");
const blockCard = this.page.locator(
`[data-id="block-card-${blockCardId}"]`,
);
await blockCard.waitFor({ state: "visible", timeout: 10000 });
await blockCard.click();
}
async hasBlock(_block: Block) {
// In the new flow editor, verify a node exists on the canvas
const node = this.page.locator('[data-id^="custom-node-"]').first();
await node.isVisible();
}
async getBlockInputs(blockId: string): Promise<string[]> {
console.log(`Getting block ${blockId} inputs`);
try {
const node = this.page.locator(`[data-blockid="${blockId}"]`).first();
const inputsData = await node.getAttribute("data-inputs");
return inputsData ? JSON.parse(inputsData) : [];
} catch (error) {
console.error("Error getting block inputs:", error);
return [];
}
}
async selectBlockCategory(category: string): Promise<void> {
console.log(`Selecting block category: ${category}`);
await this.page.getByText(category, { exact: true }).click();
// Wait for the blocks to load after category selection
await this.page.waitForTimeout(3000);
}
async getBlocksForCategory(category: string): Promise<Block[]> {
console.log(`Getting blocks for category: ${category}`);
// Clear any existing search to ensure we see all blocks in the category
const searchInput = this.page.locator(
'[data-id="blocks-control-search-bar"] input[type="text"]',
);
await searchInput.clear();
// Wait for search to clear
await this.page.waitForTimeout(300);
// Select the category first
await this.selectBlockCategory(category);
try {
const blockFinder = this.page.locator('[data-id^="block-card-"]');
await blockFinder.first().waitFor();
const blocks = await blockFinder.all();
console.log(`found ${blocks.length} blocks in category ${category}`);
const results = await Promise.all(
blocks.map(async (block) => {
try {
const fullId = (await block.getAttribute("data-id")) || "";
const id = fullId.replace("block-card-", "");
const nameElement = block.locator('[data-testid^="block-name-"]');
const descriptionElement = block.locator(
'[data-testid^="block-description-"]',
);
const name = (await nameElement.textContent()) || "";
const description = (await descriptionElement.textContent()) || "";
const type = (await nameElement.getAttribute("data-type")) || "";
return {
id,
name: name.trim(),
type: type.trim(),
description: description.trim(),
};
} catch (elementError) {
console.error("Error processing block:", elementError);
return null;
}
}),
);
// Filter out any null results from errors
return results.filter((block): block is Block => block !== null);
} catch (error) {
console.error(`Error getting blocks for category ${category}:`, error);
return [];
}
}
async _buildBlockSelector(blockId: string, dataId?: string): Promise<string> {
const selector = dataId
? `[data-id="${dataId}"] [data-blockid="${blockId}"]`
: `[data-blockid="${blockId}"]`;
return selector;
}
private async moveBlockToViewportPosition(
blockSelector: string,
options: { xRatio?: number; yRatio?: number } = {},
): Promise<void> {
const { xRatio = 0.5, yRatio = 0.5 } = options;
const blockLocator = this.page.locator(blockSelector).first();
await blockLocator.waitFor({ state: "visible" });
const boundingBox = await blockLocator.boundingBox();
const viewport = this.page.viewportSize();
if (!boundingBox || !viewport) {
return;
}
const currentX = boundingBox.x + boundingBox.width / 2;
const currentY = boundingBox.y + boundingBox.height / 2;
const targetX = viewport.width * xRatio;
const targetY = viewport.height * yRatio;
const distance = Math.hypot(targetX - currentX, targetY - currentY);
if (distance < 5) {
return;
}
await this.page.mouse.move(currentX, currentY);
await this.page.mouse.down();
await this.page.mouse.move(targetX, targetY, { steps: 15 });
await this.page.mouse.up();
await this.page.waitForTimeout(200);
}
async getBlockById(blockId: string, dataId?: string): Promise<Locator> {
console.log(`getting block ${blockId} with dataId ${dataId}`);
return this.page.locator(await this._buildBlockSelector(blockId, dataId));
}
// dataId is optional. If provided, the search is scoped to that container; otherwise it matches on blockId alone.
// This is useful when several blocks share the same blockId but have different dataIds, as is expected when blocks are added to the graph.
// Note that dataIds change once an agent has run, so tests that run an agent must either use the new dataId or avoid reusing the same block.
async fillBlockInputByPlaceholder(
blockId: string,
placeholder: string,
value: string,
dataId?: string,
): Promise<void> {
console.log(
`filling block input ${placeholder} with value ${value} of block ${blockId}`,
);
const block = await this.getBlockById(blockId, dataId);
const input = block.getByPlaceholder(placeholder);
await input.fill(value);
}
async selectBlockInputValue(
blockId: string,
inputName: string,
value: string,
dataId?: string,
): Promise<void> {
console.log(
`selecting value ${value} for input ${inputName} of block ${blockId}`,
);
// First get the button that opens the dropdown
const baseSelector = await this._buildBlockSelector(blockId, dataId);
// Find the combobox button within the input handle container
const comboboxSelector = `${baseSelector} [data-id="input-handle-${inputName.toLowerCase()}"] button[role="combobox"]`;
try {
// Click the combobox to open it
await this.page.click(comboboxSelector);
// Wait a moment for the dropdown to open
await this.page.waitForTimeout(100);
// Select the option from the dropdown
// The actual selector for the option might need adjustment based on the dropdown structure
await this.page.getByRole("option", { name: value }).click();
} catch (error) {
console.error(
`Error selecting value "${value}" for input "${inputName}":`,
error,
);
throw error;
}
}
async fillBlockInputByLabel(
blockId: string,
label: string,
value: string,
): Promise<void> {
console.log(`filling block input ${label} with value ${value}`);
const block = await this.getBlockById(blockId);
const input = block.getByLabel(label);
await input.fill(value);
}
async connectBlockOutputToBlockInputViaDataId(
blockOutputId: string,
blockInputId: string,
): Promise<void> {
console.log(
`connecting block output ${blockOutputId} to block input ${blockInputId}`,
);
try {
// Locate the output element
const outputElement = this.page.locator(`[data-id="${blockOutputId}"]`);
// Locate the input element
const inputElement = this.page.locator(`[data-id="${blockInputId}"]`);
await outputElement.dragTo(inputElement);
} catch (error) {
console.error("Error connecting block output to input:", error);
}
}
async connectBlockOutputToBlockInputViaName(
startBlockId: string,
startBlockOutputName: string,
endBlockId: string,
endBlockInputName: string,
startDataId?: string,
endDataId?: string,
): Promise<void> {
console.log(
`connecting block output ${startBlockOutputName} of block ${startBlockId} to block input ${endBlockInputName} of block ${endBlockId}`,
);
const startBlockBase = await this._buildBlockSelector(
startBlockId,
startDataId,
);
const endBlockBase = await this._buildBlockSelector(endBlockId, endDataId);
await this.moveBlockToViewportPosition(startBlockBase, { xRatio: 0.35 });
await this.moveBlockToViewportPosition(endBlockBase, { xRatio: 0.65 });
const startBlockOutputSelector = `${startBlockBase} [data-testid="output-handle-${startBlockOutputName.toLowerCase()}"]`;
const endBlockInputSelector = `${endBlockBase} [data-testid="input-handle-${endBlockInputName.toLowerCase()}"]`;
console.log("Start block selector:", startBlockOutputSelector);
console.log("End block selector:", endBlockInputSelector);
const startElement = this.page.locator(startBlockOutputSelector);
const endElement = this.page.locator(endBlockInputSelector);
await startElement.scrollIntoViewIfNeeded();
await this.page.waitForTimeout(200);
await endElement.scrollIntoViewIfNeeded();
await this.page.waitForTimeout(200);
await startElement.dragTo(endElement);
}
async isLoaded(): Promise<boolean> {
console.log(`checking if build page is loaded`);
try {
await this.page.waitForLoadState("domcontentloaded", { timeout: 10_000 });
return true;
} catch {
return false;
}
}
async isRunButtonEnabled(): Promise<boolean> {
console.log(`checking if run button is enabled`);
const runButton = this.page.locator('[data-id="run-graph-button"]');
return await runButton.isEnabled();
}
async clickRunButton(): Promise<void> {
async runAgent(): Promise<void> {
console.log(`clicking run button`);
const runButton = this.page.locator('[data-id="run-graph-button"]');
await runButton.click();
await this.page.waitForTimeout(1000);
await runButton.click();
}
// --- Undo / Redo ---
async isUndoEnabled(): Promise<boolean> {
const btn = this.page.locator('[data-id="undo-button"]');
return !(await btn.isDisabled());
}
async isRedoEnabled(): Promise<boolean> {
const btn = this.page.locator('[data-id="redo-button"]');
return !(await btn.isDisabled());
}
async clickUndo(): Promise<void> {
await this.page.locator('[data-id="undo-button"]').click();
}
async clickRedo(): Promise<void> {
await this.page.locator('[data-id="redo-button"]').click();
}
// --- Copy / Paste ---
async copyViaKeyboard(): Promise<void> {
const isMac = process.platform === "darwin";
await this.page.keyboard.press(isMac ? "Meta+c" : "Control+c");
}
async pasteViaKeyboard(): Promise<void> {
const isMac = process.platform === "darwin";
await this.page.keyboard.press(isMac ? "Meta+v" : "Control+v");
}
// --- Helpers ---
async fillBlockInputByPlaceholder(
placeholder: string,
value: string,
nodeIndex: number = 0,
): Promise<void> {
const node = this.getNodeLocator(nodeIndex);
const input = node.getByPlaceholder(placeholder);
await input.fill(value);
}
async clickCanvas(): Promise<void> {
const pane = this.page.locator(".react-flow__pane").first();
const box = await pane.boundingBox();
if (box) {
// Click in the center of the canvas to avoid sidebar/toolbar overlaps
await pane.click({
position: { x: box.width / 2, y: box.height / 2 },
});
} else {
await pane.click();
async fillRunDialog(inputs: Record<string, string>): Promise<void> {
console.log(`filling run dialog`);
for (const [key, value] of Object.entries(inputs)) {
await this.page.getByTestId(`agent-input-${key}`).fill(value);
}
}
getPlaywrightPage(): Page {
return this.page;
async clickRunDialogRunButton(): Promise<void> {
console.log(`clicking run button`);
await this.page.getByTestId("agent-run-button").click();
}
async createDummyAgent(): Promise<void> {
async waitForCompletionBadge(): Promise<void> {
console.log(`waiting for completion badge`);
await this.page.waitForSelector(
'[data-id^="badge-"][data-id$="-COMPLETED"]',
);
}
async waitForSaveButton(): Promise<void> {
console.log(`waiting for save button`);
await this.page.waitForSelector(
'[data-testid="save-control-save-button"]:not([disabled])',
);
}
async isCompletionBadgeVisible(): Promise<boolean> {
console.log(`checking for completion badge`);
const completionBadge = this.page
.locator('[data-id^="badge-"][data-id$="-COMPLETED"]')
.first();
return await completionBadge.isVisible();
}
async waitForVersionField(): Promise<void> {
console.log(`waiting for version field`);
// wait for the url to have the flowID
await this.page.waitForSelector(
'[data-testid="save-control-version-output"]',
);
}
async getDictionaryBlockDetails(): Promise<Block> {
return {
id: "dummy-id-1",
name: "Add to Dictionary",
description: "Add to Dictionary",
type: "Standard",
};
}
async getCalculatorBlockDetails(): Promise<Block> {
return {
id: "dummy-id-2",
name: "Calculator",
description: "Calculator",
type: "Standard",
};
}
async waitForSaveDialogClose(): Promise<void> {
console.log(`waiting for save dialog to close`);
await this.page.waitForSelector(
'[data-id="save-control-popover-content"]',
{ state: "hidden" },
);
}
async getGithubTriggerBlockDetails(): Promise<Block[]> {
return [
{
id: "6c60ec01-8128-419e-988f-96a063ee2fea",
name: "Github Trigger",
description:
"This block triggers on pull request events and outputs the event type and payload.",
type: "Standard",
},
{
id: "551e0a35-100b-49b7-89b8-3031322239b6",
name: "Github Star Trigger",
description:
"This block triggers on star events and outputs the event type and payload.",
type: "Standard",
},
{
id: "2052dd1b-74e1-46ac-9c87-c7a0e057b60b",
name: "Github Release Trigger",
description:
"This block triggers on release events and outputs the event type and payload.",
type: "Standard",
},
{
id: "b2605464-e486-4bf4-aad3-d8a213c8a48a",
name: "Github Issue Trigger",
description:
"This block triggers on issue events and outputs the event type and payload.",
type: "Standard",
},
{
id: "87f847b3-d81a-424e-8e89-acadb5c9d52b",
name: "Github Discussion Trigger",
description:
"This block triggers on discussion events and outputs the event type and payload.",
type: "Standard",
},
];
}
async nextTutorialStep(): Promise<void> {
console.log(`clicking next tutorial step`);
await this.page.getByRole("button", { name: "Next" }).click();
}
async getBlocksToSkip(): Promise<string[]> {
return [
(await this.getGithubTriggerBlockDetails()).map((b) => b.id),
// MCP Tool block requires an interactive dialog (server URL + OAuth) before
// it can be placed, so it can't be tested via the standard "add block" flow.
"a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4",
].flat();
}
async createDummyAgent() {
await this.closeTutorial();
await this.addBlockByClick("Add to Dictionary");
await this.waitForNodeOnCanvas(1);
await this.openBlocksPanel();
const searchInput = this.page.locator(
'[data-id="blocks-control-search-bar"] input[type="text"]',
);
await searchInput.clear();
await searchInput.fill("Add to Dictionary");
const blockCard = this.page.locator('[data-id^="block-card-"]').first();
try {
await blockCard.waitFor({ state: "visible", timeout: 10000 });
await blockCard.click();
} catch (error) {
console.log("Could not find Add to Dictionary block:", error);
}
await this.saveAgent("Test Agent", "Test Description");
await this.waitForSaveComplete();
}
}
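
Reviewer note: the exact-name matching in `getBlockCardByName` is easier to reason about in isolation. The sketch below pulls the same two steps (escape regex metacharacters, then anchor the pattern with optional surrounding whitespace and case-insensitivity) into standalone functions. The names `escapeRegExp` and `exactNamePattern` are hypothetical and do not appear in the diff; this is an illustration under those assumptions, not the code as committed.

```typescript
// Hypothetical standalone helpers mirroring the logic in getBlockCardByName.
// Neither function name exists in the diff; they are for illustration only.

// Escape every character that has a special meaning inside a RegExp,
// so a block name such as "A+B (beta)" is matched literally.
function escapeRegExp(text: string): string {
  return text.replace(/[.*+?^${}()|[\]\\]/g, "\\$&");
}

// Build a case-insensitive pattern that matches the whole string,
// tolerating leading and trailing whitespace around the name.
function exactNamePattern(name: string): RegExp {
  return new RegExp(`^\\s*${escapeRegExp(name)}\\s*$`, "i");
}

const pattern = exactNamePattern("Add to Dictionary");

// Same name with different casing and padding: expected to match.
console.log(pattern.test("  add to dictionary "));

// Longer name that merely contains the search term: expected not to match.
console.log(pattern.test("Add to Dictionary Block"));

// Metacharacters in the name are treated as literal characters.
console.log(exactNamePattern("A+B (beta)").test("A+B (beta)"));
```

Anchoring the pattern is what keeps a search for one block from selecting a card whose name only contains the term, which `.first()` on an unfiltered locator could otherwise do.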