Compare commits

...

91 Commits

Author SHA1 Message Date
Zamil Majdy
9684b99949 fix(backend/copilot): fix logging, token counting, and compaction edge cases
- Convert all remaining f-string logger calls to %-style across SDK files
- Fix _msg_tokens not counting Anthropic text blocks (underestimated tokens)
- Fix _truncate_middle_tokens producing wrong output when max_tok < 3
- Fix CompactionTracker.reset_for_query not clearing _compact_start event
- Add cycle guard to strip_progress_entries reparenting loop
- Replace bare except with logged exception in transcript metadata loading
2026-03-14 21:58:42 +07:00
Zamil Majdy
d00059dc94 test(backend/copilot): add tests for transcript compaction helpers
Add comprehensive unit tests for the JSONL-message conversion layer:
- _flatten_assistant_content and _flatten_tool_result_content
- _transcript_to_messages: strippable types, compact summaries, edge cases
- _messages_to_transcript: structure, parentUuid chain, validation
- Roundtrip: messages to transcript to messages content preservation
- compact_transcript: too few messages, mock compaction, no-op, failure
2026-03-14 04:34:32 +07:00
Zamil Majdy
3a14077d52 refactor(backend/copilot): remove dead delete_transcript function
delete_transcript is no longer needed — oversized transcripts are now
compacted proactively at download time via _maybe_compact_and_upload
(400KB threshold), preventing the "prompt too long" error before it
can occur. If compaction fails, resume is gracefully skipped.
2026-03-14 03:58:36 +07:00
Zamil Majdy
4446be94ae test(backend/copilot): add e2e compaction lifecycle tests
Exercises the full service.py compaction flow end-to-end:
TranscriptBuilder load → CompactionTracker state machine →
read compacted entries → replace_entries → export → roundtrip.
2026-03-14 01:38:11 +07:00
Zamil Majdy
983aed2b0a fix(backend/copilot): preserve isCompactSummary through transcript roundtrip
TranscriptEntry didn't carry isCompactSummary, so exported JSONL had
bare "type": "summary" entries. On next turn's load_previous, these
were stripped as regular summaries, losing compaction context.

- Add isCompactSummary field to TranscriptEntry
- Preserve isCompactSummary entries in load_previous, strip_progress_entries,
  and _transcript_to_messages
- Add 6 roundtrip tests verifying isCompactSummary survives export→reload
2026-03-14 01:11:56 +07:00
Zamil Majdy
66a8cf69be fix(backend/copilot): async file I/O, add TranscriptBuilder tests
- Convert read_cli_session_file to async using aiofiles (addresses
  review comment about sync I/O in async streaming loop)
- Add 10 tests for TranscriptBuilder (replace_entries, load_previous,
  append operations, message ID consistency)
- Update existing read_cli_session_file tests to async
2026-03-14 00:16:51 +07:00
Zamil Majdy
d1b8766fa4 fix(platform): restore compaction code lost in merge, address remaining review comments
The merge of feat/tracking-cost-block reverted transcript compaction code
and several review fixes from 90b7edf1f. This commit restores the lost
code and applies additional improvements requested in review:

- Restore transcript compaction functions (_transcript_to_messages,
  _messages_to_transcript, compact_transcript, _flatten_* helpers)
- Restore _maybe_compact_and_upload helper in service.py to flatten
  deep nesting (5 levels -> 2) in transcript compaction block
- Restore CLI session file reading (read_cli_session_file) for
  mid-stream compaction sync
- Restore total_tokens DRY fix (compute once, reuse in finally)
- Extract _run_compression() helper to eliminate nested try blocks
- Add STOP_REASON_END_TURN, COMPACT_MSG_ID_PREFIX, ENTRY_TYPE_MESSAGE
  named constants replacing magic strings
- Add MS_PER_MINUTE and MS_PER_HOUR constants in UsageLimits.tsx
- Add docstring explaining Monday edge case in _weekly_reset_time
2026-03-13 22:56:19 +07:00
Zamil Majdy
628b779128 fix(backend): preserve at least one assistant message during middle-out deletion
The middle-out deletion step in compress_context could remove all assistant
messages when client=None (fallback compaction), causing validate_transcript
to fail and returning None (context loss). Now skips deleting the last
remaining assistant message.
2026-03-13 22:09:28 +07:00
Zamil Majdy
90b7edf1f1 fix(backend/copilot): address review comments — top-level imports, bug fixes, refactoring
- Move all local imports to top-level (transcript.py, helpers.py,
  baseline/service.py) per project style rules
- Fix sub.get("text") bug: use `or` instead of default arg to handle
  None values in tool_result content blocks
- Extract _flatten_assistant_content and _flatten_tool_result_content
  helpers from _transcript_to_messages to reduce nesting
- Use early returns in _get_credits/_spend_credits
- Extract _fetch_counters helper in rate_limit.py to DRY the Redis
  fetch pattern between get_usage_status and check_rate_limit
- Fix DRY violation: compute total_tokens once before StreamUsage,
  reuse in finally block for session persistence
- Fix chunk.choices guard: use early continue instead of ternary
  to prevent IndexError on empty list
- Make generic error message in execute_block to avoid leaking internals
2026-03-13 19:49:28 +07:00
Zamil Majdy
8b970c4c3d Merge remote-tracking branch 'origin/dev' into fix/copilot-transcript-compaction-v2 2026-03-13 19:25:12 +07:00
Zamil Majdy
601fed93b8 fix(backend/copilot): resolve stash conflicts in transcript.py, handle file-stat race condition 2026-03-13 19:10:07 +07:00
Zamil Majdy
a8259ca935 feat(analytics): read-only SQL views layer with analytics schema (#12367)
### Changes 🏗️

Adds `autogpt_platform/analytics/` — 14 SQL view definitions that expose
production data safely through a locked-down `analytics` schema.

**Security model:**
- Views use `security_invoker = false` (PostgreSQL 15+), so they execute
as their owner (`postgres`), not the caller
- `analytics_readonly` role only has access to `analytics.*` — cannot
touch `platform` or `auth` tables directly

**Files:**
- `backend/generate_views.py` — does everything; auto-reads credentials
from `backend/.env`
- `analytics/queries/*.sql` — 14 documented view definitions (auth, user
activity, executions, onboarding funnel, cohort retention)

---

### Running locally (dev)

```bash
cd autogpt_platform/backend

# First time only — creates analytics schema, role, grants
poetry run analytics-setup

# Create / refresh views (auto-reads backend/.env)
poetry run analytics-views
```

### Running in production (Supabase)

```bash
cd autogpt_platform/backend

# Step 1 — first time only (run in Supabase SQL Editor as postgres superuser)
poetry run analytics-setup --dry-run
# Paste the output into Supabase SQL Editor and run

# Step 2 — apply views (use direct connection host, not pooler)
poetry run analytics-views --db-url "postgresql://postgres:PASSWORD@db.<ref>.supabase.co:5432/postgres"

# Step 3 — set password for analytics_readonly so external tools can connect
# Run in Supabase SQL Editor:
# ALTER ROLE analytics_readonly WITH PASSWORD 'your-password';
```

---

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Setup + views applied cleanly on local Postgres 15
- [x] `analytics_readonly` can `SELECT` from all 14 `analytics.*` views
- [x] `analytics_readonly` gets `permission denied` on `platform.*` and
`auth.*` directly

---------

Co-authored-by: Otto (AGPT) <otto@agpt.co>
2026-03-13 12:04:42 +00:00
Zamil Majdy
e3f9fa3648 fix(platform): merge dev branch, resolve conflicts in routes_test and openapi.json 2026-03-13 19:03:20 +07:00
Zamil Majdy
809ba56f0b fix(backend/copilot): preserve structured tool_result content during transcript flattening
When flattening tool_result blocks for summarisation, dict blocks without a
"text" key were silently dropped (replaced with empty string). Now falls back
to json.dumps(sub) so JSON/structured payloads are preserved for compaction.
2026-03-13 18:30:23 +07:00
Zamil Majdy
9a467e1dba fix(platform): preserve message_count after transcript compaction, clean up imports
- Fix bug where message_count was reset to JSONL line count after
  compaction instead of preserving the original session.messages
  watermark. The JSONL line count is smaller than the original message
  count, causing the gap-fill logic to re-inject already-covered
  messages into the prompt.
- Move pathlib.Path and uuid.uuid4 imports to module level in
  transcript.py.
- Use Next.js router instead of window.location.href in credits page.
2026-03-13 18:21:15 +07:00
Zamil Majdy
0200748225 fix(platform): address PR review comments
- transcript.py: properly flatten tool_result content blocks including
  nested list structures instead of losing tool_use_id silently
- transcript_builder.py: make replace_entries atomic — parse into temp
  builder first, only swap on success
- service.py: skip resume when compaction fails (avoid re-triggering
  "Prompt is too long"), wrap upload in try/except for best-effort
- baseline/service.py: use floor of 0 instead of 1 for token estimates
- rate_limit.py: broaden Redis exception handling to catch all errors
  (timeouts, etc.) for true fail-open behavior; remove unused import
- helpers.py: return ErrorResponse on post-exec InsufficientBalanceError
  instead of swallowing; remove raw exception text from user message
- helpers_test.py: update test to expect ErrorResponse
- UsageLimits.tsx: remove dark:* classes (copilot has no dark mode yet)
2026-03-13 18:12:24 +07:00
Swifty
1f1288d623 feat(copilot): generate personalized quick-action prompts from Tally business understanding (#12374)
During Tally data extraction, the system now also generates personalized
quick-action prompts as part of the existing LLM extraction call
(configurable model, defaults to GPT-4o-mini, `temperature=0.0`). The
prompt asks the LLM for 5 candidates, then the code validates (filters
prompts >20 words) and keeps the top 3. These prompts are stored in the
existing `CoPilotUnderstanding.data` JSON field (at the top level, not
under `business`) and served to the frontend via a new API endpoint. The
copilot chat page uses them instead of hardcoded defaults when
available.

### Changes 🏗️

**Backend – Data models** (`understanding.py`):
- Added `suggested_prompts` field to `BusinessUnderstandingInput`
(optional) and `BusinessUnderstanding` (default empty list)
- Updated `from_db()` to deserialize `suggested_prompts` from top-level
of the data JSON
- Updated `merge_business_understanding_data()` with overwrite strategy
for prompts (full replace, not append)
- `format_understanding_for_prompt()` intentionally does **not** include
`suggested_prompts` — they are UI-only

**Backend – Prompt generation** (`tally.py`):
- Extended `_EXTRACTION_PROMPT` to request 5 suggested prompts alongside
the existing business understanding fields — all extracted in a single
LLM call (`temperature=0.0`)
- Post-extraction validation filters out prompts exceeding 20 words and
slices to the top 3
- Model is now configurable via `tally_extraction_llm_model` setting
(defaults to `openai/gpt-4o-mini`)

**Backend – API endpoint** (`routes.py`):
- Added `GET /api/chat/suggested-prompts` (auth required)
- Returns `{prompts: string[]}` from the user's cached business
understanding (48h Redis TTL)
- Returns empty array if no understanding or no prompts exist

**Frontend** (`EmptySession/`):
- `helpers.ts`: Extracted defaults to `DEFAULT_QUICK_ACTIONS`,
`getQuickActions()` now accepts optional custom prompts and falls back
to defaults
- `EmptySession.tsx`: Calls `useGetV2GetSuggestedPrompts` hook
(`staleTime: Infinity`) and passes results to `getQuickActions()` with
hardcoded fallback
- Fixed `useEffect` resize handler that previously used
`window.innerWidth` as a dependency (re-ran every render); now uses a
proper resize event listener
- Added skeleton loading state while prompts are being fetched

**Generated** (`__generated__/`):
- Regenerated Orval API client with new endpoint types and hooks

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Backend format + lint + pyright pass
  - [x] Frontend format + lint pass
  - [x] All existing tally tests pass (28/28)
  - [x] All chat route tests pass (9/9)
  - [x] All invited_user tests pass (7/7)
- [x] E2E: New user with tally data sees custom prompts on copilot page
  - [x] E2E: User without tally data sees hardcoded default prompts
  - [x] E2E: Clicking a custom prompt sends it as a chat message
2026-03-13 12:11:31 +01:00
Zamil Majdy
1704214394 fix(backend/copilot): recount message_count after transcript compaction
After compacting a transcript at download time, the message_count was
stale (from the original download), causing the gap-fill logic to miss
messages on subsequent turns. Now recount from the compacted content
before uploading.
2026-03-13 18:01:48 +07:00
Zamil Majdy
f2efd3ad7f fix(backend/copilot): address review - path traversal fix, pathlib, tests
- Use pathlib.Path.glob instead of import glob
- Add realpath + prefix check on glob results to prevent symlink escapes
- Add unit tests for _cli_project_dir, read_cli_session_file,
  _transcript_to_messages, and _messages_to_transcript
2026-03-13 17:51:45 +07:00
Zamil Majdy
ee841d1515 fix(backend/copilot): make TranscriptBuilder compaction-aware and compact oversized transcripts
TranscriptBuilder accumulated all messages including pre-compaction content,
causing uploaded transcripts to grow unbounded. When the CLI compacted
mid-stream, TranscriptBuilder kept the full uncompacted history. On the next
turn, --resume would fail with "Prompt is too long".

Fix 1 (prevent): After the CLI's PreCompact hook fires and compaction ends,
read the CLI's session file (which reflects compaction) and replace
TranscriptBuilder's entries via new replace_entries() method.

Fix 2 (mitigate): At download time, if transcript exceeds 400KB threshold,
compact it using compress_context (LLM summarization + truncation fallback)
before passing to --resume. Upload the compacted version for future turns.
2026-03-13 17:43:48 +07:00
Zamil Majdy
5966d3669d fix(platform): use round() for cache weights, accept string dates in UsageLimits 2026-03-13 17:38:25 +07:00
Otto
02645732b8 feat(backend/copilot): enable E2B auto_resume and reduce safety-net timeout (#12397)
Enable E2B `auto_resume` lifecycle option and reduce the safety-net
timeout from 3 hours to 5 minutes.

Currently, if the explicit per-turn `pause_sandbox_direct()` call fails
(process crash, network issue, fire-and-forget task cancellation), the
sandbox keeps running for up to **3 hours** before the safety-net
timeout fires. With this change, worst-case billing drops to **5
minutes**.

### Changes
- Add `auto_resume: True` to sandbox lifecycle config — paused sandboxes
wake transparently on SDK activity
- Reduce `e2b_sandbox_timeout` default from 10800s (3h) → 300s (5min)
- Add `e2b_sandbox_auto_resume` config field (default: `True`)
- Guard: `auto_resume` only added when `on_timeout == "pause"`

### What doesn't change
- Explicit per-turn `pause_sandbox_direct()` remains the primary
mechanism
- `connect()` / `_try_reconnect()` flow unchanged
- Redis key management unchanged
- No latency impact (resume is ~1-2s regardless of trigger)

### Risk
Very low — `auto_resume` is additive. If it doesn't work as advertised,
`connect()` still resumes paused sandboxes exactly as before.

Ref: https://e2b.dev/docs/sandbox/auto-resume
Linear: SECRT-2118

---
Co-authored-by: Zamil Majdy (@majdyz) <zamil.majdy@agpt.co>
2026-03-13 10:29:28 +00:00
Zamil Majdy
c81ab1fc3b fix(backend/copilot): use adapter pattern for credit operations in executor
The CoPilot executor runs without a Prisma connection, so direct calls
to get_user_credit_model() caused "Client is not connected to the query
engine" errors. Replace with _get_credits/_spend_credits adapters that
fall back to RPC via DatabaseManagerAsyncClient when Prisma is unavailable.
Also add missing spend_credits/get_credits to DatabaseManagerAsyncClient.
2026-03-13 17:14:08 +07:00
Swifty
ba301a3912 feat(platform): add whitelisting-backed beta user provisioning (#12347)
### Changes 🏗️

- add invite-backed beta provisioning with a new `InvitedUser` platform
model, Prisma migration, and first-login activation path that
materializes `User`, `Profile`, `UserOnboarding`, and
`CoPilotUnderstanding`
- replace the legacy beta allowlist check with invite-backed gating for
email/password signup and Tally pre-seeding during activation
- add admin backend APIs and frontend `/admin/users` management UI for
listing, creating, revoking, retrying, and bulk-uploading invited users
- add the design doc for the beta invite system and extend backend
coverage for invite activation, bulk uploads, and auth-route behavior
- configuration changes: introduce the new invite/tally schema objects
and migration; no new env vars or docker service changes are required

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] `cd autogpt_platform/backend && poetry run format`
- [x] `cd autogpt_platform/backend && poetry run pytest -q` (run against
an isolated local Postgres database with non-conflicting service port
overrides)

#### For configuration changes:

- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)
2026-03-13 10:25:49 +01:00
Zamil Majdy
5446c7f18f fix(platform): consistent header icon sizing, halve rate limits
- Add size="icon" to SidebarTrigger for uniform h-9 w-9 button sizing
- Remove extra positioning wrappers (relative left-1, left-5) around header icons
- Halve daily token limit to 2.5M and weekly to 12.5M for more reasonable defaults
2026-03-13 15:55:18 +07:00
Zamil Majdy
2b0c9ba703 fix(frontend): use Button component for icon buttons, add timezone and <1% display
- UsageLimits and NotificationToggle now use <Button variant="ghost" size="icon">
  to match SidebarTrigger's padding/sizing
- Weekly reset time shows timezone abbreviation (e.g., "Mon 7:00 AM PST")
- Usage below 1% shows "<1% used" instead of "0% used" with a 1% min bar width
2026-03-13 15:42:05 +07:00
Zamil Majdy
195c7011ae fix(frontend): update usage after chat, show reset day, fix icon weight
- Invalidate usage query when stream completes so the usage bar
  updates immediately after chatting instead of waiting 30s.
- Show reset time as day/time in local timezone when over 24h away
  (e.g. "Mon 12:00 AM") instead of unclear "63h 41m".
- Use weight="light" on ChartBar icon to match other header icons.
2026-03-13 15:21:42 +07:00
Zamil Majdy
d4944fb22b fix(platform): emit StreamUsage as SSE comment, move usage to popover
StreamUsage events crashed the frontend because the Vercel AI SDK uses
z.strictObject() and rejects unknown event types. Fix by overriding
to_sse() to emit as an SSE comment (invisible to the parser). Usage
data is already recorded server-side (session DB + Redis counters).

Move usage limits from sidebar footer back to a ChartBar icon button
in the sidebar header that opens a popover on click.
2026-03-13 15:14:14 +07:00
Zamil Majdy
a5ed8fefa9 feat(copilot): cost-weighted token rate limiting with cache breakdown
- Rate limiter now uses Anthropic's cost model: cache_read at 10%,
  cache_creation at 25%, uncached and output at 100%
- Track cache_read_tokens and cache_creation_tokens separately in
  Usage model, StreamUsage response, and SDK token extraction
- Pass cache breakdown through to record_token_usage() for accurate
  weighted counting
- Add test for cost-weighted counting (10K cache_read → 1K weighted)

This makes multi-turn conversations fairer: cached system prompts
and tool schemas don't penalize users at full token cost.
2026-03-13 14:36:04 +07:00
Zamil Majdy
a52a777b29 fix(copilot): increase rate limits from 500K/5M to 5M/25M daily/weekly
A single CoPilot turn consumes ~10-15K tokens (system prompt + tool
schemas), so 500K daily only allowed ~35-50 turns which is too
restrictive for normal use.
2026-03-13 14:21:01 +07:00
Zamil Majdy
8bec7a6933 fix(frontend): move usage limits to left column on billing page
Move CoPilot Usage Limits below Automatic Refill Settings in the left
column and add "Open CoPilot" button for consistency with other sections.
2026-03-13 14:17:57 +07:00
Zamil Majdy
e73791efed fix(copilot): move usage limits to sidebar bottom, add to billing page
- Move UsageLimits from top-right headerSlot to left sidebar footer
- Show usage limits on both main copilot page and inside chat sessions
- Add CoPilot Usage Limits section to billing/credits page
- Change "Learn more about usage limits" to "Manage billing & credits"
- Remove unused headerSlot prop from ChatContainer/ChatMessagesContainer
- Clean up unused imports in CopilotPage
2026-03-13 14:08:57 +07:00
Zamil Majdy
2d161ce2b9 refactor(platform): replace per-session rate limits with daily fixed-window
Per-session limits were gameable (create new session to reset) and had
confusing reset semantics (12h inactivity TTL that refreshed on use).
Replace with daily fixed-window counter that resets at midnight UTC,
matching the weekly window pattern.

- session_token_limit → daily_token_limit (500K tokens/day)
- Redis key: copilot:usage:daily:{user_id}:{YYYY-MM-DD} with TTL
- Remove session_id from usage API endpoint and record_token_usage()
- Frontend: "Current session" → "Today", "Weekly limits" → "This week"
2026-03-13 13:22:59 +07:00
Zamil Majdy
6fc4989654 fix(copilot): include full context window in fallback token estimation
Each API call sends the complete openai_messages list (system prompt +
history + turn), so the fallback estimator should count all of it to
match what prompt_tokens would have reported.
2026-03-13 12:36:51 +07:00
Abhimanyu Yadav
0cd9c0d87a fix(frontend): show sub-folders when navigating inside a folder (#12316)
## Summary

When opening a folder in the library, sub-folders were not displayed —
only agents were shown. This was caused by two issues:

1. The folder list query always fetched root-level folders (no
`parent_id` filter), so sub-folders were never requested
2. `showFolders` was set to `false` whenever a folder was selected,
hiding all folders from the view

### Changes 🏗️

- Pass `parent_id` to the `useGetV2ListLibraryFolders` hook so it
fetches child folders of the currently selected folder
- Remove the `!selectedFolderId` condition from `showFolders` so folders
render inside other folders
- Fetch the current folder via `useGetV2GetFolder` instead of searching
the (now differently-scoped) folder list
- Clean up breadcrumb: remove emoji icon, match folder name text size to
"My Library", replace `Button` with plain `<button>` to remove extra
padding/gap

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Open a folder in the library and verify sub-folders are displayed
  - [x] Verify agents inside the folder still display correctly
- [x] Verify breadcrumb shows folder name without emoji, matching "My
Library" text size
- [x] Verify clicking "My Library" in breadcrumb navigates back to root
  - [x] Verify root-level view still shows all top-level folders
  - [x] Verify favorites tab does not show folders
2026-03-13 04:40:09 +00:00
Zamil Majdy
976443bf6e refactor(copilot): use tiktoken for fallback token estimation
Replace rough chars/4 heuristic with proper tiktoken tokenizer via
estimate_token_count/estimate_token_count_str from backend.util.prompt.
2026-03-13 05:24:53 +07:00
Zamil Majdy
4ceb15b3f1 fix(copilot): scope fallback token estimation to current turn only
The fallback estimator was counting the entire openai_messages history
(system prompt + all previous turns) instead of just the messages added
during the current turn. This caused overcounting and overly strict
rate limiting when providers don't return streaming usage data.
2026-03-13 03:44:30 +07:00
Zamil Majdy
3096f94996 feat(copilot): set default rate limits based on observed usage data
Session: 500K tokens (P90 session usage ~550K)
Weekly: 5M tokens (~10x heaviest observed weekly user)
2026-03-13 03:34:05 +07:00
Zamil Majdy
6f90729612 merge: resolve conflicts with dev (coerce_inputs_to_schema)
Merge dev's coerce_inputs_to_schema into execute_block alongside
credit charging. Both features coexist: coercion runs before
credit check. Test file combines both test suites.
2026-03-13 03:15:37 +07:00
Zamil Majdy
ebf89dde8b fix(copilot): include cached tokens in SDK token tracking
Anthropic's API reports cached tokens separately (cache_read_input_tokens,
cache_creation_input_tokens) from input_tokens. The previous code only read
input_tokens, undercounting total tokens for rate limiting.

OpenRouter (baseline path) already includes cached tokens in prompt_tokens
per OpenAI-compatible format — added clarifying comment.
2026-03-13 02:46:57 +07:00
Zamil Majdy
a083493aa2 fix(backend/copilot): auto-parse structured data and robust type coercion (#12390)
The copilot's `@@agptfile:` reference system always produces strings
when expanding
file references. This breaks blocks that expect structured types — e.g.
`GoogleSheetsWriteBlock` expects `values: list[list[str]]`, but receives
a raw CSV
string instead. Additionally, the copilot's input coercion was
duplicating logic from
the executor instead of reusing the shared `convert()` utility, and the
coercion had
no type-aware gating — it would always call `convert()`, which could
incorrectly
transform values that already matched the expected type (e.g.
stringifying a valid
`list[str]` in a `str | list[str]` union).

### Changes 🏗️

**Structured data parsing for `@@agptfile:` bare references:**
- When an entire tool argument value is a bare `@@agptfile:` reference,
the resolved
content is now auto-parsed: JSON → native types, CSV/TSV →
`list[list[str]]`
- Embedded references within larger strings still do plain text
substitution
- Updated copilot system prompt to document the structured data
capability

**Shared type coercion utility (`coerce_inputs_to_schema`):**
- Extracted `coerce_inputs_to_schema()` into `backend/util/type.py` —
shared by both
  the executor's `validate_exec()` and the copilot's `execute_block()`
- Uses Pydantic `model_fields` (not `__annotations__`) to include
inherited fields
- Added `_value_satisfies_type()` gate: only calls `convert()` when the
value doesn't
already match the target type, including recursive inner-element
checking for generics

**`_value_satisfies_type` — recursive type checking:**
- Handles `Any`, `Optional`, `Union`, `list[T]`, `dict[K,V]`, `set[T]`,
`tuple[T, ...]`,
  heterogeneous `tuple[str, int, bool]`, bare generics, nested generics
- Guards against non-runtime origins (`Literal`, etc.) to prevent
`isinstance()` crashes
- Returns `False` (not `True`) for unhandled generic origins as a safe
fallback

**Test coverage:**
- 51 new tests for `_value_satisfies_type` and `coerce_inputs_to_schema`
in `type_test.py`
- 8 new tests for `execute_block` type coercion in `helpers_test.py`

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] All existing file_ref tests pass
- [x] All new type_test.py tests pass (51 tests covering
_value_satisfies_type and coerce_inputs_to_schema)
- [x] All new helpers_test.py tests pass (8 tests covering execute_block
coercion)
  - [x] `poetry run format` passes clean
  - [x] `poetry run lint` passes clean
  - [x] Pyright type checking passes
2026-03-12 19:27:41 +00:00
Zamil Majdy
c51dc7ad99 fix(backend): agent generator sets invalid model on PerplexityBlocks (#12391)
Fixes the agent generator setting `gpt-5.2-2025-12-11` (or `gpt-4o`) as
the model for PerplexityBlocks instead of valid Perplexity models,
causing 100% failure rate for agents using Perplexity blocks.

### Changes 🏗️

- **Fixer: block-aware model validation** — `fix_ai_model_parameter()`
now reads the block's `inputSchema` to check for `enum` constraints on
the model field. Blocks with their own model enum (PerplexityBlock,
IdeogramBlock, CodexBlock, etc.) are validated against their own allowed
values with the correct default, instead of the hardcoded generic set
(`gpt-4o`, `claude-opus-4-6`). This also fixes `edit_agent` which runs
through the same fixer pipeline.
- **PerplexityBlock: runtime fallback** — Added a `field_validator` on
the model field that gracefully falls back to `SONAR` instead of
crashing when an invalid model value is encountered at runtime. Also
overrides `validate_data` to sanitize invalid model values *before* JSON
schema validation (which runs in `Block._execute` before Pydantic
instantiation), ensuring the fallback is actually reachable during block
execution.
- **DB migration** — Fixes existing PerplexityBlock nodes with invalid
model values in both `AgentNode.constantInput` and
`AgentNodeExecutionInputOutput` (preset overrides), matching the pattern
from the Gemini migration.
- **Tests** — Fixer tests for block-specific enum validation, plus
`validate_data`-level tests ensuring invalid models are sanitized before
JSON schema validation rejects them.

Resolves [SECRT-2097](https://linear.app/autogpt/issue/SECRT-2097)

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] All existing + new fixer tests pass
  - [x] PerplexityBlock block test passes
- [x] 11 perplexity_test.py tests pass (field_validator + validate_data
paths)
- [x] Verified invalid model (`gpt-5.2-2025-12-11`) falls back to
`perplexity/sonar` at runtime
  - [x] Verified valid Perplexity models are preserved by the fixer
  - [x] Migration covers both constantInput and preset overrides
2026-03-12 18:54:18 +00:00
Zamil Majdy
5d057e97e5 fix(copilot): move StreamUsage/StreamFinish back out of finally block
PEP 525 prohibits yielding from finally in async generators during
aclose() — doing so raises RuntimeError on client disconnect. Move
yields after try/finally where they work on normal completion and are
harmlessly unreachable on GeneratorExit.
2026-03-13 00:40:58 +07:00
Zamil Majdy
1d2f641a26 fix(copilot): move session.usage.append to finally block in SDK service
Ensures session usage persistence and rate-limit recording stay
consistent even when an exception interrupts the try block. Mirrors
the baseline service pattern.
2026-03-13 00:35:09 +07:00
Zamil Majdy
dcb71ab0b9 test(platform): add tests for usage endpoint and UsageLimits component
- 3 new tests for GET /usage endpoint (with/without session_id, config limits)
- 9 new tests for UsageLimits frontend component (loading, empty, rendering,
  percentage capping, session-only/weekly-only, link, hook args)
2026-03-13 00:30:52 +07:00
Zamil Majdy
8136b90860 fix(copilot): move StreamUsage/StreamFinish yields into finally block
On client disconnect, GeneratorExit terminates the generator after the
finally block, making yields after it unreachable. Moving them inside
finally ensures they are at least attempted.
2026-03-13 00:05:41 +07:00
Zamil Majdy
4d179a7c37 Merge branch 'dev' into feat/tracking-cost-block 2026-03-12 23:56:38 +07:00
Zamil Majdy
f78adcdc65 fix(platform): address all open review items — clock skew, token recording, generated hooks
- Fix clock skew: share single `now` timestamp across `_weekly_key` and
  `_weekly_reset_time` calls to prevent ISO week boundary race condition
- SDK: move `record_token_usage` to finally block so tokens are always
  recorded even when exceptions interrupt the stream (prevents rate limit bypass)
- Baseline: wrap `record_token_usage` in try/except so it cannot block
  session persistence
- Routes: pass `session_id=None` instead of empty string to avoid
  malformed Redis keys; `get_usage_status` now skips session lookup when
  session_id is None
- Tests: add partial None counter tests and no-session-id test
- Frontend: replace raw fetch with generated Orval hook
  (`useGetV2GetCopilotUsage`), use generated `CoPilotUsageStatus` type,
  fix `Date` vs `string` type for `resets_at`
2026-03-12 23:18:31 +07:00
Zamil Majdy
40388b7520 fix(platform): address review — orphan keys, TTL gather, dark mode, lazy logging
- Guard record_token_usage with user_id check to prevent orphan Redis keys
- Fold session TTL lookup into asyncio.gather (eliminate serial round-trip)
- Add dark: variants to UsageLimits component
- Use lazy %s formatting in logger calls instead of eager f-strings
2026-03-12 23:03:36 +07:00
Zamil Majdy
dd7be1158b test(backend): expand rate limit test coverage, fix token estimation null safety
- Add tests for: past resets_at (negative time), expired TTL fallback,
  _session_reset_from_ttl Redis error, pipeline TTL assertions,
  pipeline execute RedisError
- Fix null safety in baseline token estimation (m.get("content") or "")
- Use MagicMock for pipeline sync methods to eliminate coroutine warnings
2026-03-12 22:38:31 +07:00
Zamil Majdy
c0e59f0a6b fix(backend): address pushback on credit charging and token estimation
- Post-exec InsufficientBalanceError: return output + log warning
  instead of ErrorResponse (block already executed with side effects)
- Add token estimation fallback for providers that don't support
  stream_options include_usage (1 token ≈ 4 chars)
- Remove unnecessary # type: ignore on AsyncRedis param
2026-03-12 22:25:14 +07:00
Zamil Majdy
104d1f1bf4 fix(backend): revert to += for prompt tokens, add negative time guard
- Revert prompt token tracking back to += (sum all rounds): each API
  call bills independently, so total billed tokens is the correct
  metric for rate limiting
- Guard against negative reset time in RateLimitExceeded message
  (clock skew / stale Redis TTL)
2026-03-12 22:09:05 +07:00
Zamil Majdy
d9e9cd4c98 fix(backend): address human + Sentry review feedback on CoPilot tracking
- Fix double-counting of prompt tokens in multi-round tool calls:
  use = (not +=) since each API call's prompt_tokens includes full
  conversation history (both baseline and SDK paths)
- Fix docstring: "sliding-window" → "fixed-window" counters
- Type redis param as AsyncRedis instead of object
- Use pipeline(transaction=False) for independent INCRBY+EXPIRE
- Simplify days_until_monday calculation with `or 7`
- Flatten time_str into ternary, merge f-strings
- Parallelize Redis gets with asyncio.gather
- Document session_id=None behavior in usage endpoint
2026-03-12 21:54:31 +07:00
Zamil Majdy
ca416300ec fix(backend): address second round CodeRabbit review feedback
- Narrow except blocks to (RedisError, ConnectionError, OSError) instead
  of bare Exception to avoid hiding coding bugs
- Remove raw user_id from log messages to prevent PII leaks under Redis
  outages
- Reuse credit_model instead of fetching twice in execute_block()
- Treat post-execution InsufficientBalanceError as fatal (matches
  executor behavior) instead of silently swallowing it
2026-03-12 21:37:32 +07:00
Zamil Majdy
c589cd0c43 fix(backend): address CodeRabbit review feedback
- Derive session reset time from Redis TTL instead of hardcoded 3h
- Add description to UsageWindow.limit documenting 0 = unlimited
- Compare balance < cost instead of balance <= 0 in pre-exec check
- Document TOCTOU behavior in check_rate_limit docstring
2026-03-12 21:10:52 +07:00
Krzysztof Czerwinski
bc6b82218a feat(platform): add autopilot notification system (#12364)
Adds a notification system for the Copilot (AutoPilot) so users know
when background chats finish processing — via in-app indicators, sounds,
browser notifications, and document title badges.

### Changes 🏗️

**Backend**
- Add `is_processing` field to `SessionSummaryResponse` — batch-checks
Redis for active stream status on each session in the list endpoint
- Fix `is_processing` always returning `false` due to bytes vs string
comparison (`b"running"` → `"running"`) with `decode_responses=True`
Redis client
- Add `CopilotCompletionPayload` model for WebSocket notification events
- Publish `copilot_completion` notification via WebSocket when a session
completes in `stream_registry.mark_session_completed`

**Frontend — Notification UI**
- Add `NotificationBanner` component — amber banner prompting users to
enable browser notifications (auto-hides when already enabled or
dismissed)
- Add `NotificationDialog` component — modal dialog for enabling
notifications, supports force-open from sidebar menu for testing
- Fix repeated word "response" in dialog copy

**Frontend — Sidebar**
- Add bell icon in sidebar header with popover menu containing:
- Notifications toggle (requests browser permission on enable; shows
toast if denied)
  - Sound toggle (disabled when notifications are off)
  - "Show notification popup" button (for testing the dialog)
  - "Clear local data" button (resets all copilot localStorage keys)
- Bell icon states: `BellSlash` (disabled), `Bell` (enabled, no sound),
`BellRinging` (enabled + sound)
- Add processing indicator (PulseLoader) and completion checkmark
(CheckCircle) inline with chat title, to the left of the hamburger menu
- Processing indicator hides immediately when completion arrives (no
overlap with checkmark)
- Fix PulseLoader initial flash — start at `scale(0); opacity: 0` with
smoother keyframes
- Add 10s polling (`refetchInterval`) to session list so `is_processing`
updates automatically
- Clear document title badge when navigating to a completed chat
- Remove duplicate "Your chats" heading that appeared in both
SidebarHeader and SidebarContent

**Frontend — Notification Hook (`useCopilotNotifications`)**
- Listen for `copilot_completion` WebSocket events
- Track completed sessions in Zustand store
- Play notification sound (only for background sessions, not active
chat)
- Update `document.title` with unread count badge
- Send browser `Notification` when tab is hidden, with click-to-navigate
to the completed chat
- Reset document title on tab focus

**Frontend — Store & Storage**
- Add `completedSessionIDs`, `isNotificationsEnabled`, `isSoundEnabled`,
`showNotificationDialog`, `clearCopilotLocalData` to Zustand store
- Persist notification and sound preferences in localStorage
- On init, validate `isNotificationsEnabled` against actual
`Notification.permission`
- Add localStorage keys: `COPILOT_NOTIFICATIONS_ENABLED`,
`COPILOT_SOUND_ENABLED`, `COPILOT_NOTIFICATION_BANNER_DISMISSED`,
`COPILOT_NOTIFICATION_DIALOG_DISMISSED`

**Mobile**
- Add processing/completion indicators and sound toggle to MobileDrawer

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Open copilot, start a chat, switch to another chat — verify
processing indicator appears on the background chat
- [x] Wait for background chat to complete — verify checkmark appears,
processing indicator disappears
- [x] Enable notifications via bell menu — verify browser permission
prompt appears
- [x] With notifications enabled, complete a background chat while on
another tab — verify system notification appears with sound
- [x] Click system notification — verify it navigates to the completed
chat
- [x] Verify document title shows unread count and resets when
navigating to the chat or focusing the tab
  - [x] Toggle sound off — verify no sound plays on completion
- [x] Toggle notifications off — verify no sound, no system
notification, no badge
  - [x] Clear local data — verify all preferences reset
- [x] Verify notification banner hides when notifications already
enabled
- [x] Verify dialog auto-shows for first-time users and can be
force-opened from menu

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-12 14:03:24 +00:00
Otto
83e49f71cd fix(frontend): pass through Supabase error params in password reset callback (#12384)
When Supabase rejects a password reset token (expired, already used,
etc.), it redirects to the callback URL with `error`, `error_code`, and
`error_description` params instead of a `code`. Previously, the callback
only checked for `!code` and returned a generic "Missing verification
code" error, swallowing the actual Supabase error.

This meant the `ExpiredLinkMessage` UX (added in SECRT-1369 / #12123)
was never triggered for these cases — users just saw the email input
form again with no explanation.

Now the callback reads Supabase's error params and forwards them to
`/reset-password`, where the existing expired link detection picks them
up correctly.

**Note:** This doesn't fix the root cause of Pwuts's token expiry issue
(likely link preview/prefetch consuming the OTP), but it ensures users
see the proper "link expired" message with a "Request new link" button
instead of a confusing silent redirect.

---
Co-authored-by: Reinier van der Leer (@Pwuts) <pwuts@agpt.co>
2026-03-12 13:51:15 +00:00
Zamil Majdy
b6d863fcd2 feat(platform): add CoPilot credit charging, token tracking, and rate limiting
- Charge credits for block execution in CoPilot (matching graph executor behavior)
- Track LLM token usage for both SDK (Claude) and baseline (OpenAI) paths
- Add Redis-based per-user rate limiting with session and weekly token windows
- Expose usage status via GET /api/chat/usage endpoint
- Add frontend UsageLimits component with progress bars in CoPilot header
- Include unit tests for rate limiting and block credit charging
2026-03-12 20:50:21 +07:00
Bently
ef446e4fe9 feat(llm): Add Cohere Command A Family Models (#12339)
## Summary
Adds the Cohere Command A family of models to AutoGPT Platform with
proper pricing configuration.

## Models Added
- **Command A 03.2025**: Flagship model (256k context, 8k output) - 3
credits
- **Command A Translate 08.2025**: State-of-the-art translation (8k
context, 8k output) - 3 credits
- **Command A Reasoning 08.2025**: First reasoning model (256k context,
32k output) - 6 credits
- **Command A Vision 07.2025**: First vision-capable model (128k
context, 8k output) - 3 credits

## Changes
- Added 4 new LlmModel enum entries with proper OpenRouter model IDs
- Added ModelMetadata for each model with correct context windows,
output limits, and price tiers
- Added pricing configuration in block_cost_config.py

## Testing
- [ ] Models appear in AutoGPT Platform model selector
- [ ] Pricing is correctly applied when using models

Resolves **SECRT-2083**
2026-03-12 11:56:30 +00:00
Bently
7b1e8ed786 feat(llm): Add Microsoft Phi-4 model support (#12342)
## Changes
- Added `MICROSOFT_PHI_4` to LlmModel enum (`microsoft/phi-4`)
- Configured model metadata:
  - 16K context window
  - 16K max output tokens
  - OpenRouter provider
- Set cost tier: 1
  - Input: $0.06 per 1M tokens
  - Output: $0.14 per 1M tokens

## Details
Microsoft Phi-4 is a 14B parameter model available through OpenRouter.
This PR adds proper support in the autogpt_platform backend.

Resolves SECRT-2086
2026-03-12 11:15:27 +00:00
Abhimanyu Yadav
7ccfff1040 feat(frontend): add credential type selector for multi-auth providers (#12378)
### Changes

- When a provider supports multiple credential types (e.g. GitHub with
both OAuth and API Key),
clicking "Add credential" now opens a tabbed dialog where users can
choose which type to use.
  Previously, OAuth always took priority and API key was unreachable.
- Each credential in the list now shows a type-specific icon (provider
icon for OAuth, key for API Key,
password/lock for others) and a small label badge (e.g. "API Key",
"OAuth").
- The native dropdown options also include the credential type in
parentheses for clarity.
- Single credential type providers behave exactly as before — no dialog,
direct action.


https://github.com/user-attachments/assets/79f3a097-ea97-426b-a2d9-781d7dcdb8a4



  ## Test plan
- [x] Test with a provider that has only one credential type (e.g.
OpenAI with api_key only) — should
  behave as before
- [x] Test with a provider that has multiple types (e.g. GitHub with
OAuth + API Key configured) —
  should show tabbed dialog
  - [x] Verify OAuth tab triggers the OAuth flow correctly
  - [x] Verify API Key tab shows the inline form and creates credentials
  - [x] Verify credential list shows correct icons and type badges
  - [x] Verify dropdown options show type in parentheses
2026-03-12 10:17:58 +00:00
Otto
81c7685a82 fix(frontend): release test fixes — scheduler time picker, unpublished banner (#12376)
Two frontend fixes from release testing (2026-03-11):

**SECRT-2102:** The schedule dialog shows an "At [hh]:[mm]" time picker
when selecting Custom > Every x Minutes or Hours, which makes no sense
for sub-day intervals. Now only shows the time picker for Custom > Days
and other frequency types.

**SECRT-2103:** The "Unpublished changes" banner shows for agents the
user doesn't own or create. Root cause: `owner_user_id` is the library
copy owner, not the graph creator. Changed to use `can_access_graph`
which correctly reflects write access.

---
Co-authored-by: Reinier van der Leer (@Pwuts) <pwuts@agpt.co>

---------

Co-authored-by: Reinier van der Leer (@Pwuts) <reinier@agpt.co>
Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
2026-03-12 10:02:26 +00:00
Bently
3595c6e769 feat(llm): add Perplexity Sonar Reasoning Pro model (#12341)
## Summary
Adds support for Perplexity's new reasoning model:
`perplexity/sonar-reasoning-pro`

## Changes
-  Added `PERPLEXITY_SONAR_REASONING_PRO` to `LlmModel` enum
-  Added model metadata (128K context window, 8K max output tokens,
tier 2)
-  Set pricing at 5 credits (matches sonar-pro tier)

## Model Details
- **Model ID:** `perplexity/sonar-reasoning-pro`
- **Provider:** OpenRouter
- **Context Window:** 128,000 tokens
- **Max Output:** 8,000 tokens
- **Pricing:** $0.000002/token (prompt), $0.000008/token (completion)
- **Cost Tier:** 2 (5 credits)

## Testing
-  Black formatting passed
-  Ruff linting passed

Resolves SECRT-2084
2026-03-12 09:58:29 +00:00
Abhimanyu Yadav
1c2953d61b fix(frontend): restore broken tutorial in builder (#12377)
### Changes
- Restored missing `shepherd.js/dist/css/shepherd.css` base styles
import
- Added missing .new-builder-tutorial-disable and
.new-builder-tutorial-highlight CSS classes to
  tutorial.css
- Fixed getFormContainerSelector() to include -node suffix matching the
actual DOM attribute

###  What broke
The old legacy-builder/tutorial.ts was the only file importing
Shepherd's base CSS. When #12082 removed
the legacy builder, the new tutorial lost all base Shepherd styles
(close button positioning, modal
overlay, tooltip layout). The new tutorial's custom CSS overrides
depended on these base styles
  existing.

  Test plan
  - [x] Start the tutorial from the builder (click the chalkboard icon)
- [x] Verify the close (X) button is positioned correctly in the
top-right of the popover
  - [x] Verify the modal overlay dims the background properly
- [x] Verify element highlighting works when the tutorial points to
blocks/buttons
- [x] Verify non-target blocks are grayed out during the "select
calculator" step
- [x] Complete the full tutorial flow end-to-end (add block → configure
→ connect → save → run)
2026-03-12 09:23:34 +00:00
Zamil Majdy
755bc84b1a fix(copilot): replace MCP jargon with user-friendly language (#12381)
Closes SECRT-2105

### Changes 🏗️

Replace all user-facing MCP technical terminology with plain, friendly
language across the CoPilot UI and LLM prompting.

**Backend (`run_mcp_tool.py`)**
- Added `_service_name()` helper that extracts a readable name from an
MCP host (`mcp.sentry.dev` → `Sentry`)
- `agent_name` in `SetupRequirementsResponse`: `"MCP: mcp.sentry.dev"` →
`"Sentry"`
- Auth message: `"The MCP server at X requires authentication. Please
connect your credentials to continue."` → `"To continue, sign in to
Sentry and approve access."`

**Backend (`mcp_tool_guide.md`)**
- Added "Communication style" section with before/after examples to
teach the LLM to avoid "MCP server", "OAuth", "credentials" jargon in
responses to users

**Frontend (`MCPSetupCard.tsx`)**
- Button: `"Connect to mcp.sentry.dev"` → `"Connect Sentry"`
- Connected state: `"Connected to mcp.sentry.dev!"` → `"Connected to
Sentry!"`
- Retry message: `"I've connected the MCP server credentials. Please
retry."` → `"I've connected. Please retry."`

**Frontend (`helpers.tsx`)**
- Added `serviceNameFromHost()` helper (exported, mirrors the backend
logic)
- Run text: `"Discovering MCP tools on mcp.sentry.dev"` → `"Connecting
to Sentry…"`
- Run text: `"Connecting to MCP server"` → `"Connecting…"`
- Run text: `"Connect to MCP: mcp.sentry.dev"` → `"Connect Sentry"`
(uses `agent_name` which is now just `"Sentry"`)
- Run text: `"Discovered N tool(s) on mcp.sentry.dev"` → `"Connected to
Sentry"`
- Error text: `"MCP error"` → `"Connection error"`

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [ ] Open CoPilot and ask it to connect to a service (e.g. Sentry,
Notion)
- [ ] Verify the run text accordion title shows `"Connecting to
Sentry…"` instead of `"Discovering MCP tools on mcp.sentry.dev"`
- [ ] Verify the auth card button shows `"Connect Sentry"` instead of
`"Connect to mcp.sentry.dev"`
- [ ] Verify the connected state shows `"Connected to Sentry!"` instead
of `"Connected to mcp.sentry.dev!"`
- [ ] Verify the LLM response text avoids "MCP server", "OAuth",
"credentials" terminology
2026-03-12 08:54:15 +00:00
Bently
ade2baa58f feat(llm): Add Grok 3 model support (#12343)
## Summary
Adds support for xAI's Grok 3 model to AutoGPT.

## Changes
- Added `GROK_3` to `LlmModel` enum with identifier `x-ai/grok-3`
- Configured model metadata:
  - Context window: 131,072 tokens (128k)
  - Max output: 32,768 tokens (32k)  
  - Provider: OpenRouter
  - Creator: xAI
  - Price tier: 2 (mid-tier)
- Set model cost to 3 credits (mid-tier pricing between fast models and
Grok 4)
- Updated block documentation to include Grok 3 in model lists

## Pricing Rationale
- **Grok 4**: 9 credits (tier 3 - premium, 256k context)
- **Grok 3**: 3 credits (tier 2 - mid-tier, 128k context) ← NEW
- **Grok 4 Fast/4.1 Fast/Code Fast**: 1 credit (tier 1 - affordable)

Grok 3 is positioned as a mid-tier model, priced similarly to other tier
2 models.

## Testing
- [x] Code passes `black` formatting
- [x] Code passes `ruff` linting
- [x] Model metadata and cost configuration added
- [x] Documentation updated

Closes SECRT-2079
2026-03-12 07:31:59 +00:00
Reinier van der Leer
4d35534a89 Merge branch 'master' into dev 2026-03-11 22:26:35 +01:00
Zamil Majdy
2cc748f34c chore(frontend): remove accidentally committed generated file (#12373)
`responseType.ts` was accidentally committed inside
`src/app/api/__generated__/models/` despite that directory being listed
in `.gitignore` (added in PR #12238).

### Changes 🏗️

- Removes
`autogpt_platform/frontend/src/app/api/__generated__/models/responseType.ts`
from git tracking — the file is already covered by the `.gitignore` rule
`src/app/api/__generated__/` and should never have been committed.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] No functional changes — only removes a stale tracked file that is
already gitignored
2026-03-11 14:22:37 +00:00
Shunyu Wu
c2e79fa5e1 fix(gmail): fallback to raw HTML when html2text conversion fails (#12369)
## Summary
- keep Gmail body extraction resilient when `html2text` converter raises
- fallback to raw HTML instead of failing extraction
- add regression test for converter failure path

Closes #12368

## Testing
- added unit test in
`autogpt_platform/backend/test/blocks/test_gmail.py`

---------

Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
2026-03-11 11:46:57 +00:00
Bently
89a5b3178a fix(llm): Update Gemini model lineup - add 3.1 models, deprecate 3 Pro Preview (#12331)
## 🔴 URGENT: Gemini 3 Pro Preview Shutdown - March 9, 2026

Google is shutting down Gemini 3 Pro Preview **tomorrow (March 9,
2026)**. This PR addresses SECRT-2067 by updating the Gemini model
lineup to prevent disruption.

---

## Changes

###  P0 - Critical (This Week)
- [x] **Remove/Replace Gemini 3 Pro Preview** → Migrated to 3.1 Pro
Preview
- [x] **Add Gemini 3.1 Pro Preview** (released Feb 19, 2026)

###  P1 - High Priority  
- [x] **Add Gemini 3.1 Flash Lite Preview** (released Mar 3, 2026)
- [x] **Add Gemini 3 Flash Preview** (released Dec 17, 2025)

###  P2 - Medium Priority
- [x] **Add Gemini 2.5 Pro (stable/GA)** (released Jun 17, 2025)

---

## Model Details

| Model | Context | Input Cost | Output Cost | Price Tier |
|-------|---------|------------|-------------|------------|
| **Gemini 3.1 Pro Preview** | 1.05M | $2.00/1M | $12.00/1M | 2 |
| **Gemini 3.1 Flash Lite Preview** | 1.05M | $0.25/1M | $1.50/1M | 1 |
| **Gemini 3 Flash Preview** | 1.05M | $0.50/1M | $3.00/1M | 1 |
| **Gemini 2.5 Pro (GA)** | 1.05M | $1.25/1M | $10.00/1M | 2 |
| ~~Gemini 3 Pro Preview~~ | ~~1.05M~~ | ~~$2.00/1M~~ | ~~$12.00/1M~~ |
**DEPRECATED** |

---

## Migration Strategy

**Database Migration:**
`20260308095500_migrate_deprecated_gemini_3_pro_preview`

- Automatically migrates all existing graphs using
`google/gemini-3-pro-preview` to `google/gemini-3.1-pro-preview`
- Updates: AgentBlock, AgentGraphExecution, AgentNodeExecution,
AgentGraph
- Zero user-facing disruption
- Migration runs on next deployment (before March 9 shutdown)

---

## Testing

- [ ] Verify new models appear in LLM block dropdown
- [ ] Test migration on staging database
- [ ] Confirm existing graphs using deprecated model auto-migrate
- [ ] Validate cost calculations for new models

---

## References

- **Linear Issue:**
[SECRT-2067](https://linear.app/autogpt/issue/SECRT-2067)
- **OpenRouter Models:** https://openrouter.ai/models/google
- **Google Deprecation Notice:**
https://ai.google.dev/gemini-api/docs/deprecations

---

## Checklist

- [x] Models added to `LlmModel` enum
- [x] Model metadata configured
- [x] Cost config updated
- [x] Database migration created
- [x] Deprecated model commented out (not removed for historical
reference)
- [ ] PR reviewed and approved
- [ ] Merged before March 9, 2026 deadline

---

**Priority:** 🔴 Critical - Must merge before March 9, 2026
2026-03-11 11:21:16 +00:00
Abhimanyu Yadav
c62d9a24ff fix(frontend): show correct status in agent submission view modal (#12360)
### Changes 🏗️
- The "View" modal for agent submissions hardcoded "Agent is awaiting
review" regardless of actual status
- Now displays "Agent approved", "Agent rejected", or "Agent is awaiting
review" based on the submission's actual status
- Shows review feedback in a highlighted section for rejected agents
when review comments are available

<img width="1127" height="788" alt="Screenshot 2026-03-11 at 9 02 29 AM"
src="https://github.com/user-attachments/assets/840e0fb1-22c2-4fda-891b-967c8b8b4043"
/>
<img width="1105" height="680" alt="Screenshot 2026-03-11 at 9 02 46 AM"
src="https://github.com/user-attachments/assets/f0c407e6-c58e-4ec8-9988-9f5c69bfa9a7"
/>

  ## Test plan
- [x] Submit an agent and verify the view modal shows "Agent is awaiting
review"
- [x] View an approved agent submission and verify it shows "Agent
approved"
- [x] View a rejected agent submission and verify it shows "Agent
rejected"
- [x] View a rejected agent with review comments and verify the feedback
section appears

  Closes SECRT-2092
2026-03-11 10:08:17 +00:00
Abhimanyu Yadav
0e0bfaac29 fix(frontend): show specific error messages for store image upload failures (#12361)
### Changes
- Surface backend error details (file size limit, invalid file type,
virus detected, etc.) in the upload failed toast instead of showing a
generic "Upload Failed" message
- The backend already returns specific error messages (e.g., "File too
large. Maximum size is 50MB") but the frontend was discarding them with
a catch-all handler
  
<img width="1222" height="411" alt="Screenshot 2026-03-11 at 9 13 30 AM"
src="https://github.com/user-attachments/assets/34ab3d90-fffa-4788-917a-fe2a7f4144b9"
/>

  ## Test plan
- [x] Upload an image larger than 50MB to a store submission → should
see "File too large. Maximum size is 50MB"
- [x] Upload an unsupported file type → should see file type error
message
  - [x] Upload a valid image → should still work normally

  Resolves SECRT-2093
2026-03-11 10:07:37 +00:00
Bently
0633475915 fix(frontend/library): graceful schedule deletion with auto-selection (#12278)
### Motivation 🎯

Fixes the issue where deleting a schedule shows an error screen instead
of gracefully handling the deletion. Previously, when a user deleted a
schedule, a race condition occurred where the query cache refetch
completed before the URL
state updated, causing the component to try rendering a schedule that no
longer existed (resulting in a 404 error screen).

### Changes 🏗️

**1. Fixed deletion order to prevent error screen flash**
- `useSelectedScheduleActions.ts` - Call `onDeleted()` callback
**before** invalidating queries to clear selection first
- `ScheduleActionsDropdown.tsx` - Same fix for sidebar dropdown deletion

**2. Added smart auto-selection logic**
- `useNewAgentLibraryView.ts`:
  - Added query to fetch current schedules list
  - Added `handleScheduleDeleted(deletedScheduleId)` function that:
    - Auto-selects the first remaining schedule if others exist
    - Clears selection to show empty state if no schedules remain

**3. Wired up smart deletion handler throughout component tree**
- `NewAgentLibraryView.tsx` - Passes `handleScheduleDeleted` to child
components
- `SelectedScheduleView.tsx` - Changed callback from
`onClearSelectedRun` to `onScheduleDeleted` and passes schedule ID
- `SidebarRunsList.tsx` - Added `onScheduleDeleted` prop and passes it
through to list items

### Checklist 📋

**Test Plan:**
- [] Create 2-3 test schedules for an agent
- [] Delete a schedule from the detail view (trash icon in actions) when
other schedules exist → Verify next schedule auto-selects without error
- [] Delete a schedule from the sidebar dropdown (three-dot menu) when
other schedules exist → Verify next schedule auto-selects without error
- [] Delete all schedules until only one remains → Verify empty state
shows gracefully without error
- [] Verify "Schedule deleted" toast appears on successful deletion
- [] Verify no error screen appears at any point during deletion flow
2026-03-11 09:01:55 +00:00
Bently
34a2f9a0a2 feat(llm): add Mistral flagship models (Large 3, Medium 3.1, Small 3.2, Codestral) (#12337)
## Summary

Adds four missing Mistral AI flagship models to address the critical
coverage gap identified in
[SECRT-2082](https://linear.app/autogpt/issue/SECRT-2082).

## Models Added

| Model | Context | Max Output | Price Tier | Use Case |
|-------|---------|------------|------------|----------|
| **Mistral Large 3** | 262K | None | 2 (Medium) | Flagship reasoning
model, 41B active params (675B total), MoE architecture |
| **Mistral Medium 3.1** | 131K | None | 2 (Medium) | Balanced
performance/cost, 8x cheaper than traditional large models |
| **Mistral Small 3.2** | 131K | 131K | 1 (Low) | Fast, cost-efficient,
high-volume use cases |
| **Codestral 2508** | 256K | None | 1 (Low) | Code generation
specialist (FIM, correction, test gen) |

## Problem

Previously, the platform only offered:
- Mistral Nemo (1 official model)
- dolphin-mistral (third-party Ollama fine-tune)

This left significant gaps in Mistral's lineup, particularly:
- No flagship reasoning model
- No balanced mid-tier option
- No code-specialized model
- Missing multimodal capabilities (Large 3, Medium 3.1, Small 3.2 all
support text+image)

## Changes

**File:** `autogpt_platform/backend/backend/blocks/llm.py`

- Added 4 enum entries in `LlmModel` class
- Added 4 metadata entries in `MODEL_METADATA` dict
- All models use OpenRouter provider
- Follows existing pattern for model additions

## Testing

-  Enum values match OpenRouter model IDs
-  Metadata follows existing format
-  Context windows verified from OpenRouter API
-  Price tiers assigned appropriately

## Closes

- SECRT-2082

---

**Note:** All models are available via OpenRouter and tested. This
brings Mistral coverage in line with other major providers (OpenAI,
Anthropic, Google).
2026-03-11 08:48:48 +00:00
Zamil Majdy
9f4caa7dfc feat(blocks): add and harden GitHub blocks for full-cycle development (#12334)
## Summary
- Add 8 new GitHub blocks: GetRepositoryInfo, ForkRepository,
ListCommits, SearchCode, CompareBranches, GetRepositoryTree,
MultiFileCommit, MergePullRequest
- Split `repo.py` (2094 lines, 19 blocks) into domain-specific modules:
`repo.py`, `repo_branches.py`, `repo_files.py`, `commits.py`
- Concurrent blob creation via `asyncio.gather()` in MultiFileCommit
- URL-encode branch/ref params via `urllib.parse.quote()` for
defense-in-depth
- Step-level error handling in MultiFileCommit ref update with recovery
SHA
- Collapse FileOperation CREATE/UPDATE into UPSERT (Git Trees API treats
them identically)
- Add `ge=1, le=100` constraints on per_page SchemaFields
- Preserve URL scheme in `prepare_pr_api_url`
- Handle null commit authors gracefully in ListCommits
- Add unit tests for `prepare_pr_api_url`, error-path tests for
MergePR/MultiFileCommit, FileOperation enum validation tests

## Test plan
- [ ] Block tests pass for all 19 GitHub blocks (CI:
`test_available_blocks`)
- [ ] New test file `test_github_blocks.py` passes (prepare_pr_api_url,
error paths, enum)
- [ ] `check-docs-sync` passes with regenerated docs
- [ ] pyright/ruff clean on all changed files
2026-03-11 08:35:37 +00:00
Otto
0876d22e22 feat(frontend/copilot): improve TTS voice selection to avoid robotic voices (#12317)
Requested by @0ubbe

Refines the `pickBestVoice()` function to ensure non-robotic voices are
always preferred:

- **Filter out known low-quality engines** — eSpeak, Festival, MBROLA,
Flite, and Pico voices are deprioritized
- **Prefer remote/cloud-backed voices** — `localService: false` voices
are typically higher quality
- **Expand preferred voices list** — added Moira, Tessa (macOS), Jenny,
Aria, Guy (Windows OneCore)
- **Smarter fallback chain** — English default → English → any default →
first available

The previous fallback could select eSpeak or Festival voices on Linux
systems, resulting in robotic output. Now those are filtered out unless
they're the only option.

---
Co-authored-by: Ubbe <ubbe@users.noreply.github.com>

---------

Co-authored-by: Ubbe <hi@ubbe.dev>
Co-authored-by: Lluis Agusti <hi@llu.lu>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-11 16:47:42 +08:00
Zamil Majdy
15e3980d65 fix(frontend): buffer workspace file downloads to prevent truncation (#12349)
## Summary
- Workspace file downloads (images, CSVs, etc.) were silently truncated
(~10 KB lost from the end) when served through the Next.js proxy
- Root cause: `new NextResponse(response.body)` passes a
`ReadableStream` directly, which Next.js / Vercel silently truncates for
larger files
- Fix: fully buffer with `response.arrayBuffer()` before forwarding, and
set `Content-Length` from the actual buffer size
- Keeps the auth proxy intact — no signed URLs (which would be public
and expire, breaking chat history)

## Root cause verification
Confirmed locally on session `080f27f9-0379-4085-a67a-ee34cc40cd62`:
- Backend `write_workspace_file` logs **978,831 bytes** written
- Direct backend download (`curl
localhost:8006/api/workspace/files/.../download`): **978,831 bytes** 
- Browser download through Next.js proxy: **truncated** 

## Why not signed URLs?
- Signed URLs are effectively public — anyone with the link can download
the file (privacy concern)
- Signed URLs expire, but chat history persists — reopening a
conversation later would show broken downloads
- Buffering is fine: workspace files are capped at 100 MB, Vercel
function memory is 1 GB

## Related
- Discord thread: `#Truncated File Bug` channel
- Related PR #12319 (signed URL approach) — this fix is simpler and
preserves auth

## Test plan
- [ ] Download a workspace file (CSV, PNG, any type) through the copilot
UI
- [ ] Verify downloaded file size matches the original
- [ ] Verify PNGs open correctly and CSVs have all rows

cc @Swiftyos @uberdot @AdarshRawat1
2026-03-10 18:23:51 +00:00
Zamil Majdy
fe9eb2564b feat(copilot): HITL review for sensitive block execution (#12356)
## Summary
- Integrates existing Human-In-The-Loop (HITL) review infrastructure
into CoPilot's direct block execution (`run_block`) for blocks marked
with `is_sensitive_action=True`
- Removes the `PendingHumanReview → AgentGraphExecution` FK constraint
to support synthetic CoPilot session IDs (migration included)
- Adds `ReviewRequiredResponse` model + frontend `ReviewRequiredCard`
component to surface review status in the chat UI
- Auto-approval works within a CoPilot session: once a block is
approved, subsequent executions of the same block in the same session
are auto-approved (using `copilot-session-{session_id}` as
`graph_exec_id` and `copilot-node-{block_id}` as `node_id`)

## Test plan
- [x] All 11 `run_block_test.py` tests pass (3 new sensitive action
tests)
- [ ] Manual: Execute a block with `is_sensitive_action=True` in CoPilot
→ verify ReviewRequiredResponse is returned and rendered
- [ ] Manual: Approve in review panel → re-execute the same block →
verify auto-approval kicks in
- [ ] Manual: Verify non-sensitive blocks still execute without review
2026-03-10 18:20:11 +00:00
Otto
5641cdd3ca fix(backend): update test patches for validate_url → validate_url_host rename (#12358)
bfb843a renamed `validate_url` to `validate_url_host` in
`agent_browser`, `run_mcp_tool`, and MCP routes, but the corresponding
test files still patched the old name, causing `AttributeError` in CI.

Updates all mock patch targets and assertions across 3 test files:
- `agent_browser_test.py`
- `test_run_mcp_tool.py`  
- `mcp/test_routes.py`

---
Co-authored-by: Zamil Majdy (@majdyz) <zamil.majdy@agpt.co>
Co-authored-by: Reinier van der Leer (@Pwuts) <pwuts@agpt.co>
2026-03-10 17:22:11 +00:00
Otto
bfb843a56e Merge commit from fork
* Fix SSRF via user-controlled ollama_host field

Validate ollama_host against BLOCKED_IP_NETWORKS before passing to
ollama.AsyncClient(). The server-configured default (env: OLLAMA_HOST)
is allowed without validation; user-supplied values that differ are
checked for private/internal IP resolution.

Fixes GHSA-6jx2-4h7q-3fx3

* Generalize validate_ollama_host to validate_host; fix description line length

* Rename to validate_untrusted_host with whitelist parameter

* Apply PR suggestion: include whitelist in error message; run formatting

* Move whitelist check after URL normalization; match on netloc

* revert unrelated formatting changes

* Dedup validate_url and validate_untrusted_host; normalize whitelist

* Move _resolve_and_check_blocked after calling functions

* dedup and clean up

* make trusted_hostnames truly optional

---------

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
2026-03-10 15:51:58 +01:00
Abhimanyu Yadav
684845d946 fix(frontend/builder): handle discriminated unions and improve node layout (#12354)
## Summary
- **Discriminated union support (oneOf)**: Added a new `OneOfField`
component that properly
renders Pydantic discriminated unions. Hides the unusable parent object
handle, auto-populates
the discriminator value, shows a dropdown with variant titles (e.g.,
"Username" / "UserId"), and
filters out the internal discriminator field from the form.
Non-discriminated `oneOf` schemas
  fall back to existing `AnyOfField` behavior.
- **Collapsible object outputs**: Object-type outputs with nested keys
(e.g.,
`PersonLookupResponse.Url`, `PersonLookupResponse.profile`) are now
collapsed by default behind a
caret toggle. Nested keys show short names instead of the full
`Parent.Key` prefix.
- **Node layout cleanup**: Removed excessive bottom margin (`mb-6`) from
`FormRenderer`, hide the
Advanced toggle when no advanced fields exist, and add rounded bottom
corners on OUTPUT-type
  blocks.

<img width="440" height="427" alt="Screenshot 2026-03-10 at 11 31 55 AM"
src="https://github.com/user-attachments/assets/06cc5414-4e02-4371-bdeb-1695e7cb2c97"
/>
<img width="371" height="320" alt="Screenshot 2026-03-10 at 11 36 52 AM"
src="https://github.com/user-attachments/assets/1a55f87a-c602-4f4d-b91b-6e49f810e5d5"
/>

  ## Test plan
- [x] Add a Twitter Get User block — verify "Identifier" shows a
dropdown (Username/UserId) with
no unusable parent handle, discriminator field is hidden, and the block
can run without staying
  INCOMPLETE
- [x] Add any block with object outputs (e.g., PersonLookupResponse) —
verify nested keys are
  collapsed by default and expand on click with short labels
- [x] Verify blocks without advanced fields don't show the Advanced
toggle
- [x] Verify existing `anyOf` schemas (optional types, 3+ variant
unions) still render correctly
  - [x] Check OUTPUT-type blocks have rounded bottom corners

---------

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
Co-authored-by: eureka928 <meobius123@gmail.com>
2026-03-10 14:13:32 +00:00
Bently
6a6b23c2e1 fix(frontend): Remove unused Otto Server Action causing 107K+ errors (#12336)
## Summary

Fixes [OPEN-3025](https://linear.app/autogpt/issue/OPEN-3025) —
**107,571+ Server Action errors** in production

Removes the orphaned `askOtto` Server Action that was left behind after
the Otto chat widget removal in PR #12082.

## Problem

Next.js Server Actions that are never imported are excluded from the
server manifest. Old client bundles still reference the action ID,
causing "not found" errors.

**Sentry impact:**
- **BUILDER-3BN:** 107,571 events
- **BUILDER-729:** 285 events  
- **BUILDER-3QH:** 1,611 events
- **36+ users affected**

## Root Cause

1. **Mar 2025:** Otto widget added to `/build` page with `askOtto`
Server Action
2. **Feb 2026:** Otto widget removed (PR #12082), but `actions.ts` left
behind
3. **Result:** Dead code → not in manifest → errors

## Evidence

```bash
# Zero imports across frontend:
grep -r "askOtto" src/ --exclude="actions.ts"
# → No results

# Server manifest missing the action:
cat .next/server/server-reference-manifest.json
# → Only includes login/supabase actions, NOT build/actions
```

## Changes

-  Delete
`autogpt_platform/frontend/src/app/(platform)/build/actions.ts`

## Testing

1. Verify no imports of `askOtto` in codebase 
2. Check Sentry for error drop after deploy
3. Monitor for new "Server Action not found" errors

## Checklist

- [x] Dead code confirmed (zero imports)
- [x] Sentry issues documented
- [x] Clear commit message with context
2026-03-10 09:03:38 +00:00
Dream
d0a1d72e8a fix(frontend/builder): batch undo history for cascading operations (#12344)
## Summary

Fixes undo in the Builder not working correctly when deleting nodes.
When a node is deleted, React Flow fires `onNodesChange` (node removal)
and `onEdgesChange` (cascading edge cleanup) as separate callbacks —
each independently pushing to the undo history stack. This creates
intermediate states that break undo:

- Single undo restores a partial state (e.g. edges pointing to a deleted
node)
- Multiple undos required to fully restore the graph
- Redo also produces inconsistent states

Resolves #10999

### Changes 🏗️

- **`historyStore.ts`** — Added microtask-based batching to
`pushState()`. Multiple calls within the same synchronous execution
(same event loop tick) are coalesced into a single history entry,
keeping only the first pre-change snapshot. Uses `queueMicrotask` so all
cascading store updates from a single user action settle before the
history entry is committed.
- Reset `pendingState` in `initializeHistory()` and `clear()` to prevent
stale batched state from leaking across graph loads or navigation.

**Side benefit:** Copy/paste operations that add multiple nodes and
edges now also produce a single history entry instead of one per
node/edge.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Place 3 blocks A, B, C and connect A→B→C
  - [x] Delete block C (removes node + cascading edge B→C)
  - [x] Delete connection A→B
  - [x] Undo — connection A→B restored (single undo, not multiple)
  - [x] Undo — block C and connection B→C restored
  - [x] Redo — block C removed again with its connections
- [x] Copy/paste multiple connected blocks — single undo reverts entire
paste

---------

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
Co-authored-by: Abhimanyu Yadav <122007096+Abhi1992002@users.noreply.github.com>
2026-03-10 04:55:07 +00:00
Zamil Majdy
f1945d6a2f feat(platform/copilot): @@agptfile: file-ref protocol for tool call inputs + block input toggle (#12332)
## Summary

- **Problem**: When the LLM calls a tool with large file content, it
must rewrite all content token-by-token. This is wasteful since the
files are already accessible on disk.
- **Solution**: Introduces an \`@@agptfile:\` reference protocol. The
LLM passes a file path reference; the processor loads and substitutes
the content before executing the tool.

### Protocol

\`\`\`
@@agptfile:<uri>[<start>-<end>]
\`\`\`

**Supported URI types:**
| URI | Source |
|-----|--------|
| \`workspace://<file_id>\` | Persistent workspace file by ID |
| \`workspace:///<path>\` | Workspace file by virtual path |
| \`/absolute/path\` | Absolute host or sandbox path |

**Line range** is optional; omitting it reads the whole file.

### Backend changes

- Rename \`@file:\` → \`@@agptfile:\` prefix for uniqueness; extract
\`FILE_REF_PREFIX\` constant
- Extract shared execution-context ContextVars into
\`backend/copilot/context.py\` — eliminates duplicate ContextVar objects
that caused \`e2b_file_tools.py\` to always see empty context
- \`tool_adapter.py\` imports ContextVars from \`context.py\` (single
source of truth)
- \`expand_file_refs_in_string\` raises \`FileRefExpansionError\` on
failure (instead of inline error strings), blocking tool execution and
returning a clear error hint to the model
- Tighten URI regex: only expand refs starting with \`workspace://\` or
\`/\`
- Aggregate budget: 1 MB total expansion cap across all refs in one
string
- Per-file cap: 200 KB per individual ref
- Fix \`_read_file_handler\` to pass \`get_sdk_cwd()\` to
\`is_allowed_local_path\` — ephemeral working directory files were
incorrectly blocked
- Fix \`_is_allowed_local\` in \`e2b_file_tools.py\` to pass
\`get_sdk_cwd()\`
- Restrict local path allow-list to \`tool-results/\` subdirectory only
(was entire session project dir)
- Add \`raise_on_error\` param + remove two-pass \`_FILE_REF_ERROR_RE\`
detection
- Update system prompt docs and tool_adapter error messages

### Frontend changes

- \`BlockInputCard\`: hidden by default with Show/Hide toggle + \`mb-2\`
spacing

## Test plan

- [ ] \`poetry run pytest backend/copilot/ -x
--ignore=backend/copilot/sdk/file_ref_integration_test.py\` passes
- [ ] \`@@agptfile:workspace:///<path>[1-50]\` expands correctly in tool
calls
- [ ] Invalid line ranges produce \`[file-ref error: ...]\` inline
messages
- [ ] Files outside \`sdk_cwd\` / \`tool-results/\` are rejected
- [ ] Block input card shows hidden by default with toggle
2026-03-09 18:39:13 +00:00
Zamil Majdy
6491cb1e23 feat(copilot): local agent generation with validation, fixing, MCP & sub-agent support (#12238)
## Summary

Port the agent generation pipeline from the external AgentGenerator
service into local copilot tools, making the Claude Agent SDK itself
handle validation, fixing, and block recommendation — no separate inner
LLM calls needed.

Key capabilities:
- **Local agent generation**: Create, edit, and customize agents
entirely within the SDK session
- **Graph validation**: 9 validation checks (block existence, link
references, type compatibility, IO blocks, etc.)
- **Graph fixing**: 17+ auto-fix methods (ID repair, link rewiring, type
conversion, credential stripping, dynamic block sink names, etc.)
- **MCP tool blocks**: Guide and fixer support for MCPToolBlock nodes
with proper dynamic input schema handling
- **Sub-agent composition**: AgentExecutorBlock support with library
agent schema enrichment
- **Embedding fallback**: Falls back to OpenRouter for embeddings when
`openai_internal_api_key` is unavailable
- **Actionable error messages**: Excluded block types (MCP, Agent)
return specific hints redirecting to the correct tool

### New Tools
- `validate_agent_graph` — run 9 validation checks on agent JSON
- `fix_agent_graph` — apply 17+ auto-fixes to agent JSON
- `get_blocks_for_goal` — recommend blocks for a given goal (with
optimized descriptions)

### Refactored Tools
- `create_agent`, `edit_agent`, `customize_agent` — accept `agent_json`
for local generation with shared fix→validate→save pipeline
- `find_block` — added `include_schemas` parameter, excludes MCP/Agent
blocks with actionable hints
- `run_block` — actionable error messages for excluded block types
- `find_library_agent` — enriched with `graph_version`, `input_schema`,
`output_schema` for sub-agent composition

### Architecture
- Split 2,558-line `validation.py` into `fixer.py`, `validator.py`,
`helpers.py`, `pipeline.py`
- Extracted shared `fix_validate_and_save()` pipeline (was duplicated
across 3 tools)
- Shared `OPENROUTER_BASE_URL` constant across codebase
- Comprehensive test coverage: 78+ unit tests for fixer/validator, 8
run_block tests, 17 SDK compat tests

## Test plan
- [x] `poetry run format` passes
- [x] `poetry run pytest -s -vvv backend/copilot/` — all tests pass
- [x] CI green on all Python versions (3.11, 3.12, 3.13)
- [x] Manual E2E: copilot generates agents with correct IO blocks,
links, and node structure
- [x] Manual E2E: MCP tool blocks use bare field names for dynamic
inputs
- [x] Manual E2E: sub-agent composition with AgentExecutorBlock
2026-03-09 16:10:22 +00:00
nKOxxx
c7124a5240 Add documentation for Google Gemini integration (#12283)
## Summary
Adding comprehensive documentation for Google Gemini integration with
AutoGPT.

## Changes
- Added setup instructions for Gemini API
- Documented configuration options
- Added examples and best practices

## Related Issues
N/A - Documentation improvement

## Testing
- Verified documentation accuracy
- Tested all code examples

## Checklist
- [x] Code follows project style
- [x] Documentation updated
- [x] Tests pass (if applicable)
2026-03-09 15:13:28 +00:00
Zamil Majdy
5537cb2858 dx: add shared Claude Code skills as auto-triggered guidelines (#12297)
## Summary
- Add 8 Claude Code skills under \`.claude/skills/\` that act as
**auto-triggered guidelines** — the LLM invokes them automatically based
on context, no manual \`/command\` needed
- Skills: \`pr-review\`, \`pr-create\`, \`new-block\`,
\`openapi-regen\`, \`backend-check\`, \`frontend-check\`,
\`worktree-setup\`, \`code-style\`
- Each skill has an explicit TRIGGER condition so the LLM knows when to
apply it without being asked

## Changes

### Skills (all auto-triggered by context)
| Skill | Trigger |
|-------|---------|
| \`pr-review\` | User shares a PR URL or asks to address review
comments |
| \`pr-create\` | User asks to create a PR, push changes for review, or
submit work |
| \`new-block\` | User asks to create a new block or add a new
integration |
| \`openapi-regen\` | API routes change, new endpoints added, or
frontend types are stale |
| \`backend-check\` | Backend Python code has been modified |
| \`frontend-check\` | Frontend TypeScript/React code has been modified
|
| \`worktree-setup\` | User asks to work on a branch in isolation or set
up a worktree |
| \`code-style\` | Writing or reviewing Python code |

## Test plan
- [ ] Verify skills appear automatically in Claude Code when context
matches (no \`/command\` needed)
- [ ] Modify frontend code — confirm \`frontend-check\` fires
automatically
- [ ] Ask Claude to "create a PR" — confirm \`pr-create\` fires without
\`/pr-create\`
- [ ] Share a PR URL — confirm \`pr-review\` fires automatically

---------

Co-authored-by: Krzysztof Czerwinski <kpczerwinski@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-09 15:10:38 +00:00
Zamil Majdy
aef5f6d666 feat(copilot): E2B sandbox auto-pause between turns to eliminate idle billing (#12330)
## Summary

### Before
- E2B sandboxes ran continuously between CoPilot turns, billing for idle
time
- Sandbox timeout caused **termination** (kill), losing all session
state
- No explicit cleanup when sessions were deleted — sandboxes leaked
- Single timeout concept with no separation between pause and kill
semantics

### After
- **Per-turn pause**: `pause_sandbox()` is called in the `finally` block
after every CoPilot turn, stopping billing instantly between turns
(paused sandboxes cost \$0 compute)
- **Auto-pause safety net**: Sandboxes are created with
`lifecycle={"on_timeout": "pause"}` (`pause_timeout` = 4h default) so
they auto-pause rather than terminate if the explicit pause is missed
- **Auto-reconnect**: `AsyncSandbox.connect()` in e2b SDK v2
auto-resumes paused sandboxes transparently — no extra code needed
- **Session delete cleanup**: `kill_sandbox()` is now called in
`delete_chat_session()` to explicitly terminate sandboxes and free
resources
- **Two distinct timeouts**: `pause_timeout` (4h, e2b auto-pause) vs
`redis_ttl` (12h, session key lifetime)

### Key Changes

| File | Change |
|------|--------|
| `pyproject.toml` | Bump `e2b-code-interpreter` `1.x` → `2.x` |
| `e2b_sandbox.py` | Add `pause_sandbox()`, `kill_sandbox()`,
`_act_on_sandbox()` helper; `lifecycle={"on_timeout": "pause"}`;
separate `pause_timeout` / `redis_ttl` params |
| `sdk/service.py` | Call `pause_sandbox()` in `finally` block
**before** transcript upload; use walrus operator for type-safe
`e2b_api_key` narrowing |
| `model.py` | Call `kill_sandbox()` in `delete_chat_session()`; inline
import to avoid circular dependency |
| `config.py` | Add `e2b_active` property; rename `e2b_sandbox_timeout`
default to 4h |
| `e2b_sandbox_test.py` | Add `test_pause_then_reconnect_reuses_sandbox`
test; update all `sandbox_timeout` → `pause_timeout` |

### Verified E2E
- Used real `E2B_API_KEY` from k8s dev cluster to manually verify:
sandbox created → paused → `is_running() == False` → reconnected via
`connect()` → state preserved → killed

## Test plan
- [x] `poetry run pytest backend/copilot/tools/e2b_sandbox_test.py` —
all 19 tests pass
- [x] CI: test (3.11, 3.12, 3.13), types — all green
- [x] E2E verified with real E2B credentials
2026-03-09 14:55:10 +00:00
Ubbe
8063391d0a feat(frontend/copilot): pin interactive tool cards outside reasoning collapse (#12346)
## Summary

<img width="400" height="227" alt="Screenshot 2026-03-09 at 22 43 10"
src="https://github.com/user-attachments/assets/0116e260-860d-4466-9763-e02de2766e50"
/>

<img width="600" height="618" alt="Screenshot 2026-03-09 at 22 43 14"
src="https://github.com/user-attachments/assets/beaa6aca-afa8-483f-ac06-439bf162c951"
/>

- When the copilot stream finishes, tool calls that require user
interaction (credentials, inputs, clarification) are now **pinned**
outside the "Show reasoning" collapse instead of being hidden
- Added `isInteractiveToolPart()` helper that checks tool output's
`type` field against a set of interactive response types
- Modified `splitReasoningAndResponse()` to extract interactive tools
from reasoning into the visible response section
- Added styleguide section with 3 demos: `setup_requirements`,
`agent_details`, and `agent_saved` pinning scenarios

### Interactive response types kept visible:
`setup_requirements`, `agent_details`, `block_details`, `need_login`,
`input_validation_error`, `clarification_needed`, `suggested_goal`,
`agent_preview`, `agent_saved`

Error responses remain in reasoning (LLM explains them in final text).

Closes SECRT-2088

## Test plan
- [ ] Verify copilot stream with interactive tool (e.g. run_agent
requiring credentials) keeps the tool card visible after stream ends
- [ ] Verify non-interactive tools (find_block, bash_exec) still
collapse into "Show reasoning"
- [ ] Verify styleguide page at `/copilot/styleguide` renders the new
"Reasoning Collapse: Interactive Tool Pinning" section correctly
- [ ] Verify `pnpm types`, `pnpm lint`, `pnpm format` all pass

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-09 23:12:14 +08:00
Otto
0bbb12d688 fix(frontend/copilot): hide New Chat button on Autopilot homepage (#12321)
Requested by @0ubbe

The **New Chat** button was visible on the Autopilot homepage where
clicking it has no effect (since `sessionId` is already `null`). This
hides the button when no chat session is active, so it only appears when
the user is viewing a conversation and wants to start a new one.

**Changes:**
- `ChatSidebar.tsx` — hide button in both collapsed and expanded sidebar
states when `sessionId` is null
- `MobileDrawer.tsx` — same fix for mobile drawer

---
Co-authored-by: Ubbe <ubbe@users.noreply.github.com>
2026-03-09 22:41:11 +08:00
Reinier van der Leer
19d775c435 Merge commit from fork 2026-03-08 10:25:24 +01:00
264 changed files with 24788 additions and 5170 deletions

View File

@@ -0,0 +1,17 @@
---
name: backend-check
description: Run the full backend formatting, linting, and test suite. Ensures code quality before commits and PRs. TRIGGER when backend Python code has been modified and needs validation.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# Backend Check
## Steps
1. **Format**: `poetry run format` — runs formatting AND linting. NEVER run ruff/black/isort individually
2. **Fix** any remaining errors manually, re-run until clean
3. **Test**: `poetry run test` (runs DB setup + pytest). For specific files: `poetry run pytest -s -vvv <test_files>`
4. **Snapshots** (if needed): `poetry run pytest path/to/test.py --snapshot-update` — review with `git diff`

View File

@@ -0,0 +1,35 @@
---
name: code-style
description: Python code style preferences for the AutoGPT backend. Apply when writing or reviewing Python code. TRIGGER when writing new Python code, reviewing PRs, or refactoring backend code.
user-invocable: false
metadata:
author: autogpt-team
version: "1.0.0"
---
# Code Style
## Imports
- **Top-level only** — no local/inner imports. Move all imports to the top of the file.
## Typing
- **No duck typing** — avoid `hasattr`, `getattr`, `isinstance` for type dispatch. Use proper typed interfaces, unions, or protocols.
- **Pydantic models** over dataclass, namedtuple, or raw dict for structured data.
- **No linter suppressors** — avoid `# type: ignore`, `# noqa`, `# pyright: ignore` etc. 99% of the time the right fix is fixing the type/code, not silencing the tool.
## Code Structure
- **List comprehensions** over manual loop-and-append.
- **Early return** — guard clauses first, avoid deep nesting.
- **Flatten inline** — prefer short, concise expressions. Reduce `if/else` chains with direct returns or ternaries when readable.
- **Modular functions** — break complex logic into small, focused functions rather than long blocks with nested conditionals.
## Review Checklist
Before finishing, always ask:
- Can any function be split into smaller pieces?
- Is there unnecessary nesting that an early return would eliminate?
- Can any loop be a comprehension?
- Is there a simpler way to express this logic?

View File

@@ -0,0 +1,16 @@
---
name: frontend-check
description: Run the full frontend formatting, linting, and type checking suite. Ensures code quality before commits and PRs. TRIGGER when frontend TypeScript/React code has been modified and needs validation.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# Frontend Check
## Steps (in order)
1. **Format**: `pnpm format` — NEVER run individual formatters
2. **Lint**: `pnpm lint` — fix errors, re-run until clean
3. **Types**: `pnpm types` — if it keeps failing after multiple attempts, stop and ask the user

View File

@@ -0,0 +1,29 @@
---
name: new-block
description: Create a new backend block following the Block SDK Guide. Guides through provider configuration, schema definition, authentication, and testing. TRIGGER when user asks to create a new block, add a new integration, or build a new node for the graph editor.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# New Block Creation
Read `docs/platform/block-sdk-guide.md` first for the full guide.
## Steps
1. **Provider config** (if external service): create `_config.py` with `ProviderBuilder`
2. **Block file** in `backend/blocks/` (from `autogpt_platform/backend/`):
- Generate a UUID once with `uuid.uuid4()`, then **hard-code that string** as `id` (IDs must be stable across imports)
- `Input(BlockSchema)` and `Output(BlockSchema)` classes
- `async def run` that `yield`s output fields
3. **Files**: use `store_media_file()` with `"for_block_output"` for outputs
4. **Test**: `poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[MyBlock]' -xvs`
5. **Format**: `poetry run format`
## Rules
- Analyze interfaces: do inputs/outputs connect well with other blocks in a graph?
- Use top-level imports, avoid duck typing
- Always use `for_block_output` for block outputs

View File

@@ -0,0 +1,28 @@
---
name: openapi-regen
description: Regenerate the OpenAPI spec and frontend API client. Starts the backend REST server, fetches the spec, and regenerates the typed frontend hooks. TRIGGER when API routes change, new endpoints are added, or frontend API types are stale.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# OpenAPI Spec Regeneration
## Steps
1. **Run end-to-end** in a single shell block (so `REST_PID` persists):
```bash
cd autogpt_platform/backend && poetry run rest &
REST_PID=$!
WAIT=0; until curl -sf http://localhost:8006/health > /dev/null 2>&1; do sleep 1; WAIT=$((WAIT+1)); [ $WAIT -ge 60 ] && echo "Timed out" && kill $REST_PID && exit 1; done
cd ../frontend && pnpm generate:api:force
kill $REST_PID
pnpm types && pnpm lint && pnpm format
```
## Rules
- Always use `pnpm generate:api:force` (not `pnpm generate:api`)
- Don't manually edit files in `src/app/api/__generated__/`
- Generated hooks follow: `use{Method}{Version}{OperationName}`

View File

@@ -0,0 +1,31 @@
---
name: pr-create
description: Create a pull request for the current branch. TRIGGER when user asks to create a PR, open a pull request, push changes for review, or submit work for merging.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# Create Pull Request
## Steps
1. **Check for existing PR**: `gh pr view --json url -q .url 2>/dev/null` — if a PR already exists, output its URL and stop
2. **Understand changes**: `git status`, `git diff dev...HEAD`, `git log dev..HEAD --oneline`
3. **Read PR template**: `.github/PULL_REQUEST_TEMPLATE.md`
4. **Draft PR title**: Use conventional commits format (see CLAUDE.md for types and scopes)
5. **Fill out PR template** as the body — be thorough in the Changes section
6. **Format first** (if relevant changes exist):
- Backend: `cd autogpt_platform/backend && poetry run format`
- Frontend: `cd autogpt_platform/frontend && pnpm format`
- Fix any lint errors, then commit formatting changes before pushing
7. **Push**: `git push -u origin HEAD`
8. **Create PR**: `gh pr create --base dev`
9. **Output** the PR URL
## Rules
- Always target `dev` branch
- Do NOT run tests — CI will handle that
- Use the PR template from `.github/PULL_REQUEST_TEMPLATE.md`

View File

@@ -0,0 +1,51 @@
---
name: pr-review
description: Address all open PR review comments systematically. Fetches comments, addresses each one, reacts +1/-1, and replies when clarification is needed. Keeps iterating until all comments are addressed and CI is green. TRIGGER when user shares a PR URL, asks to address review comments, fix PR feedback, or respond to reviewer comments.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# PR Review Comment Workflow
## Steps
1. **Find PR**: `gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT`
2. **Fetch comments** (all three sources):
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews` (top-level reviews)
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments` (inline review comments)
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` (PR conversation comments)
3. **Skip** comments already reacted to by PR author
4. **For each unreacted comment**:
- Read referenced code, make the fix (or reply if you disagree/need info)
- **Inline review comments** (`pulls/{N}/comments`):
- React: `gh api repos/.../pulls/comments/{ID}/reactions -f content="+1"` (or `-1`)
- Reply: `gh api repos/.../pulls/{N}/comments/{ID}/replies -f body="..."`
- **PR conversation comments** (`issues/{N}/comments`):
- React: `gh api repos/.../issues/comments/{ID}/reactions -f content="+1"` (or `-1`)
- No threaded replies — post a new issue comment if needed
- **Top-level reviews**: no reaction API — address in code, reply via issue comment if needed
5. **Include autogpt-reviewer bot fixes** too
6. **Format**: `cd autogpt_platform/backend && poetry run format`, `cd autogpt_platform/frontend && pnpm format`
7. **Commit & push**
8. **Re-fetch comments** immediately — address any new unreacted ones before waiting on CI
9. **Stay productive while CI runs** — don't idle. In priority order:
- Run any pending local tests (`poetry run pytest`, e2e, etc.) and fix failures
- Address any remaining comments
- Only poll `gh pr checks {N}` as the last resort when there's truly nothing left to do
10. **If CI fails** — fix, go back to step 6
11. **Re-fetch comments again** after CI is green — address anything that appeared while CI was running
12. **Done** only when: all comments reacted AND CI is green.
## CRITICAL: Do Not Stop
**Loop is: address → format → commit → push → re-check comments → run local tests → wait CI → re-check comments → repeat.**
Never idle. If CI is running and you have nothing to address, run local tests. Waiting on CI is the last resort.
## Rules
- One todo per comment
- For inline review comments: reply on existing threads. For PR conversation comments: post a new issue comment (API doesn't support threaded replies)
- React to every comment: +1 addressed, -1 disagreed (with explanation)

View File

@@ -0,0 +1,45 @@
---
name: worktree-setup
description: Set up a new git worktree for parallel development. Creates the worktree, copies .env files, installs dependencies, generates Prisma client, and optionally starts the app (with port conflict resolution) or runs tests. TRIGGER when user asks to set up a worktree, work on a branch in isolation, or needs a separate environment for a branch or PR.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# Worktree Setup
## Preferred: Use Branchlet
The repo has a `.branchlet.json` config — it handles env file copying, dependency installation, and Prisma generation automatically.
```bash
npm install -g branchlet # install once
branchlet create -n <name> -s <source-branch> -b <new-branch>
branchlet list --json # list all worktrees
```
## Manual Fallback
If branchlet isn't available:
1. `git worktree add ../<RepoName><N> <branch-name>`
2. Copy `.env` files: `backend/.env`, `frontend/.env`, `autogpt_platform/.env`, `db/docker/.env`
3. Install deps:
- `cd autogpt_platform/backend && poetry install && poetry run prisma generate`
- `cd autogpt_platform/frontend && pnpm install`
## Running the App
Free ports first — backend uses: 8001, 8002, 8003, 8005, 8006, 8007, 8008.
```bash
for port in 8001 8002 8003 8005 8006 8007 8008; do
lsof -ti :$port | xargs kill -9 2>/dev/null || true
done
cd <worktree>/autogpt_platform/backend && poetry run app
```
## CoPilot Testing Gotcha
SDK mode spawns a Claude subprocess — **won't work inside Claude Code**. Set `CHAT_USE_CLAUDE_AGENT_SDK=false` in `backend/.env` to use baseline mode.

View File

@@ -0,0 +1,40 @@
-- =============================================================
-- View: analytics.auth_activities
-- Looker source alias: ds49 | Charts: 1
-- =============================================================
-- DESCRIPTION
-- Tracks authentication events (login, logout, SSO, password
-- reset, etc.) from Supabase's internal audit log.
-- Useful for monitoring sign-in patterns and detecting anomalies.
--
-- SOURCE TABLES
-- auth.audit_log_entries — Supabase internal auth event log
--
-- OUTPUT COLUMNS
-- created_at TIMESTAMPTZ When the auth event occurred
-- actor_id TEXT User ID who triggered the event
-- actor_via_sso TEXT Whether the action was via SSO ('true'/'false')
-- action TEXT Event type (e.g. 'login', 'logout', 'token_refreshed')
--
-- WINDOW
-- Rolling 90 days from current date
--
-- EXAMPLE QUERIES
-- -- Daily login counts
-- SELECT DATE_TRUNC('day', created_at) AS day, COUNT(*) AS logins
-- FROM analytics.auth_activities
-- WHERE action = 'login'
-- GROUP BY 1 ORDER BY 1;
--
-- -- SSO vs password login breakdown
-- SELECT actor_via_sso, COUNT(*) FROM analytics.auth_activities
-- WHERE action = 'login' GROUP BY 1;
-- =============================================================
SELECT
created_at,
payload->>'actor_id' AS actor_id,
payload->>'actor_via_sso' AS actor_via_sso,
payload->>'action' AS action
FROM auth.audit_log_entries
WHERE created_at >= NOW() - INTERVAL '90 days'

View File

@@ -0,0 +1,105 @@
-- =============================================================
-- View: analytics.graph_execution
-- Looker source alias: ds16 | Charts: 21
-- =============================================================
-- DESCRIPTION
-- One row per agent graph execution (last 90 days).
-- Unpacks the JSONB stats column into individual numeric columns
-- and normalises the executionStatus — runs that failed due to
-- insufficient credits are reclassified as 'NO_CREDITS' for
-- easier filtering. Error messages are scrubbed of IDs and URLs
-- to allow safe grouping.
--
-- SOURCE TABLES
-- platform.AgentGraphExecution — Execution records
-- platform.AgentGraph — Agent graph metadata (for name)
-- platform.LibraryAgent — To flag possibly-AI (safe-mode) agents
--
-- OUTPUT COLUMNS
-- id TEXT Execution UUID
-- agentGraphId TEXT Agent graph UUID
-- agentGraphVersion INT Graph version number
-- executionStatus TEXT COMPLETED | FAILED | NO_CREDITS | RUNNING | QUEUED | TERMINATED
-- createdAt TIMESTAMPTZ When the execution was queued
-- updatedAt TIMESTAMPTZ Last status update time
-- userId TEXT Owner user UUID
-- agentGraphName TEXT Human-readable agent name
-- cputime DECIMAL Total CPU seconds consumed
-- walltime DECIMAL Total wall-clock seconds
-- node_count DECIMAL Number of nodes in the graph
-- nodes_cputime DECIMAL CPU time across all nodes
-- nodes_walltime DECIMAL Wall time across all nodes
-- execution_cost DECIMAL Credit cost of this execution
-- correctness_score FLOAT AI correctness score (if available)
-- possibly_ai BOOLEAN True if agent has sensitive_action_safe_mode enabled
-- groupedErrorMessage TEXT Scrubbed error string (IDs/URLs replaced with wildcards)
--
-- WINDOW
-- Rolling 90 days (createdAt > CURRENT_DATE - 90 days)
--
-- EXAMPLE QUERIES
-- -- Daily execution counts by status
-- SELECT DATE_TRUNC('day', "createdAt") AS day, "executionStatus", COUNT(*)
-- FROM analytics.graph_execution
-- GROUP BY 1, 2 ORDER BY 1;
--
-- -- Average cost per execution by agent
-- SELECT "agentGraphName", AVG("execution_cost") AS avg_cost, COUNT(*) AS runs
-- FROM analytics.graph_execution
-- WHERE "executionStatus" = 'COMPLETED'
-- GROUP BY 1 ORDER BY avg_cost DESC;
--
-- -- Top error messages
-- SELECT "groupedErrorMessage", COUNT(*) AS occurrences
-- FROM analytics.graph_execution
-- WHERE "executionStatus" = 'FAILED'
-- GROUP BY 1 ORDER BY 2 DESC LIMIT 20;
-- =============================================================
SELECT
ge."id" AS id,
ge."agentGraphId" AS agentGraphId,
ge."agentGraphVersion" AS agentGraphVersion,
CASE
WHEN jsonb_exists(ge."stats"::jsonb, 'error')
AND (
(ge."stats"::jsonb->>'error') ILIKE '%insufficient balance%'
OR (ge."stats"::jsonb->>'error') ILIKE '%you have no credits left%'
)
THEN 'NO_CREDITS'
ELSE CAST(ge."executionStatus" AS TEXT)
END AS executionStatus,
ge."createdAt" AS createdAt,
ge."updatedAt" AS updatedAt,
ge."userId" AS userId,
g."name" AS agentGraphName,
(ge."stats"::jsonb->>'cputime')::decimal AS cputime,
(ge."stats"::jsonb->>'walltime')::decimal AS walltime,
(ge."stats"::jsonb->>'node_count')::decimal AS node_count,
(ge."stats"::jsonb->>'nodes_cputime')::decimal AS nodes_cputime,
(ge."stats"::jsonb->>'nodes_walltime')::decimal AS nodes_walltime,
(ge."stats"::jsonb->>'cost')::decimal AS execution_cost,
(ge."stats"::jsonb->>'correctness_score')::float AS correctness_score,
COALESCE(la.possibly_ai, FALSE) AS possibly_ai,
REGEXP_REPLACE(
REGEXP_REPLACE(
TRIM(BOTH '"' FROM ge."stats"::jsonb->>'error'),
'(https?://)([A-Za-z0-9.-]+)(:[0-9]+)?(/[^\s]*)?',
'\1\2/...', 'gi'
),
'[a-zA-Z0-9_:-]*\d[a-zA-Z0-9_:-]*', '*', 'g'
) AS groupedErrorMessage
FROM platform."AgentGraphExecution" ge
LEFT JOIN platform."AgentGraph" g
ON ge."agentGraphId" = g."id"
AND ge."agentGraphVersion" = g."version"
LEFT JOIN (
SELECT DISTINCT ON ("userId", "agentGraphId")
"userId", "agentGraphId",
("settings"::jsonb->>'sensitive_action_safe_mode')::boolean AS possibly_ai
FROM platform."LibraryAgent"
WHERE "isDeleted" = FALSE
AND "isArchived" = FALSE
ORDER BY "userId", "agentGraphId", "agentGraphVersion" DESC
) la ON la."userId" = ge."userId" AND la."agentGraphId" = ge."agentGraphId"
WHERE ge."createdAt" > CURRENT_DATE - INTERVAL '90 days'

View File

@@ -0,0 +1,101 @@
-- =============================================================
-- View: analytics.node_block_execution
-- Looker source alias: ds14 | Charts: 11
-- =============================================================
-- DESCRIPTION
-- One row per node (block) execution (last 90 days).
-- Unpacks stats JSONB and joins to identify which block type
-- was run. For failed nodes, joins the error output and
-- scrubs it for safe grouping.
--
-- SOURCE TABLES
-- platform.AgentNodeExecution — Node execution records
-- platform.AgentNode — Node → block mapping
-- platform.AgentBlock — Block name/ID
-- platform.AgentNodeExecutionInputOutput — Error output values
--
-- OUTPUT COLUMNS
-- id TEXT Node execution UUID
-- agentGraphExecutionId TEXT Parent graph execution UUID
-- agentNodeId TEXT Node UUID within the graph
-- executionStatus TEXT COMPLETED | FAILED | QUEUED | RUNNING | TERMINATED
-- addedTime TIMESTAMPTZ When the node was queued
-- queuedTime TIMESTAMPTZ When it entered the queue
-- startedTime TIMESTAMPTZ When execution started
-- endedTime TIMESTAMPTZ When execution finished
-- inputSize BIGINT Input payload size in bytes
-- outputSize BIGINT Output payload size in bytes
-- walltime NUMERIC Wall-clock seconds for this node
-- cputime NUMERIC CPU seconds for this node
-- llmRetryCount INT Number of LLM retries
-- llmCallCount INT Number of LLM API calls made
-- inputTokenCount BIGINT LLM input tokens consumed
-- outputTokenCount BIGINT LLM output tokens produced
-- blockName TEXT Human-readable block name (e.g. 'OpenAIBlock')
-- blockId TEXT Block UUID
-- groupedErrorMessage TEXT Scrubbed error (IDs/URLs wildcarded)
-- errorMessage TEXT Raw error output (only set when FAILED)
--
-- WINDOW
-- Rolling 90 days (addedTime > CURRENT_DATE - 90 days)
--
-- EXAMPLE QUERIES
-- -- Most-used blocks by execution count
-- SELECT "blockName", COUNT(*) AS executions,
-- COUNT(*) FILTER (WHERE "executionStatus"='FAILED') AS failures
-- FROM analytics.node_block_execution
-- GROUP BY 1 ORDER BY executions DESC LIMIT 20;
--
-- -- Average LLM token usage per block
-- SELECT "blockName",
-- AVG("inputTokenCount") AS avg_input_tokens,
-- AVG("outputTokenCount") AS avg_output_tokens
-- FROM analytics.node_block_execution
-- WHERE "llmCallCount" > 0
-- GROUP BY 1 ORDER BY avg_input_tokens DESC;
--
-- -- Top failure reasons
-- SELECT "blockName", "groupedErrorMessage", COUNT(*) AS count
-- FROM analytics.node_block_execution
-- WHERE "executionStatus" = 'FAILED'
-- GROUP BY 1, 2 ORDER BY count DESC LIMIT 20;
-- =============================================================
SELECT
ne."id" AS id,
ne."agentGraphExecutionId" AS agentGraphExecutionId,
ne."agentNodeId" AS agentNodeId,
CAST(ne."executionStatus" AS TEXT) AS executionStatus,
ne."addedTime" AS addedTime,
ne."queuedTime" AS queuedTime,
ne."startedTime" AS startedTime,
ne."endedTime" AS endedTime,
(ne."stats"::jsonb->>'input_size')::bigint AS inputSize,
(ne."stats"::jsonb->>'output_size')::bigint AS outputSize,
(ne."stats"::jsonb->>'walltime')::numeric AS walltime,
(ne."stats"::jsonb->>'cputime')::numeric AS cputime,
(ne."stats"::jsonb->>'llm_retry_count')::int AS llmRetryCount,
(ne."stats"::jsonb->>'llm_call_count')::int AS llmCallCount,
(ne."stats"::jsonb->>'input_token_count')::bigint AS inputTokenCount,
(ne."stats"::jsonb->>'output_token_count')::bigint AS outputTokenCount,
b."name" AS blockName,
b."id" AS blockId,
REGEXP_REPLACE(
REGEXP_REPLACE(
TRIM(BOTH '"' FROM eio."data"::text),
'(https?://)([A-Za-z0-9.-]+)(:[0-9]+)?(/[^\s]*)?',
'\1\2/...', 'gi'
),
'[a-zA-Z0-9_:-]*\d[a-zA-Z0-9_:-]*', '*', 'g'
) AS groupedErrorMessage,
eio."data" AS errorMessage
FROM platform."AgentNodeExecution" ne
LEFT JOIN platform."AgentNode" nd
ON ne."agentNodeId" = nd."id"
LEFT JOIN platform."AgentBlock" b
ON nd."agentBlockId" = b."id"
LEFT JOIN platform."AgentNodeExecutionInputOutput" eio
ON eio."referencedByOutputExecId" = ne."id"
AND eio."name" = 'error'
AND ne."executionStatus" = 'FAILED'
WHERE ne."addedTime" > CURRENT_DATE - INTERVAL '90 days'

View File

@@ -0,0 +1,97 @@
-- =============================================================
-- View: analytics.retention_agent
-- Looker source alias: ds35 | Charts: 2
-- =============================================================
-- DESCRIPTION
-- Weekly cohort retention broken down per individual agent.
-- Cohort = week of a user's first use of THAT specific agent.
-- Tells you which agents keep users coming back vs. one-shot
-- use. Only includes cohorts from the last 180 days.
--
-- SOURCE TABLES
-- platform.AgentGraphExecution — Execution records (user × agent × time)
-- platform.AgentGraph — Agent names
--
-- OUTPUT COLUMNS
-- agent_id TEXT Agent graph UUID
-- agent_label TEXT 'AgentName [first8chars]'
-- agent_label_n TEXT 'AgentName [first8chars] (n=total_users)'
-- cohort_week_start DATE Week users first ran this agent
-- cohort_label TEXT ISO week label
-- cohort_label_n TEXT ISO week label with cohort size
-- user_lifetime_week INT Weeks since first use of this agent
-- cohort_users BIGINT Users in this cohort for this agent
-- active_users BIGINT Users who ran the agent again in week k
-- retention_rate FLOAT active_users / cohort_users
-- cohort_users_w0 BIGINT cohort_users only at week 0 (safe to SUM)
-- agent_total_users BIGINT Total users across all cohorts for this agent
--
-- EXAMPLE QUERIES
-- -- Best-retained agents at week 2
-- SELECT agent_label, AVG(retention_rate) AS w2_retention
-- FROM analytics.retention_agent
-- WHERE user_lifetime_week = 2 AND cohort_users >= 10
-- GROUP BY 1 ORDER BY w2_retention DESC LIMIT 10;
--
-- -- Agents with most unique users
-- SELECT DISTINCT agent_label, agent_total_users
-- FROM analytics.retention_agent
-- ORDER BY agent_total_users DESC LIMIT 20;
-- =============================================================
WITH params AS (SELECT 12::int AS max_weeks, (CURRENT_DATE - INTERVAL '180 days') AS cohort_start),
events AS (
SELECT e."userId"::text AS user_id, e."agentGraphId" AS agent_id,
e."createdAt"::timestamptz AS created_at,
DATE_TRUNC('week', e."createdAt")::date AS week_start
FROM platform."AgentGraphExecution" e
),
first_use AS (
SELECT user_id, agent_id, MIN(created_at) AS first_use_at,
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
FROM events GROUP BY 1,2
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
),
activity_weeks AS (SELECT DISTINCT user_id, agent_id, week_start FROM events),
user_week_age AS (
SELECT aw.user_id, aw.agent_id, fu.cohort_week_start,
((aw.week_start - DATE_TRUNC('week',fu.first_use_at)::date)/7)::int AS user_lifetime_week
FROM activity_weeks aw JOIN first_use fu USING (user_id, agent_id)
WHERE aw.week_start >= DATE_TRUNC('week',fu.first_use_at)::date
),
active_counts AS (
SELECT agent_id, cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2,3
),
cohort_sizes AS (
SELECT agent_id, cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_use GROUP BY 1,2
),
cohort_caps AS (
SELECT cs.agent_id, cs.cohort_week_start, cs.cohort_users,
LEAST((SELECT max_weeks FROM params),
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date-cs.cohort_week_start)/7)::int)) AS cap_weeks
FROM cohort_sizes cs
),
grid AS (
SELECT cc.agent_id, cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
),
agent_names AS (SELECT DISTINCT ON (g."id") g."id" AS agent_id, g."name" AS agent_name FROM platform."AgentGraph" g ORDER BY g."id", g."version" DESC),
agent_total_users AS (SELECT agent_id, SUM(cohort_users) AS agent_total_users FROM cohort_sizes GROUP BY 1)
SELECT
g.agent_id,
COALESCE(an.agent_name,'(unnamed)')||' ['||LEFT(g.agent_id::text,8)||']' AS agent_label,
COALESCE(an.agent_name,'(unnamed)')||' ['||LEFT(g.agent_id::text,8)||'] (n='||COALESCE(atu.agent_total_users,0)||')' AS agent_label_n,
g.cohort_week_start,
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
g.user_lifetime_week, g.cohort_users,
COALESCE(ac.active_users,0) AS active_users,
COALESCE(ac.active_users,0)::float / NULLIF(g.cohort_users,0) AS retention_rate,
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0,
COALESCE(atu.agent_total_users,0) AS agent_total_users
FROM grid g
LEFT JOIN active_counts ac ON ac.agent_id=g.agent_id AND ac.cohort_week_start=g.cohort_week_start AND ac.user_lifetime_week=g.user_lifetime_week
LEFT JOIN agent_names an ON an.agent_id=g.agent_id
LEFT JOIN agent_total_users atu ON atu.agent_id=g.agent_id
ORDER BY agent_label, g.cohort_week_start, g.user_lifetime_week;

View File

@@ -0,0 +1,81 @@
-- =============================================================
-- View: analytics.retention_execution_daily
-- Looker source alias: ds111 | Charts: 1
-- =============================================================
-- DESCRIPTION
-- Daily cohort retention based on agent executions.
-- Cohort anchor = day of user's FIRST ever execution.
-- Only includes cohorts from the last 90 days, up to day 30.
-- Great for early engagement analysis (did users run another
-- agent the next day?).
--
-- SOURCE TABLES
-- platform.AgentGraphExecution — Execution records
--
-- OUTPUT COLUMNS
-- Same pattern as retention_login_daily.
-- cohort_day_start = day of first execution (not first login)
--
-- EXAMPLE QUERIES
-- -- Day-3 execution retention
-- SELECT cohort_label, retention_rate_bounded AS d3_retention
-- FROM analytics.retention_execution_daily
-- WHERE user_lifetime_day = 3 ORDER BY cohort_day_start;
-- =============================================================
WITH params AS (SELECT 30::int AS max_days, (CURRENT_DATE - INTERVAL '90 days') AS cohort_start),
events AS (
SELECT e."userId"::text AS user_id, e."createdAt"::timestamptz AS created_at,
DATE_TRUNC('day', e."createdAt")::date AS day_start
FROM platform."AgentGraphExecution" e WHERE e."userId" IS NOT NULL
),
first_exec AS (
SELECT user_id, MIN(created_at) AS first_exec_at,
DATE_TRUNC('day', MIN(created_at))::date AS cohort_day_start
FROM events GROUP BY 1
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
),
activity_days AS (SELECT DISTINCT user_id, day_start FROM events),
user_day_age AS (
SELECT ad.user_id, fe.cohort_day_start,
(ad.day_start - DATE_TRUNC('day',fe.first_exec_at)::date)::int AS user_lifetime_day
FROM activity_days ad JOIN first_exec fe USING (user_id)
WHERE ad.day_start >= DATE_TRUNC('day',fe.first_exec_at)::date
),
bounded_counts AS (
SELECT cohort_day_start, user_lifetime_day, COUNT(DISTINCT user_id) AS active_users_bounded
FROM user_day_age WHERE user_lifetime_day >= 0 GROUP BY 1,2
),
last_active AS (
SELECT cohort_day_start, user_id, MAX(user_lifetime_day) AS last_active_day FROM user_day_age GROUP BY 1,2
),
unbounded_counts AS (
SELECT la.cohort_day_start, gs AS user_lifetime_day, COUNT(*) AS retained_users_unbounded
FROM last_active la
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_day,(SELECT max_days FROM params))) gs
GROUP BY 1,2
),
cohort_sizes AS (SELECT cohort_day_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_exec GROUP BY 1),
cohort_caps AS (
SELECT cs.cohort_day_start, cs.cohort_users,
LEAST((SELECT max_days FROM params), GREATEST(0,(CURRENT_DATE-cs.cohort_day_start)::int)) AS cap_days
FROM cohort_sizes cs
),
grid AS (
SELECT cc.cohort_day_start, gs AS user_lifetime_day, cc.cohort_users
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_days) gs
)
SELECT
g.cohort_day_start,
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD') AS cohort_label,
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD')||' (n='||g.cohort_users||')' AS cohort_label_n,
g.user_lifetime_day, g.cohort_users,
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
CASE WHEN g.user_lifetime_day=0 THEN g.cohort_users ELSE 0 END AS cohort_users_d0
FROM grid g
LEFT JOIN bounded_counts b ON b.cohort_day_start=g.cohort_day_start AND b.user_lifetime_day=g.user_lifetime_day
LEFT JOIN unbounded_counts u ON u.cohort_day_start=g.cohort_day_start AND u.user_lifetime_day=g.user_lifetime_day
ORDER BY g.cohort_day_start, g.user_lifetime_day;

View File

@@ -0,0 +1,81 @@
-- =============================================================
-- View: analytics.retention_execution_weekly
-- Looker source alias: ds92 | Charts: 2
-- =============================================================
-- DESCRIPTION
-- Weekly cohort retention based on agent executions.
-- Cohort anchor = week of user's FIRST ever agent execution
-- (not first login). Only includes cohorts from the last 180 days.
-- Useful when you care about product engagement, not just visits.
--
-- SOURCE TABLES
-- platform.AgentGraphExecution — Execution records
--
-- OUTPUT COLUMNS
-- Same pattern as retention_login_weekly.
-- cohort_week_start = week of first execution (not first login)
--
-- EXAMPLE QUERIES
-- -- Week-2 execution retention
-- SELECT cohort_label, retention_rate_bounded
-- FROM analytics.retention_execution_weekly
-- WHERE user_lifetime_week = 2 ORDER BY cohort_week_start;
-- =============================================================
WITH params AS (SELECT 12::int AS max_weeks, (CURRENT_DATE - INTERVAL '180 days') AS cohort_start),
events AS (
SELECT e."userId"::text AS user_id, e."createdAt"::timestamptz AS created_at,
DATE_TRUNC('week', e."createdAt")::date AS week_start
FROM platform."AgentGraphExecution" e WHERE e."userId" IS NOT NULL
),
first_exec AS (
SELECT user_id, MIN(created_at) AS first_exec_at,
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
FROM events GROUP BY 1
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
),
activity_weeks AS (SELECT DISTINCT user_id, week_start FROM events),
user_week_age AS (
SELECT aw.user_id, fe.cohort_week_start,
((aw.week_start - DATE_TRUNC('week',fe.first_exec_at)::date)/7)::int AS user_lifetime_week
FROM activity_weeks aw JOIN first_exec fe USING (user_id)
WHERE aw.week_start >= DATE_TRUNC('week',fe.first_exec_at)::date
),
bounded_counts AS (
SELECT cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users_bounded
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2
),
last_active AS (
SELECT cohort_week_start, user_id, MAX(user_lifetime_week) AS last_active_week FROM user_week_age GROUP BY 1,2
),
unbounded_counts AS (
SELECT la.cohort_week_start, gs AS user_lifetime_week, COUNT(*) AS retained_users_unbounded
FROM last_active la
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_week,(SELECT max_weeks FROM params))) gs
GROUP BY 1,2
),
cohort_sizes AS (SELECT cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_exec GROUP BY 1),
cohort_caps AS (
SELECT cs.cohort_week_start, cs.cohort_users,
LEAST((SELECT max_weeks FROM params),
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date-cs.cohort_week_start)/7)::int)) AS cap_weeks
FROM cohort_sizes cs
),
grid AS (
SELECT cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
)
SELECT
g.cohort_week_start,
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
g.user_lifetime_week, g.cohort_users,
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0
FROM grid g
LEFT JOIN bounded_counts b ON b.cohort_week_start=g.cohort_week_start AND b.user_lifetime_week=g.user_lifetime_week
LEFT JOIN unbounded_counts u ON u.cohort_week_start=g.cohort_week_start AND u.user_lifetime_week=g.user_lifetime_week
ORDER BY g.cohort_week_start, g.user_lifetime_week;

View File

@@ -0,0 +1,94 @@
-- =============================================================
-- View: analytics.retention_login_daily
-- Looker source alias: ds112 | Charts: 1
-- =============================================================
-- DESCRIPTION
-- Daily cohort retention based on login sessions.
-- Same logic as retention_login_weekly but at day granularity,
-- showing up to day 30 for cohorts from the last 90 days.
-- Useful for analysing early activation (days 1-7) in detail.
--
-- SOURCE TABLES
-- auth.sessions — Login session records
--
-- OUTPUT COLUMNS (same pattern as retention_login_weekly)
-- cohort_day_start DATE First day the cohort logged in
-- cohort_label TEXT Date string (e.g. '2025-03-01')
-- cohort_label_n TEXT Date + cohort size (e.g. '2025-03-01 (n=12)')
-- user_lifetime_day INT Days since first login (0 = signup day)
-- cohort_users BIGINT Total users in cohort
-- active_users_bounded BIGINT Users active on exactly day k
-- retained_users_unbounded BIGINT Users active any time on/after day k
-- retention_rate_bounded FLOAT bounded / cohort_users
-- retention_rate_unbounded FLOAT unbounded / cohort_users
-- cohort_users_d0 BIGINT cohort_users only at day 0, else 0 (safe to SUM)
--
-- EXAMPLE QUERIES
-- -- Day-1 retention rate (came back next day)
-- SELECT cohort_label, retention_rate_bounded AS d1_retention
-- FROM analytics.retention_login_daily
-- WHERE user_lifetime_day = 1 ORDER BY cohort_day_start;
--
-- -- Average retention curve across all cohorts
-- SELECT user_lifetime_day,
-- SUM(active_users_bounded)::float / NULLIF(SUM(cohort_users_d0), 0) AS avg_retention
-- FROM analytics.retention_login_daily
-- GROUP BY 1 ORDER BY 1;
-- =============================================================
WITH params AS (SELECT 30::int AS max_days, (CURRENT_DATE - INTERVAL '90 days')::date AS cohort_start),
events AS (
SELECT s.user_id::text AS user_id, s.created_at::timestamptz AS created_at,
DATE_TRUNC('day', s.created_at)::date AS day_start
FROM auth.sessions s WHERE s.user_id IS NOT NULL
),
first_login AS (
SELECT user_id, MIN(created_at) AS first_login_time,
DATE_TRUNC('day', MIN(created_at))::date AS cohort_day_start
FROM events GROUP BY 1
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
),
activity_days AS (SELECT DISTINCT user_id, day_start FROM events),
user_day_age AS (
SELECT ad.user_id, fl.cohort_day_start,
(ad.day_start - DATE_TRUNC('day', fl.first_login_time)::date)::int AS user_lifetime_day
FROM activity_days ad JOIN first_login fl USING (user_id)
WHERE ad.day_start >= DATE_TRUNC('day', fl.first_login_time)::date
),
bounded_counts AS (
SELECT cohort_day_start, user_lifetime_day, COUNT(DISTINCT user_id) AS active_users_bounded
FROM user_day_age WHERE user_lifetime_day >= 0 GROUP BY 1,2
),
last_active AS (
SELECT cohort_day_start, user_id, MAX(user_lifetime_day) AS last_active_day FROM user_day_age GROUP BY 1,2
),
unbounded_counts AS (
SELECT la.cohort_day_start, gs AS user_lifetime_day, COUNT(*) AS retained_users_unbounded
FROM last_active la
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_day,(SELECT max_days FROM params))) gs
GROUP BY 1,2
),
cohort_sizes AS (SELECT cohort_day_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_login GROUP BY 1),
cohort_caps AS (
SELECT cs.cohort_day_start, cs.cohort_users,
LEAST((SELECT max_days FROM params), GREATEST(0,(CURRENT_DATE-cs.cohort_day_start)::int)) AS cap_days
FROM cohort_sizes cs
),
grid AS (
SELECT cc.cohort_day_start, gs AS user_lifetime_day, cc.cohort_users
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_days) gs
)
SELECT
g.cohort_day_start,
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD') AS cohort_label,
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD')||' (n='||g.cohort_users||')' AS cohort_label_n,
g.user_lifetime_day, g.cohort_users,
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
CASE WHEN g.user_lifetime_day=0 THEN g.cohort_users ELSE 0 END AS cohort_users_d0
FROM grid g
LEFT JOIN bounded_counts b ON b.cohort_day_start=g.cohort_day_start AND b.user_lifetime_day=g.user_lifetime_day
LEFT JOIN unbounded_counts u ON u.cohort_day_start=g.cohort_day_start AND u.user_lifetime_day=g.user_lifetime_day
ORDER BY g.cohort_day_start, g.user_lifetime_day;

View File

@@ -0,0 +1,96 @@
-- =============================================================
-- View: analytics.retention_login_onboarded_weekly
-- Looker source alias: ds101 | Charts: 2
-- =============================================================
-- DESCRIPTION
-- Weekly cohort retention from login sessions, restricted to
-- users who "onboarded" — defined as running at least one
-- agent within 365 days of their first login.
-- Filters out users who signed up but never activated,
-- giving a cleaner view of engaged-user retention.
--
-- SOURCE TABLES
-- auth.sessions — Login session records
-- platform.AgentGraphExecution — Used to identify onboarders
--
-- OUTPUT COLUMNS
-- Same as retention_login_weekly (cohort_week_start, user_lifetime_week,
-- retention_rate_bounded, retention_rate_unbounded, etc.)
-- Only difference: cohort is filtered to onboarded users only.
--
-- EXAMPLE QUERIES
-- -- Compare week-4 retention: all users vs onboarded only
-- SELECT 'all_users' AS segment, AVG(retention_rate_bounded) AS w4_retention
-- FROM analytics.retention_login_weekly WHERE user_lifetime_week = 4
-- UNION ALL
-- SELECT 'onboarded', AVG(retention_rate_bounded)
-- FROM analytics.retention_login_onboarded_weekly WHERE user_lifetime_week = 4;
-- =============================================================
WITH params AS (SELECT 12::int AS max_weeks, 365::int AS onboarding_window_days),
events AS (
SELECT s.user_id::text AS user_id, s.created_at::timestamptz AS created_at,
DATE_TRUNC('week', s.created_at)::date AS week_start
FROM auth.sessions s WHERE s.user_id IS NOT NULL
),
first_login_all AS (
SELECT user_id, MIN(created_at) AS first_login_time,
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
FROM events GROUP BY 1
),
onboarders AS (
SELECT fl.user_id FROM first_login_all fl
WHERE EXISTS (
SELECT 1 FROM platform."AgentGraphExecution" e
WHERE e."userId"::text = fl.user_id
AND e."createdAt" >= fl.first_login_time
AND e."createdAt" < fl.first_login_time
+ make_interval(days => (SELECT onboarding_window_days FROM params))
)
),
first_login AS (SELECT * FROM first_login_all WHERE user_id IN (SELECT user_id FROM onboarders)),
activity_weeks AS (SELECT DISTINCT user_id, week_start FROM events),
user_week_age AS (
SELECT aw.user_id, fl.cohort_week_start,
((aw.week_start - DATE_TRUNC('week',fl.first_login_time)::date)/7)::int AS user_lifetime_week
FROM activity_weeks aw JOIN first_login fl USING (user_id)
WHERE aw.week_start >= DATE_TRUNC('week',fl.first_login_time)::date
),
bounded_counts AS (
SELECT cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users_bounded
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2
),
last_active AS (
SELECT cohort_week_start, user_id, MAX(user_lifetime_week) AS last_active_week FROM user_week_age GROUP BY 1,2
),
unbounded_counts AS (
SELECT la.cohort_week_start, gs AS user_lifetime_week, COUNT(*) AS retained_users_unbounded
FROM last_active la
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_week,(SELECT max_weeks FROM params))) gs
GROUP BY 1,2
),
cohort_sizes AS (SELECT cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_login GROUP BY 1),
cohort_caps AS (
SELECT cs.cohort_week_start, cs.cohort_users,
LEAST((SELECT max_weeks FROM params),
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date-cs.cohort_week_start)/7)::int)) AS cap_weeks
FROM cohort_sizes cs
),
grid AS (
SELECT cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
)
SELECT
g.cohort_week_start,
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
g.user_lifetime_week, g.cohort_users,
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0
FROM grid g
LEFT JOIN bounded_counts b ON b.cohort_week_start=g.cohort_week_start AND b.user_lifetime_week=g.user_lifetime_week
LEFT JOIN unbounded_counts u ON u.cohort_week_start=g.cohort_week_start AND u.user_lifetime_week=g.user_lifetime_week
ORDER BY g.cohort_week_start, g.user_lifetime_week;

View File

@@ -0,0 +1,103 @@
-- =============================================================
-- View: analytics.retention_login_weekly
-- Looker source alias: ds83 | Charts: 2
-- =============================================================
-- DESCRIPTION
-- Weekly cohort retention based on login sessions.
-- Users are grouped by the ISO week of their first ever login.
-- For each cohort × lifetime-week combination, outputs both:
-- - bounded rate: % active in exactly that week
-- - unbounded rate: % who were ever active on or after that week
-- Weeks are capped to the cohort's actual age (no future data points).
--
-- SOURCE TABLES
-- auth.sessions — Login session records
--
-- HOW TO READ THE OUTPUT
-- cohort_week_start The Monday of the week users first logged in
-- user_lifetime_week 0 = signup week, 1 = one week later, etc.
-- retention_rate_bounded = active_users_bounded / cohort_users
-- retention_rate_unbounded = retained_users_unbounded / cohort_users
--
-- OUTPUT COLUMNS
-- cohort_week_start DATE First day of the cohort's signup week
-- cohort_label TEXT ISO week label (e.g. '2025-W01')
-- cohort_label_n TEXT ISO week label with cohort size (e.g. '2025-W01 (n=42)')
-- user_lifetime_week INT Weeks since first login (0 = signup week)
-- cohort_users BIGINT Total users in this cohort (denominator)
-- active_users_bounded BIGINT Users active in exactly week k
-- retained_users_unbounded BIGINT Users active any time on/after week k
-- retention_rate_bounded FLOAT bounded active / cohort_users
-- retention_rate_unbounded FLOAT unbounded retained / cohort_users
-- cohort_users_w0 BIGINT cohort_users only at week 0, else 0 (safe to SUM in pivot tables)
--
-- EXAMPLE QUERIES
-- -- Week-1 retention rate per cohort
-- SELECT cohort_label, retention_rate_bounded AS w1_retention
-- FROM analytics.retention_login_weekly
-- WHERE user_lifetime_week = 1
-- ORDER BY cohort_week_start;
--
-- -- Overall average retention curve (all cohorts combined)
-- SELECT user_lifetime_week,
-- SUM(active_users_bounded)::float / NULLIF(SUM(cohort_users_w0), 0) AS avg_retention
-- FROM analytics.retention_login_weekly
-- GROUP BY 1 ORDER BY 1;
-- =============================================================
WITH params AS (SELECT 12::int AS max_weeks),
events AS (
SELECT s.user_id::text AS user_id, s.created_at::timestamptz AS created_at,
DATE_TRUNC('week', s.created_at)::date AS week_start
FROM auth.sessions s WHERE s.user_id IS NOT NULL
),
first_login AS (
SELECT user_id, MIN(created_at) AS first_login_time,
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
FROM events GROUP BY 1
),
activity_weeks AS (SELECT DISTINCT user_id, week_start FROM events),
user_week_age AS (
SELECT aw.user_id, fl.cohort_week_start,
((aw.week_start - DATE_TRUNC('week', fl.first_login_time)::date) / 7)::int AS user_lifetime_week
FROM activity_weeks aw JOIN first_login fl USING (user_id)
WHERE aw.week_start >= DATE_TRUNC('week', fl.first_login_time)::date
),
bounded_counts AS (
SELECT cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users_bounded
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2
),
last_active AS (
SELECT cohort_week_start, user_id, MAX(user_lifetime_week) AS last_active_week FROM user_week_age GROUP BY 1,2
),
unbounded_counts AS (
SELECT la.cohort_week_start, gs AS user_lifetime_week, COUNT(*) AS retained_users_unbounded
FROM last_active la
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_week,(SELECT max_weeks FROM params))) gs
GROUP BY 1,2
),
cohort_sizes AS (SELECT cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_login GROUP BY 1),
cohort_caps AS (
SELECT cs.cohort_week_start, cs.cohort_users,
LEAST((SELECT max_weeks FROM params),
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date - cs.cohort_week_start)/7)::int)) AS cap_weeks
FROM cohort_sizes cs
),
grid AS (
SELECT cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
)
SELECT
g.cohort_week_start,
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
g.user_lifetime_week, g.cohort_users,
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0
FROM grid g
LEFT JOIN bounded_counts b ON b.cohort_week_start=g.cohort_week_start AND b.user_lifetime_week=g.user_lifetime_week
LEFT JOIN unbounded_counts u ON u.cohort_week_start=g.cohort_week_start AND u.user_lifetime_week=g.user_lifetime_week
ORDER BY g.cohort_week_start, g.user_lifetime_week

View File

@@ -0,0 +1,71 @@
-- =============================================================
-- View: analytics.user_block_spending
-- Looker source alias: ds6 | Charts: 5
-- =============================================================
-- DESCRIPTION
-- One row per credit transaction (last 90 days).
-- Shows how users spend credits broken down by block type,
-- LLM provider and model. Joins node execution stats for
-- token-level detail.
--
-- SOURCE TABLES
-- platform.CreditTransaction — Credit debit/credit records
-- platform.AgentNodeExecution — Node execution stats (for token counts)
--
-- OUTPUT COLUMNS
-- transactionKey TEXT Unique transaction identifier
-- userId TEXT User who was charged
-- amount DECIMAL Credit amount (positive = credit, negative = debit)
-- negativeAmount DECIMAL amount * -1 (convenience for spend charts)
-- transactionType TEXT Transaction type (e.g. 'USAGE', 'REFUND', 'TOP_UP')
-- transactionTime TIMESTAMPTZ When the transaction was recorded
-- blockId TEXT Block UUID that triggered the spend
-- blockName TEXT Human-readable block name
-- llm_provider TEXT LLM provider (e.g. 'openai', 'anthropic')
-- llm_model TEXT Model name (e.g. 'gpt-4o', 'claude-3-5-sonnet')
-- node_exec_id TEXT Linked node execution UUID
-- llm_call_count INT LLM API calls made in that execution
-- llm_retry_count INT LLM retries in that execution
-- llm_input_token_count INT Input tokens consumed
-- llm_output_token_count INT Output tokens produced
--
-- WINDOW
-- Rolling 90 days (createdAt > CURRENT_DATE - 90 days)
--
-- EXAMPLE QUERIES
-- -- Total spend per user (last 90 days)
-- SELECT "userId", SUM("negativeAmount") AS total_spent
-- FROM analytics.user_block_spending
-- WHERE "transactionType" = 'USAGE'
-- GROUP BY 1 ORDER BY total_spent DESC;
--
-- -- Spend by LLM provider + model
-- SELECT "llm_provider", "llm_model",
-- SUM("negativeAmount") AS total_cost,
-- SUM("llm_input_token_count") AS input_tokens,
-- SUM("llm_output_token_count") AS output_tokens
-- FROM analytics.user_block_spending
-- WHERE "llm_provider" IS NOT NULL
-- GROUP BY 1, 2 ORDER BY total_cost DESC;
-- =============================================================
SELECT
c."transactionKey" AS transactionKey,
c."userId" AS userId,
c."amount" AS amount,
c."amount" * -1 AS negativeAmount,
c."type" AS transactionType,
c."createdAt" AS transactionTime,
c.metadata->>'block_id' AS blockId,
c.metadata->>'block' AS blockName,
c.metadata->'input'->'credentials'->>'provider' AS llm_provider,
c.metadata->'input'->>'model' AS llm_model,
c.metadata->>'node_exec_id' AS node_exec_id,
(ne."stats"->>'llm_call_count')::int AS llm_call_count,
(ne."stats"->>'llm_retry_count')::int AS llm_retry_count,
(ne."stats"->>'input_token_count')::int AS llm_input_token_count,
(ne."stats"->>'output_token_count')::int AS llm_output_token_count
FROM platform."CreditTransaction" c
LEFT JOIN platform."AgentNodeExecution" ne
ON (c.metadata->>'node_exec_id') = ne."id"::text
WHERE c."createdAt" > CURRENT_DATE - INTERVAL '90 days'

View File

@@ -0,0 +1,45 @@
-- =============================================================
-- View: analytics.user_onboarding
-- Looker source alias: ds68 | Charts: 3
-- =============================================================
-- DESCRIPTION
-- One row per user onboarding record. Contains the user's
-- stated usage reason, selected integrations, completed
-- onboarding steps and optional first agent selection.
-- Full history (no date filter) since onboarding happens
-- once per user.
--
-- SOURCE TABLES
-- platform.UserOnboarding — Onboarding state per user
--
-- OUTPUT COLUMNS
-- id TEXT Onboarding record UUID
-- createdAt TIMESTAMPTZ When onboarding started
-- updatedAt TIMESTAMPTZ Last update to onboarding state
-- usageReason TEXT Why user signed up (e.g. 'work', 'personal')
-- integrations TEXT[] Array of integration names the user selected
-- userId TEXT User UUID
-- completedSteps TEXT[] Array of onboarding step enums completed
-- selectedStoreListingVersionId TEXT First marketplace agent the user chose (if any)
--
-- EXAMPLE QUERIES
-- -- Usage reason breakdown
-- SELECT "usageReason", COUNT(*) FROM analytics.user_onboarding GROUP BY 1;
--
-- -- Completion rate per step
-- SELECT step, COUNT(*) AS users_completed
-- FROM analytics.user_onboarding
-- CROSS JOIN LATERAL UNNEST("completedSteps") AS step
-- GROUP BY 1 ORDER BY users_completed DESC;
-- =============================================================
SELECT
id,
"createdAt",
"updatedAt",
"usageReason",
integrations,
"userId",
"completedSteps",
"selectedStoreListingVersionId"
FROM platform."UserOnboarding"

View File

@@ -0,0 +1,100 @@
-- =============================================================
-- View: analytics.user_onboarding_funnel
-- Looker source alias: ds74 | Charts: 1
-- =============================================================
-- DESCRIPTION
-- Pre-aggregated onboarding funnel showing how many users
-- completed each step and the drop-off percentage from the
-- previous step. One row per onboarding step (all 22 steps
-- always present, even with 0 completions — prevents sparse
-- gaps from making LAG compare the wrong predecessors).
--
-- SOURCE TABLES
-- platform.UserOnboarding — Onboarding records with completedSteps array
--
-- OUTPUT COLUMNS
-- step TEXT Onboarding step enum name (e.g. 'WELCOME', 'CONGRATS')
-- step_order INT Numeric position in the funnel (1=first, 22=last)
-- users_completed BIGINT Distinct users who completed this step
-- pct_from_prev NUMERIC % of users from the previous step who reached this one
--
-- STEP ORDER
-- 1 WELCOME 9 MARKETPLACE_VISIT 17 SCHEDULE_AGENT
-- 2 USAGE_REASON 10 MARKETPLACE_ADD_AGENT 18 RUN_AGENTS
-- 3 INTEGRATIONS 11 MARKETPLACE_RUN_AGENT 19 RUN_3_DAYS
-- 4 AGENT_CHOICE 12 BUILDER_OPEN 20 TRIGGER_WEBHOOK
-- 5 AGENT_NEW_RUN 13 BUILDER_SAVE_AGENT 21 RUN_14_DAYS
-- 6 AGENT_INPUT 14 BUILDER_RUN_AGENT 22 RUN_AGENTS_100
-- 7 CONGRATS 15 VISIT_COPILOT
-- 8 GET_RESULTS 16 RE_RUN_AGENT
--
-- WINDOW
-- Users who started onboarding in the last 90 days
--
-- EXAMPLE QUERIES
-- -- Full funnel
-- SELECT * FROM analytics.user_onboarding_funnel ORDER BY step_order;
--
-- -- Biggest drop-off point
-- SELECT step, pct_from_prev FROM analytics.user_onboarding_funnel
-- ORDER BY pct_from_prev ASC LIMIT 3;
-- =============================================================
WITH all_steps AS (
-- Complete ordered grid of all 22 steps so zero-completion steps
-- are always present, keeping LAG comparisons correct.
SELECT step_name, step_order
FROM (VALUES
('WELCOME', 1),
('USAGE_REASON', 2),
('INTEGRATIONS', 3),
('AGENT_CHOICE', 4),
('AGENT_NEW_RUN', 5),
('AGENT_INPUT', 6),
('CONGRATS', 7),
('GET_RESULTS', 8),
('MARKETPLACE_VISIT', 9),
('MARKETPLACE_ADD_AGENT', 10),
('MARKETPLACE_RUN_AGENT', 11),
('BUILDER_OPEN', 12),
('BUILDER_SAVE_AGENT', 13),
('BUILDER_RUN_AGENT', 14),
('VISIT_COPILOT', 15),
('RE_RUN_AGENT', 16),
('SCHEDULE_AGENT', 17),
('RUN_AGENTS', 18),
('RUN_3_DAYS', 19),
('TRIGGER_WEBHOOK', 20),
('RUN_14_DAYS', 21),
('RUN_AGENTS_100', 22)
) AS t(step_name, step_order)
),
raw AS (
SELECT
u."userId",
step_txt::text AS step
FROM platform."UserOnboarding" u
CROSS JOIN LATERAL UNNEST(u."completedSteps") AS step_txt
WHERE u."createdAt" >= CURRENT_DATE - INTERVAL '90 days'
),
step_counts AS (
SELECT step, COUNT(DISTINCT "userId") AS users_completed
FROM raw GROUP BY step
),
funnel AS (
SELECT
a.step_name AS step,
a.step_order,
COALESCE(sc.users_completed, 0) AS users_completed,
ROUND(
100.0 * COALESCE(sc.users_completed, 0)
/ NULLIF(
LAG(COALESCE(sc.users_completed, 0)) OVER (ORDER BY a.step_order),
0
),
2
) AS pct_from_prev
FROM all_steps a
LEFT JOIN step_counts sc ON sc.step = a.step_name
)
SELECT * FROM funnel ORDER BY step_order

View File

@@ -0,0 +1,41 @@
-- =============================================================
-- View: analytics.user_onboarding_integration
-- Looker source alias: ds75 | Charts: 1
-- =============================================================
-- DESCRIPTION
-- Pre-aggregated count of users who selected each integration
-- during onboarding. One row per integration type, sorted
-- by popularity.
--
-- SOURCE TABLES
-- platform.UserOnboarding — integrations array column
--
-- OUTPUT COLUMNS
-- integration TEXT Integration name (e.g. 'github', 'slack', 'notion')
-- users_with_integration BIGINT Distinct users who selected this integration
--
-- WINDOW
-- Users who started onboarding in the last 90 days
--
-- EXAMPLE QUERIES
-- -- Full integration popularity ranking
-- SELECT * FROM analytics.user_onboarding_integration;
--
-- -- Top 5 integrations
-- SELECT * FROM analytics.user_onboarding_integration LIMIT 5;
-- =============================================================
WITH exploded AS (
SELECT
u."userId" AS user_id,
UNNEST(u."integrations") AS integration
FROM platform."UserOnboarding" u
WHERE u."createdAt" >= CURRENT_DATE - INTERVAL '90 days'
)
SELECT
integration,
COUNT(DISTINCT user_id) AS users_with_integration
FROM exploded
WHERE integration IS NOT NULL AND integration <> ''
GROUP BY integration
ORDER BY users_with_integration DESC

View File

@@ -0,0 +1,145 @@
-- =============================================================
-- View: analytics.users_activities
-- Looker source alias: ds56 | Charts: 5
-- =============================================================
-- DESCRIPTION
-- One row per user with lifetime activity summary.
-- Joins login sessions with agent graphs, executions and
-- node-level runs to give a full picture of how engaged
-- each user is. Includes a convenience flag for 7-day
-- activation (did the user return at least 7 days after
-- their first login?).
--
-- SOURCE TABLES
-- auth.sessions — Login/session records
-- platform.AgentGraph — Graphs (agents) built by the user
-- platform.AgentGraphExecution — Agent run history
-- platform.AgentNodeExecution — Individual block execution history
--
-- PERFORMANCE NOTE
-- Each CTE aggregates its own table independently by userId.
-- This avoids the fan-out that occurs when driving every join
-- from user_logins across the two largest tables
-- (AgentGraphExecution and AgentNodeExecution).
--
-- OUTPUT COLUMNS
-- user_id TEXT Supabase user UUID
-- first_login_time TIMESTAMPTZ First ever session created_at
-- last_login_time TIMESTAMPTZ Most recent session created_at
-- last_visit_time TIMESTAMPTZ Max of last refresh or login
-- last_agent_save_time TIMESTAMPTZ Last time user saved an agent graph
-- agent_count BIGINT Number of distinct active graphs built (0 if none)
-- first_agent_run_time TIMESTAMPTZ First ever graph execution
-- last_agent_run_time TIMESTAMPTZ Most recent graph execution
-- unique_agent_runs BIGINT Distinct agent graphs ever run (0 if none)
-- agent_runs BIGINT Total graph execution count (0 if none)
-- node_execution_count BIGINT Total node executions across all runs
-- node_execution_failed BIGINT Node executions with FAILED status
-- node_execution_completed BIGINT Node executions with COMPLETED status
-- node_execution_terminated BIGINT Node executions with TERMINATED status
-- node_execution_queued BIGINT Node executions with QUEUED status
-- node_execution_running BIGINT Node executions with RUNNING status
-- is_active_after_7d INT 1=returned after day 7, 0=did not, NULL=too early to tell
-- node_execution_incomplete BIGINT Node executions with INCOMPLETE status
-- node_execution_review BIGINT Node executions with REVIEW status
--
-- EXAMPLE QUERIES
-- -- Users who ran at least one agent and returned after 7 days
-- SELECT COUNT(*) FROM analytics.users_activities
-- WHERE agent_runs > 0 AND is_active_after_7d = 1;
--
-- -- Top 10 most active users by agent runs
-- SELECT user_id, agent_runs, node_execution_count
-- FROM analytics.users_activities
-- ORDER BY agent_runs DESC LIMIT 10;
--
-- -- 7-day activation rate
-- SELECT
-- SUM(CASE WHEN is_active_after_7d = 1 THEN 1 ELSE 0 END)::float
-- / NULLIF(COUNT(CASE WHEN is_active_after_7d IS NOT NULL THEN 1 END), 0)
-- AS activation_rate
-- FROM analytics.users_activities;
-- =============================================================
WITH user_logins AS (
SELECT
user_id::text AS user_id,
MIN(created_at) AS first_login_time,
MAX(created_at) AS last_login_time,
GREATEST(
MAX(refreshed_at)::timestamptz,
MAX(created_at)::timestamptz
) AS last_visit_time
FROM auth.sessions
GROUP BY user_id
),
user_agents AS (
-- Aggregate AgentGraph directly by userId (no fan-out from user_logins)
SELECT
"userId"::text AS user_id,
MAX("updatedAt") AS last_agent_save_time,
COUNT(DISTINCT "id") AS agent_count
FROM platform."AgentGraph"
WHERE "isActive"
GROUP BY "userId"
),
user_graph_runs AS (
-- Aggregate AgentGraphExecution directly by userId
SELECT
"userId"::text AS user_id,
MIN("createdAt") AS first_agent_run_time,
MAX("createdAt") AS last_agent_run_time,
COUNT(DISTINCT "agentGraphId") AS unique_agent_runs,
COUNT("id") AS agent_runs
FROM platform."AgentGraphExecution"
GROUP BY "userId"
),
user_node_runs AS (
-- Aggregate AgentNodeExecution directly; resolve userId via a
-- single join to AgentGraphExecution instead of fanning out from
-- user_logins through both large tables.
SELECT
g."userId"::text AS user_id,
COUNT(*) AS node_execution_count,
COUNT(*) FILTER (WHERE n."executionStatus" = 'FAILED') AS node_execution_failed,
COUNT(*) FILTER (WHERE n."executionStatus" = 'COMPLETED') AS node_execution_completed,
COUNT(*) FILTER (WHERE n."executionStatus" = 'TERMINATED') AS node_execution_terminated,
COUNT(*) FILTER (WHERE n."executionStatus" = 'QUEUED') AS node_execution_queued,
COUNT(*) FILTER (WHERE n."executionStatus" = 'RUNNING') AS node_execution_running,
COUNT(*) FILTER (WHERE n."executionStatus" = 'INCOMPLETE') AS node_execution_incomplete,
COUNT(*) FILTER (WHERE n."executionStatus" = 'REVIEW') AS node_execution_review
FROM platform."AgentNodeExecution" n
JOIN platform."AgentGraphExecution" g
ON g."id" = n."agentGraphExecutionId"
GROUP BY g."userId"
)
SELECT
ul.user_id,
ul.first_login_time,
ul.last_login_time,
ul.last_visit_time,
ua.last_agent_save_time,
COALESCE(ua.agent_count, 0) AS agent_count,
gr.first_agent_run_time,
gr.last_agent_run_time,
COALESCE(gr.unique_agent_runs, 0) AS unique_agent_runs,
COALESCE(gr.agent_runs, 0) AS agent_runs,
COALESCE(nr.node_execution_count, 0) AS node_execution_count,
COALESCE(nr.node_execution_failed, 0) AS node_execution_failed,
COALESCE(nr.node_execution_completed, 0) AS node_execution_completed,
COALESCE(nr.node_execution_terminated, 0) AS node_execution_terminated,
COALESCE(nr.node_execution_queued, 0) AS node_execution_queued,
COALESCE(nr.node_execution_running, 0) AS node_execution_running,
CASE
WHEN ul.first_login_time < NOW() - INTERVAL '7 days'
AND ul.last_visit_time >= ul.first_login_time + INTERVAL '7 days' THEN 1
WHEN ul.first_login_time < NOW() - INTERVAL '7 days'
AND ul.last_visit_time < ul.first_login_time + INTERVAL '7 days' THEN 0
ELSE NULL
END AS is_active_after_7d,
COALESCE(nr.node_execution_incomplete, 0) AS node_execution_incomplete,
COALESCE(nr.node_execution_review, 0) AS node_execution_review
FROM user_logins ul
LEFT JOIN user_agents ua ON ul.user_id = ua.user_id
LEFT JOIN user_graph_runs gr ON ul.user_id = gr.user_id
LEFT JOIN user_node_runs nr ON ul.user_id = nr.user_id

View File

@@ -37,6 +37,10 @@ JWT_VERIFY_KEY=your-super-secret-jwt-token-with-at-least-32-characters-long
ENCRYPTION_KEY=dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=
UNSUBSCRIBE_SECRET_KEY=HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio=
## ===== SIGNUP / INVITE GATE ===== ##
# Set to true to require an invite before users can sign up
ENABLE_INVITE_GATE=false
## ===== IMPORTANT OPTIONAL CONFIGURATION ===== ##
# Platform URLs (set these for webhooks and OAuth to work)
PLATFORM_BASE_URL=http://localhost:8000

View File

@@ -1,8 +1,17 @@
from pydantic import BaseModel
from __future__ import annotations
from datetime import datetime
from typing import TYPE_CHECKING, Any, Literal, Optional
import prisma.enums
from pydantic import BaseModel, EmailStr
from backend.data.model import UserTransaction
from backend.util.models import Pagination
if TYPE_CHECKING:
from backend.data.invited_user import BulkInvitedUsersResult, InvitedUserRecord
class UserHistoryResponse(BaseModel):
"""Response model for listings with version history"""
@@ -14,3 +23,70 @@ class UserHistoryResponse(BaseModel):
class AddUserCreditsResponse(BaseModel):
new_balance: int
transaction_key: str
class CreateInvitedUserRequest(BaseModel):
email: EmailStr
name: Optional[str] = None
class InvitedUserResponse(BaseModel):
id: str
email: str
status: prisma.enums.InvitedUserStatus
auth_user_id: Optional[str] = None
name: Optional[str] = None
tally_understanding: Optional[dict[str, Any]] = None
tally_status: prisma.enums.TallyComputationStatus
tally_computed_at: Optional[datetime] = None
tally_error: Optional[str] = None
created_at: datetime
updated_at: datetime
@classmethod
def from_record(cls, record: InvitedUserRecord) -> InvitedUserResponse:
return cls.model_validate(record.model_dump())
class InvitedUsersResponse(BaseModel):
invited_users: list[InvitedUserResponse]
pagination: Pagination
class BulkInvitedUserRowResponse(BaseModel):
row_number: int
email: Optional[str] = None
name: Optional[str] = None
status: Literal["CREATED", "SKIPPED", "ERROR"]
message: str
invited_user: Optional[InvitedUserResponse] = None
class BulkInvitedUsersResponse(BaseModel):
created_count: int
skipped_count: int
error_count: int
results: list[BulkInvitedUserRowResponse]
@classmethod
def from_result(cls, result: BulkInvitedUsersResult) -> BulkInvitedUsersResponse:
return cls(
created_count=result.created_count,
skipped_count=result.skipped_count,
error_count=result.error_count,
results=[
BulkInvitedUserRowResponse(
row_number=row.row_number,
email=row.email,
name=row.name,
status=row.status,
message=row.message,
invited_user=(
InvitedUserResponse.from_record(row.invited_user)
if row.invited_user is not None
else None
),
)
for row in result.results
],
)

View File

@@ -0,0 +1,137 @@
import logging
import math
from autogpt_libs.auth import get_user_id, requires_admin_user
from fastapi import APIRouter, File, Query, Security, UploadFile
from backend.data.invited_user import (
bulk_create_invited_users_from_file,
create_invited_user,
list_invited_users,
retry_invited_user_tally,
revoke_invited_user,
)
from backend.data.tally import mask_email
from backend.util.models import Pagination
from .model import (
BulkInvitedUsersResponse,
CreateInvitedUserRequest,
InvitedUserResponse,
InvitedUsersResponse,
)
logger = logging.getLogger(__name__)
router = APIRouter(
prefix="/admin",
tags=["users", "admin"],
dependencies=[Security(requires_admin_user)],
)
@router.get(
"/invited-users",
response_model=InvitedUsersResponse,
summary="List Invited Users",
)
async def get_invited_users(
admin_user_id: str = Security(get_user_id),
page: int = Query(1, ge=1),
page_size: int = Query(50, ge=1, le=200),
) -> InvitedUsersResponse:
logger.info("Admin user %s requested invited users", admin_user_id)
invited_users, total = await list_invited_users(page=page, page_size=page_size)
return InvitedUsersResponse(
invited_users=[InvitedUserResponse.from_record(iu) for iu in invited_users],
pagination=Pagination(
total_items=total,
total_pages=max(1, math.ceil(total / page_size)),
current_page=page,
page_size=page_size,
),
)
@router.post(
"/invited-users",
response_model=InvitedUserResponse,
summary="Create Invited User",
)
async def create_invited_user_route(
request: CreateInvitedUserRequest,
admin_user_id: str = Security(get_user_id),
) -> InvitedUserResponse:
logger.info(
"Admin user %s creating invited user for %s",
admin_user_id,
mask_email(request.email),
)
invited_user = await create_invited_user(request.email, request.name)
logger.info(
"Admin user %s created invited user %s",
admin_user_id,
invited_user.id,
)
return InvitedUserResponse.from_record(invited_user)
@router.post(
"/invited-users/bulk",
response_model=BulkInvitedUsersResponse,
summary="Bulk Create Invited Users",
operation_id="postV2BulkCreateInvitedUsers",
)
async def bulk_create_invited_users_route(
file: UploadFile = File(...),
admin_user_id: str = Security(get_user_id),
) -> BulkInvitedUsersResponse:
logger.info(
"Admin user %s bulk invited users from %s",
admin_user_id,
file.filename or "<unnamed>",
)
content = await file.read()
result = await bulk_create_invited_users_from_file(file.filename, content)
return BulkInvitedUsersResponse.from_result(result)
@router.post(
"/invited-users/{invited_user_id}/revoke",
response_model=InvitedUserResponse,
summary="Revoke Invited User",
)
async def revoke_invited_user_route(
invited_user_id: str,
admin_user_id: str = Security(get_user_id),
) -> InvitedUserResponse:
logger.info(
"Admin user %s revoking invited user %s", admin_user_id, invited_user_id
)
invited_user = await revoke_invited_user(invited_user_id)
logger.info("Admin user %s revoked invited user %s", admin_user_id, invited_user_id)
return InvitedUserResponse.from_record(invited_user)
@router.post(
"/invited-users/{invited_user_id}/retry-tally",
response_model=InvitedUserResponse,
summary="Retry Invited User Tally",
)
async def retry_invited_user_tally_route(
invited_user_id: str,
admin_user_id: str = Security(get_user_id),
) -> InvitedUserResponse:
logger.info(
"Admin user %s retrying Tally seed for invited user %s",
admin_user_id,
invited_user_id,
)
invited_user = await retry_invited_user_tally(invited_user_id)
logger.info(
"Admin user %s retried Tally seed for invited user %s",
admin_user_id,
invited_user_id,
)
return InvitedUserResponse.from_record(invited_user)

View File

@@ -0,0 +1,168 @@
from datetime import datetime, timezone
from unittest.mock import AsyncMock
import fastapi
import fastapi.testclient
import prisma.enums
import pytest
import pytest_mock
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from backend.data.invited_user import (
BulkInvitedUserRowResult,
BulkInvitedUsersResult,
InvitedUserRecord,
)
from .user_admin_routes import router as user_admin_router
app = fastapi.FastAPI()
app.include_router(user_admin_router)
client = fastapi.testclient.TestClient(app)
@pytest.fixture(autouse=True)
def setup_app_admin_auth(mock_jwt_admin):
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
yield
app.dependency_overrides.clear()
def _sample_invited_user() -> InvitedUserRecord:
now = datetime.now(timezone.utc)
return InvitedUserRecord(
id="invite-1",
email="invited@example.com",
status=prisma.enums.InvitedUserStatus.INVITED,
auth_user_id=None,
name="Invited User",
tally_understanding=None,
tally_status=prisma.enums.TallyComputationStatus.PENDING,
tally_computed_at=None,
tally_error=None,
created_at=now,
updated_at=now,
)
def _sample_bulk_invited_users_result() -> BulkInvitedUsersResult:
return BulkInvitedUsersResult(
created_count=1,
skipped_count=1,
error_count=0,
results=[
BulkInvitedUserRowResult(
row_number=1,
email="invited@example.com",
name=None,
status="CREATED",
message="Invite created",
invited_user=_sample_invited_user(),
),
BulkInvitedUserRowResult(
row_number=2,
email="duplicate@example.com",
name=None,
status="SKIPPED",
message="An invited user with this email already exists",
invited_user=None,
),
],
)
def test_get_invited_users(
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"backend.api.features.admin.user_admin_routes.list_invited_users",
AsyncMock(return_value=([_sample_invited_user()], 1)),
)
response = client.get("/admin/invited-users")
assert response.status_code == 200
data = response.json()
assert len(data["invited_users"]) == 1
assert data["invited_users"][0]["email"] == "invited@example.com"
assert data["invited_users"][0]["status"] == "INVITED"
assert data["pagination"]["total_items"] == 1
assert data["pagination"]["current_page"] == 1
assert data["pagination"]["page_size"] == 50
def test_create_invited_user(
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"backend.api.features.admin.user_admin_routes.create_invited_user",
AsyncMock(return_value=_sample_invited_user()),
)
response = client.post(
"/admin/invited-users",
json={"email": "invited@example.com", "name": "Invited User"},
)
assert response.status_code == 200
data = response.json()
assert data["email"] == "invited@example.com"
assert data["name"] == "Invited User"
def test_bulk_create_invited_users(
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"backend.api.features.admin.user_admin_routes.bulk_create_invited_users_from_file",
AsyncMock(return_value=_sample_bulk_invited_users_result()),
)
response = client.post(
"/admin/invited-users/bulk",
files={
"file": ("invites.txt", b"invited@example.com\nduplicate@example.com\n")
},
)
assert response.status_code == 200
data = response.json()
assert data["created_count"] == 1
assert data["skipped_count"] == 1
assert data["results"][0]["status"] == "CREATED"
assert data["results"][1]["status"] == "SKIPPED"
def test_revoke_invited_user(
mocker: pytest_mock.MockerFixture,
) -> None:
revoked = _sample_invited_user().model_copy(
update={"status": prisma.enums.InvitedUserStatus.REVOKED}
)
mocker.patch(
"backend.api.features.admin.user_admin_routes.revoke_invited_user",
AsyncMock(return_value=revoked),
)
response = client.post("/admin/invited-users/invite-1/revoke")
assert response.status_code == 200
assert response.json()["status"] == "REVOKED"
def test_retry_invited_user_tally(
mocker: pytest_mock.MockerFixture,
) -> None:
retried = _sample_invited_user().model_copy(
update={"tally_status": prisma.enums.TallyComputationStatus.RUNNING}
)
mocker.patch(
"backend.api.features.admin.user_admin_routes.retry_invited_user_tally",
AsyncMock(return_value=retried),
)
response = client.post("/admin/invited-users/invite-1/retry-tally")
assert response.status_code == 200
assert response.json()["tally_status"] == "RUNNING"

View File

@@ -27,7 +27,14 @@ from backend.copilot.model import (
get_user_sessions,
update_session_title,
)
from backend.copilot.rate_limit import (
CoPilotUsageStatus,
RateLimitExceeded,
check_rate_limit,
get_usage_status,
)
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
from backend.copilot.tools.e2b_sandbox import kill_sandbox
from backend.copilot.tools.models import (
AgentDetailsResponse,
AgentOutputResponse,
@@ -52,6 +59,8 @@ from backend.copilot.tools.models import (
UnderstandingUpdatedResponse,
)
from backend.copilot.tracking import track_user_message
from backend.data.redis_client import get_redis_async
from backend.data.understanding import get_business_understanding
from backend.data.workspace import get_or_create_workspace
from backend.util.exceptions import NotFoundError
@@ -117,6 +126,8 @@ class SessionDetailResponse(BaseModel):
user_id: str | None
messages: list[dict]
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
total_prompt_tokens: int = 0
total_completion_tokens: int = 0
class SessionSummaryResponse(BaseModel):
@@ -126,6 +137,7 @@ class SessionSummaryResponse(BaseModel):
created_at: str
updated_at: str
title: str | None = None
is_processing: bool
class ListSessionsResponse(BaseModel):
@@ -184,6 +196,28 @@ async def list_sessions(
"""
sessions, total_count = await get_user_sessions(user_id, limit, offset)
# Batch-check Redis for active stream status on each session
processing_set: set[str] = set()
if sessions:
try:
redis = await get_redis_async()
pipe = redis.pipeline(transaction=False)
for session in sessions:
pipe.hget(
f"{config.session_meta_prefix}{session.session_id}",
"status",
)
statuses = await pipe.execute()
processing_set = {
session.session_id
for session, st in zip(sessions, statuses)
if st == "running"
}
except Exception:
logger.warning(
"Failed to fetch processing status from Redis; " "defaulting to empty"
)
return ListSessionsResponse(
sessions=[
SessionSummaryResponse(
@@ -191,6 +225,7 @@ async def list_sessions(
created_at=session.started_at.isoformat(),
updated_at=session.updated_at.isoformat(),
title=session.title,
is_processing=session.session_id in processing_set,
)
for session in sessions
],
@@ -265,12 +300,12 @@ async def delete_session(
)
# Best-effort cleanup of the E2B sandbox (if any).
config = ChatConfig()
if config.use_e2b_sandbox and config.e2b_api_key:
from backend.copilot.tools.e2b_sandbox import kill_sandbox
# sandbox_id is in Redis; kill_sandbox() fetches it from there.
e2b_cfg = ChatConfig()
if e2b_cfg.e2b_active:
assert e2b_cfg.e2b_api_key # guaranteed by e2b_active check
try:
await kill_sandbox(session_id, config.e2b_api_key)
await kill_sandbox(session_id, e2b_cfg.e2b_api_key)
except Exception:
logger.warning(
"[E2B] Failed to kill sandbox for session %s", session_id[:12]
@@ -362,6 +397,10 @@ async def get_session(
last_message_id=last_message_id,
)
# Sum token usage from session
total_prompt = sum(u.prompt_tokens for u in session.usage)
total_completion = sum(u.completion_tokens for u in session.usage)
return SessionDetailResponse(
id=session.session_id,
created_at=session.started_at.isoformat(),
@@ -369,6 +408,26 @@ async def get_session(
user_id=session.user_id or None,
messages=messages,
active_stream=active_stream_info,
total_prompt_tokens=total_prompt,
total_completion_tokens=total_completion,
)
@router.get("/usage")
async def get_copilot_usage(
user_id: Annotated[str | None, Depends(auth.get_user_id)],
) -> CoPilotUsageStatus:
"""Get CoPilot usage status for the authenticated user.
Returns current token usage vs limits for daily and weekly windows.
"""
if not user_id:
raise HTTPException(status_code=401, detail="Authentication required")
return await get_usage_status(
user_id=user_id,
daily_token_limit=config.daily_token_limit,
weekly_token_limit=config.weekly_token_limit,
)
@@ -469,6 +528,17 @@ async def stream_chat_post(
},
)
# Pre-turn rate limit check (token-based)
if user_id and (config.daily_token_limit > 0 or config.weekly_token_limit > 0):
try:
await check_rate_limit(
user_id=user_id,
daily_token_limit=config.daily_token_limit,
weekly_token_limit=config.weekly_token_limit,
)
except RateLimitExceeded as e:
raise HTTPException(status_code=429, detail=str(e)) from e
# Enrich message with file metadata if file_ids are provided.
# Also sanitise file_ids so only validated, workspace-scoped IDs are
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
@@ -827,6 +897,36 @@ async def session_assign_user(
return {"status": "ok"}
# ========== Suggested Prompts ==========
class SuggestedPromptsResponse(BaseModel):
"""Response model for user-specific suggested prompts."""
prompts: list[str]
@router.get(
"/suggested-prompts",
dependencies=[Security(auth.requires_user)],
)
async def get_suggested_prompts(
user_id: Annotated[str, Security(auth.get_user_id)],
) -> SuggestedPromptsResponse:
"""
Get LLM-generated suggested prompts for the authenticated user.
Returns personalized quick-action prompts based on the user's
business understanding. Returns an empty list if no custom prompts
are available.
"""
understanding = await get_business_understanding(user_id)
if understanding is None:
return SuggestedPromptsResponse(prompts=[])
return SuggestedPromptsResponse(prompts=understanding.suggested_prompts)
# ========== Configuration ==========

View File

@@ -1,6 +1,7 @@
"""Tests for chat API routes: session title update and file attachment validation."""
"""Tests for chat API routes: session title update, file attachment validation, usage, and suggested prompts."""
from unittest.mock import AsyncMock
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock
import fastapi
import fastapi.testclient
@@ -249,3 +250,130 @@ def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
call_kwargs = mock_prisma.find_many.call_args[1]
assert call_kwargs["where"]["workspaceId"] == "my-workspace-id"
assert call_kwargs["where"]["isDeleted"] is False
# ─── Usage endpoint ───────────────────────────────────────────────────
def _mock_usage(
mocker: pytest_mock.MockerFixture,
*,
daily_used: int = 500,
weekly_used: int = 2000,
) -> AsyncMock:
"""Mock get_usage_status to return a predictable CoPilotUsageStatus."""
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
resets_at = datetime.now(UTC) + timedelta(days=1)
status = CoPilotUsageStatus(
daily=UsageWindow(used=daily_used, limit=10000, resets_at=resets_at),
weekly=UsageWindow(used=weekly_used, limit=50000, resets_at=resets_at),
)
return mocker.patch(
"backend.api.features.chat.routes.get_usage_status",
new_callable=AsyncMock,
return_value=status,
)
def test_usage_returns_daily_and_weekly(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""GET /usage returns daily and weekly usage."""
mock_get = _mock_usage(mocker, daily_used=500, weekly_used=2000)
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
response = client.get("/usage")
assert response.status_code == 200
data = response.json()
assert data["daily"]["used"] == 500
assert data["weekly"]["used"] == 2000
mock_get.assert_called_once_with(
user_id=test_user_id,
daily_token_limit=10000,
weekly_token_limit=50000,
)
def test_usage_uses_config_limits(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""The endpoint forwards daily_token_limit and weekly_token_limit from config."""
mock_get = _mock_usage(mocker)
mocker.patch.object(chat_routes.config, "daily_token_limit", 99999)
mocker.patch.object(chat_routes.config, "weekly_token_limit", 77777)
response = client.get("/usage")
assert response.status_code == 200
mock_get.assert_called_once_with(
user_id=test_user_id,
daily_token_limit=99999,
weekly_token_limit=77777,
)
# ─── Suggested prompts endpoint ──────────────────────────────────────
def _mock_get_business_understanding(
mocker: pytest_mock.MockerFixture,
*,
return_value=None,
):
"""Mock get_business_understanding."""
return mocker.patch(
"backend.api.features.chat.routes.get_business_understanding",
new_callable=AsyncMock,
return_value=return_value,
)
def test_suggested_prompts_returns_prompts(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""User with understanding and prompts gets them back."""
mock_understanding = MagicMock()
mock_understanding.suggested_prompts = ["Do X", "Do Y", "Do Z"]
_mock_get_business_understanding(mocker, return_value=mock_understanding)
response = client.get("/suggested-prompts")
assert response.status_code == 200
assert response.json() == {"prompts": ["Do X", "Do Y", "Do Z"]}
def test_suggested_prompts_no_understanding(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""User with no understanding gets empty list."""
_mock_get_business_understanding(mocker, return_value=None)
response = client.get("/suggested-prompts")
assert response.status_code == 200
assert response.json() == {"prompts": []}
def test_suggested_prompts_empty_prompts(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""User with understanding but no prompts gets empty list."""
mock_understanding = MagicMock()
mock_understanding.suggested_prompts = []
_mock_get_business_understanding(mocker, return_value=mock_understanding)
response = client.get("/suggested-prompts")
assert response.status_code == 200
assert response.json() == {"prompts": []}

View File

@@ -638,7 +638,7 @@ async def test_process_review_action_auto_approve_creates_auto_approval_records(
# Mock get_node_executions to return node_id mapping
mock_get_node_executions = mocker.patch(
"backend.data.execution.get_node_executions"
"backend.api.features.executions.review.routes.get_node_executions"
)
mock_node_exec = mocker.Mock(spec=NodeExecutionResult)
mock_node_exec.node_exec_id = "test_node_123"
@@ -936,7 +936,7 @@ async def test_process_review_action_auto_approve_only_applies_to_approved_revie
# Mock get_node_executions to return node_id mapping
mock_get_node_executions = mocker.patch(
"backend.data.execution.get_node_executions"
"backend.api.features.executions.review.routes.get_node_executions"
)
mock_node_exec = mocker.Mock(spec=NodeExecutionResult)
mock_node_exec.node_exec_id = "node_exec_approved"
@@ -1148,7 +1148,7 @@ async def test_process_review_action_per_review_auto_approve_granularity(
# Mock get_node_executions to return batch node data
mock_get_node_executions = mocker.patch(
"backend.data.execution.get_node_executions"
"backend.api.features.executions.review.routes.get_node_executions"
)
# Create mock node executions for each review
mock_node_execs = []

View File

@@ -6,10 +6,15 @@ import autogpt_libs.auth as autogpt_auth_lib
from fastapi import APIRouter, HTTPException, Query, Security, status
from prisma.enums import ReviewStatus
from backend.copilot.constants import (
is_copilot_synthetic_id,
parse_node_id_from_exec_id,
)
from backend.data.execution import (
ExecutionContext,
ExecutionStatus,
get_graph_execution_meta,
get_node_executions,
)
from backend.data.graph import get_graph_settings
from backend.data.human_review import (
@@ -36,6 +41,38 @@ router = APIRouter(
)
async def _resolve_node_ids(
node_exec_ids: list[str],
graph_exec_id: str,
is_copilot: bool,
) -> dict[str, str]:
"""Resolve node_exec_id -> node_id for auto-approval records.
CoPilot synthetic IDs encode node_id in the format "{node_id}:{random}".
Graph executions look up node_id from NodeExecution records.
"""
if not node_exec_ids:
return {}
if is_copilot:
return {neid: parse_node_id_from_exec_id(neid) for neid in node_exec_ids}
node_execs = await get_node_executions(
graph_exec_id=graph_exec_id, include_exec_data=False
)
node_exec_map = {ne.node_exec_id: ne.node_id for ne in node_execs}
result = {}
for neid in node_exec_ids:
if neid in node_exec_map:
result[neid] = node_exec_map[neid]
else:
logger.error(
f"Failed to resolve node_id for {neid}: Node execution not found."
)
return result
@router.get(
"/pending",
summary="Get Pending Reviews",
@@ -110,14 +147,16 @@ async def list_pending_reviews_for_execution(
"""
# Verify user owns the graph execution before returning reviews
graph_exec = await get_graph_execution_meta(
user_id=user_id, execution_id=graph_exec_id
)
if not graph_exec:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Graph execution #{graph_exec_id} not found",
# (CoPilot synthetic IDs don't have graph execution records)
if not is_copilot_synthetic_id(graph_exec_id):
graph_exec = await get_graph_execution_meta(
user_id=user_id, execution_id=graph_exec_id
)
if not graph_exec:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Graph execution #{graph_exec_id} not found",
)
return await get_pending_reviews_for_execution(graph_exec_id, user_id)
@@ -160,30 +199,26 @@ async def process_review_action(
)
graph_exec_id = next(iter(graph_exec_ids))
is_copilot = is_copilot_synthetic_id(graph_exec_id)
# Validate execution status before processing reviews
graph_exec_meta = await get_graph_execution_meta(
user_id=user_id, execution_id=graph_exec_id
)
if not graph_exec_meta:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Graph execution #{graph_exec_id} not found",
)
# Only allow processing reviews if execution is paused for review
# or incomplete (partial execution with some reviews already processed)
if graph_exec_meta.status not in (
ExecutionStatus.REVIEW,
ExecutionStatus.INCOMPLETE,
):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Cannot process reviews while execution status is {graph_exec_meta.status}. "
f"Reviews can only be processed when execution is paused (REVIEW status). "
f"Current status: {graph_exec_meta.status}",
# Validate execution status for graph executions (skip for CoPilot synthetic IDs)
if not is_copilot:
graph_exec_meta = await get_graph_execution_meta(
user_id=user_id, execution_id=graph_exec_id
)
if not graph_exec_meta:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Graph execution #{graph_exec_id} not found",
)
if graph_exec_meta.status not in (
ExecutionStatus.REVIEW,
ExecutionStatus.INCOMPLETE,
):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Cannot process reviews while execution status is {graph_exec_meta.status}",
)
# Build review decisions map and track which reviews requested auto-approval
# Auto-approved reviews use original data (no modifications allowed)
@@ -236,7 +271,7 @@ async def process_review_action(
)
return (node_id, False)
# Collect node_exec_ids that need auto-approval
# Collect node_exec_ids that need auto-approval and resolve their node_ids
node_exec_ids_needing_auto_approval = [
node_exec_id
for node_exec_id, review_result in updated_reviews.items()
@@ -244,29 +279,16 @@ async def process_review_action(
and auto_approve_requests.get(node_exec_id, False)
]
# Batch-fetch node executions to get node_ids
node_id_map = await _resolve_node_ids(
node_exec_ids_needing_auto_approval, graph_exec_id, is_copilot
)
# Deduplicate by node_id — one auto-approval per node
nodes_needing_auto_approval: dict[str, Any] = {}
if node_exec_ids_needing_auto_approval:
from backend.data.execution import get_node_executions
node_execs = await get_node_executions(
graph_exec_id=graph_exec_id, include_exec_data=False
)
node_exec_map = {node_exec.node_exec_id: node_exec for node_exec in node_execs}
for node_exec_id in node_exec_ids_needing_auto_approval:
node_exec = node_exec_map.get(node_exec_id)
if node_exec:
review_result = updated_reviews[node_exec_id]
# Use the first approved review for this node (deduplicate by node_id)
if node_exec.node_id not in nodes_needing_auto_approval:
nodes_needing_auto_approval[node_exec.node_id] = review_result
else:
logger.error(
f"Failed to create auto-approval record for {node_exec_id}: "
f"Node execution not found. This may indicate a race condition "
f"or data inconsistency."
)
for node_exec_id in node_exec_ids_needing_auto_approval:
node_id = node_id_map.get(node_exec_id)
if node_id and node_id not in nodes_needing_auto_approval:
nodes_needing_auto_approval[node_id] = updated_reviews[node_exec_id]
# Execute all auto-approval creations in parallel (deduplicated by node_id)
auto_approval_results = await asyncio.gather(
@@ -281,13 +303,11 @@ async def process_review_action(
auto_approval_failed_count = 0
for result in auto_approval_results:
if isinstance(result, Exception):
# Unexpected exception during auto-approval creation
auto_approval_failed_count += 1
logger.error(
f"Unexpected exception during auto-approval creation: {result}"
)
elif isinstance(result, tuple) and len(result) == 2 and not result[1]:
# Auto-approval creation failed (returned False)
auto_approval_failed_count += 1
# Count results
@@ -302,22 +322,20 @@ async def process_review_action(
if review.status == ReviewStatus.REJECTED
)
# Resume execution only if ALL pending reviews for this execution have been processed
if updated_reviews:
# Resume graph execution only for real graph executions (not CoPilot)
# CoPilot sessions are resumed by the LLM retrying run_block with review_id
if not is_copilot and updated_reviews:
still_has_pending = await has_pending_reviews_for_graph_exec(graph_exec_id)
if not still_has_pending:
# Get the graph_id from any processed review
first_review = next(iter(updated_reviews.values()))
try:
# Fetch user and settings to build complete execution context
user = await get_user_by_id(user_id)
settings = await get_graph_settings(
user_id=user_id, graph_id=first_review.graph_id
)
# Preserve user's timezone preference when resuming execution
user_timezone = (
user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
)

View File

@@ -165,7 +165,6 @@ class LibraryAgent(pydantic.BaseModel):
id: str
graph_id: str
graph_version: int
owner_user_id: str
image_url: str | None
@@ -206,7 +205,9 @@ class LibraryAgent(pydantic.BaseModel):
default_factory=list,
description="List of recent executions with status, score, and summary",
)
can_access_graph: bool
can_access_graph: bool = pydantic.Field(
description="Indicates whether the same user owns the corresponding graph"
)
is_latest_version: bool
is_favorite: bool
folder_id: str | None = None
@@ -324,7 +325,6 @@ class LibraryAgent(pydantic.BaseModel):
id=agent.id,
graph_id=agent.agentGraphId,
graph_version=agent.agentGraphVersion,
owner_user_id=agent.userId,
image_url=agent.imageUrl,
creator_name=creator_name,
creator_image_url=creator_image_url,

View File

@@ -42,7 +42,6 @@ async def test_get_library_agents_success(
id="test-agent-1",
graph_id="test-agent-1",
graph_version=1,
owner_user_id=test_user_id,
name="Test Agent 1",
description="Test Description 1",
image_url=None,
@@ -67,7 +66,6 @@ async def test_get_library_agents_success(
id="test-agent-2",
graph_id="test-agent-2",
graph_version=1,
owner_user_id=test_user_id,
name="Test Agent 2",
description="Test Description 2",
image_url=None,
@@ -131,7 +129,6 @@ async def test_get_favorite_library_agents_success(
id="test-agent-1",
graph_id="test-agent-1",
graph_version=1,
owner_user_id=test_user_id,
name="Favorite Agent 1",
description="Test Favorite Description 1",
image_url=None,
@@ -184,7 +181,6 @@ def test_add_agent_to_library_success(
id="test-library-agent-id",
graph_id="test-agent-1",
graph_version=1,
owner_user_id=test_user_id,
name="Test Agent 1",
description="Test Description 1",
image_url=None,

View File

@@ -24,7 +24,7 @@ from backend.blocks.mcp.oauth import MCPOAuthHandler
from backend.data.model import OAuth2Credentials
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.providers import ProviderName
from backend.util.request import HTTPClientError, Requests, validate_url
from backend.util.request import HTTPClientError, Requests, validate_url_host
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
@@ -80,7 +80,7 @@ async def discover_tools(
"""
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
try:
await validate_url(request.server_url, trusted_origins=[])
await validate_url_host(request.server_url)
except ValueError as e:
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
@@ -167,7 +167,7 @@ async def mcp_oauth_login(
"""
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
try:
await validate_url(request.server_url, trusted_origins=[])
await validate_url_host(request.server_url)
except ValueError as e:
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
@@ -187,7 +187,7 @@ async def mcp_oauth_login(
# Validate the auth server URL from metadata to prevent SSRF.
try:
await validate_url(auth_server_url, trusted_origins=[])
await validate_url_host(auth_server_url)
except ValueError as e:
raise fastapi.HTTPException(
status_code=400,
@@ -234,7 +234,7 @@ async def mcp_oauth_login(
if registration_endpoint:
# Validate the registration endpoint to prevent SSRF via metadata.
try:
await validate_url(registration_endpoint, trusted_origins=[])
await validate_url_host(registration_endpoint)
except ValueError:
pass # Skip registration, fall back to default client_id
else:
@@ -429,7 +429,7 @@ async def mcp_store_token(
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
try:
await validate_url(request.server_url, trusted_origins=[])
await validate_url_host(request.server_url)
except ValueError as e:
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")

View File

@@ -32,9 +32,9 @@ async def client():
@pytest.fixture(autouse=True)
def _bypass_ssrf_validation():
"""Bypass validate_url in all route tests (test URLs don't resolve)."""
"""Bypass validate_url_host in all route tests (test URLs don't resolve)."""
with patch(
"backend.api.features.mcp.routes.validate_url",
"backend.api.features.mcp.routes.validate_url_host",
new_callable=AsyncMock,
):
yield
@@ -521,12 +521,12 @@ class TestStoreToken:
class TestSSRFValidation:
"""Verify that validate_url is enforced on all endpoints."""
"""Verify that validate_url_host is enforced on all endpoints."""
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_ssrf_blocked(self, client):
with patch(
"backend.api.features.mcp.routes.validate_url",
"backend.api.features.mcp.routes.validate_url_host",
new_callable=AsyncMock,
side_effect=ValueError("blocked loopback"),
):
@@ -541,7 +541,7 @@ class TestSSRFValidation:
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_login_ssrf_blocked(self, client):
with patch(
"backend.api.features.mcp.routes.validate_url",
"backend.api.features.mcp.routes.validate_url_host",
new_callable=AsyncMock,
side_effect=ValueError("blocked private IP"),
):
@@ -556,7 +556,7 @@ class TestSSRFValidation:
@pytest.mark.asyncio(loop_scope="session")
async def test_store_token_ssrf_blocked(self, client):
with patch(
"backend.api.features.mcp.routes.validate_url",
"backend.api.features.mcp.routes.validate_url_host",
new_callable=AsyncMock,
side_effect=ValueError("blocked loopback"),
):

View File

@@ -55,6 +55,7 @@ from backend.data.credit import (
set_auto_top_up,
)
from backend.data.graph import GraphSettings
from backend.data.invited_user import get_or_activate_user
from backend.data.model import CredentialsMetaInput, UserOnboarding
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
from backend.data.onboarding import (
@@ -70,7 +71,6 @@ from backend.data.onboarding import (
update_user_onboarding,
)
from backend.data.user import (
get_or_create_user,
get_user_by_id,
get_user_notification_preference,
update_user_email,
@@ -136,12 +136,10 @@ _tally_background_tasks: set[asyncio.Task] = set()
dependencies=[Security(requires_user)],
)
async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
user = await get_or_create_user(user_data)
user = await get_or_activate_user(user_data)
# Fire-and-forget: populate business understanding from Tally form.
# We use created_at proximity instead of an is_new flag because
# get_or_create_user is cached — a separate is_new return value would be
# unreliable on repeated calls within the cache TTL.
# Fire-and-forget: backfill Tally understanding when invite pre-seeding did
# not produce a stored result before first activation.
age_seconds = (datetime.now(timezone.utc) - user.created_at).total_seconds()
if age_seconds < 30:
try:
@@ -165,7 +163,8 @@ async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
dependencies=[Security(requires_user)],
)
async def update_user_email_route(
user_id: Annotated[str, Security(get_user_id)], email: str = Body(...)
user_id: Annotated[str, Security(get_user_id)],
email: str = Body(...),
) -> dict[str, str]:
await update_user_email(user_id, email)
@@ -179,10 +178,16 @@ async def update_user_email_route(
dependencies=[Security(requires_user)],
)
async def get_user_timezone_route(
user_data: dict = Security(get_jwt_payload),
user_id: Annotated[str, Security(get_user_id)],
) -> TimezoneResponse:
"""Get user timezone setting."""
user = await get_or_create_user(user_data)
try:
user = await get_user_by_id(user_id)
except ValueError:
raise HTTPException(
status_code=HTTP_404_NOT_FOUND,
detail="User not found. Please complete activation via /auth/user first.",
)
return TimezoneResponse(timezone=user.timezone)
@@ -193,7 +198,8 @@ async def get_user_timezone_route(
dependencies=[Security(requires_user)],
)
async def update_user_timezone_route(
user_id: Annotated[str, Security(get_user_id)], request: UpdateTimezoneRequest
user_id: Annotated[str, Security(get_user_id)],
request: UpdateTimezoneRequest,
) -> TimezoneResponse:
"""Update user timezone. The timezone should be a valid IANA timezone identifier."""
user = await update_user_timezone(user_id, str(request.timezone))

View File

@@ -51,7 +51,7 @@ def test_get_or_create_user_route(
}
mocker.patch(
"backend.api.features.v1.get_or_create_user",
"backend.api.features.v1.get_or_activate_user",
return_value=mock_user,
)

View File

@@ -94,3 +94,8 @@ class NotificationPayload(pydantic.BaseModel):
class OnboardingNotificationPayload(NotificationPayload):
step: OnboardingStep | None
class CopilotCompletionPayload(NotificationPayload):
session_id: str
status: Literal["completed", "failed"]

View File

@@ -19,6 +19,7 @@ from prisma.errors import PrismaError
import backend.api.features.admin.credit_admin_routes
import backend.api.features.admin.execution_analytics_routes
import backend.api.features.admin.store_admin_routes
import backend.api.features.admin.user_admin_routes
import backend.api.features.builder
import backend.api.features.builder.routes
import backend.api.features.chat.routes as chat_routes
@@ -311,6 +312,11 @@ app.include_router(
tags=["v2", "admin"],
prefix="/api/executions",
)
app.include_router(
backend.api.features.admin.user_admin_routes.router,
tags=["v2", "admin"],
prefix="/api/users",
)
app.include_router(
backend.api.features.executions.review.routes.router,
tags=["v2", "executions", "review"],

View File

@@ -418,6 +418,8 @@ class BlockWebhookConfig(BlockManualWebhookConfig):
class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
_optimized_description: ClassVar[str | None] = None
def __init__(
self,
id: str = "",
@@ -470,6 +472,8 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
self.block_type = block_type
self.webhook_config = webhook_config
self.is_sensitive_action = is_sensitive_action
# Read from ClassVar set by initialize_blocks()
self.optimized_description: str | None = type(self)._optimized_description
self.execution_stats: "NodeExecutionStats" = NodeExecutionStats()
if self.webhook_config:
@@ -620,6 +624,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
graph_id: str,
graph_version: int,
execution_context: "ExecutionContext",
is_graph_execution: bool = True,
**kwargs,
) -> tuple[bool, BlockInput]:
"""
@@ -648,6 +653,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
graph_version=graph_version,
block_name=self.name,
editable=True,
is_graph_execution=is_graph_execution,
)
if decision is None:

View File

@@ -126,7 +126,7 @@ class PrintToConsoleBlock(Block):
output_schema=PrintToConsoleBlock.Output,
test_input={"text": "Hello, World!"},
is_sensitive_action=True,
disabled=True, # Disabled per Nick Tindle's request (OPEN-3000)
disabled=True,
test_output=[
("output", "Hello, World!"),
("status", "printed"),

View File

@@ -142,7 +142,7 @@ class BaseE2BExecutorMixin:
start_timestamp = ts_result.stdout.strip() if ts_result.stdout else None
# Execute the code
execution = await sandbox.run_code(
execution = await sandbox.run_code( # type: ignore[attr-defined]
code,
language=language.value,
on_error=lambda e: sandbox.kill(), # Kill the sandbox on error

View File

@@ -96,6 +96,7 @@ class SendEmailBlock(Block):
test_credentials=TEST_CREDENTIALS,
test_output=[("status", "Email sent successfully")],
test_mock={"send_email": lambda *args, **kwargs: "Email sent successfully"},
is_sensitive_action=True,
)
@staticmethod

View File

@@ -0,0 +1,3 @@
def github_repo_path(repo_url: str) -> str:
"""Extract 'owner/repo' from a GitHub repository URL."""
return repo_url.replace("https://github.com/", "")

View File

@@ -0,0 +1,374 @@
import asyncio
from enum import StrEnum
from urllib.parse import quote
from typing_extensions import TypedDict
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import SchemaField
from ._api import get_api
from ._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
GithubCredentials,
GithubCredentialsField,
GithubCredentialsInput,
)
from ._utils import github_repo_path
class GithubListCommitsBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
branch: str = SchemaField(
description="Branch name to list commits from",
default="main",
)
per_page: int = SchemaField(
description="Number of commits to return (max 100)",
default=30,
ge=1,
le=100,
)
page: int = SchemaField(
description="Page number for pagination",
default=1,
ge=1,
)
class Output(BlockSchemaOutput):
class CommitItem(TypedDict):
sha: str
message: str
author: str
date: str
url: str
commit: CommitItem = SchemaField(
title="Commit", description="A commit with its details"
)
commits: list[CommitItem] = SchemaField(
description="List of commits with their details"
)
error: str = SchemaField(description="Error message if listing commits failed")
def __init__(self):
super().__init__(
id="8b13f579-d8b6-4dc2-a140-f770428805de",
description="This block lists commits on a branch in a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubListCommitsBlock.Input,
output_schema=GithubListCommitsBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"branch": "main",
"per_page": 30,
"page": 1,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"commits",
[
{
"sha": "abc123",
"message": "Initial commit",
"author": "octocat",
"date": "2024-01-01T00:00:00Z",
"url": "https://github.com/owner/repo/commit/abc123",
}
],
),
(
"commit",
{
"sha": "abc123",
"message": "Initial commit",
"author": "octocat",
"date": "2024-01-01T00:00:00Z",
"url": "https://github.com/owner/repo/commit/abc123",
},
),
],
test_mock={
"list_commits": lambda *args, **kwargs: [
{
"sha": "abc123",
"message": "Initial commit",
"author": "octocat",
"date": "2024-01-01T00:00:00Z",
"url": "https://github.com/owner/repo/commit/abc123",
}
]
},
)
@staticmethod
async def list_commits(
credentials: GithubCredentials,
repo_url: str,
branch: str,
per_page: int,
page: int,
) -> list[Output.CommitItem]:
api = get_api(credentials)
commits_url = repo_url + "/commits"
params = {"sha": branch, "per_page": str(per_page), "page": str(page)}
response = await api.get(commits_url, params=params)
data = response.json()
repo_path = github_repo_path(repo_url)
return [
GithubListCommitsBlock.Output.CommitItem(
sha=c["sha"],
message=c["commit"]["message"],
author=(c["commit"].get("author") or {}).get("name", "Unknown"),
date=(c["commit"].get("author") or {}).get("date", ""),
url=f"https://github.com/{repo_path}/commit/{c['sha']}",
)
for c in data
]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
commits = await self.list_commits(
credentials,
input_data.repo_url,
input_data.branch,
input_data.per_page,
input_data.page,
)
yield "commits", commits
for commit in commits:
yield "commit", commit
except Exception as e:
yield "error", str(e)
class FileOperation(StrEnum):
"""File operations for GithubMultiFileCommitBlock.
UPSERT creates or overwrites a file (the Git Trees API does not distinguish
between creation and update — the blob is placed at the given path regardless
of whether a file already exists there).
DELETE removes a file from the tree.
"""
UPSERT = "upsert"
DELETE = "delete"
class FileOperationInput(TypedDict):
path: str
content: str
operation: FileOperation
class GithubMultiFileCommitBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
branch: str = SchemaField(
description="Branch to commit to",
placeholder="feature-branch",
)
commit_message: str = SchemaField(
description="Commit message",
placeholder="Add new feature",
)
files: list[FileOperationInput] = SchemaField(
description=(
"List of file operations. Each item has: "
"'path' (file path), 'content' (file content, ignored for delete), "
"'operation' (upsert/delete)"
),
)
class Output(BlockSchemaOutput):
sha: str = SchemaField(description="SHA of the new commit")
url: str = SchemaField(description="URL of the new commit")
error: str = SchemaField(description="Error message if the commit failed")
def __init__(self):
super().__init__(
id="389eee51-a95e-4230-9bed-92167a327802",
description=(
"This block creates a single commit with multiple file "
"upsert/delete operations using the Git Trees API."
),
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubMultiFileCommitBlock.Input,
output_schema=GithubMultiFileCommitBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"branch": "feature",
"commit_message": "Add files",
"files": [
{
"path": "src/new.py",
"content": "print('hello')",
"operation": "upsert",
},
{
"path": "src/old.py",
"content": "",
"operation": "delete",
},
],
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("sha", "newcommitsha"),
("url", "https://github.com/owner/repo/commit/newcommitsha"),
],
test_mock={
"multi_file_commit": lambda *args, **kwargs: (
"newcommitsha",
"https://github.com/owner/repo/commit/newcommitsha",
)
},
)
@staticmethod
async def multi_file_commit(
credentials: GithubCredentials,
repo_url: str,
branch: str,
commit_message: str,
files: list[FileOperationInput],
) -> tuple[str, str]:
api = get_api(credentials)
safe_branch = quote(branch, safe="")
# 1. Get the latest commit SHA for the branch
ref_url = repo_url + f"/git/refs/heads/{safe_branch}"
response = await api.get(ref_url)
ref_data = response.json()
latest_commit_sha = ref_data["object"]["sha"]
# 2. Get the tree SHA of the latest commit
commit_url = repo_url + f"/git/commits/{latest_commit_sha}"
response = await api.get(commit_url)
commit_data = response.json()
base_tree_sha = commit_data["tree"]["sha"]
# 3. Build tree entries for each file operation (blobs created concurrently)
async def _create_blob(content: str) -> str:
blob_url = repo_url + "/git/blobs"
blob_response = await api.post(
blob_url,
json={"content": content, "encoding": "utf-8"},
)
return blob_response.json()["sha"]
tree_entries: list[dict] = []
upsert_files = []
for file_op in files:
path = file_op["path"]
operation = FileOperation(file_op.get("operation", "upsert"))
if operation == FileOperation.DELETE:
tree_entries.append(
{
"path": path,
"mode": "100644",
"type": "blob",
"sha": None, # null SHA = delete
}
)
else:
upsert_files.append((path, file_op.get("content", "")))
# Create all blobs concurrently
if upsert_files:
blob_shas = await asyncio.gather(
*[_create_blob(content) for _, content in upsert_files]
)
for (path, _), blob_sha in zip(upsert_files, blob_shas):
tree_entries.append(
{
"path": path,
"mode": "100644",
"type": "blob",
"sha": blob_sha,
}
)
# 4. Create a new tree
tree_url = repo_url + "/git/trees"
tree_response = await api.post(
tree_url,
json={"base_tree": base_tree_sha, "tree": tree_entries},
)
new_tree_sha = tree_response.json()["sha"]
# 5. Create a new commit
new_commit_url = repo_url + "/git/commits"
commit_response = await api.post(
new_commit_url,
json={
"message": commit_message,
"tree": new_tree_sha,
"parents": [latest_commit_sha],
},
)
new_commit_sha = commit_response.json()["sha"]
# 6. Update the branch reference
try:
await api.patch(
ref_url,
json={"sha": new_commit_sha},
)
except Exception as e:
raise RuntimeError(
f"Commit {new_commit_sha} was created but failed to update "
f"ref heads/{branch}: {e}. "
f"You can recover by manually updating the branch to {new_commit_sha}."
) from e
repo_path = github_repo_path(repo_url)
commit_web_url = f"https://github.com/{repo_path}/commit/{new_commit_sha}"
return new_commit_sha, commit_web_url
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
sha, url = await self.multi_file_commit(
credentials,
input_data.repo_url,
input_data.branch,
input_data.commit_message,
input_data.files,
)
yield "sha", sha
yield "url", url
except Exception as e:
yield "error", str(e)

View File

@@ -1,4 +1,5 @@
import re
from typing import Literal
from typing_extensions import TypedDict
@@ -20,6 +21,8 @@ from ._auth import (
GithubCredentialsInput,
)
MergeMethod = Literal["merge", "squash", "rebase"]
class GithubListPullRequestsBlock(Block):
class Input(BlockSchemaInput):
@@ -558,12 +561,109 @@ class GithubListPRReviewersBlock(Block):
yield "reviewer", reviewer
class GithubMergePullRequestBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
pr_url: str = SchemaField(
description="URL of the GitHub pull request",
placeholder="https://github.com/owner/repo/pull/1",
)
merge_method: MergeMethod = SchemaField(
description="Merge method to use: merge, squash, or rebase",
default="merge",
)
commit_title: str = SchemaField(
description="Title for the merge commit (optional, used for merge and squash)",
default="",
)
commit_message: str = SchemaField(
description="Message for the merge commit (optional, used for merge and squash)",
default="",
)
class Output(BlockSchemaOutput):
sha: str = SchemaField(description="SHA of the merge commit")
merged: bool = SchemaField(description="Whether the PR was merged")
message: str = SchemaField(description="Merge status message")
error: str = SchemaField(description="Error message if the merge failed")
def __init__(self):
super().__init__(
id="77456c22-33d8-4fd4-9eef-50b46a35bb48",
description="This block merges a pull request using merge, squash, or rebase.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubMergePullRequestBlock.Input,
output_schema=GithubMergePullRequestBlock.Output,
test_input={
"pr_url": "https://github.com/owner/repo/pull/1",
"merge_method": "squash",
"commit_title": "",
"commit_message": "",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("sha", "abc123"),
("merged", True),
("message", "Pull Request successfully merged"),
],
test_mock={
"merge_pr": lambda *args, **kwargs: (
"abc123",
True,
"Pull Request successfully merged",
)
},
is_sensitive_action=True,
)
@staticmethod
async def merge_pr(
credentials: GithubCredentials,
pr_url: str,
merge_method: MergeMethod,
commit_title: str,
commit_message: str,
) -> tuple[str, bool, str]:
api = get_api(credentials)
merge_url = prepare_pr_api_url(pr_url=pr_url, path="merge")
data: dict[str, str] = {"merge_method": merge_method}
if commit_title:
data["commit_title"] = commit_title
if commit_message:
data["commit_message"] = commit_message
response = await api.put(merge_url, json=data)
result = response.json()
return result["sha"], result["merged"], result["message"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
sha, merged, message = await self.merge_pr(
credentials,
input_data.pr_url,
input_data.merge_method,
input_data.commit_title,
input_data.commit_message,
)
yield "sha", sha
yield "merged", merged
yield "message", message
except Exception as e:
yield "error", str(e)
def prepare_pr_api_url(pr_url: str, path: str) -> str:
# Pattern to capture the base repository URL and the pull request number
pattern = r"^(?:https?://)?([^/]+/[^/]+/[^/]+)/pull/(\d+)"
pattern = r"^(?:(https?)://)?([^/]+/[^/]+/[^/]+)/pull/(\d+)"
match = re.match(pattern, pr_url)
if not match:
return pr_url
base_url, pr_number = match.groups()
return f"{base_url}/pulls/{pr_number}/{path}"
scheme, base_url, pr_number = match.groups()
return f"{scheme or 'https'}://{base_url}/pulls/{pr_number}/{path}"

View File

@@ -1,5 +1,3 @@
import base64
from typing_extensions import TypedDict
from backend.blocks._base import (
@@ -19,6 +17,7 @@ from ._auth import (
GithubCredentialsField,
GithubCredentialsInput,
)
from ._utils import github_repo_path
class GithubListTagsBlock(Block):
@@ -89,7 +88,7 @@ class GithubListTagsBlock(Block):
tags_url = repo_url + "/tags"
response = await api.get(tags_url)
data = response.json()
repo_path = repo_url.replace("https://github.com/", "")
repo_path = github_repo_path(repo_url)
tags: list[GithubListTagsBlock.Output.TagItem] = [
{
"name": tag["name"],
@@ -115,101 +114,6 @@ class GithubListTagsBlock(Block):
yield "tag", tag
class GithubListBranchesBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
class Output(BlockSchemaOutput):
class BranchItem(TypedDict):
name: str
url: str
branch: BranchItem = SchemaField(
title="Branch",
description="Branches with their name and file tree browser URL",
)
branches: list[BranchItem] = SchemaField(
description="List of branches with their name and file tree browser URL"
)
def __init__(self):
super().__init__(
id="74243e49-2bec-4916-8bf4-db43d44aead5",
description="This block lists all branches for a specified GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubListBranchesBlock.Input,
output_schema=GithubListBranchesBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"branches",
[
{
"name": "main",
"url": "https://github.com/owner/repo/tree/main",
}
],
),
(
"branch",
{
"name": "main",
"url": "https://github.com/owner/repo/tree/main",
},
),
],
test_mock={
"list_branches": lambda *args, **kwargs: [
{
"name": "main",
"url": "https://github.com/owner/repo/tree/main",
}
]
},
)
@staticmethod
async def list_branches(
credentials: GithubCredentials, repo_url: str
) -> list[Output.BranchItem]:
api = get_api(credentials)
branches_url = repo_url + "/branches"
response = await api.get(branches_url)
data = response.json()
repo_path = repo_url.replace("https://github.com/", "")
branches: list[GithubListBranchesBlock.Output.BranchItem] = [
{
"name": branch["name"],
"url": f"https://github.com/{repo_path}/tree/{branch['name']}",
}
for branch in data
]
return branches
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
branches = await self.list_branches(
credentials,
input_data.repo_url,
)
yield "branches", branches
for branch in branches:
yield "branch", branch
class GithubListDiscussionsBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
@@ -283,7 +187,7 @@ class GithubListDiscussionsBlock(Block):
) -> list[Output.DiscussionItem]:
api = get_api(credentials)
# GitHub GraphQL API endpoint is different; we'll use api.post with custom URL
repo_path = repo_url.replace("https://github.com/", "")
repo_path = github_repo_path(repo_url)
owner, repo = repo_path.split("/")
query = """
query($owner: String!, $repo: String!, $num: Int!) {
@@ -416,564 +320,6 @@ class GithubListReleasesBlock(Block):
yield "release", release
class GithubReadFileBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
file_path: str = SchemaField(
description="Path to the file in the repository",
placeholder="path/to/file",
)
branch: str = SchemaField(
description="Branch to read from",
placeholder="branch_name",
default="master",
)
class Output(BlockSchemaOutput):
text_content: str = SchemaField(
description="Content of the file (decoded as UTF-8 text)"
)
raw_content: str = SchemaField(
description="Raw base64-encoded content of the file"
)
size: int = SchemaField(description="The size of the file (in bytes)")
def __init__(self):
super().__init__(
id="87ce6c27-5752-4bbc-8e26-6da40a3dcfd3",
description="This block reads the content of a specified file from a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubReadFileBlock.Input,
output_schema=GithubReadFileBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"file_path": "path/to/file",
"branch": "master",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("raw_content", "RmlsZSBjb250ZW50"),
("text_content", "File content"),
("size", 13),
],
test_mock={"read_file": lambda *args, **kwargs: ("RmlsZSBjb250ZW50", 13)},
)
@staticmethod
async def read_file(
credentials: GithubCredentials, repo_url: str, file_path: str, branch: str
) -> tuple[str, int]:
api = get_api(credentials)
content_url = repo_url + f"/contents/{file_path}?ref={branch}"
response = await api.get(content_url)
data = response.json()
if isinstance(data, list):
# Multiple entries of different types exist at this path
if not (file := next((f for f in data if f["type"] == "file"), None)):
raise TypeError("Not a file")
data = file
if data["type"] != "file":
raise TypeError("Not a file")
return data["content"], data["size"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
content, size = await self.read_file(
credentials,
input_data.repo_url,
input_data.file_path,
input_data.branch,
)
yield "raw_content", content
yield "text_content", base64.b64decode(content).decode("utf-8")
yield "size", size
class GithubReadFolderBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
folder_path: str = SchemaField(
description="Path to the folder in the repository",
placeholder="path/to/folder",
)
branch: str = SchemaField(
description="Branch name to read from (defaults to master)",
placeholder="branch_name",
default="master",
)
class Output(BlockSchemaOutput):
class DirEntry(TypedDict):
name: str
path: str
class FileEntry(TypedDict):
name: str
path: str
size: int
file: FileEntry = SchemaField(description="Files in the folder")
dir: DirEntry = SchemaField(description="Directories in the folder")
error: str = SchemaField(
description="Error message if reading the folder failed"
)
def __init__(self):
super().__init__(
id="1355f863-2db3-4d75-9fba-f91e8a8ca400",
description="This block reads the content of a specified folder from a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubReadFolderBlock.Input,
output_schema=GithubReadFolderBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"folder_path": "path/to/folder",
"branch": "master",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"file",
{
"name": "file1.txt",
"path": "path/to/folder/file1.txt",
"size": 1337,
},
),
("dir", {"name": "dir2", "path": "path/to/folder/dir2"}),
],
test_mock={
"read_folder": lambda *args, **kwargs: (
[
{
"name": "file1.txt",
"path": "path/to/folder/file1.txt",
"size": 1337,
}
],
[{"name": "dir2", "path": "path/to/folder/dir2"}],
)
},
)
@staticmethod
async def read_folder(
credentials: GithubCredentials, repo_url: str, folder_path: str, branch: str
) -> tuple[list[Output.FileEntry], list[Output.DirEntry]]:
api = get_api(credentials)
contents_url = repo_url + f"/contents/{folder_path}?ref={branch}"
response = await api.get(contents_url)
data = response.json()
if not isinstance(data, list):
raise TypeError("Not a folder")
files: list[GithubReadFolderBlock.Output.FileEntry] = [
GithubReadFolderBlock.Output.FileEntry(
name=entry["name"],
path=entry["path"],
size=entry["size"],
)
for entry in data
if entry["type"] == "file"
]
dirs: list[GithubReadFolderBlock.Output.DirEntry] = [
GithubReadFolderBlock.Output.DirEntry(
name=entry["name"],
path=entry["path"],
)
for entry in data
if entry["type"] == "dir"
]
return files, dirs
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
files, dirs = await self.read_folder(
credentials,
input_data.repo_url,
input_data.folder_path.lstrip("/"),
input_data.branch,
)
for file in files:
yield "file", file
for dir in dirs:
yield "dir", dir
class GithubMakeBranchBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
new_branch: str = SchemaField(
description="Name of the new branch",
placeholder="new_branch_name",
)
source_branch: str = SchemaField(
description="Name of the source branch",
placeholder="source_branch_name",
)
class Output(BlockSchemaOutput):
status: str = SchemaField(description="Status of the branch creation operation")
error: str = SchemaField(
description="Error message if the branch creation failed"
)
def __init__(self):
super().__init__(
id="944cc076-95e7-4d1b-b6b6-b15d8ee5448d",
description="This block creates a new branch from a specified source branch.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubMakeBranchBlock.Input,
output_schema=GithubMakeBranchBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"new_branch": "new_branch_name",
"source_branch": "source_branch_name",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[("status", "Branch created successfully")],
test_mock={
"create_branch": lambda *args, **kwargs: "Branch created successfully"
},
)
@staticmethod
async def create_branch(
credentials: GithubCredentials,
repo_url: str,
new_branch: str,
source_branch: str,
) -> str:
api = get_api(credentials)
ref_url = repo_url + f"/git/refs/heads/{source_branch}"
response = await api.get(ref_url)
data = response.json()
sha = data["object"]["sha"]
# Create the new branch
new_ref_url = repo_url + "/git/refs"
data = {
"ref": f"refs/heads/{new_branch}",
"sha": sha,
}
response = await api.post(new_ref_url, json=data)
return "Branch created successfully"
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
status = await self.create_branch(
credentials,
input_data.repo_url,
input_data.new_branch,
input_data.source_branch,
)
yield "status", status
class GithubDeleteBranchBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
branch: str = SchemaField(
description="Name of the branch to delete",
placeholder="branch_name",
)
class Output(BlockSchemaOutput):
status: str = SchemaField(description="Status of the branch deletion operation")
error: str = SchemaField(
description="Error message if the branch deletion failed"
)
def __init__(self):
super().__init__(
id="0d4130f7-e0ab-4d55-adc3-0a40225e80f4",
description="This block deletes a specified branch.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubDeleteBranchBlock.Input,
output_schema=GithubDeleteBranchBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"branch": "branch_name",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[("status", "Branch deleted successfully")],
test_mock={
"delete_branch": lambda *args, **kwargs: "Branch deleted successfully"
},
)
@staticmethod
async def delete_branch(
credentials: GithubCredentials, repo_url: str, branch: str
) -> str:
api = get_api(credentials)
ref_url = repo_url + f"/git/refs/heads/{branch}"
await api.delete(ref_url)
return "Branch deleted successfully"
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
status = await self.delete_branch(
credentials,
input_data.repo_url,
input_data.branch,
)
yield "status", status
class GithubCreateFileBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
file_path: str = SchemaField(
description="Path where the file should be created",
placeholder="path/to/file.txt",
)
content: str = SchemaField(
description="Content to write to the file",
placeholder="File content here",
)
branch: str = SchemaField(
description="Branch where the file should be created",
default="main",
)
commit_message: str = SchemaField(
description="Message for the commit",
default="Create new file",
)
class Output(BlockSchemaOutput):
url: str = SchemaField(description="URL of the created file")
sha: str = SchemaField(description="SHA of the commit")
error: str = SchemaField(
description="Error message if the file creation failed"
)
def __init__(self):
super().__init__(
id="8fd132ac-b917-428a-8159-d62893e8a3fe",
description="This block creates a new file in a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubCreateFileBlock.Input,
output_schema=GithubCreateFileBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"file_path": "test/file.txt",
"content": "Test content",
"branch": "main",
"commit_message": "Create test file",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
("sha", "abc123"),
],
test_mock={
"create_file": lambda *args, **kwargs: (
"https://github.com/owner/repo/blob/main/test/file.txt",
"abc123",
)
},
)
@staticmethod
async def create_file(
credentials: GithubCredentials,
repo_url: str,
file_path: str,
content: str,
branch: str,
commit_message: str,
) -> tuple[str, str]:
api = get_api(credentials)
contents_url = repo_url + f"/contents/{file_path}"
content_base64 = base64.b64encode(content.encode()).decode()
data = {
"message": commit_message,
"content": content_base64,
"branch": branch,
}
response = await api.put(contents_url, json=data)
data = response.json()
return data["content"]["html_url"], data["commit"]["sha"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
url, sha = await self.create_file(
credentials,
input_data.repo_url,
input_data.file_path,
input_data.content,
input_data.branch,
input_data.commit_message,
)
yield "url", url
yield "sha", sha
except Exception as e:
yield "error", str(e)
class GithubUpdateFileBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
file_path: str = SchemaField(
description="Path to the file to update",
placeholder="path/to/file.txt",
)
content: str = SchemaField(
description="New content for the file",
placeholder="Updated content here",
)
branch: str = SchemaField(
description="Branch containing the file",
default="main",
)
commit_message: str = SchemaField(
description="Message for the commit",
default="Update file",
)
class Output(BlockSchemaOutput):
url: str = SchemaField(description="URL of the updated file")
sha: str = SchemaField(description="SHA of the commit")
def __init__(self):
super().__init__(
id="30be12a4-57cb-4aa4-baf5-fcc68d136076",
description="This block updates an existing file in a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubUpdateFileBlock.Input,
output_schema=GithubUpdateFileBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"file_path": "test/file.txt",
"content": "Updated content",
"branch": "main",
"commit_message": "Update test file",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
("sha", "def456"),
],
test_mock={
"update_file": lambda *args, **kwargs: (
"https://github.com/owner/repo/blob/main/test/file.txt",
"def456",
)
},
)
@staticmethod
async def update_file(
credentials: GithubCredentials,
repo_url: str,
file_path: str,
content: str,
branch: str,
commit_message: str,
) -> tuple[str, str]:
api = get_api(credentials)
contents_url = repo_url + f"/contents/{file_path}"
params = {"ref": branch}
response = await api.get(contents_url, params=params)
data = response.json()
# Convert new content to base64
content_base64 = base64.b64encode(content.encode()).decode()
data = {
"message": commit_message,
"content": content_base64,
"sha": data["sha"],
"branch": branch,
}
response = await api.put(contents_url, json=data)
data = response.json()
return data["content"]["html_url"], data["commit"]["sha"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
url, sha = await self.update_file(
credentials,
input_data.repo_url,
input_data.file_path,
input_data.content,
input_data.branch,
input_data.commit_message,
)
yield "url", url
yield "sha", sha
except Exception as e:
yield "error", str(e)
class GithubCreateRepositoryBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
@@ -1103,7 +449,7 @@ class GithubListStargazersBlock(Block):
def __init__(self):
super().__init__(
id="a4b9c2d1-e5f6-4g7h-8i9j-0k1l2m3n4o5p", # Generated unique UUID
id="e96d01ec-b55e-4a99-8ce8-c8776dce850b", # Generated unique UUID
description="This block lists all users who have starred a specified GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubListStargazersBlock.Input,
@@ -1172,3 +518,230 @@ class GithubListStargazersBlock(Block):
yield "stargazers", stargazers
for stargazer in stargazers:
yield "stargazer", stargazer
class GithubGetRepositoryInfoBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
class Output(BlockSchemaOutput):
name: str = SchemaField(description="Repository name")
full_name: str = SchemaField(description="Full repository name (owner/repo)")
description: str = SchemaField(description="Repository description")
default_branch: str = SchemaField(description="Default branch name (e.g. main)")
private: bool = SchemaField(description="Whether the repository is private")
html_url: str = SchemaField(description="Web URL of the repository")
clone_url: str = SchemaField(description="Git clone URL")
stars: int = SchemaField(description="Number of stars")
forks: int = SchemaField(description="Number of forks")
open_issues: int = SchemaField(description="Number of open issues")
error: str = SchemaField(
description="Error message if fetching repo info failed"
)
def __init__(self):
super().__init__(
id="59d4f241-968a-4040-95da-348ac5c5ce27",
description="This block retrieves metadata about a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubGetRepositoryInfoBlock.Input,
output_schema=GithubGetRepositoryInfoBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("name", "repo"),
("full_name", "owner/repo"),
("description", "A test repo"),
("default_branch", "main"),
("private", False),
("html_url", "https://github.com/owner/repo"),
("clone_url", "https://github.com/owner/repo.git"),
("stars", 42),
("forks", 5),
("open_issues", 3),
],
test_mock={
"get_repo_info": lambda *args, **kwargs: {
"name": "repo",
"full_name": "owner/repo",
"description": "A test repo",
"default_branch": "main",
"private": False,
"html_url": "https://github.com/owner/repo",
"clone_url": "https://github.com/owner/repo.git",
"stargazers_count": 42,
"forks_count": 5,
"open_issues_count": 3,
}
},
)
@staticmethod
async def get_repo_info(credentials: GithubCredentials, repo_url: str) -> dict:
api = get_api(credentials)
response = await api.get(repo_url)
return response.json()
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
data = await self.get_repo_info(credentials, input_data.repo_url)
yield "name", data["name"]
yield "full_name", data["full_name"]
yield "description", data.get("description", "") or ""
yield "default_branch", data["default_branch"]
yield "private", data["private"]
yield "html_url", data["html_url"]
yield "clone_url", data["clone_url"]
yield "stars", data["stargazers_count"]
yield "forks", data["forks_count"]
yield "open_issues", data["open_issues_count"]
except Exception as e:
yield "error", str(e)
class GithubForkRepositoryBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository to fork",
placeholder="https://github.com/owner/repo",
)
organization: str = SchemaField(
description="Organization to fork into (leave empty to fork to your account)",
default="",
)
class Output(BlockSchemaOutput):
url: str = SchemaField(description="URL of the forked repository")
clone_url: str = SchemaField(description="Git clone URL of the fork")
full_name: str = SchemaField(description="Full name of the fork (owner/repo)")
error: str = SchemaField(description="Error message if the fork failed")
def __init__(self):
super().__init__(
id="a439f2f4-835f-4dae-ba7b-0205ffa70be6",
description="This block forks a GitHub repository to your account or an organization.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubForkRepositoryBlock.Input,
output_schema=GithubForkRepositoryBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"organization": "",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("url", "https://github.com/myuser/repo"),
("clone_url", "https://github.com/myuser/repo.git"),
("full_name", "myuser/repo"),
],
test_mock={
"fork_repo": lambda *args, **kwargs: (
"https://github.com/myuser/repo",
"https://github.com/myuser/repo.git",
"myuser/repo",
)
},
)
@staticmethod
async def fork_repo(
credentials: GithubCredentials,
repo_url: str,
organization: str,
) -> tuple[str, str, str]:
api = get_api(credentials)
forks_url = repo_url + "/forks"
data: dict[str, str] = {}
if organization:
data["organization"] = organization
response = await api.post(forks_url, json=data)
result = response.json()
return result["html_url"], result["clone_url"], result["full_name"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
url, clone_url, full_name = await self.fork_repo(
credentials,
input_data.repo_url,
input_data.organization,
)
yield "url", url
yield "clone_url", clone_url
yield "full_name", full_name
except Exception as e:
yield "error", str(e)
class GithubStarRepositoryBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository to star",
placeholder="https://github.com/owner/repo",
)
class Output(BlockSchemaOutput):
status: str = SchemaField(description="Status of the star operation")
error: str = SchemaField(description="Error message if starring failed")
def __init__(self):
super().__init__(
id="bd700764-53e3-44dd-a969-d1854088458f",
description="This block stars a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubStarRepositoryBlock.Input,
output_schema=GithubStarRepositoryBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[("status", "Repository starred successfully")],
test_mock={
"star_repo": lambda *args, **kwargs: "Repository starred successfully"
},
)
@staticmethod
async def star_repo(credentials: GithubCredentials, repo_url: str) -> str:
api = get_api(credentials, convert_urls=False)
repo_path = github_repo_path(repo_url)
owner, repo = repo_path.split("/")
await api.put(
f"https://api.github.com/user/starred/{owner}/{repo}",
headers={"Content-Length": "0"},
)
return "Repository starred successfully"
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
status = await self.star_repo(credentials, input_data.repo_url)
yield "status", status
except Exception as e:
yield "error", str(e)

View File

@@ -0,0 +1,452 @@
from urllib.parse import quote
from typing_extensions import TypedDict
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import SchemaField
from ._api import get_api
from ._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
GithubCredentials,
GithubCredentialsField,
GithubCredentialsInput,
)
from ._utils import github_repo_path
class GithubListBranchesBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
per_page: int = SchemaField(
description="Number of branches to return per page (max 100)",
default=30,
ge=1,
le=100,
)
page: int = SchemaField(
description="Page number for pagination",
default=1,
ge=1,
)
class Output(BlockSchemaOutput):
class BranchItem(TypedDict):
name: str
url: str
branch: BranchItem = SchemaField(
title="Branch",
description="Branches with their name and file tree browser URL",
)
branches: list[BranchItem] = SchemaField(
description="List of branches with their name and file tree browser URL"
)
error: str = SchemaField(description="Error message if listing branches failed")
def __init__(self):
super().__init__(
id="74243e49-2bec-4916-8bf4-db43d44aead5",
description="This block lists all branches for a specified GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubListBranchesBlock.Input,
output_schema=GithubListBranchesBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"per_page": 30,
"page": 1,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"branches",
[
{
"name": "main",
"url": "https://github.com/owner/repo/tree/main",
}
],
),
(
"branch",
{
"name": "main",
"url": "https://github.com/owner/repo/tree/main",
},
),
],
test_mock={
"list_branches": lambda *args, **kwargs: [
{
"name": "main",
"url": "https://github.com/owner/repo/tree/main",
}
]
},
)
@staticmethod
async def list_branches(
credentials: GithubCredentials, repo_url: str, per_page: int, page: int
) -> list[Output.BranchItem]:
api = get_api(credentials)
branches_url = repo_url + "/branches"
response = await api.get(
branches_url, params={"per_page": str(per_page), "page": str(page)}
)
data = response.json()
repo_path = github_repo_path(repo_url)
branches: list[GithubListBranchesBlock.Output.BranchItem] = [
{
"name": branch["name"],
"url": f"https://github.com/{repo_path}/tree/{branch['name']}",
}
for branch in data
]
return branches
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
branches = await self.list_branches(
credentials,
input_data.repo_url,
input_data.per_page,
input_data.page,
)
yield "branches", branches
for branch in branches:
yield "branch", branch
except Exception as e:
yield "error", str(e)
class GithubMakeBranchBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
new_branch: str = SchemaField(
description="Name of the new branch",
placeholder="new_branch_name",
)
source_branch: str = SchemaField(
description="Name of the source branch",
placeholder="source_branch_name",
)
class Output(BlockSchemaOutput):
status: str = SchemaField(description="Status of the branch creation operation")
error: str = SchemaField(
description="Error message if the branch creation failed"
)
def __init__(self):
super().__init__(
id="944cc076-95e7-4d1b-b6b6-b15d8ee5448d",
description="This block creates a new branch from a specified source branch.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubMakeBranchBlock.Input,
output_schema=GithubMakeBranchBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"new_branch": "new_branch_name",
"source_branch": "source_branch_name",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[("status", "Branch created successfully")],
test_mock={
"create_branch": lambda *args, **kwargs: "Branch created successfully"
},
)
@staticmethod
async def create_branch(
credentials: GithubCredentials,
repo_url: str,
new_branch: str,
source_branch: str,
) -> str:
api = get_api(credentials)
ref_url = repo_url + f"/git/refs/heads/{quote(source_branch, safe='')}"
response = await api.get(ref_url)
data = response.json()
sha = data["object"]["sha"]
# Create the new branch
new_ref_url = repo_url + "/git/refs"
data = {
"ref": f"refs/heads/{new_branch}",
"sha": sha,
}
response = await api.post(new_ref_url, json=data)
return "Branch created successfully"
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
status = await self.create_branch(
credentials,
input_data.repo_url,
input_data.new_branch,
input_data.source_branch,
)
yield "status", status
except Exception as e:
yield "error", str(e)
class GithubDeleteBranchBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
branch: str = SchemaField(
description="Name of the branch to delete",
placeholder="branch_name",
)
class Output(BlockSchemaOutput):
status: str = SchemaField(description="Status of the branch deletion operation")
error: str = SchemaField(
description="Error message if the branch deletion failed"
)
def __init__(self):
super().__init__(
id="0d4130f7-e0ab-4d55-adc3-0a40225e80f4",
description="This block deletes a specified branch.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubDeleteBranchBlock.Input,
output_schema=GithubDeleteBranchBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"branch": "branch_name",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[("status", "Branch deleted successfully")],
test_mock={
"delete_branch": lambda *args, **kwargs: "Branch deleted successfully"
},
is_sensitive_action=True,
)
@staticmethod
async def delete_branch(
credentials: GithubCredentials, repo_url: str, branch: str
) -> str:
api = get_api(credentials)
ref_url = repo_url + f"/git/refs/heads/{quote(branch, safe='')}"
await api.delete(ref_url)
return "Branch deleted successfully"
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
status = await self.delete_branch(
credentials,
input_data.repo_url,
input_data.branch,
)
yield "status", status
except Exception as e:
yield "error", str(e)
class GithubCompareBranchesBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
base: str = SchemaField(
description="Base branch or commit SHA",
placeholder="main",
)
head: str = SchemaField(
description="Head branch or commit SHA to compare against base",
placeholder="feature-branch",
)
class Output(BlockSchemaOutput):
class FileChange(TypedDict):
filename: str
status: str
additions: int
deletions: int
patch: str
status: str = SchemaField(
description="Comparison status: ahead, behind, diverged, or identical"
)
ahead_by: int = SchemaField(
description="Number of commits head is ahead of base"
)
behind_by: int = SchemaField(
description="Number of commits head is behind base"
)
total_commits: int = SchemaField(
description="Total number of commits in the comparison"
)
diff: str = SchemaField(description="Unified diff of all file changes")
file: FileChange = SchemaField(
title="Changed File", description="A changed file with its diff"
)
files: list[FileChange] = SchemaField(
description="List of changed files with their diffs"
)
error: str = SchemaField(description="Error message if comparison failed")
def __init__(self):
super().__init__(
id="2e4faa8c-6086-4546-ba77-172d1d560186",
description="This block compares two branches or commits in a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubCompareBranchesBlock.Input,
output_schema=GithubCompareBranchesBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"base": "main",
"head": "feature",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("status", "ahead"),
("ahead_by", 2),
("behind_by", 0),
("total_commits", 2),
("diff", "+++ b/file.py\n+new line"),
(
"files",
[
{
"filename": "file.py",
"status": "modified",
"additions": 1,
"deletions": 0,
"patch": "+new line",
}
],
),
(
"file",
{
"filename": "file.py",
"status": "modified",
"additions": 1,
"deletions": 0,
"patch": "+new line",
},
),
],
test_mock={
"compare_branches": lambda *args, **kwargs: {
"status": "ahead",
"ahead_by": 2,
"behind_by": 0,
"total_commits": 2,
"files": [
{
"filename": "file.py",
"status": "modified",
"additions": 1,
"deletions": 0,
"patch": "+new line",
}
],
}
},
)
@staticmethod
async def compare_branches(
credentials: GithubCredentials,
repo_url: str,
base: str,
head: str,
) -> dict:
api = get_api(credentials)
safe_base = quote(base, safe="")
safe_head = quote(head, safe="")
compare_url = repo_url + f"/compare/{safe_base}...{safe_head}"
response = await api.get(compare_url)
return response.json()
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
data = await self.compare_branches(
credentials,
input_data.repo_url,
input_data.base,
input_data.head,
)
yield "status", data["status"]
yield "ahead_by", data["ahead_by"]
yield "behind_by", data["behind_by"]
yield "total_commits", data["total_commits"]
files: list[GithubCompareBranchesBlock.Output.FileChange] = [
GithubCompareBranchesBlock.Output.FileChange(
filename=f["filename"],
status=f["status"],
additions=f["additions"],
deletions=f["deletions"],
patch=f.get("patch", ""),
)
for f in data.get("files", [])
]
# Build unified diff
diff_parts = []
for f in data.get("files", []):
patch = f.get("patch", "")
if patch:
diff_parts.append(f"+++ b/{f['filename']}\n{patch}")
yield "diff", "\n".join(diff_parts)
yield "files", files
for file in files:
yield "file", file
except Exception as e:
yield "error", str(e)

View File

@@ -0,0 +1,720 @@
import base64
from urllib.parse import quote
from typing_extensions import TypedDict
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import SchemaField
from ._api import get_api
from ._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
GithubCredentials,
GithubCredentialsField,
GithubCredentialsInput,
)
class GithubReadFileBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
file_path: str = SchemaField(
description="Path to the file in the repository",
placeholder="path/to/file",
)
branch: str = SchemaField(
description="Branch to read from",
placeholder="branch_name",
default="main",
)
class Output(BlockSchemaOutput):
text_content: str = SchemaField(
description="Content of the file (decoded as UTF-8 text)"
)
raw_content: str = SchemaField(
description="Raw base64-encoded content of the file"
)
size: int = SchemaField(description="The size of the file (in bytes)")
error: str = SchemaField(description="Error message if reading the file failed")
def __init__(self):
super().__init__(
id="87ce6c27-5752-4bbc-8e26-6da40a3dcfd3",
description="This block reads the content of a specified file from a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubReadFileBlock.Input,
output_schema=GithubReadFileBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"file_path": "path/to/file",
"branch": "main",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("raw_content", "RmlsZSBjb250ZW50"),
("text_content", "File content"),
("size", 13),
],
test_mock={"read_file": lambda *args, **kwargs: ("RmlsZSBjb250ZW50", 13)},
)
@staticmethod
async def read_file(
credentials: GithubCredentials, repo_url: str, file_path: str, branch: str
) -> tuple[str, int]:
api = get_api(credentials)
content_url = (
repo_url
+ f"/contents/{quote(file_path, safe='')}?ref={quote(branch, safe='')}"
)
response = await api.get(content_url)
data = response.json()
if isinstance(data, list):
# Multiple entries of different types exist at this path
if not (file := next((f for f in data if f["type"] == "file"), None)):
raise TypeError("Not a file")
data = file
if data["type"] != "file":
raise TypeError("Not a file")
return data["content"], data["size"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
content, size = await self.read_file(
credentials,
input_data.repo_url,
input_data.file_path,
input_data.branch,
)
yield "raw_content", content
yield "text_content", base64.b64decode(content).decode("utf-8")
yield "size", size
except Exception as e:
yield "error", str(e)
class GithubReadFolderBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
folder_path: str = SchemaField(
description="Path to the folder in the repository",
placeholder="path/to/folder",
)
branch: str = SchemaField(
description="Branch name to read from (defaults to main)",
placeholder="branch_name",
default="main",
)
class Output(BlockSchemaOutput):
class DirEntry(TypedDict):
name: str
path: str
class FileEntry(TypedDict):
name: str
path: str
size: int
file: FileEntry = SchemaField(description="Files in the folder")
dir: DirEntry = SchemaField(description="Directories in the folder")
error: str = SchemaField(
description="Error message if reading the folder failed"
)
def __init__(self):
super().__init__(
id="1355f863-2db3-4d75-9fba-f91e8a8ca400",
description="This block reads the content of a specified folder from a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubReadFolderBlock.Input,
output_schema=GithubReadFolderBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"folder_path": "path/to/folder",
"branch": "main",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"file",
{
"name": "file1.txt",
"path": "path/to/folder/file1.txt",
"size": 1337,
},
),
("dir", {"name": "dir2", "path": "path/to/folder/dir2"}),
],
test_mock={
"read_folder": lambda *args, **kwargs: (
[
{
"name": "file1.txt",
"path": "path/to/folder/file1.txt",
"size": 1337,
}
],
[{"name": "dir2", "path": "path/to/folder/dir2"}],
)
},
)
@staticmethod
async def read_folder(
credentials: GithubCredentials, repo_url: str, folder_path: str, branch: str
) -> tuple[list[Output.FileEntry], list[Output.DirEntry]]:
api = get_api(credentials)
contents_url = (
repo_url
+ f"/contents/{quote(folder_path, safe='/')}?ref={quote(branch, safe='')}"
)
response = await api.get(contents_url)
data = response.json()
if not isinstance(data, list):
raise TypeError("Not a folder")
files: list[GithubReadFolderBlock.Output.FileEntry] = [
GithubReadFolderBlock.Output.FileEntry(
name=entry["name"],
path=entry["path"],
size=entry["size"],
)
for entry in data
if entry["type"] == "file"
]
dirs: list[GithubReadFolderBlock.Output.DirEntry] = [
GithubReadFolderBlock.Output.DirEntry(
name=entry["name"],
path=entry["path"],
)
for entry in data
if entry["type"] == "dir"
]
return files, dirs
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
files, dirs = await self.read_folder(
credentials,
input_data.repo_url,
input_data.folder_path.lstrip("/"),
input_data.branch,
)
for file in files:
yield "file", file
for dir in dirs:
yield "dir", dir
except Exception as e:
yield "error", str(e)
class GithubCreateFileBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
file_path: str = SchemaField(
description="Path where the file should be created",
placeholder="path/to/file.txt",
)
content: str = SchemaField(
description="Content to write to the file",
placeholder="File content here",
)
branch: str = SchemaField(
description="Branch where the file should be created",
default="main",
)
commit_message: str = SchemaField(
description="Message for the commit",
default="Create new file",
)
class Output(BlockSchemaOutput):
url: str = SchemaField(description="URL of the created file")
sha: str = SchemaField(description="SHA of the commit")
error: str = SchemaField(
description="Error message if the file creation failed"
)
def __init__(self):
super().__init__(
id="8fd132ac-b917-428a-8159-d62893e8a3fe",
description="This block creates a new file in a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubCreateFileBlock.Input,
output_schema=GithubCreateFileBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"file_path": "test/file.txt",
"content": "Test content",
"branch": "main",
"commit_message": "Create test file",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
("sha", "abc123"),
],
test_mock={
"create_file": lambda *args, **kwargs: (
"https://github.com/owner/repo/blob/main/test/file.txt",
"abc123",
)
},
)
@staticmethod
async def create_file(
credentials: GithubCredentials,
repo_url: str,
file_path: str,
content: str,
branch: str,
commit_message: str,
) -> tuple[str, str]:
api = get_api(credentials)
contents_url = repo_url + f"/contents/{quote(file_path, safe='/')}"
content_base64 = base64.b64encode(content.encode()).decode()
data = {
"message": commit_message,
"content": content_base64,
"branch": branch,
}
response = await api.put(contents_url, json=data)
data = response.json()
return data["content"]["html_url"], data["commit"]["sha"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
url, sha = await self.create_file(
credentials,
input_data.repo_url,
input_data.file_path,
input_data.content,
input_data.branch,
input_data.commit_message,
)
yield "url", url
yield "sha", sha
except Exception as e:
yield "error", str(e)
class GithubUpdateFileBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
file_path: str = SchemaField(
description="Path to the file to update",
placeholder="path/to/file.txt",
)
content: str = SchemaField(
description="New content for the file",
placeholder="Updated content here",
)
branch: str = SchemaField(
description="Branch containing the file",
default="main",
)
commit_message: str = SchemaField(
description="Message for the commit",
default="Update file",
)
class Output(BlockSchemaOutput):
url: str = SchemaField(description="URL of the updated file")
sha: str = SchemaField(description="SHA of the commit")
def __init__(self):
super().__init__(
id="30be12a4-57cb-4aa4-baf5-fcc68d136076",
description="This block updates an existing file in a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubUpdateFileBlock.Input,
output_schema=GithubUpdateFileBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"file_path": "test/file.txt",
"content": "Updated content",
"branch": "main",
"commit_message": "Update test file",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
("sha", "def456"),
],
test_mock={
"update_file": lambda *args, **kwargs: (
"https://github.com/owner/repo/blob/main/test/file.txt",
"def456",
)
},
)
@staticmethod
async def update_file(
credentials: GithubCredentials,
repo_url: str,
file_path: str,
content: str,
branch: str,
commit_message: str,
) -> tuple[str, str]:
api = get_api(credentials)
contents_url = repo_url + f"/contents/{quote(file_path, safe='/')}"
params = {"ref": branch}
response = await api.get(contents_url, params=params)
data = response.json()
# Convert new content to base64
content_base64 = base64.b64encode(content.encode()).decode()
data = {
"message": commit_message,
"content": content_base64,
"sha": data["sha"],
"branch": branch,
}
response = await api.put(contents_url, json=data)
data = response.json()
return data["content"]["html_url"], data["commit"]["sha"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
url, sha = await self.update_file(
credentials,
input_data.repo_url,
input_data.file_path,
input_data.content,
input_data.branch,
input_data.commit_message,
)
yield "url", url
yield "sha", sha
except Exception as e:
yield "error", str(e)
class GithubSearchCodeBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
query: str = SchemaField(
description="Search query (GitHub code search syntax)",
placeholder="className language:python",
)
repo: str = SchemaField(
description="Restrict search to a repository (owner/repo format, optional)",
default="",
placeholder="owner/repo",
)
per_page: int = SchemaField(
description="Number of results to return (max 100)",
default=30,
ge=1,
le=100,
)
class Output(BlockSchemaOutput):
class SearchResult(TypedDict):
name: str
path: str
repository: str
url: str
score: float
result: SearchResult = SchemaField(
title="Result", description="A code search result"
)
results: list[SearchResult] = SchemaField(
description="List of code search results"
)
total_count: int = SchemaField(description="Total number of matching results")
error: str = SchemaField(description="Error message if search failed")
def __init__(self):
super().__init__(
id="47f94891-a2b1-4f1c-b5f2-573c043f721e",
description="This block searches for code in GitHub repositories.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubSearchCodeBlock.Input,
output_schema=GithubSearchCodeBlock.Output,
test_input={
"query": "addClass",
"repo": "owner/repo",
"per_page": 30,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("total_count", 1),
(
"results",
[
{
"name": "file.py",
"path": "src/file.py",
"repository": "owner/repo",
"url": "https://github.com/owner/repo/blob/main/src/file.py",
"score": 1.0,
}
],
),
(
"result",
{
"name": "file.py",
"path": "src/file.py",
"repository": "owner/repo",
"url": "https://github.com/owner/repo/blob/main/src/file.py",
"score": 1.0,
},
),
],
test_mock={
"search_code": lambda *args, **kwargs: (
1,
[
{
"name": "file.py",
"path": "src/file.py",
"repository": "owner/repo",
"url": "https://github.com/owner/repo/blob/main/src/file.py",
"score": 1.0,
}
],
)
},
)
@staticmethod
async def search_code(
credentials: GithubCredentials,
query: str,
repo: str,
per_page: int,
) -> tuple[int, list[Output.SearchResult]]:
api = get_api(credentials, convert_urls=False)
full_query = f"{query} repo:{repo}" if repo else query
params = {"q": full_query, "per_page": str(per_page)}
response = await api.get("https://api.github.com/search/code", params=params)
data = response.json()
results: list[GithubSearchCodeBlock.Output.SearchResult] = [
GithubSearchCodeBlock.Output.SearchResult(
name=item["name"],
path=item["path"],
repository=item["repository"]["full_name"],
url=item["html_url"],
score=item["score"],
)
for item in data["items"]
]
return data["total_count"], results
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
total_count, results = await self.search_code(
credentials,
input_data.query,
input_data.repo,
input_data.per_page,
)
yield "total_count", total_count
yield "results", results
for result in results:
yield "result", result
except Exception as e:
yield "error", str(e)
class GithubGetRepositoryTreeBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
branch: str = SchemaField(
description="Branch name to get the tree from",
default="main",
)
recursive: bool = SchemaField(
description="Whether to recursively list the entire tree",
default=True,
)
class Output(BlockSchemaOutput):
class TreeEntry(TypedDict):
path: str
type: str
size: int
sha: str
entry: TreeEntry = SchemaField(
title="Tree Entry", description="A file or directory in the tree"
)
entries: list[TreeEntry] = SchemaField(
description="List of all files and directories in the tree"
)
truncated: bool = SchemaField(
description="Whether the tree was truncated due to size"
)
error: str = SchemaField(description="Error message if getting tree failed")
def __init__(self):
super().__init__(
id="89c5c0ec-172e-4001-a32c-bdfe4d0c9e81",
description="This block lists the entire file tree of a GitHub repository recursively.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubGetRepositoryTreeBlock.Input,
output_schema=GithubGetRepositoryTreeBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"branch": "main",
"recursive": True,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("truncated", False),
(
"entries",
[
{
"path": "src/main.py",
"type": "blob",
"size": 1234,
"sha": "abc123",
}
],
),
(
"entry",
{
"path": "src/main.py",
"type": "blob",
"size": 1234,
"sha": "abc123",
},
),
],
test_mock={
"get_tree": lambda *args, **kwargs: (
False,
[
{
"path": "src/main.py",
"type": "blob",
"size": 1234,
"sha": "abc123",
}
],
)
},
)
@staticmethod
async def get_tree(
credentials: GithubCredentials,
repo_url: str,
branch: str,
recursive: bool,
) -> tuple[bool, list[Output.TreeEntry]]:
api = get_api(credentials)
tree_url = repo_url + f"/git/trees/{quote(branch, safe='')}"
params = {"recursive": "1"} if recursive else {}
response = await api.get(tree_url, params=params)
data = response.json()
entries: list[GithubGetRepositoryTreeBlock.Output.TreeEntry] = [
GithubGetRepositoryTreeBlock.Output.TreeEntry(
path=item["path"],
type=item["type"],
size=item.get("size", 0),
sha=item["sha"],
)
for item in data["tree"]
]
return data.get("truncated", False), entries
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
truncated, entries = await self.get_tree(
credentials,
input_data.repo_url,
input_data.branch,
input_data.recursive,
)
yield "truncated", truncated
yield "entries", entries
for entry in entries:
yield "entry", entry
except Exception as e:
yield "error", str(e)

View File

@@ -0,0 +1,120 @@
import inspect
import pytest
from backend.blocks.github._auth import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT
from backend.blocks.github.commits import FileOperation, GithubMultiFileCommitBlock
from backend.blocks.github.pull_requests import (
GithubMergePullRequestBlock,
prepare_pr_api_url,
)
from backend.util.exceptions import BlockExecutionError
# ── prepare_pr_api_url tests ──
class TestPreparePrApiUrl:
def test_https_scheme_preserved(self):
result = prepare_pr_api_url("https://github.com/owner/repo/pull/42", "merge")
assert result == "https://github.com/owner/repo/pulls/42/merge"
def test_http_scheme_preserved(self):
result = prepare_pr_api_url("http://github.com/owner/repo/pull/1", "files")
assert result == "http://github.com/owner/repo/pulls/1/files"
def test_no_scheme_defaults_to_https(self):
result = prepare_pr_api_url("github.com/owner/repo/pull/5", "merge")
assert result == "https://github.com/owner/repo/pulls/5/merge"
def test_reviewers_path(self):
result = prepare_pr_api_url(
"https://github.com/owner/repo/pull/99", "requested_reviewers"
)
assert result == "https://github.com/owner/repo/pulls/99/requested_reviewers"
def test_invalid_url_returned_as_is(self):
url = "https://example.com/not-a-pr"
assert prepare_pr_api_url(url, "merge") == url
def test_empty_string(self):
assert prepare_pr_api_url("", "merge") == ""
# ── Error-path block tests ──
# When a block's run() yields ("error", msg), _execute() converts it to a
# BlockExecutionError. We call block.execute() directly (not execute_block_test,
# which returns early on empty test_output).
def _mock_block(block, mocks: dict):
"""Apply mocks to a block's static methods, wrapping sync mocks as async."""
for name, mock_fn in mocks.items():
original = getattr(block, name)
if inspect.iscoroutinefunction(original):
async def async_mock(*args, _fn=mock_fn, **kwargs):
return _fn(*args, **kwargs)
setattr(block, name, async_mock)
else:
setattr(block, name, mock_fn)
def _raise(exc: Exception):
"""Helper that returns a callable which raises the given exception."""
def _raiser(*args, **kwargs):
raise exc
return _raiser
@pytest.mark.asyncio
async def test_merge_pr_error_path():
block = GithubMergePullRequestBlock()
_mock_block(block, {"merge_pr": _raise(RuntimeError("PR not mergeable"))})
input_data = {
"pr_url": "https://github.com/owner/repo/pull/1",
"merge_method": "squash",
"commit_title": "",
"commit_message": "",
"credentials": TEST_CREDENTIALS_INPUT,
}
with pytest.raises(BlockExecutionError, match="PR not mergeable"):
async for _ in block.execute(input_data, credentials=TEST_CREDENTIALS):
pass
@pytest.mark.asyncio
async def test_multi_file_commit_error_path():
block = GithubMultiFileCommitBlock()
_mock_block(block, {"multi_file_commit": _raise(RuntimeError("ref update failed"))})
input_data = {
"repo_url": "https://github.com/owner/repo",
"branch": "feature",
"commit_message": "test",
"files": [{"path": "a.py", "content": "x", "operation": "upsert"}],
"credentials": TEST_CREDENTIALS_INPUT,
}
with pytest.raises(BlockExecutionError, match="ref update failed"):
async for _ in block.execute(input_data, credentials=TEST_CREDENTIALS):
pass
# ── FileOperation enum tests ──
class TestFileOperation:
def test_upsert_value(self):
assert FileOperation.UPSERT == "upsert"
def test_delete_value(self):
assert FileOperation.DELETE == "delete"
def test_invalid_value_raises(self):
with pytest.raises(ValueError):
FileOperation("create")
def test_invalid_value_raises_typo(self):
with pytest.raises(ValueError):
FileOperation("upser")

View File

@@ -241,8 +241,8 @@ class GmailBase(Block, ABC):
h.ignore_links = False
h.ignore_images = True
return h.handle(html_content)
except ImportError:
# Fallback: return raw HTML if html2text is not available
except Exception:
# Keep extraction resilient if html2text is unavailable or fails.
return html_content
# Handle content stored as attachment

View File

@@ -67,6 +67,7 @@ class HITLReviewHelper:
graph_version: int,
block_name: str = "Block",
editable: bool = False,
is_graph_execution: bool = True,
) -> Optional[ReviewResult]:
"""
Handle a review request for a block that requires human review.
@@ -143,10 +144,11 @@ class HITLReviewHelper:
logger.info(
f"Block {block_name} pausing execution for node {node_exec_id} - awaiting human review"
)
await HITLReviewHelper.update_node_execution_status(
exec_id=node_exec_id,
status=ExecutionStatus.REVIEW,
)
if is_graph_execution:
await HITLReviewHelper.update_node_execution_status(
exec_id=node_exec_id,
status=ExecutionStatus.REVIEW,
)
return None # Signal that execution should pause
# Mark review as processed if not already done
@@ -168,6 +170,7 @@ class HITLReviewHelper:
graph_version: int,
block_name: str = "Block",
editable: bool = False,
is_graph_execution: bool = True,
) -> Optional[ReviewDecision]:
"""
Handle a review request and return the decision in a single call.
@@ -197,6 +200,7 @@ class HITLReviewHelper:
graph_version=graph_version,
block_name=block_name,
editable=editable,
is_graph_execution=is_graph_execution,
)
if review_result is None:

View File

@@ -17,7 +17,7 @@ from backend.blocks.jina._auth import (
from backend.blocks.search import GetRequest
from backend.data.model import SchemaField
from backend.util.exceptions import BlockExecutionError
from backend.util.request import HTTPClientError, HTTPServerError, validate_url
from backend.util.request import HTTPClientError, HTTPServerError, validate_url_host
class SearchTheWebBlock(Block, GetRequest):
@@ -112,7 +112,7 @@ class ExtractWebsiteContentBlock(Block, GetRequest):
) -> BlockOutput:
if input_data.raw_content:
try:
parsed_url, _, _ = await validate_url(input_data.url, [])
parsed_url, _, _ = await validate_url_host(input_data.url)
url = parsed_url.geturl()
except ValueError as e:
yield "error", f"Invalid URL: {e}"

View File

@@ -31,10 +31,14 @@ from backend.data.model import (
)
from backend.integrations.providers import ProviderName
from backend.util import json
from backend.util.clients import OPENROUTER_BASE_URL
from backend.util.logging import TruncatedLogger
from backend.util.prompt import compress_context, estimate_token_count
from backend.util.request import validate_url_host
from backend.util.settings import Settings
from backend.util.text import TextFormatter
settings = Settings()
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
fmt = TextFormatter(autoescape=False)
@@ -136,19 +140,31 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
# OpenRouter models
OPENAI_GPT_OSS_120B = "openai/gpt-oss-120b"
OPENAI_GPT_OSS_20B = "openai/gpt-oss-20b"
GEMINI_2_5_PRO = "google/gemini-2.5-pro-preview-03-25"
GEMINI_3_PRO_PREVIEW = "google/gemini-3-pro-preview"
GEMINI_2_5_PRO_PREVIEW = "google/gemini-2.5-pro-preview-03-25"
GEMINI_2_5_PRO = "google/gemini-2.5-pro"
GEMINI_3_1_PRO_PREVIEW = "google/gemini-3.1-pro-preview"
GEMINI_3_FLASH_PREVIEW = "google/gemini-3-flash-preview"
GEMINI_2_5_FLASH = "google/gemini-2.5-flash"
GEMINI_2_0_FLASH = "google/gemini-2.0-flash-001"
GEMINI_3_1_FLASH_LITE_PREVIEW = "google/gemini-3.1-flash-lite-preview"
GEMINI_2_5_FLASH_LITE_PREVIEW = "google/gemini-2.5-flash-lite-preview-06-17"
GEMINI_2_0_FLASH_LITE = "google/gemini-2.0-flash-lite-001"
MISTRAL_NEMO = "mistralai/mistral-nemo"
MISTRAL_LARGE_3 = "mistralai/mistral-large-2512"
MISTRAL_MEDIUM_3_1 = "mistralai/mistral-medium-3.1"
MISTRAL_SMALL_3_2 = "mistralai/mistral-small-3.2-24b-instruct"
CODESTRAL = "mistralai/codestral-2508"
COHERE_COMMAND_R_08_2024 = "cohere/command-r-08-2024"
COHERE_COMMAND_R_PLUS_08_2024 = "cohere/command-r-plus-08-2024"
COHERE_COMMAND_A_03_2025 = "cohere/command-a-03-2025"
COHERE_COMMAND_A_TRANSLATE_08_2025 = "cohere/command-a-translate-08-2025"
COHERE_COMMAND_A_REASONING_08_2025 = "cohere/command-a-reasoning-08-2025"
COHERE_COMMAND_A_VISION_07_2025 = "cohere/command-a-vision-07-2025"
DEEPSEEK_CHAT = "deepseek/deepseek-chat" # Actually: DeepSeek V3
DEEPSEEK_R1_0528 = "deepseek/deepseek-r1-0528"
PERPLEXITY_SONAR = "perplexity/sonar"
PERPLEXITY_SONAR_PRO = "perplexity/sonar-pro"
PERPLEXITY_SONAR_REASONING_PRO = "perplexity/sonar-reasoning-pro"
PERPLEXITY_SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B = "nousresearch/hermes-3-llama-3.1-405b"
NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B = "nousresearch/hermes-3-llama-3.1-70b"
@@ -156,9 +172,11 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
AMAZON_NOVA_MICRO_V1 = "amazon/nova-micro-v1"
AMAZON_NOVA_PRO_V1 = "amazon/nova-pro-v1"
MICROSOFT_WIZARDLM_2_8X22B = "microsoft/wizardlm-2-8x22b"
MICROSOFT_PHI_4 = "microsoft/phi-4"
GRYPHE_MYTHOMAX_L2_13B = "gryphe/mythomax-l2-13b"
META_LLAMA_4_SCOUT = "meta-llama/llama-4-scout"
META_LLAMA_4_MAVERICK = "meta-llama/llama-4-maverick"
GROK_3 = "x-ai/grok-3"
GROK_4 = "x-ai/grok-4"
GROK_4_FAST = "x-ai/grok-4-fast"
GROK_4_1_FAST = "x-ai/grok-4.1-fast"
@@ -336,17 +354,41 @@ MODEL_METADATA = {
"ollama", 32768, None, "Dolphin Mistral Latest", "Ollama", "Mistral AI", 1
),
# https://openrouter.ai/models
LlmModel.GEMINI_2_5_PRO: ModelMetadata(
LlmModel.GEMINI_2_5_PRO_PREVIEW: ModelMetadata(
"open_router",
1050000,
8192,
1048576,
65536,
"Gemini 2.5 Pro Preview 03.25",
"OpenRouter",
"Google",
2,
),
LlmModel.GEMINI_3_PRO_PREVIEW: ModelMetadata(
"open_router", 1048576, 65535, "Gemini 3 Pro Preview", "OpenRouter", "Google", 2
LlmModel.GEMINI_2_5_PRO: ModelMetadata(
"open_router",
1048576,
65536,
"Gemini 2.5 Pro",
"OpenRouter",
"Google",
2,
),
LlmModel.GEMINI_3_1_PRO_PREVIEW: ModelMetadata(
"open_router",
1048576,
65536,
"Gemini 3.1 Pro Preview",
"OpenRouter",
"Google",
2,
),
LlmModel.GEMINI_3_FLASH_PREVIEW: ModelMetadata(
"open_router",
1048576,
65536,
"Gemini 3 Flash Preview",
"OpenRouter",
"Google",
1,
),
LlmModel.GEMINI_2_5_FLASH: ModelMetadata(
"open_router", 1048576, 65535, "Gemini 2.5 Flash", "OpenRouter", "Google", 1
@@ -354,6 +396,15 @@ MODEL_METADATA = {
LlmModel.GEMINI_2_0_FLASH: ModelMetadata(
"open_router", 1048576, 8192, "Gemini 2.0 Flash 001", "OpenRouter", "Google", 1
),
LlmModel.GEMINI_3_1_FLASH_LITE_PREVIEW: ModelMetadata(
"open_router",
1048576,
65536,
"Gemini 3.1 Flash Lite Preview",
"OpenRouter",
"Google",
1,
),
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: ModelMetadata(
"open_router",
1048576,
@@ -375,12 +426,78 @@ MODEL_METADATA = {
LlmModel.MISTRAL_NEMO: ModelMetadata(
"open_router", 128000, 4096, "Mistral Nemo", "OpenRouter", "Mistral AI", 1
),
LlmModel.MISTRAL_LARGE_3: ModelMetadata(
"open_router",
262144,
None,
"Mistral Large 3 2512",
"OpenRouter",
"Mistral AI",
2,
),
LlmModel.MISTRAL_MEDIUM_3_1: ModelMetadata(
"open_router",
131072,
None,
"Mistral Medium 3.1",
"OpenRouter",
"Mistral AI",
2,
),
LlmModel.MISTRAL_SMALL_3_2: ModelMetadata(
"open_router",
131072,
131072,
"Mistral Small 3.2 24B",
"OpenRouter",
"Mistral AI",
1,
),
LlmModel.CODESTRAL: ModelMetadata(
"open_router",
256000,
None,
"Codestral 2508",
"OpenRouter",
"Mistral AI",
1,
),
LlmModel.COHERE_COMMAND_R_08_2024: ModelMetadata(
"open_router", 128000, 4096, "Command R 08.2024", "OpenRouter", "Cohere", 1
),
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: ModelMetadata(
"open_router", 128000, 4096, "Command R Plus 08.2024", "OpenRouter", "Cohere", 2
),
LlmModel.COHERE_COMMAND_A_03_2025: ModelMetadata(
"open_router", 256000, 8192, "Command A 03.2025", "OpenRouter", "Cohere", 2
),
LlmModel.COHERE_COMMAND_A_TRANSLATE_08_2025: ModelMetadata(
"open_router",
128000,
8192,
"Command A Translate 08.2025",
"OpenRouter",
"Cohere",
2,
),
LlmModel.COHERE_COMMAND_A_REASONING_08_2025: ModelMetadata(
"open_router",
256000,
32768,
"Command A Reasoning 08.2025",
"OpenRouter",
"Cohere",
3,
),
LlmModel.COHERE_COMMAND_A_VISION_07_2025: ModelMetadata(
"open_router",
128000,
8192,
"Command A Vision 07.2025",
"OpenRouter",
"Cohere",
2,
),
LlmModel.DEEPSEEK_CHAT: ModelMetadata(
"open_router", 64000, 2048, "DeepSeek Chat", "OpenRouter", "DeepSeek", 1
),
@@ -393,6 +510,15 @@ MODEL_METADATA = {
LlmModel.PERPLEXITY_SONAR_PRO: ModelMetadata(
"open_router", 200000, 8000, "Sonar Pro", "OpenRouter", "Perplexity", 2
),
LlmModel.PERPLEXITY_SONAR_REASONING_PRO: ModelMetadata(
"open_router",
128000,
8000,
"Sonar Reasoning Pro",
"OpenRouter",
"Perplexity",
2,
),
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: ModelMetadata(
"open_router",
128000,
@@ -438,6 +564,9 @@ MODEL_METADATA = {
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: ModelMetadata(
"open_router", 65536, 4096, "WizardLM 2 8x22B", "OpenRouter", "Microsoft", 1
),
LlmModel.MICROSOFT_PHI_4: ModelMetadata(
"open_router", 16384, 16384, "Phi-4", "OpenRouter", "Microsoft", 1
),
LlmModel.GRYPHE_MYTHOMAX_L2_13B: ModelMetadata(
"open_router", 4096, 4096, "MythoMax L2 13B", "OpenRouter", "Gryphe", 1
),
@@ -447,6 +576,15 @@ MODEL_METADATA = {
LlmModel.META_LLAMA_4_MAVERICK: ModelMetadata(
"open_router", 1048576, 1000000, "Llama 4 Maverick", "OpenRouter", "Meta", 1
),
LlmModel.GROK_3: ModelMetadata(
"open_router",
131072,
131072,
"Grok 3",
"OpenRouter",
"xAI",
2,
),
LlmModel.GROK_4: ModelMetadata(
"open_router", 256000, 256000, "Grok 4", "OpenRouter", "xAI", 3
),
@@ -804,6 +942,11 @@ async def llm_call(
if tools:
raise ValueError("Ollama does not support tools.")
# Validate user-provided Ollama host to prevent SSRF etc.
await validate_url_host(
ollama_host, trusted_hostnames=[settings.config.ollama_host]
)
client = ollama.AsyncClient(host=ollama_host)
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
@@ -825,7 +968,7 @@ async def llm_call(
elif provider == "open_router":
tools_param = tools if tools else openai.NOT_GIVEN
client = openai.AsyncOpenAI(
base_url="https://openrouter.ai/api/v1",
base_url=OPENROUTER_BASE_URL,
api_key=credentials.api_key.get_secret_value(),
)

View File

@@ -4,7 +4,7 @@ from enum import Enum
from typing import Any, Literal
import openai
from pydantic import SecretStr
from pydantic import SecretStr, field_validator
from backend.blocks._base import (
Block,
@@ -13,6 +13,7 @@ from backend.blocks._base import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.block import BlockInput
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -21,6 +22,7 @@ from backend.data.model import (
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.clients import OPENROUTER_BASE_URL
from backend.util.logging import TruncatedLogger
logger = TruncatedLogger(logging.getLogger(__name__), "[Perplexity-Block]")
@@ -34,6 +36,20 @@ class PerplexityModel(str, Enum):
SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
def _sanitize_perplexity_model(value: Any) -> PerplexityModel:
"""Return a valid PerplexityModel, falling back to SONAR for invalid values."""
if isinstance(value, PerplexityModel):
return value
try:
return PerplexityModel(value)
except ValueError:
logger.warning(
f"Invalid PerplexityModel '{value}', "
f"falling back to {PerplexityModel.SONAR.value}"
)
return PerplexityModel.SONAR
PerplexityCredentials = CredentialsMetaInput[
Literal[ProviderName.OPEN_ROUTER], Literal["api_key"]
]
@@ -72,6 +88,25 @@ class PerplexityBlock(Block):
advanced=False,
)
credentials: PerplexityCredentials = PerplexityCredentialsField()
@field_validator("model", mode="before")
@classmethod
def fallback_invalid_model(cls, v: Any) -> PerplexityModel:
"""Fall back to SONAR if the model value is not a valid
PerplexityModel (e.g. an OpenAI model ID set by the agent
generator)."""
return _sanitize_perplexity_model(v)
@classmethod
def validate_data(cls, data: BlockInput) -> str | None:
"""Sanitize the model field before JSON schema validation so that
invalid values are replaced with the default instead of raising a
BlockInputError."""
model_value = data.get("model")
if model_value is not None:
data["model"] = _sanitize_perplexity_model(model_value).value
return super().validate_data(data)
system_prompt: str = SchemaField(
title="System Prompt",
default="",
@@ -136,7 +171,7 @@ class PerplexityBlock(Block):
) -> dict[str, Any]:
"""Call Perplexity via OpenRouter and extract annotations."""
client = openai.AsyncOpenAI(
base_url="https://openrouter.ai/api/v1",
base_url=OPENROUTER_BASE_URL,
api_key=credentials.api_key.get_secret_value(),
)

View File

@@ -2232,6 +2232,7 @@ class DeleteRedditPostBlock(Block):
("post_id", "abc123"),
],
test_mock={"delete_post": lambda creds, post_id: True},
is_sensitive_action=True,
)
@staticmethod
@@ -2290,6 +2291,7 @@ class DeleteRedditCommentBlock(Block):
("comment_id", "xyz789"),
],
test_mock={"delete_comment": lambda creds, comment_id: True},
is_sensitive_action=True,
)
@staticmethod

View File

@@ -72,6 +72,7 @@ class Slant3DCreateOrderBlock(Slant3DBlockBase):
"_make_request": lambda *args, **kwargs: {"orderId": "314144241"},
"_convert_to_color": lambda *args, **kwargs: "black",
},
is_sensitive_action=True,
)
async def run(

View File

@@ -0,0 +1,81 @@
"""Unit tests for PerplexityBlock model fallback behavior."""
import pytest
from backend.blocks.perplexity import (
TEST_CREDENTIALS_INPUT,
PerplexityBlock,
PerplexityModel,
)
def _make_input(**overrides) -> dict:
defaults = {
"prompt": "test query",
"credentials": TEST_CREDENTIALS_INPUT,
}
defaults.update(overrides)
return defaults
class TestPerplexityModelFallback:
"""Tests for fallback_invalid_model field_validator."""
def test_invalid_model_falls_back_to_sonar(self):
inp = PerplexityBlock.Input(**_make_input(model="gpt-5.2-2025-12-11"))
assert inp.model == PerplexityModel.SONAR
def test_another_invalid_model_falls_back_to_sonar(self):
inp = PerplexityBlock.Input(**_make_input(model="gpt-4o"))
assert inp.model == PerplexityModel.SONAR
def test_valid_model_string_is_kept(self):
inp = PerplexityBlock.Input(**_make_input(model="perplexity/sonar-pro"))
assert inp.model == PerplexityModel.SONAR_PRO
def test_valid_enum_value_is_kept(self):
inp = PerplexityBlock.Input(
**_make_input(model=PerplexityModel.SONAR_DEEP_RESEARCH)
)
assert inp.model == PerplexityModel.SONAR_DEEP_RESEARCH
def test_default_model_when_omitted(self):
inp = PerplexityBlock.Input(**_make_input())
assert inp.model == PerplexityModel.SONAR
@pytest.mark.parametrize(
"model_value",
[
"perplexity/sonar",
"perplexity/sonar-pro",
"perplexity/sonar-deep-research",
],
)
def test_all_valid_models_accepted(self, model_value: str):
inp = PerplexityBlock.Input(**_make_input(model=model_value))
assert inp.model.value == model_value
class TestPerplexityValidateData:
"""Tests for validate_data which runs during block execution (before
Pydantic instantiation). Invalid models must be sanitized here so
JSON schema validation does not reject them."""
def test_invalid_model_sanitized_before_schema_validation(self):
data = _make_input(model="gpt-5.2-2025-12-11")
error = PerplexityBlock.Input.validate_data(data)
assert error is None
assert data["model"] == PerplexityModel.SONAR.value
def test_valid_model_unchanged_by_validate_data(self):
data = _make_input(model="perplexity/sonar-pro")
error = PerplexityBlock.Input.validate_data(data)
assert error is None
assert data["model"] == "perplexity/sonar-pro"
def test_missing_model_uses_default(self):
data = _make_input() # no model key
error = PerplexityBlock.Input.validate_data(data)
assert error is None
inp = PerplexityBlock.Input(**data)
assert inp.model == PerplexityModel.SONAR

View File

@@ -18,11 +18,13 @@ from langfuse import propagate_attributes
from backend.copilot.model import (
ChatMessage,
ChatSession,
Usage,
get_chat_session,
update_session_title,
upsert_chat_session,
)
from backend.copilot.prompting import get_baseline_supplement
from backend.copilot.rate_limit import record_token_usage
from backend.copilot.response_model import (
StreamBaseResponse,
StreamError,
@@ -36,6 +38,7 @@ from backend.copilot.response_model import (
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
StreamUsage,
)
from backend.copilot.service import (
_build_system_prompt,
@@ -46,7 +49,11 @@ from backend.copilot.service import (
from backend.copilot.tools import execute_tool, get_available_tools
from backend.copilot.tracking import track_user_message
from backend.util.exceptions import NotFoundError
from backend.util.prompt import compress_context
from backend.util.prompt import (
compress_context,
estimate_token_count,
estimate_token_count_str,
)
logger = logging.getLogger(__name__)
@@ -221,6 +228,9 @@ async def stream_chat_completion_baseline(
text_block_id = str(uuid.uuid4())
text_started = False
step_open = False
# Token usage accumulators — populated from streaming chunks
turn_prompt_tokens = 0
turn_completion_tokens = 0
try:
for _round in range(_MAX_TOOL_ROUNDS):
# Open a new step for each LLM round
@@ -232,6 +242,7 @@ async def stream_chat_completion_baseline(
model=config.model,
messages=openai_messages,
stream=True,
stream_options={"include_usage": True},
)
if tools:
create_kwargs["tools"] = tools
@@ -242,7 +253,18 @@ async def stream_chat_completion_baseline(
tool_calls_by_index: dict[int, dict[str, str]] = {}
async for chunk in response:
delta = chunk.choices[0].delta if chunk.choices else None
# Capture token usage from the streaming chunk.
# OpenRouter normalises all providers into OpenAI format
# where prompt_tokens already includes cached tokens
# (unlike Anthropic's native API). Use += to sum all
# tool-call rounds since each API call is independent.
if chunk.usage:
turn_prompt_tokens += chunk.usage.prompt_tokens or 0
turn_completion_tokens += chunk.usage.completion_tokens or 0
if not chunk.choices:
continue
delta = chunk.choices[0].delta
if not delta:
continue
@@ -411,6 +433,53 @@ async def stream_chat_completion_baseline(
except Exception:
logger.warning("[Baseline] Langfuse trace context teardown failed")
# Fallback: estimate tokens via tiktoken when the provider does
# not honour stream_options={"include_usage": True}.
# Count the full message list (system + history + turn) since
# each API call sends the complete context window.
if turn_prompt_tokens == 0 and turn_completion_tokens == 0:
turn_prompt_tokens = max(
estimate_token_count(openai_messages, model=config.model), 0
)
turn_completion_tokens = max(
estimate_token_count_str(assistant_text, model=config.model), 0
)
logger.info(
"[Baseline] No streaming usage reported; estimated tokens: "
"prompt=%d, completion=%d",
turn_prompt_tokens,
turn_completion_tokens,
)
# Emit token usage and update session for persistence
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
total_tokens = turn_prompt_tokens + turn_completion_tokens
session.usage.append(
Usage(
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
total_tokens=total_tokens,
)
)
logger.info(
"[Baseline] Turn usage: prompt=%d, completion=%d, total=%d",
turn_prompt_tokens,
turn_completion_tokens,
total_tokens,
)
# Record for rate limiting counters
if user_id:
try:
await record_token_usage(
user_id=user_id,
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
)
except Exception as usage_err:
logger.warning(
"[Baseline] Failed to record token usage: %s", usage_err
)
# Persist assistant response
if assistant_text:
session.messages.append(
@@ -421,4 +490,16 @@ async def stream_chat_completion_baseline(
except Exception as persist_err:
logger.error("[Baseline] Failed to persist session: %s", persist_err)
# Yield usage and finish AFTER try/finally (not inside finally).
# PEP 525 prohibits yielding from finally in async generators during
# aclose() — doing so raises RuntimeError on client disconnect.
# On GeneratorExit the client is already gone, so unreachable yields
# are harmless; on normal completion they reach the SSE stream.
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
yield StreamUsage(
promptTokens=turn_prompt_tokens,
completionTokens=turn_completion_tokens,
totalTokens=turn_prompt_tokens + turn_completion_tokens,
)
yield StreamFinish()

View File

@@ -1,10 +1,13 @@
"""Configuration management for chat system."""
import os
from typing import Literal
from pydantic import Field, field_validator
from pydantic_settings import BaseSettings
from backend.util.clients import OPENROUTER_BASE_URL
class ChatConfig(BaseSettings):
"""Configuration for the chat system."""
@@ -19,7 +22,7 @@ class ChatConfig(BaseSettings):
)
api_key: str | None = Field(default=None, description="OpenAI API key")
base_url: str | None = Field(
default="https://openrouter.ai/api/v1",
default=OPENROUTER_BASE_URL,
description="Base URL for API (e.g., for OpenRouter)",
)
@@ -67,6 +70,20 @@ class ChatConfig(BaseSettings):
description="Cache TTL in seconds for Langfuse prompt (0 to disable caching)",
)
# Rate limiting — token-based limits per day and per week.
# Each CoPilot turn consumes ~10-15K tokens (system prompt + tool schemas + response),
# so 2.5M daily allows ~170-250 turns/day which is reasonable for normal use.
# TODO: These are global deploy-time constants. For per-user or per-plan limits,
# move to the database (e.g. UserPlan table) and look up in get_usage_status.
daily_token_limit: int = Field(
default=2_500_000,
description="Max tokens per day, resets at midnight UTC (0 = unlimited)",
)
weekly_token_limit: int = Field(
default=12_500_000,
description="Max tokens per week, resets Monday 00:00 UTC (0 = unlimited)",
)
# Claude Agent SDK Configuration
use_claude_agent_sdk: bool = Field(
default=True,
@@ -112,9 +129,37 @@ class ChatConfig(BaseSettings):
description="E2B sandbox template to use for copilot sessions.",
)
e2b_sandbox_timeout: int = Field(
default=43200, # 12 hours — same as session_ttl
description="E2B sandbox keepalive timeout in seconds.",
default=300, # 5 min safety net — explicit per-turn pause is the primary mechanism
description="E2B sandbox running-time timeout (seconds). "
"E2B timeout is wall-clock (not idle). Explicit per-turn pause is the primary "
"mechanism; this is the safety net.",
)
e2b_sandbox_on_timeout: Literal["kill", "pause"] = Field(
default="pause",
description="E2B lifecycle action on timeout: 'pause' (default, free) or 'kill'.",
)
@property
def e2b_active(self) -> bool:
"""True when E2B is enabled and the API key is present.
Single source of truth for "should we use E2B right now?".
Prefer this over combining ``use_e2b_sandbox`` and ``e2b_api_key``
separately at call sites.
"""
return self.use_e2b_sandbox and bool(self.e2b_api_key)
@property
def active_e2b_api_key(self) -> str | None:
"""Return the E2B API key when E2B is enabled and configured, else None.
Combines the ``use_e2b_sandbox`` flag check and key presence into one.
Use in callers::
if api_key := config.active_e2b_api_key:
# E2B is active; api_key is narrowed to str
"""
return self.e2b_api_key if self.e2b_active else None
@field_validator("use_e2b_sandbox", mode="before")
@classmethod
@@ -164,7 +209,7 @@ class ChatConfig(BaseSettings):
if not v:
v = os.getenv("OPENAI_BASE_URL")
if not v:
v = "https://openrouter.ai/api/v1"
v = OPENROUTER_BASE_URL
return v
@field_validator("use_claude_agent_sdk", mode="before")

View File

@@ -0,0 +1,38 @@
"""Unit tests for ChatConfig."""
import pytest
from .config import ChatConfig
# Env vars that the ChatConfig validators read — must be cleared so they don't
# override the explicit constructor values we pass in each test.
_E2B_ENV_VARS = (
"CHAT_USE_E2B_SANDBOX",
"CHAT_E2B_API_KEY",
"E2B_API_KEY",
)
@pytest.fixture(autouse=True)
def _clean_e2b_env(monkeypatch: pytest.MonkeyPatch) -> None:
for var in _E2B_ENV_VARS:
monkeypatch.delenv(var, raising=False)
class TestE2BActive:
"""Tests for the e2b_active property — single source of truth for E2B usage."""
def test_both_enabled_and_key_present_returns_true(self):
"""e2b_active is True when use_e2b_sandbox=True and e2b_api_key is set."""
cfg = ChatConfig(use_e2b_sandbox=True, e2b_api_key="test-key")
assert cfg.e2b_active is True
def test_enabled_but_missing_key_returns_false(self):
"""e2b_active is False when use_e2b_sandbox=True but e2b_api_key is absent."""
cfg = ChatConfig(use_e2b_sandbox=True, e2b_api_key=None)
assert cfg.e2b_active is False
def test_disabled_returns_false(self):
"""e2b_active is False when use_e2b_sandbox=False regardless of key."""
cfg = ChatConfig(use_e2b_sandbox=False, e2b_api_key="test-key")
assert cfg.e2b_active is False

View File

@@ -6,6 +6,32 @@
COPILOT_ERROR_PREFIX = "[__COPILOT_ERROR_f7a1__]" # Renders as ErrorCard
COPILOT_SYSTEM_PREFIX = "[__COPILOT_SYSTEM_e3b0__]" # Renders as system info message
# Prefix for all synthetic IDs generated by CoPilot block execution.
# Used to distinguish CoPilot-generated records from real graph execution records
# in PendingHumanReview and other tables.
COPILOT_SYNTHETIC_ID_PREFIX = "copilot-"
# Sub-prefixes for session-scoped and node-scoped synthetic IDs.
COPILOT_SESSION_PREFIX = f"{COPILOT_SYNTHETIC_ID_PREFIX}session-"
COPILOT_NODE_PREFIX = f"{COPILOT_SYNTHETIC_ID_PREFIX}node-"
# Separator used in synthetic node_exec_id to encode node_id.
# Format: "{node_id}:{random_hex}" — extract node_id via rsplit(":", 1)[0]
COPILOT_NODE_EXEC_ID_SEPARATOR = ":"
# Compaction notice messages shown to users.
COMPACTION_DONE_MSG = "Earlier messages were summarized to fit within context limits."
COMPACTION_TOOL_NAME = "context_compaction"
def is_copilot_synthetic_id(id_value: str) -> bool:
"""Check if an ID is a CoPilot synthetic ID (not from a real graph execution)."""
return id_value.startswith(COPILOT_SYNTHETIC_ID_PREFIX)
def parse_node_id_from_exec_id(node_exec_id: str) -> str:
"""Extract node_id from a synthetic node_exec_id.
Format: "{node_id}:{random_hex}" → returns "{node_id}".
"""
return node_exec_id.rsplit(COPILOT_NODE_EXEC_ID_SEPARATOR, 1)[0]

View File

@@ -0,0 +1,115 @@
"""Shared execution context for copilot SDK tool handlers.
All context variables and their accessors live here so that
``tool_adapter``, ``file_ref``, and ``e2b_file_tools`` can import them
without creating circular dependencies.
"""
import os
import re
from contextvars import ContextVar
from typing import TYPE_CHECKING
from backend.copilot.model import ChatSession
if TYPE_CHECKING:
from e2b import AsyncSandbox
# Allowed base directory for the Read tool.
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
# Encoded project-directory name for the current session (e.g.
# "-private-tmp-copilot-<uuid>"). Set by set_execution_context() so path
# validation can scope tool-results reads to the current session.
_current_project_dir: ContextVar[str] = ContextVar("_current_project_dir", default="")
_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None)
_current_session: ContextVar[ChatSession | None] = ContextVar(
"current_session", default=None
)
_current_sandbox: ContextVar["AsyncSandbox | None"] = ContextVar(
"_current_sandbox", default=None
)
_current_sdk_cwd: ContextVar[str] = ContextVar("_current_sdk_cwd", default="")
def _encode_cwd_for_cli(cwd: str) -> str:
"""Encode a working directory path the same way the Claude CLI does."""
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
def set_execution_context(
user_id: str | None,
session: ChatSession,
sandbox: "AsyncSandbox | None" = None,
sdk_cwd: str | None = None,
) -> None:
"""Set per-turn context variables used by file-resolution tool handlers."""
_current_user_id.set(user_id)
_current_session.set(session)
_current_sandbox.set(sandbox)
_current_sdk_cwd.set(sdk_cwd or "")
_current_project_dir.set(_encode_cwd_for_cli(sdk_cwd) if sdk_cwd else "")
def get_execution_context() -> tuple[str | None, ChatSession | None]:
"""Return the current (user_id, session) pair for the active request."""
return _current_user_id.get(), _current_session.get()
def get_current_sandbox() -> "AsyncSandbox | None":
"""Return the E2B sandbox for the current session, or None if not active."""
return _current_sandbox.get()
def get_sdk_cwd() -> str:
"""Return the SDK working directory for the current session (empty string if unset)."""
return _current_sdk_cwd.get()
E2B_WORKDIR = "/home/user"
def resolve_sandbox_path(path: str) -> str:
"""Normalise *path* to an absolute sandbox path under ``/home/user``.
Raises :class:`ValueError` if the resolved path escapes the sandbox.
"""
candidate = path if os.path.isabs(path) else os.path.join(E2B_WORKDIR, path)
normalized = os.path.normpath(candidate)
if normalized != E2B_WORKDIR and not normalized.startswith(E2B_WORKDIR + "/"):
raise ValueError(f"Path must be within {E2B_WORKDIR}: {path}")
return normalized
def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
"""Return True if *path* is within an allowed host-filesystem location.
Allowed:
- Files under *sdk_cwd* (``/tmp/copilot-<session>/``)
- Files under ``~/.claude/projects/<encoded-cwd>/tool-results/`` (SDK tool-results)
"""
if not path:
return False
if path.startswith("~"):
resolved = os.path.realpath(os.path.expanduser(path))
elif not os.path.isabs(path) and sdk_cwd:
resolved = os.path.realpath(os.path.join(sdk_cwd, path))
else:
resolved = os.path.realpath(path)
if sdk_cwd:
norm_cwd = os.path.realpath(sdk_cwd)
if resolved == norm_cwd or resolved.startswith(norm_cwd + os.sep):
return True
encoded = _current_project_dir.get("")
if encoded:
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
if resolved == tool_results_dir or resolved.startswith(
tool_results_dir + os.sep
):
return True
return False

View File

@@ -0,0 +1,163 @@
"""Tests for context.py — execution context variables and path helpers."""
from __future__ import annotations
import os
import tempfile
from unittest.mock import MagicMock
import pytest
from backend.copilot.context import (
_SDK_PROJECTS_DIR,
_current_project_dir,
get_current_sandbox,
get_execution_context,
get_sdk_cwd,
is_allowed_local_path,
resolve_sandbox_path,
set_execution_context,
)
def _make_session() -> MagicMock:
s = MagicMock()
s.session_id = "test-session"
return s
# ---------------------------------------------------------------------------
# Context variable getters
# ---------------------------------------------------------------------------
def test_get_execution_context_defaults():
"""get_execution_context returns (None, session) when user_id is not set."""
set_execution_context(None, _make_session())
user_id, session = get_execution_context()
assert user_id is None
assert session is not None
def test_set_and_get_execution_context():
"""set_execution_context stores user_id and session."""
mock_session = _make_session()
set_execution_context("user-abc", mock_session)
user_id, session = get_execution_context()
assert user_id == "user-abc"
assert session is mock_session
def test_get_current_sandbox_none_by_default():
"""get_current_sandbox returns None when no sandbox is set."""
set_execution_context("u1", _make_session(), sandbox=None)
assert get_current_sandbox() is None
def test_get_current_sandbox_returns_set_value():
"""get_current_sandbox returns the sandbox set via set_execution_context."""
mock_sandbox = MagicMock()
set_execution_context("u1", _make_session(), sandbox=mock_sandbox)
assert get_current_sandbox() is mock_sandbox
def test_get_sdk_cwd_empty_when_not_set():
"""get_sdk_cwd returns empty string when sdk_cwd is not set."""
set_execution_context("u1", _make_session(), sdk_cwd=None)
assert get_sdk_cwd() == ""
def test_get_sdk_cwd_returns_set_value():
"""get_sdk_cwd returns the value set via set_execution_context."""
set_execution_context("u1", _make_session(), sdk_cwd="/tmp/copilot-test")
assert get_sdk_cwd() == "/tmp/copilot-test"
# ---------------------------------------------------------------------------
# is_allowed_local_path
# ---------------------------------------------------------------------------
def test_is_allowed_local_path_empty():
assert not is_allowed_local_path("")
def test_is_allowed_local_path_inside_sdk_cwd():
with tempfile.TemporaryDirectory() as cwd:
path = os.path.join(cwd, "file.txt")
assert is_allowed_local_path(path, cwd)
def test_is_allowed_local_path_sdk_cwd_itself():
with tempfile.TemporaryDirectory() as cwd:
assert is_allowed_local_path(cwd, cwd)
def test_is_allowed_local_path_outside_sdk_cwd():
with tempfile.TemporaryDirectory() as cwd:
assert not is_allowed_local_path("/etc/passwd", cwd)
def test_is_allowed_local_path_no_sdk_cwd_no_project_dir():
"""Without sdk_cwd or project_dir, all paths are rejected."""
_current_project_dir.set("")
assert not is_allowed_local_path("/tmp/some-file.txt", sdk_cwd=None)
def test_is_allowed_local_path_tool_results_dir():
"""Files under the tool-results directory for the current project are allowed."""
encoded = "test-encoded-dir"
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
path = os.path.join(tool_results_dir, "output.txt")
_current_project_dir.set(encoded)
try:
assert is_allowed_local_path(path, sdk_cwd=None)
finally:
_current_project_dir.set("")
def test_is_allowed_local_path_sibling_of_tool_results_is_rejected():
"""A path adjacent to tool-results/ but not inside it is rejected."""
encoded = "test-encoded-dir"
sibling_path = os.path.join(_SDK_PROJECTS_DIR, encoded, "other-dir", "file.txt")
_current_project_dir.set(encoded)
try:
assert not is_allowed_local_path(sibling_path, sdk_cwd=None)
finally:
_current_project_dir.set("")
# ---------------------------------------------------------------------------
# resolve_sandbox_path
# ---------------------------------------------------------------------------
def test_resolve_sandbox_path_absolute_valid():
assert (
resolve_sandbox_path("/home/user/project/main.py")
== "/home/user/project/main.py"
)
def test_resolve_sandbox_path_relative():
assert resolve_sandbox_path("project/main.py") == "/home/user/project/main.py"
def test_resolve_sandbox_path_workdir_itself():
assert resolve_sandbox_path("/home/user") == "/home/user"
def test_resolve_sandbox_path_normalizes_dots():
assert resolve_sandbox_path("/home/user/a/../b") == "/home/user/b"
def test_resolve_sandbox_path_escape_raises():
with pytest.raises(ValueError, match="/home/user"):
resolve_sandbox_path("/home/user/../../etc/passwd")
def test_resolve_sandbox_path_absolute_outside_raises():
with pytest.raises(ValueError, match="/home/user"):
resolve_sandbox_path("/etc/passwd")

View File

@@ -73,6 +73,9 @@ class Usage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
# Cache breakdown (Anthropic-specific; zero for non-Anthropic models)
cache_read_tokens: int = 0
cache_creation_tokens: int = 0
class ChatSessionInfo(BaseModel):

View File

@@ -0,0 +1,138 @@
"""Scheduler job to generate LLM-optimized block descriptions.
Runs periodically to rewrite block descriptions into concise, actionable
summaries that help the copilot LLM pick the right blocks during agent
generation.
"""
import asyncio
import logging
from backend.blocks import get_blocks
from backend.util.clients import get_database_manager_client, get_openai_client
logger = logging.getLogger(__name__)
SYSTEM_PROMPT = (
"You are a technical writer for an automation platform. "
"Rewrite the following block description to be concise (under 50 words), "
"informative, and actionable. Focus on what the block does and when to "
"use it. Output ONLY the rewritten description, nothing else. "
"Do not use markdown formatting."
)
# Rate-limit delay between sequential LLM calls (seconds)
_RATE_LIMIT_DELAY = 0.5
# Maximum tokens for optimized description generation
_MAX_DESCRIPTION_TOKENS = 150
# Model for generating optimized descriptions (fast, cheap)
_MODEL = "gpt-4o-mini"
async def _optimize_descriptions(blocks: list[dict[str, str]]) -> dict[str, str]:
"""Call the shared OpenAI client to rewrite each block description."""
client = get_openai_client()
if client is None:
logger.error(
"No OpenAI client configured, skipping block description optimization"
)
return {}
results: dict[str, str] = {}
for block in blocks:
block_id = block["id"]
block_name = block["name"]
description = block["description"]
try:
response = await client.chat.completions.create(
model=_MODEL,
messages=[
{"role": "system", "content": SYSTEM_PROMPT},
{
"role": "user",
"content": f"Block name: {block_name}\nDescription: {description}",
},
],
max_tokens=_MAX_DESCRIPTION_TOKENS,
)
optimized = (response.choices[0].message.content or "").strip()
if optimized:
results[block_id] = optimized
logger.debug("Optimized description for %s", block_name)
else:
logger.warning("Empty response for block %s", block_name)
except Exception:
logger.warning(
"Failed to optimize description for %s", block_name, exc_info=True
)
await asyncio.sleep(_RATE_LIMIT_DELAY)
return results
def optimize_block_descriptions() -> dict[str, int]:
"""Generate optimized descriptions for blocks that don't have one yet.
Uses the shared OpenAI client to rewrite block descriptions into concise
summaries suitable for agent generation prompts.
Returns:
Dict with counts: processed, success, failed, skipped.
"""
db_client = get_database_manager_client()
blocks = db_client.get_blocks_needing_optimization()
if not blocks:
logger.info("All blocks already have optimized descriptions")
return {"processed": 0, "success": 0, "failed": 0, "skipped": 0}
logger.info("Found %d blocks needing optimized descriptions", len(blocks))
non_empty = [b for b in blocks if b.get("description", "").strip()]
skipped = len(blocks) - len(non_empty)
new_descriptions = asyncio.run(_optimize_descriptions(non_empty))
stats = {
"processed": len(non_empty),
"success": len(new_descriptions),
"failed": len(non_empty) - len(new_descriptions),
"skipped": skipped,
}
logger.info(
"Block description optimization complete: "
"%d/%d succeeded, %d failed, %d skipped",
stats["success"],
stats["processed"],
stats["failed"],
stats["skipped"],
)
if new_descriptions:
for block_id, optimized in new_descriptions.items():
db_client.update_block_optimized_description(block_id, optimized)
# Update in-memory descriptions first so the cache rebuilds with fresh data.
try:
block_classes = get_blocks()
for block_id, optimized in new_descriptions.items():
if block_id in block_classes:
block_classes[block_id]._optimized_description = optimized
logger.info(
"Updated %d in-memory block descriptions", len(new_descriptions)
)
except Exception:
logger.warning(
"Could not update in-memory block descriptions", exc_info=True
)
from backend.copilot.tools.agent_generator.blocks import (
reset_block_caches, # local to avoid circular import
)
reset_block_caches()
return stats

View File

@@ -0,0 +1,91 @@
"""Unit tests for optimize_blocks._optimize_descriptions."""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
from backend.copilot.optimize_blocks import _RATE_LIMIT_DELAY, _optimize_descriptions
def _make_client_response(text: str) -> MagicMock:
"""Build a minimal mock that looks like an OpenAI ChatCompletion response."""
choice = MagicMock()
choice.message.content = text
response = MagicMock()
response.choices = [choice]
return response
def _run(coro):
return asyncio.get_event_loop().run_until_complete(coro)
class TestOptimizeDescriptions:
"""Tests for _optimize_descriptions async function."""
def test_returns_empty_when_no_client(self):
with patch(
"backend.copilot.optimize_blocks.get_openai_client", return_value=None
):
result = _run(
_optimize_descriptions([{"id": "b1", "name": "B", "description": "d"}])
)
assert result == {}
def test_success_single_block(self):
client = MagicMock()
client.chat.completions.create = AsyncMock(
return_value=_make_client_response("Short desc.")
)
blocks = [{"id": "b1", "name": "MyBlock", "description": "A block."}]
with (
patch(
"backend.copilot.optimize_blocks.get_openai_client", return_value=client
),
patch(
"backend.copilot.optimize_blocks.asyncio.sleep", new_callable=AsyncMock
),
):
result = _run(_optimize_descriptions(blocks))
assert result == {"b1": "Short desc."}
client.chat.completions.create.assert_called_once()
def test_skips_block_on_exception(self):
client = MagicMock()
client.chat.completions.create = AsyncMock(side_effect=Exception("API error"))
blocks = [{"id": "b1", "name": "MyBlock", "description": "A block."}]
with (
patch(
"backend.copilot.optimize_blocks.get_openai_client", return_value=client
),
patch(
"backend.copilot.optimize_blocks.asyncio.sleep", new_callable=AsyncMock
),
):
result = _run(_optimize_descriptions(blocks))
assert result == {}
def test_sleeps_between_blocks(self):
client = MagicMock()
client.chat.completions.create = AsyncMock(
return_value=_make_client_response("desc")
)
blocks = [
{"id": "b1", "name": "B1", "description": "d1"},
{"id": "b2", "name": "B2", "description": "d2"},
]
sleep_mock = AsyncMock()
with (
patch(
"backend.copilot.optimize_blocks.get_openai_client", return_value=client
),
patch("backend.copilot.optimize_blocks.asyncio.sleep", sleep_mock),
):
_run(_optimize_descriptions(blocks))
assert sleep_mock.call_count == 2
sleep_mock.assert_called_with(_RATE_LIMIT_DELAY)

View File

@@ -26,6 +26,38 @@ your message as a Markdown link or image:
The `download_url` field in the `write_workspace_file` response is already
in the correct format — paste it directly after the `(` in the Markdown.
### Passing file content to tools — @@agptfile: references
Instead of copying large file contents into a tool argument, pass a file
reference and the platform will load the content for you.
Syntax: `@@agptfile:<uri>[<start>-<end>]`
- `<uri>` **must** start with `workspace://` or `/` (absolute path):
- `workspace://<file_id>` — workspace file by ID
- `workspace:///<path>` — workspace file by virtual path
- `/absolute/local/path` — ephemeral or sdk_cwd file
- E2B sandbox absolute path (e.g. `/home/user/script.py`)
- `[<start>-<end>]` is an optional 1-indexed inclusive line range.
- URIs that do not start with `workspace://` or `/` are **not** expanded.
Examples:
```
@@agptfile:workspace://abc123
@@agptfile:workspace://abc123[10-50]
@@agptfile:workspace:///reports/q1.md
@@agptfile:/tmp/copilot-<session>/output.py[1-80]
@@agptfile:/home/user/script.py
```
You can embed a reference inside any string argument, or use it as the entire
value. Multiple references in one argument are all expanded.
**Type coercion**: The platform automatically coerces expanded string values
to match the block's expected input types. For example, if a block expects
`list[list[str]]` and you pass a string containing a JSON array (e.g. from
an @@agptfile: expansion), the string will be parsed into the correct type.
### Sub-agent tasks
- When using the Task tool, NEVER set `run_in_background` to true.
All tasks must run in the foreground.

View File

@@ -0,0 +1,253 @@
"""CoPilot rate limiting based on token usage.
Uses Redis fixed-window counters to track per-user token consumption
with configurable daily and weekly limits. Daily windows reset at
midnight UTC; weekly windows reset at ISO week boundary (Monday 00:00
UTC). Fails open when Redis is unavailable to avoid blocking users.
"""
import asyncio
import logging
from datetime import UTC, datetime, timedelta
from pydantic import BaseModel, Field
from backend.data.redis_client import get_redis_async
logger = logging.getLogger(__name__)
# Redis key prefixes
_PREFIX = "copilot:usage"
class UsageWindow(BaseModel):
"""Usage within a single time window."""
used: int
limit: int = Field(
description="Maximum tokens allowed in this window. 0 means unlimited."
)
resets_at: datetime
class CoPilotUsageStatus(BaseModel):
"""Current usage status for a user across all windows."""
daily: UsageWindow
weekly: UsageWindow
class RateLimitExceeded(Exception):
"""Raised when a user exceeds their CoPilot usage limit."""
def __init__(self, window: str, resets_at: datetime):
self.window = window
self.resets_at = resets_at
delta = resets_at - datetime.now(UTC)
total_secs = delta.total_seconds()
if total_secs <= 0:
time_str = "now"
else:
hours = int(total_secs // 3600)
minutes = int((total_secs % 3600) // 60)
time_str = f"{hours}h {minutes}m" if hours > 0 else f"{minutes}m"
super().__init__(
f"You've reached your {window} usage limit. Resets in {time_str}."
)
def _daily_key(user_id: str, now: datetime | None = None) -> str:
if now is None:
now = datetime.now(UTC)
return f"{_PREFIX}:daily:{user_id}:{now.strftime('%Y-%m-%d')}"
def _weekly_key(user_id: str, now: datetime | None = None) -> str:
if now is None:
now = datetime.now(UTC)
year, week, _ = now.isocalendar()
return f"{_PREFIX}:weekly:{user_id}:{year}-W{week:02d}"
def _daily_reset_time(now: datetime | None = None) -> datetime:
"""Calculate when the current daily window resets (next midnight UTC)."""
if now is None:
now = datetime.now(UTC)
return now.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(days=1)
def _weekly_reset_time(now: datetime | None = None) -> datetime:
"""Calculate when the current weekly window resets (next Monday 00:00 UTC).
On Monday itself, ``(7 - weekday) % 7`` is 0; the ``or 7`` fallback
pushes to *next* Monday so the current week's window stays open.
"""
if now is None:
now = datetime.now(UTC)
days_until_monday = (7 - now.weekday()) % 7 or 7
return now.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(
days=days_until_monday
)
async def _fetch_counters(user_id: str, now: datetime) -> tuple[int, int]:
"""Fetch daily and weekly token counters from Redis.
Returns (daily_used, weekly_used). Returns (0, 0) if Redis is unavailable.
"""
redis = await get_redis_async()
daily_raw, weekly_raw = await asyncio.gather(
redis.get(_daily_key(user_id, now=now)),
redis.get(_weekly_key(user_id, now=now)),
)
return int(daily_raw or 0), int(weekly_raw or 0)
async def get_usage_status(
user_id: str,
daily_token_limit: int,
weekly_token_limit: int,
) -> CoPilotUsageStatus:
"""Get current usage status for a user.
Args:
user_id: The user's ID.
daily_token_limit: Max tokens per day (0 = unlimited).
weekly_token_limit: Max tokens per week (0 = unlimited).
Returns:
CoPilotUsageStatus with current usage and limits.
"""
now = datetime.now(UTC)
try:
daily_used, weekly_used = await _fetch_counters(user_id, now)
except Exception:
logger.warning(
"Redis unavailable for usage status, returning zeros", exc_info=True
)
daily_used, weekly_used = 0, 0
return CoPilotUsageStatus(
daily=UsageWindow(
used=daily_used,
limit=daily_token_limit,
resets_at=_daily_reset_time(now=now),
),
weekly=UsageWindow(
used=weekly_used,
limit=weekly_token_limit,
resets_at=_weekly_reset_time(now=now),
),
)
async def check_rate_limit(
user_id: str,
daily_token_limit: int,
weekly_token_limit: int,
) -> None:
"""Check if user is within rate limits. Raises RateLimitExceeded if not.
This is a pre-turn soft check. The authoritative usage counter is updated
by ``record_token_usage()`` after the turn completes. Under concurrency,
two parallel turns may both pass this check against the same snapshot.
This is acceptable because token-based limits are approximate by nature
(the exact token count is unknown until after generation).
Fails open: if Redis is unavailable, allows the request.
"""
now = datetime.now(UTC)
try:
daily_used, weekly_used = await _fetch_counters(user_id, now)
except Exception:
logger.warning(
"Redis unavailable for rate limit check, allowing request", exc_info=True
)
return
if daily_token_limit > 0 and daily_used >= daily_token_limit:
raise RateLimitExceeded("daily", _daily_reset_time(now=now))
if weekly_token_limit > 0 and weekly_used >= weekly_token_limit:
raise RateLimitExceeded("weekly", _weekly_reset_time(now=now))
async def record_token_usage(
user_id: str,
prompt_tokens: int,
completion_tokens: int,
*,
cache_read_tokens: int = 0,
cache_creation_tokens: int = 0,
) -> None:
"""Record token usage for a user across all windows.
Uses cost-weighted counting so cached tokens don't unfairly penalise
multi-turn conversations. Anthropic's pricing:
- uncached input: 100%
- cache creation: 25%
- cache read: 10%
- output: 100%
``prompt_tokens`` should be the *uncached* input count (``input_tokens``
from the API response). Cache counts are passed separately.
Args:
user_id: The user's ID.
prompt_tokens: Uncached input tokens.
completion_tokens: Output tokens.
cache_read_tokens: Tokens served from prompt cache (10% cost).
cache_creation_tokens: Tokens written to prompt cache (25% cost).
"""
weighted_input = (
prompt_tokens
+ round(cache_creation_tokens * 0.25)
+ round(cache_read_tokens * 0.1)
)
total = weighted_input + completion_tokens
if total <= 0:
return
raw_total = (
prompt_tokens + cache_read_tokens + cache_creation_tokens + completion_tokens
)
logger.info(
"Recording token usage for %s: raw=%d, weighted=%d "
"(uncached=%d, cache_read=%d@10%%, cache_create=%d@25%%, output=%d)",
user_id[:8],
raw_total,
total,
prompt_tokens,
cache_read_tokens,
cache_creation_tokens,
completion_tokens,
)
now = datetime.now(UTC)
try:
redis = await get_redis_async()
pipe = redis.pipeline(transaction=False)
# Daily counter (expires at next midnight UTC)
d_key = _daily_key(user_id, now=now)
pipe.incrby(d_key, total)
seconds_until_daily_reset = int(
(_daily_reset_time(now=now) - now).total_seconds()
)
pipe.expire(d_key, max(seconds_until_daily_reset, 1))
# Weekly counter (expires end of week)
w_key = _weekly_key(user_id, now=now)
pipe.incrby(w_key, total)
seconds_until_weekly_reset = int(
(_weekly_reset_time(now=now) - now).total_seconds()
)
pipe.expire(w_key, max(seconds_until_weekly_reset, 1))
await pipe.execute()
except Exception:
logger.warning(
"Redis unavailable for recording token usage (tokens=%d)",
total,
exc_info=True,
)

View File

@@ -0,0 +1,334 @@
"""Unit tests for CoPilot rate limiting."""
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from redis.exceptions import RedisError
from .rate_limit import (
CoPilotUsageStatus,
RateLimitExceeded,
check_rate_limit,
get_usage_status,
record_token_usage,
)
_USER = "test-user-rl"
# ---------------------------------------------------------------------------
# RateLimitExceeded
# ---------------------------------------------------------------------------
class TestRateLimitExceeded:
def test_message_contains_window_name(self):
exc = RateLimitExceeded("daily", datetime.now(UTC) + timedelta(hours=1))
assert "daily" in str(exc)
def test_message_contains_reset_time(self):
exc = RateLimitExceeded(
"weekly", datetime.now(UTC) + timedelta(hours=2, minutes=30)
)
msg = str(exc)
# Allow for slight timing drift (29m or 30m)
assert "2h " in msg
assert "Resets in" in msg
def test_message_minutes_only_when_under_one_hour(self):
exc = RateLimitExceeded("daily", datetime.now(UTC) + timedelta(minutes=15))
msg = str(exc)
assert "Resets in" in msg
# Should not have "0h"
assert "0h" not in msg
def test_message_says_now_when_resets_at_is_in_the_past(self):
"""Negative delta (clock skew / stale TTL) should say 'now', not '-1h -30m'."""
exc = RateLimitExceeded("daily", datetime.now(UTC) - timedelta(minutes=5))
assert "Resets in now" in str(exc)
# ---------------------------------------------------------------------------
# get_usage_status
# ---------------------------------------------------------------------------
class TestGetUsageStatus:
@pytest.mark.asyncio
async def test_returns_redis_values(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["500", "2000"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert isinstance(status, CoPilotUsageStatus)
assert status.daily.used == 500
assert status.daily.limit == 10000
assert status.weekly.used == 2000
assert status.weekly.limit == 50000
@pytest.mark.asyncio
async def test_returns_zeros_when_redis_unavailable(self):
with patch(
"backend.copilot.rate_limit.get_redis_async",
side_effect=ConnectionError("Redis down"),
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert status.daily.used == 0
assert status.weekly.used == 0
@pytest.mark.asyncio
async def test_partial_none_daily_counter(self):
"""Daily counter is None (new day), weekly has usage."""
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=[None, "3000"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert status.daily.used == 0
assert status.weekly.used == 3000
@pytest.mark.asyncio
async def test_partial_none_weekly_counter(self):
"""Weekly counter is None (start of week), daily has usage."""
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["500", None])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert status.daily.used == 500
assert status.weekly.used == 0
@pytest.mark.asyncio
async def test_resets_at_daily_is_next_midnight_utc(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["0", "0"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
now = datetime.now(UTC)
# Daily reset should be within 24h
assert status.daily.resets_at > now
assert status.daily.resets_at <= now + timedelta(hours=24, seconds=5)
# ---------------------------------------------------------------------------
# check_rate_limit
# ---------------------------------------------------------------------------
class TestCheckRateLimit:
@pytest.mark.asyncio
async def test_allows_when_under_limit(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["100", "200"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
# Should not raise
await check_rate_limit(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
@pytest.mark.asyncio
async def test_raises_when_daily_limit_exceeded(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["10000", "200"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
with pytest.raises(RateLimitExceeded) as exc_info:
await check_rate_limit(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert exc_info.value.window == "daily"
@pytest.mark.asyncio
async def test_raises_when_weekly_limit_exceeded(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["100", "50000"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
with pytest.raises(RateLimitExceeded) as exc_info:
await check_rate_limit(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert exc_info.value.window == "weekly"
@pytest.mark.asyncio
async def test_allows_when_redis_unavailable(self):
"""Fail-open: allow requests when Redis is down."""
with patch(
"backend.copilot.rate_limit.get_redis_async",
side_effect=ConnectionError("Redis down"),
):
# Should not raise
await check_rate_limit(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
@pytest.mark.asyncio
async def test_skips_check_when_limit_is_zero(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["999999", "999999"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
# Should not raise — limits of 0 mean unlimited
await check_rate_limit(_USER, daily_token_limit=0, weekly_token_limit=0)
# ---------------------------------------------------------------------------
# record_token_usage
# ---------------------------------------------------------------------------
class TestRecordTokenUsage:
@staticmethod
def _make_pipeline_mock() -> MagicMock:
"""Create a pipeline mock with sync methods and async execute."""
pipe = MagicMock()
pipe.execute = AsyncMock(return_value=[])
return pipe
@pytest.mark.asyncio
async def test_increments_redis_counters(self):
mock_pipe = self._make_pipeline_mock()
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
# Should call incrby twice (daily + weekly) with total=150
incrby_calls = mock_pipe.incrby.call_args_list
assert len(incrby_calls) == 2
assert incrby_calls[0].args[1] == 150 # daily
assert incrby_calls[1].args[1] == 150 # weekly
@pytest.mark.asyncio
async def test_skips_when_zero_tokens(self):
mock_redis = AsyncMock()
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await record_token_usage(_USER, prompt_tokens=0, completion_tokens=0)
# Should not call pipeline at all
mock_redis.pipeline.assert_not_called()
@pytest.mark.asyncio
async def test_sets_expire_on_both_keys(self):
"""Pipeline should call expire for both daily and weekly keys."""
mock_pipe = self._make_pipeline_mock()
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
expire_calls = mock_pipe.expire.call_args_list
assert len(expire_calls) == 2
# Daily key TTL should be positive (seconds until next midnight)
daily_ttl = expire_calls[0].args[1]
assert daily_ttl >= 1
# Weekly key TTL should be positive (seconds until next Monday)
weekly_ttl = expire_calls[1].args[1]
assert weekly_ttl >= 1
@pytest.mark.asyncio
async def test_handles_redis_failure_gracefully(self):
"""Should not raise when Redis is unavailable."""
with patch(
"backend.copilot.rate_limit.get_redis_async",
side_effect=ConnectionError("Redis down"),
):
# Should not raise
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
@pytest.mark.asyncio
async def test_cost_weighted_counting(self):
"""Cached tokens should be weighted: cache_read=10%, cache_create=25%."""
mock_pipe = self._make_pipeline_mock()
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await record_token_usage(
_USER,
prompt_tokens=100, # uncached → 100
completion_tokens=50, # output → 50
cache_read_tokens=10000, # 10% → 1000
cache_creation_tokens=400, # 25% → 100
)
# Expected weighted total: 100 + 1000 + 100 + 50 = 1250
incrby_calls = mock_pipe.incrby.call_args_list
assert len(incrby_calls) == 2
assert incrby_calls[0].args[1] == 1250 # daily
assert incrby_calls[1].args[1] == 1250 # weekly
@pytest.mark.asyncio
async def test_handles_redis_error_during_pipeline_execute(self):
"""Should not raise when pipeline.execute() fails with RedisError."""
mock_pipe = self._make_pipeline_mock()
mock_pipe.execute = AsyncMock(side_effect=RedisError("Pipeline failed"))
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
# Should not raise — fail-open
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)

View File

@@ -186,12 +186,29 @@ class StreamToolOutputAvailable(StreamBaseResponse):
class StreamUsage(StreamBaseResponse):
"""Token usage statistics."""
"""Token usage statistics.
Emitted as an SSE comment so the Vercel AI SDK parser ignores it
(it uses z.strictObject() and rejects unknown event types).
Usage data is recorded server-side (session DB + Redis counters).
"""
type: ResponseType = ResponseType.USAGE
promptTokens: int = Field(..., description="Number of prompt tokens")
promptTokens: int = Field(..., description="Number of uncached prompt tokens")
completionTokens: int = Field(..., description="Number of completion tokens")
totalTokens: int = Field(..., description="Total number of tokens")
totalTokens: int = Field(
..., description="Total number of tokens (raw, not weighted)"
)
cacheReadTokens: int = Field(
default=0, description="Prompt tokens served from cache (10% cost)"
)
cacheCreationTokens: int = Field(
default=0, description="Prompt tokens written to cache (25% cost)"
)
def to_sse(self) -> str:
"""Emit as SSE comment so the AI SDK parser ignores it."""
return f": usage {self.model_dump_json(exclude_none=True)}\n\n"
class StreamError(StreamBaseResponse):

View File

@@ -0,0 +1,155 @@
## Agent Generation Guide
You can create, edit, and customize agents directly. You ARE the brain —
generate the agent JSON yourself using block schemas, then validate and save.
### Workflow for Creating/Editing Agents
1. **Discover blocks**: Call `find_block(query, include_schemas=true)` to
search for relevant blocks. This returns block IDs, names, descriptions,
and full input/output schemas.
2. **Find library agents**: Call `find_library_agent` to discover reusable
agents that can be composed as sub-agents via `AgentExecutorBlock`.
3. **Generate JSON**: Build the agent JSON using block schemas:
- Use block IDs from step 1 as `block_id` in nodes
- Wire outputs to inputs using links
- Set design-time config in `input_default`
- Use `AgentInputBlock` for values the user provides at runtime
4. **Write to workspace**: Save the JSON to a workspace file so the user
can review it: `write_workspace_file(filename="agent.json", content=...)`
5. **Validate**: Call `validate_agent_graph` with the agent JSON to check
for errors
6. **Fix if needed**: Call `fix_agent_graph` to auto-fix common issues,
or fix manually based on the error descriptions. Iterate until valid.
7. **Save**: Call `create_agent` (new) or `edit_agent` (existing) with
the final `agent_json`
### Agent JSON Structure
```json
{
"id": "<UUID v4>", // auto-generated if omitted
"version": 1,
"is_active": true,
"name": "Agent Name",
"description": "What the agent does",
"nodes": [
{
"id": "<UUID v4>",
"block_id": "<block UUID from find_block>",
"input_default": {
"field_name": "design-time value"
},
"metadata": {
"position": {"x": 0, "y": 0},
"customized_name": "Optional display name"
}
}
],
"links": [
{
"id": "<UUID v4>",
"source_id": "<source node UUID>",
"source_name": "output_field_name",
"sink_id": "<sink node UUID>",
"sink_name": "input_field_name",
"is_static": false
}
]
}
```
### REQUIRED: AgentInputBlock and AgentOutputBlock
Every agent MUST include at least one AgentInputBlock and one AgentOutputBlock.
These define the agent's interface — what it accepts and what it produces.
**AgentInputBlock** (ID: `c0a8e994-ebf1-4a9c-a4d8-89d09c86741b`):
- Defines a user-facing input field on the agent
- Required `input_default` fields: `name` (str), `value` (default: null)
- Optional: `title`, `description`, `placeholder_values` (for dropdowns)
- Output: `result` — the user-provided value at runtime
- Create one AgentInputBlock per distinct input the agent needs
**AgentOutputBlock** (ID: `363ae599-353e-4804-937e-b2ee3cef3da4`):
- Defines a user-facing output displayed after the agent runs
- Required `input_default` fields: `name` (str)
- The `value` input should be linked from another block's output
- Optional: `title`, `description`, `format` (Jinja2 template)
- Create one AgentOutputBlock per distinct result to show the user
Without these blocks, the agent has no interface and the user cannot provide
inputs or see outputs. NEVER skip them.
### Key Rules
- **Name & description**: Include `name` and `description` in the agent JSON
when creating a new agent, or when editing and the agent's purpose changed.
Without these the agent gets a generic default name.
- **Design-time vs runtime**: `input_default` = values known at build time.
For user-provided values, create an `AgentInputBlock` node and link its
output to the consuming block's input.
- **Credentials**: Do NOT require credentials upfront. Users configure
credentials later in the platform UI after the agent is saved.
- **Node spacing**: Position nodes with at least 800 X-units between them.
- **Nested properties**: Use `parentField_#_childField` notation in link
sink_name/source_name to access nested object fields.
- **is_static links**: Set `is_static: true` when the link carries a
design-time constant (matches a field in inputSchema with a default).
- **ConditionBlock**: Needs a `StoreValueBlock` wired to its `value2` input.
- **Prompt templates**: Use `{{variable}}` (double curly braces) for
literal braces in prompt strings — single `{` and `}` are for
template variables.
- **AgentExecutorBlock**: When composing sub-agents, set `graph_id` and
`graph_version` in input_default, and wire inputs/outputs to match
the sub-agent's schema.
### Using Sub-Agents (AgentExecutorBlock)
To compose agents using other agents as sub-agents:
1. Call `find_library_agent` to find the sub-agent — the response includes
`graph_id`, `graph_version`, `input_schema`, and `output_schema`
2. Create an `AgentExecutorBlock` node (ID: `e189baac-8c20-45a1-94a7-55177ea42565`)
3. Set `input_default`:
- `graph_id`: from the library agent's `graph_id`
- `graph_version`: from the library agent's `graph_version`
- `input_schema`: from the library agent's `input_schema` (JSON Schema)
- `output_schema`: from the library agent's `output_schema` (JSON Schema)
- `user_id`: leave as `""` (filled at runtime)
- `inputs`: `{}` (populated by links at runtime)
4. Wire inputs: link to sink names matching the sub-agent's `input_schema`
property names (e.g., if input_schema has a `"url"` property, use
`"url"` as the sink_name)
5. Wire outputs: link from source names matching the sub-agent's
`output_schema` property names
6. Pass `library_agent_ids` to `create_agent`/`customize_agent` with
the library agent IDs used, so the fixer can validate schemas
### Using MCP Tools (MCPToolBlock)
To use an MCP (Model Context Protocol) tool as a node in the agent:
1. The user must specify which MCP server URL and tool name they want
2. Create an `MCPToolBlock` node (ID: `a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4`)
3. Set `input_default`:
- `server_url`: the MCP server URL (e.g. `"https://mcp.example.com/sse"`)
- `selected_tool`: the tool name on that server
- `tool_input_schema`: JSON Schema for the tool's inputs
- `tool_arguments`: `{}` (populated by links or hardcoded values)
4. The block requires MCP credentials — the user configures these in the
platform UI after the agent is saved
5. Wire inputs using the tool argument field name directly as the sink_name
(e.g., `query`, NOT `tool_arguments_#_query`). The execution engine
automatically collects top-level fields matching tool_input_schema into
tool_arguments.
6. Output: `result` (the tool's return value) and `error` (error message)
### Example: Simple AI Text Processor
A minimal agent with input, processing, and output:
- Node 1: `AgentInputBlock` (ID: `c0a8e994-ebf1-4a9c-a4d8-89d09c86741b`,
input_default: {"name": "user_text", "title": "Text to process"},
output: "result")
- Node 2: `AITextGeneratorBlock` (input: "prompt" linked from Node 1's "result")
- Node 3: `AgentOutputBlock` (ID: `363ae599-353e-4804-937e-b2ee3cef3da4`,
input_default: {"name": "summary", "title": "Summary"},
input: "value" linked from Node 2's output)

View File

@@ -198,6 +198,7 @@ class CompactionTracker:
def reset_for_query(self) -> None:
"""Reset per-query state before a new SDK query."""
self._compact_start.clear()
self._done = False
self._start_emitted = False
self._tool_call_id = ""

View File

@@ -8,8 +8,6 @@ SDK-internal paths (``~/.claude/projects/…/tool-results/``) are handled
by the separate ``Read`` MCP tool registered in ``tool_adapter.py``.
"""
from __future__ import annotations
import itertools
import json
import logging
@@ -17,36 +15,23 @@ import os
import shlex
from typing import Any, Callable
from backend.copilot.tools.e2b_sandbox import E2B_WORKDIR
from backend.copilot.context import (
E2B_WORKDIR,
get_current_sandbox,
get_sdk_cwd,
is_allowed_local_path,
resolve_sandbox_path,
)
logger = logging.getLogger(__name__)
# Lazy imports to break circular dependency with tool_adapter.
def _get_sandbox(): # type: ignore[return]
from .tool_adapter import get_current_sandbox # noqa: E402
def _get_sandbox():
return get_current_sandbox()
def _is_allowed_local(path: str) -> bool:
from .tool_adapter import is_allowed_local_path # noqa: E402
return is_allowed_local_path(path)
def _resolve_remote(path: str) -> str:
"""Normalise *path* to an absolute sandbox path under ``/home/user``.
Raises :class:`ValueError` if the resolved path escapes the sandbox.
"""
candidate = path if os.path.isabs(path) else os.path.join(E2B_WORKDIR, path)
normalized = os.path.normpath(candidate)
if normalized != E2B_WORKDIR and not normalized.startswith(E2B_WORKDIR + "/"):
raise ValueError(f"Path must be within {E2B_WORKDIR}: {path}")
return normalized
return is_allowed_local_path(path, get_sdk_cwd())
def _mcp(text: str, *, error: bool = False) -> dict[str, Any]:
@@ -63,7 +48,7 @@ def _get_sandbox_and_path(
if sandbox is None:
return _mcp("No E2B sandbox available", error=True)
try:
remote = _resolve_remote(file_path)
remote = resolve_sandbox_path(file_path)
except ValueError as exc:
return _mcp(str(exc), error=True)
return sandbox, remote
@@ -73,6 +58,7 @@ def _get_sandbox_and_path(
async def _handle_read_file(args: dict[str, Any]) -> dict[str, Any]:
"""Read lines from a sandbox file, falling back to the local host for SDK-internal paths."""
file_path: str = args.get("file_path", "")
offset: int = max(0, int(args.get("offset", 0)))
limit: int = max(1, int(args.get("limit", 2000)))
@@ -104,6 +90,7 @@ async def _handle_read_file(args: dict[str, Any]) -> dict[str, Any]:
async def _handle_write_file(args: dict[str, Any]) -> dict[str, Any]:
"""Write content to a sandbox file, creating parent directories as needed."""
file_path: str = args.get("file_path", "")
content: str = args.get("content", "")
@@ -127,6 +114,7 @@ async def _handle_write_file(args: dict[str, Any]) -> dict[str, Any]:
async def _handle_edit_file(args: dict[str, Any]) -> dict[str, Any]:
"""Replace a substring in a sandbox file, with optional replace-all support."""
file_path: str = args.get("file_path", "")
old_string: str = args.get("old_string", "")
new_string: str = args.get("new_string", "")
@@ -172,6 +160,7 @@ async def _handle_edit_file(args: dict[str, Any]) -> dict[str, Any]:
async def _handle_glob(args: dict[str, Any]) -> dict[str, Any]:
"""Find files matching a name pattern inside the sandbox using ``find``."""
pattern: str = args.get("pattern", "")
path: str = args.get("path", "")
@@ -183,7 +172,7 @@ async def _handle_glob(args: dict[str, Any]) -> dict[str, Any]:
return _mcp("No E2B sandbox available", error=True)
try:
search_dir = _resolve_remote(path) if path else E2B_WORKDIR
search_dir = resolve_sandbox_path(path) if path else E2B_WORKDIR
except ValueError as exc:
return _mcp(str(exc), error=True)
@@ -198,6 +187,7 @@ async def _handle_glob(args: dict[str, Any]) -> dict[str, Any]:
async def _handle_grep(args: dict[str, Any]) -> dict[str, Any]:
"""Search file contents by regex inside the sandbox using ``grep -rn``."""
pattern: str = args.get("pattern", "")
path: str = args.get("path", "")
include: str = args.get("include", "")
@@ -210,7 +200,7 @@ async def _handle_grep(args: dict[str, Any]) -> dict[str, Any]:
return _mcp("No E2B sandbox available", error=True)
try:
search_dir = _resolve_remote(path) if path else E2B_WORKDIR
search_dir = resolve_sandbox_path(path) if path else E2B_WORKDIR
except ValueError as exc:
return _mcp(str(exc), error=True)
@@ -238,7 +228,7 @@ def _read_local(file_path: str, offset: int, limit: int) -> dict[str, Any]:
return _mcp(f"Path not allowed: {file_path}", error=True)
expanded = os.path.realpath(os.path.expanduser(file_path))
try:
with open(expanded) as fh:
with open(expanded, encoding="utf-8", errors="replace") as fh:
selected = list(itertools.islice(fh, offset, offset + limit))
numbered = "".join(
f"{i + offset + 1:>6}\t{line}" for i, line in enumerate(selected)

View File

@@ -7,59 +7,60 @@ import os
import pytest
from .e2b_file_tools import _read_local, _resolve_remote
from .tool_adapter import _current_project_dir
from backend.copilot.context import _current_project_dir
from .e2b_file_tools import _read_local, resolve_sandbox_path
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
# ---------------------------------------------------------------------------
# _resolve_remote — sandbox path normalisation & boundary enforcement
# resolve_sandbox_path — sandbox path normalisation & boundary enforcement
# ---------------------------------------------------------------------------
class TestResolveRemote:
class TestResolveSandboxPath:
def test_relative_path_resolved(self):
assert _resolve_remote("src/main.py") == "/home/user/src/main.py"
assert resolve_sandbox_path("src/main.py") == "/home/user/src/main.py"
def test_absolute_within_sandbox(self):
assert _resolve_remote("/home/user/file.txt") == "/home/user/file.txt"
assert resolve_sandbox_path("/home/user/file.txt") == "/home/user/file.txt"
def test_workdir_itself(self):
assert _resolve_remote("/home/user") == "/home/user"
assert resolve_sandbox_path("/home/user") == "/home/user"
def test_relative_dotslash(self):
assert _resolve_remote("./README.md") == "/home/user/README.md"
assert resolve_sandbox_path("./README.md") == "/home/user/README.md"
def test_traversal_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
_resolve_remote("../../etc/passwd")
resolve_sandbox_path("../../etc/passwd")
def test_absolute_traversal_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
_resolve_remote("/home/user/../../etc/passwd")
resolve_sandbox_path("/home/user/../../etc/passwd")
def test_absolute_outside_sandbox_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
_resolve_remote("/etc/passwd")
resolve_sandbox_path("/etc/passwd")
def test_root_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
_resolve_remote("/")
resolve_sandbox_path("/")
def test_home_other_user_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
_resolve_remote("/home/other/file.txt")
resolve_sandbox_path("/home/other/file.txt")
def test_deep_nested_allowed(self):
assert _resolve_remote("a/b/c/d/e.txt") == "/home/user/a/b/c/d/e.txt"
assert resolve_sandbox_path("a/b/c/d/e.txt") == "/home/user/a/b/c/d/e.txt"
def test_trailing_slash_normalised(self):
assert _resolve_remote("src/") == "/home/user/src"
assert resolve_sandbox_path("src/") == "/home/user/src"
def test_double_dots_within_sandbox_ok(self):
"""Path that resolves back within /home/user is allowed."""
assert _resolve_remote("a/b/../c.txt") == "/home/user/a/c.txt"
assert resolve_sandbox_path("a/b/../c.txt") == "/home/user/a/c.txt"
# ---------------------------------------------------------------------------

View File

@@ -0,0 +1,546 @@
"""End-to-end compaction flow test.
Simulates the full service.py compaction lifecycle using real-format
JSONL session files — no SDK subprocess needed. Exercises:
1. TranscriptBuilder loads a "downloaded" transcript
2. User query appended, assistant response streamed
3. PreCompact hook fires → CompactionTracker.on_compact()
4. Next message → emit_start_if_ready() yields spinner events
5. Message after that → emit_end_if_ready() returns end events
6. _read_compacted_entries() reads the CLI session file
7. TranscriptBuilder.replace_entries() syncs state
8. More messages appended post-compaction
9. to_jsonl() exports full state for upload
10. Fresh builder loads the export — roundtrip verified
"""
import asyncio
from pathlib import Path
from backend.copilot.model import ChatSession
from backend.copilot.response_model import (
StreamFinishStep,
StreamStartStep,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
)
from backend.copilot.sdk.compaction import CompactionTracker
from backend.copilot.sdk.transcript import strip_progress_entries
from backend.copilot.sdk.transcript_builder import TranscriptBuilder
from backend.util import json
def _make_jsonl(*entries: dict) -> str:
return "\n".join(json.dumps(e) for e in entries) + "\n"
def _run(coro):
"""Run an async coroutine synchronously."""
return asyncio.run(coro)
def _read_compacted_entries(path: str) -> tuple[list[dict], str] | None:
"""Test-only: read compacted entries from a session JSONL file.
Returns (parsed_dicts, jsonl_string) from the first ``isCompactSummary``
entry onward, or ``None`` if no summary is found.
"""
content = Path(path).read_text()
lines = content.strip().split("\n")
compact_idx: int | None = None
parsed: list[dict] = []
raw_lines: list[str] = []
for line in lines:
if not line.strip():
continue
entry = json.loads(line, fallback=None)
if not isinstance(entry, dict):
continue
parsed.append(entry)
raw_lines.append(line.strip())
if compact_idx is None and entry.get("isCompactSummary"):
compact_idx = len(parsed) - 1
if compact_idx is None:
return None
return parsed[compact_idx:], "\n".join(raw_lines[compact_idx:]) + "\n"
# ---------------------------------------------------------------------------
# Fixtures: realistic CLI session file content
# ---------------------------------------------------------------------------
# Pre-compaction conversation
USER_1 = {
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "What files are in this project?"},
}
ASST_1_THINKING = {
"type": "assistant",
"uuid": "a1-think",
"parentUuid": "u1",
"message": {
"role": "assistant",
"id": "msg_sdk_aaa",
"type": "message",
"content": [{"type": "thinking", "thinking": "Let me look at the files..."}],
"stop_reason": None,
"stop_sequence": None,
},
}
ASST_1_TOOL = {
"type": "assistant",
"uuid": "a1-tool",
"parentUuid": "u1",
"message": {
"role": "assistant",
"id": "msg_sdk_aaa",
"type": "message",
"content": [
{
"type": "tool_use",
"id": "tu1",
"name": "Bash",
"input": {"command": "ls"},
}
],
"stop_reason": "tool_use",
"stop_sequence": None,
},
}
TOOL_RESULT_1 = {
"type": "user",
"uuid": "tr1",
"parentUuid": "a1-tool",
"message": {
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "tu1",
"content": "file1.py\nfile2.py",
}
],
},
}
ASST_1_TEXT = {
"type": "assistant",
"uuid": "a1-text",
"parentUuid": "tr1",
"message": {
"role": "assistant",
"id": "msg_sdk_bbb",
"type": "message",
"content": [{"type": "text", "text": "I found file1.py and file2.py."}],
"stop_reason": "end_turn",
"stop_sequence": None,
},
}
# Progress entries (should be stripped during upload)
PROGRESS_1 = {
"type": "progress",
"uuid": "prog1",
"parentUuid": "a1-tool",
"data": {"type": "bash_progress", "stdout": "running ls..."},
}
# Second user message
USER_2 = {
"type": "user",
"uuid": "u2",
"parentUuid": "a1-text",
"message": {"role": "user", "content": "Show me file1.py"},
}
ASST_2 = {
"type": "assistant",
"uuid": "a2",
"parentUuid": "u2",
"message": {
"role": "assistant",
"id": "msg_sdk_ccc",
"type": "message",
"content": [{"type": "text", "text": "Here is file1.py content..."}],
"stop_reason": "end_turn",
"stop_sequence": None,
},
}
# --- Compaction summary (written by CLI after context compaction) ---
COMPACT_SUMMARY = {
"type": "summary",
"uuid": "cs1",
"isCompactSummary": True,
"message": {
"role": "user",
"content": (
"Summary: User asked about project files. Found file1.py and file2.py. "
"User then asked to see file1.py."
),
},
}
# Post-compaction assistant response
POST_COMPACT_ASST = {
"type": "assistant",
"uuid": "a3",
"parentUuid": "cs1",
"message": {
"role": "assistant",
"id": "msg_sdk_ddd",
"type": "message",
"content": [{"type": "text", "text": "Here is the content of file1.py..."}],
"stop_reason": "end_turn",
"stop_sequence": None,
},
}
# Post-compaction user follow-up
USER_3 = {
"type": "user",
"uuid": "u3",
"parentUuid": "a3",
"message": {"role": "user", "content": "Now show file2.py"},
}
ASST_3 = {
"type": "assistant",
"uuid": "a4",
"parentUuid": "u3",
"message": {
"role": "assistant",
"id": "msg_sdk_eee",
"type": "message",
"content": [{"type": "text", "text": "Here is file2.py..."}],
"stop_reason": "end_turn",
"stop_sequence": None,
},
}
# ---------------------------------------------------------------------------
# E2E test
# ---------------------------------------------------------------------------
class TestCompactionE2E:
def _write_session_file(self, session_dir, entries):
"""Write a CLI session JSONL file."""
path = session_dir / "session.jsonl"
path.write_text(_make_jsonl(*entries))
return path
def test_full_compaction_lifecycle(self, tmp_path):
"""Simulate the complete service.py compaction flow.
Timeline:
1. Previous turn uploaded transcript with [USER_1, ASST_1, USER_2, ASST_2]
2. Current turn: download → load_previous
3. User sends "Now show file2.py" → append_user
4. SDK starts streaming response
5. Mid-stream: PreCompact hook fires (context too large)
6. CLI writes compaction summary to session file
7. Next SDK message → emit_start (spinner)
8. Following message → emit_end (end events)
9. _read_compacted_entries reads the session file
10. replace_entries syncs TranscriptBuilder
11. More assistant messages appended
12. Export → upload → next turn downloads it
"""
session_dir = tmp_path / "session"
session_dir.mkdir(parents=True)
# --- Step 1-2: Load "downloaded" transcript from previous turn ---
previous_transcript = _make_jsonl(
USER_1,
ASST_1_THINKING,
ASST_1_TOOL,
TOOL_RESULT_1,
ASST_1_TEXT,
USER_2,
ASST_2,
)
builder = TranscriptBuilder()
builder.load_previous(previous_transcript)
assert builder.entry_count == 7
# --- Step 3: User sends new query ---
builder.append_user("Now show file2.py")
assert builder.entry_count == 8
# --- Step 4: SDK starts streaming ---
builder.append_assistant(
[{"type": "thinking", "thinking": "Let me read file2.py..."}],
model="claude-sonnet-4-20250514",
)
assert builder.entry_count == 9
# --- Step 5-6: PreCompact fires, CLI writes session file ---
session_file = self._write_session_file(
session_dir,
[
USER_1,
ASST_1_THINKING,
ASST_1_TOOL,
PROGRESS_1,
TOOL_RESULT_1,
ASST_1_TEXT,
USER_2,
ASST_2,
COMPACT_SUMMARY,
POST_COMPACT_ASST,
USER_3,
ASST_3,
],
)
# --- Step 7: CompactionTracker receives PreCompact hook ---
tracker = CompactionTracker()
session = ChatSession.new(user_id="test-user")
# on_compact is a property returning Event.set callable
tracker.on_compact()
# --- Step 8: Next SDK message arrives → emit_start ---
start_events = tracker.emit_start_if_ready()
assert len(start_events) == 3
assert isinstance(start_events[0], StreamStartStep)
assert isinstance(start_events[1], StreamToolInputStart)
assert isinstance(start_events[2], StreamToolInputAvailable)
# Verify tool_call_id is set
tool_call_id = start_events[1].toolCallId
assert tool_call_id.startswith("compaction-")
# --- Step 9: Following message → emit_end ---
end_events = _run(tracker.emit_end_if_ready(session))
assert len(end_events) == 2
assert isinstance(end_events[0], StreamToolOutputAvailable)
assert isinstance(end_events[1], StreamFinishStep)
# Verify same tool_call_id
assert end_events[0].toolCallId == tool_call_id
# Session should have compaction messages persisted
assert len(session.messages) == 2
assert session.messages[0].role == "assistant"
assert session.messages[1].role == "tool"
# --- Step 10: _read_compacted_entries + replace_entries ---
result = _read_compacted_entries(str(session_file))
assert result is not None
compacted_dicts, compacted_jsonl = result
# Should have: COMPACT_SUMMARY + POST_COMPACT_ASST + USER_3 + ASST_3
assert len(compacted_dicts) == 4
assert compacted_dicts[0]["uuid"] == "cs1"
assert compacted_dicts[0]["isCompactSummary"] is True
# Replace builder state with compacted JSONL
old_count = builder.entry_count
builder.replace_entries(compacted_jsonl)
assert builder.entry_count == 4 # Only compacted entries
assert builder.entry_count < old_count # Compaction reduced entries
# --- Step 11: More assistant messages after compaction ---
builder.append_assistant(
[{"type": "text", "text": "Here is file2.py:\n\ndef hello():\n pass"}],
model="claude-sonnet-4-20250514",
stop_reason="end_turn",
)
assert builder.entry_count == 5
# --- Step 12: Export for upload ---
output = builder.to_jsonl()
assert output # Not empty
output_entries = [json.loads(line) for line in output.strip().split("\n")]
assert len(output_entries) == 5
# Verify structure:
# [COMPACT_SUMMARY, POST_COMPACT_ASST, USER_3, ASST_3, new_assistant]
assert output_entries[0]["type"] == "summary"
assert output_entries[0].get("isCompactSummary") is True
assert output_entries[0]["uuid"] == "cs1"
assert output_entries[1]["uuid"] == "a3"
assert output_entries[2]["uuid"] == "u3"
assert output_entries[3]["uuid"] == "a4"
assert output_entries[4]["type"] == "assistant"
# Verify parent chain is intact
assert output_entries[1]["parentUuid"] == "cs1" # a3 → cs1
assert output_entries[2]["parentUuid"] == "a3" # u3 → a3
assert output_entries[3]["parentUuid"] == "u3" # a4 → u3
assert output_entries[4]["parentUuid"] == "a4" # new → a4
# --- Step 13: Roundtrip — next turn loads this export ---
builder2 = TranscriptBuilder()
builder2.load_previous(output)
assert builder2.entry_count == 5
# isCompactSummary survives roundtrip
output2 = builder2.to_jsonl()
first_entry = json.loads(output2.strip().split("\n")[0])
assert first_entry.get("isCompactSummary") is True
# Can append more messages
builder2.append_user("What about file3.py?")
assert builder2.entry_count == 6
final_output = builder2.to_jsonl()
last_entry = json.loads(final_output.strip().split("\n")[-1])
assert last_entry["type"] == "user"
# Parented to the last entry from previous turn
assert last_entry["parentUuid"] == output_entries[-1]["uuid"]
def test_double_compaction_within_session(self, tmp_path):
"""Two compactions in the same session (across reset_for_query)."""
session_dir = tmp_path / "session"
session_dir.mkdir(parents=True)
tracker = CompactionTracker()
session = ChatSession.new(user_id="test")
builder = TranscriptBuilder()
# --- First query with compaction ---
builder.append_user("first question")
builder.append_assistant([{"type": "text", "text": "first answer"}])
# Write session file for first compaction
first_summary = {
"type": "summary",
"uuid": "cs-first",
"isCompactSummary": True,
"message": {"role": "user", "content": "First compaction summary"},
}
first_post = {
"type": "assistant",
"uuid": "a-first",
"parentUuid": "cs-first",
"message": {"role": "assistant", "content": "first post-compact"},
}
file1 = session_dir / "session1.jsonl"
file1.write_text(_make_jsonl(first_summary, first_post))
tracker.on_compact()
tracker.emit_start_if_ready()
end_events1 = _run(tracker.emit_end_if_ready(session))
assert len(end_events1) == 2 # output + finish
result1_entries = _read_compacted_entries(str(file1))
assert result1_entries is not None
_, compacted1_jsonl = result1_entries
builder.replace_entries(compacted1_jsonl)
assert builder.entry_count == 2
# --- Reset for second query ---
tracker.reset_for_query()
# --- Second query with compaction ---
builder.append_user("second question")
builder.append_assistant([{"type": "text", "text": "second answer"}])
second_summary = {
"type": "summary",
"uuid": "cs-second",
"isCompactSummary": True,
"message": {"role": "user", "content": "Second compaction summary"},
}
second_post = {
"type": "assistant",
"uuid": "a-second",
"parentUuid": "cs-second",
"message": {"role": "assistant", "content": "second post-compact"},
}
file2 = session_dir / "session2.jsonl"
file2.write_text(_make_jsonl(second_summary, second_post))
tracker.on_compact()
tracker.emit_start_if_ready()
end_events2 = _run(tracker.emit_end_if_ready(session))
assert len(end_events2) == 2 # output + finish
result2_entries = _read_compacted_entries(str(file2))
assert result2_entries is not None
_, compacted2_jsonl = result2_entries
builder.replace_entries(compacted2_jsonl)
assert builder.entry_count == 2 # Only second compaction entries
# Export and verify
output = builder.to_jsonl()
entries = [json.loads(line) for line in output.strip().split("\n")]
assert entries[0]["uuid"] == "cs-second"
assert entries[0].get("isCompactSummary") is True
def test_strip_progress_then_load_then_compact_roundtrip(self, tmp_path):
"""Full pipeline: strip → load → compact → replace → export → reload.
This tests the exact sequence that happens across two turns:
Turn 1: SDK produces transcript with progress entries
Upload: strip_progress_entries removes progress, upload to cloud
Turn 2: Download → load_previous → compaction fires → replace → export
Turn 3: Download the Turn 2 export → load_previous (roundtrip)
"""
session_dir = tmp_path / "session"
session_dir.mkdir(parents=True)
# --- Turn 1: SDK produces raw transcript ---
raw_content = _make_jsonl(
USER_1,
ASST_1_THINKING,
ASST_1_TOOL,
PROGRESS_1,
TOOL_RESULT_1,
ASST_1_TEXT,
USER_2,
ASST_2,
)
# Strip progress for upload
stripped = strip_progress_entries(raw_content)
stripped_entries = [
json.loads(line) for line in stripped.strip().split("\n") if line.strip()
]
# Progress should be gone
assert not any(e.get("type") == "progress" for e in stripped_entries)
assert len(stripped_entries) == 7 # 8 - 1 progress
# --- Turn 2: Download stripped, load, compaction happens ---
builder = TranscriptBuilder()
builder.load_previous(stripped)
assert builder.entry_count == 7
builder.append_user("Now show file2.py")
builder.append_assistant(
[{"type": "text", "text": "Reading file2.py..."}],
model="claude-sonnet-4-20250514",
)
# CLI writes session file with compaction
session_file = self._write_session_file(
session_dir,
[
USER_1,
ASST_1_TOOL,
TOOL_RESULT_1,
ASST_1_TEXT,
USER_2,
ASST_2,
COMPACT_SUMMARY,
POST_COMPACT_ASST,
],
)
result = _read_compacted_entries(str(session_file))
assert result is not None
_, compacted_jsonl = result
builder.replace_entries(compacted_jsonl)
# Append post-compaction message
builder.append_user("Thanks!")
output = builder.to_jsonl()
# --- Turn 3: Fresh load of Turn 2 export ---
builder3 = TranscriptBuilder()
builder3.load_previous(output)
# Should have: compact_summary + post_compact_asst + "Thanks!"
assert builder3.entry_count == 3
# Compact summary survived the full pipeline
first = json.loads(builder3.to_jsonl().strip().split("\n")[0])
assert first.get("isCompactSummary") is True
assert first["type"] == "summary"

View File

@@ -0,0 +1,281 @@
"""File reference protocol for tool call inputs.
Allows the LLM to pass a file reference instead of embedding large content
inline. The processor expands ``@@agptfile:<uri>[<start>-<end>]`` tokens in tool
arguments before the tool is executed.
Protocol
--------
@@agptfile:<uri>[<start>-<end>]
``<uri>`` (required)
- ``workspace://<file_id>`` — workspace file by ID
- ``workspace://<file_id>#<mime>`` — same, MIME hint is ignored for reads
- ``workspace:///<path>`` — workspace file by virtual path
- ``/absolute/local/path`` — ephemeral or sdk_cwd file (validated by
:func:`~backend.copilot.sdk.tool_adapter.is_allowed_local_path`)
- Any absolute path that resolves inside the E2B sandbox
(``/home/user/...``) when a sandbox is active
``[<start>-<end>]`` (optional)
Line range, 1-indexed inclusive. Examples: ``[1-100]``, ``[50-200]``.
Omit to read the entire file.
Examples
--------
@@agptfile:workspace://abc123
@@agptfile:workspace://abc123[10-50]
@@agptfile:workspace:///reports/q1.md
@@agptfile:/tmp/copilot-<session>/output.py[1-80]
@@agptfile:/home/user/script.sh
"""
import itertools
import logging
import os
import re
from dataclasses import dataclass
from typing import Any
from backend.copilot.context import (
get_current_sandbox,
get_sdk_cwd,
is_allowed_local_path,
resolve_sandbox_path,
)
from backend.copilot.model import ChatSession
from backend.copilot.tools.workspace_files import get_manager
from backend.util.file import parse_workspace_uri
class FileRefExpansionError(Exception):
"""Raised when a ``@@agptfile:`` reference in tool call args fails to resolve.
Separating this from inline substitution lets callers (e.g. the MCP tool
wrapper) block tool execution and surface a helpful error to the model
rather than passing an ``[file-ref error: …]`` string as actual input.
"""
logger = logging.getLogger(__name__)
FILE_REF_PREFIX = "@@agptfile:"
# Matches: @@agptfile:<uri>[start-end]?
# Group 1 URI; must start with '/' (absolute path) or 'workspace://'
# Group 2 start line (optional)
# Group 3 end line (optional)
_FILE_REF_RE = re.compile(
re.escape(FILE_REF_PREFIX) + r"((?:workspace://|/)[^\[\s]*)(?:\[(\d+)-(\d+)\])?"
)
# Maximum characters returned for a single file reference expansion.
_MAX_EXPAND_CHARS = 200_000
# Maximum total characters across all @@agptfile: expansions in one string.
_MAX_TOTAL_EXPAND_CHARS = 1_000_000
@dataclass
class FileRef:
uri: str
start_line: int | None # 1-indexed, inclusive
end_line: int | None # 1-indexed, inclusive
def parse_file_ref(text: str) -> FileRef | None:
"""Return a :class:`FileRef` if *text* is a bare file reference token.
A "bare token" means the entire string matches the ``@@agptfile:...`` pattern
(after stripping whitespace). Use :func:`expand_file_refs_in_string` to
expand references embedded in larger strings.
"""
m = _FILE_REF_RE.fullmatch(text.strip())
if not m:
return None
start = int(m.group(2)) if m.group(2) else None
end = int(m.group(3)) if m.group(3) else None
if start is not None and start < 1:
return None
if end is not None and end < 1:
return None
if start is not None and end is not None and end < start:
return None
return FileRef(uri=m.group(1), start_line=start, end_line=end)
def _apply_line_range(text: str, start: int | None, end: int | None) -> str:
"""Slice *text* to the requested 1-indexed line range (inclusive)."""
if start is None and end is None:
return text
lines = text.splitlines(keepends=True)
s = (start - 1) if start is not None else 0
e = end if end is not None else len(lines)
selected = list(itertools.islice(lines, s, e))
return "".join(selected)
async def read_file_bytes(
uri: str,
user_id: str | None,
session: ChatSession,
) -> bytes:
"""Resolve *uri* to raw bytes using workspace, local, or E2B path logic.
Raises :class:`ValueError` if the URI cannot be resolved.
"""
# Strip MIME fragment (e.g. workspace://id#mime) before dispatching.
plain = uri.split("#")[0] if uri.startswith("workspace://") else uri
if plain.startswith("workspace://"):
if not user_id:
raise ValueError("workspace:// file references require authentication")
manager = await get_manager(user_id, session.session_id)
ws = parse_workspace_uri(plain)
try:
return await (
manager.read_file(ws.file_ref)
if ws.is_path
else manager.read_file_by_id(ws.file_ref)
)
except FileNotFoundError:
raise ValueError(f"File not found: {plain}")
except Exception as exc:
raise ValueError(f"Failed to read {plain}: {exc}") from exc
if is_allowed_local_path(plain, get_sdk_cwd()):
resolved = os.path.realpath(os.path.expanduser(plain))
try:
with open(resolved, "rb") as fh:
return fh.read()
except FileNotFoundError:
raise ValueError(f"File not found: {plain}")
except Exception as exc:
raise ValueError(f"Failed to read {plain}: {exc}") from exc
sandbox = get_current_sandbox()
if sandbox is not None:
try:
remote = resolve_sandbox_path(plain)
except ValueError as exc:
raise ValueError(
f"Path is not allowed (not in workspace, sdk_cwd, or sandbox): {plain}"
) from exc
try:
return bytes(await sandbox.files.read(remote, format="bytes"))
except Exception as exc:
raise ValueError(f"Failed to read from sandbox: {plain}: {exc}") from exc
raise ValueError(
f"Path is not allowed (not in workspace, sdk_cwd, or sandbox): {plain}"
)
async def resolve_file_ref(
ref: FileRef,
user_id: str | None,
session: ChatSession,
) -> str:
"""Resolve a :class:`FileRef` to its text content."""
raw = await read_file_bytes(ref.uri, user_id, session)
return _apply_line_range(
raw.decode("utf-8", errors="replace"), ref.start_line, ref.end_line
)
async def expand_file_refs_in_string(
text: str,
user_id: str | None,
session: "ChatSession",
*,
raise_on_error: bool = False,
) -> str:
"""Expand all ``@@agptfile:...`` tokens in *text*, returning the substituted string.
Non-reference text is passed through unchanged.
If *raise_on_error* is ``False`` (default), expansion errors are surfaced
inline as ``[file-ref error: <message>]`` — useful for display/log contexts
where partial expansion is acceptable.
If *raise_on_error* is ``True``, any resolution failure raises
:class:`FileRefExpansionError` immediately so the caller can block the
operation and surface a clean error to the model.
"""
if FILE_REF_PREFIX not in text:
return text
result: list[str] = []
last_end = 0
total_chars = 0
for m in _FILE_REF_RE.finditer(text):
result.append(text[last_end : m.start()])
start = int(m.group(2)) if m.group(2) else None
end = int(m.group(3)) if m.group(3) else None
if (start is not None and start < 1) or (end is not None and end < 1):
msg = f"line numbers must be >= 1: {m.group(0)}"
if raise_on_error:
raise FileRefExpansionError(msg)
result.append(f"[file-ref error: {msg}]")
last_end = m.end()
continue
if start is not None and end is not None and end < start:
msg = f"end line must be >= start line: {m.group(0)}"
if raise_on_error:
raise FileRefExpansionError(msg)
result.append(f"[file-ref error: {msg}]")
last_end = m.end()
continue
ref = FileRef(uri=m.group(1), start_line=start, end_line=end)
try:
content = await resolve_file_ref(ref, user_id, session)
if len(content) > _MAX_EXPAND_CHARS:
content = content[:_MAX_EXPAND_CHARS] + "\n... [truncated]"
remaining = _MAX_TOTAL_EXPAND_CHARS - total_chars
if remaining <= 0:
content = "[file-ref budget exhausted: total expansion limit reached]"
elif len(content) > remaining:
content = content[:remaining] + "\n... [total budget exhausted]"
total_chars += len(content)
result.append(content)
except ValueError as exc:
logger.warning("file-ref expansion failed for %r: %s", m.group(0), exc)
if raise_on_error:
raise FileRefExpansionError(str(exc)) from exc
result.append(f"[file-ref error: {exc}]")
last_end = m.end()
result.append(text[last_end:])
return "".join(result)
async def expand_file_refs_in_args(
args: dict[str, Any],
user_id: str | None,
session: "ChatSession",
) -> dict[str, Any]:
"""Recursively expand ``@@agptfile:...`` references in tool call arguments.
String values are expanded in-place. Nested dicts and lists are
traversed. Non-string scalars are returned unchanged.
Raises :class:`FileRefExpansionError` if any reference fails to resolve,
so the tool is *not* executed with an error string as its input. The
caller (the MCP tool wrapper) should convert this into an MCP error
response that lets the model correct the reference before retrying.
"""
if not args:
return args
async def _expand(value: Any) -> Any:
if isinstance(value, str):
return await expand_file_refs_in_string(
value, user_id, session, raise_on_error=True
)
if isinstance(value, dict):
return {k: await _expand(v) for k, v in value.items()}
if isinstance(value, list):
return [await _expand(item) for item in value]
return value
return {k: await _expand(v) for k, v in args.items()}

View File

@@ -0,0 +1,328 @@
"""Integration tests for @@agptfile: reference expansion in tool calls.
These tests verify the end-to-end behaviour of the file reference protocol:
- Parsing @@agptfile: tokens from tool arguments
- Resolving local-filesystem paths (sdk_cwd / ephemeral)
- Expanding references inside the tool-call pipeline (_execute_tool_sync)
- The extended Read tool handler (workspace:// pass-through via session context)
No real LLM or database is required; workspace reads are stubbed where needed.
"""
from __future__ import annotations
import os
import tempfile
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.copilot.sdk.file_ref import (
FileRef,
expand_file_refs_in_args,
expand_file_refs_in_string,
read_file_bytes,
resolve_file_ref,
)
from backend.copilot.sdk.tool_adapter import _read_file_handler
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_session(session_id: str = "integ-sess") -> MagicMock:
s = MagicMock()
s.session_id = session_id
return s
# ---------------------------------------------------------------------------
# Local-file resolution (sdk_cwd)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_resolve_file_ref_local_path():
"""resolve_file_ref reads a real local file when it's within sdk_cwd."""
with tempfile.TemporaryDirectory() as sdk_cwd:
# Write a test file inside sdk_cwd
test_file = os.path.join(sdk_cwd, "hello.txt")
with open(test_file, "w") as f:
f.write("line1\nline2\nline3\n")
session = _make_session()
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
ref = FileRef(uri=test_file, start_line=None, end_line=None)
content = await resolve_file_ref(ref, user_id="u1", session=session)
assert content == "line1\nline2\nline3\n"
@pytest.mark.asyncio
async def test_resolve_file_ref_local_path_with_line_range():
"""resolve_file_ref respects line ranges for local files."""
with tempfile.TemporaryDirectory() as sdk_cwd:
test_file = os.path.join(sdk_cwd, "multi.txt")
lines = [f"line{i}\n" for i in range(1, 11)] # line1 … line10
with open(test_file, "w") as f:
f.writelines(lines)
session = _make_session()
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
ref = FileRef(uri=test_file, start_line=3, end_line=5)
content = await resolve_file_ref(ref, user_id="u1", session=session)
assert content == "line3\nline4\nline5\n"
@pytest.mark.asyncio
async def test_resolve_file_ref_rejects_path_outside_sdk_cwd():
"""resolve_file_ref raises ValueError for paths outside sdk_cwd."""
with tempfile.TemporaryDirectory() as sdk_cwd:
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var, patch(
"backend.copilot.context._current_sandbox"
) as mock_sandbox_var:
mock_cwd_var.get.return_value = sdk_cwd
mock_sandbox_var.get.return_value = None
ref = FileRef(uri="/etc/passwd", start_line=None, end_line=None)
with pytest.raises(ValueError, match="not allowed"):
await resolve_file_ref(ref, user_id="u1", session=_make_session())
# ---------------------------------------------------------------------------
# expand_file_refs_in_string — integration with real files
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_expand_string_with_real_file():
"""expand_file_refs_in_string replaces @@agptfile: token with actual content."""
with tempfile.TemporaryDirectory() as sdk_cwd:
test_file = os.path.join(sdk_cwd, "data.txt")
with open(test_file, "w") as f:
f.write("hello world\n")
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
result = await expand_file_refs_in_string(
f"Content: @@agptfile:{test_file}",
user_id="u1",
session=_make_session(),
)
assert result == "Content: hello world\n"
@pytest.mark.asyncio
async def test_expand_string_missing_file_is_surfaced_inline():
"""Missing file ref yields [file-ref error: …] inline rather than raising."""
with tempfile.TemporaryDirectory() as sdk_cwd:
missing = os.path.join(sdk_cwd, "does_not_exist.txt")
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
result = await expand_file_refs_in_string(
f"@@agptfile:{missing}",
user_id="u1",
session=_make_session(),
)
assert "[file-ref error:" in result
assert "not found" in result.lower() or "not allowed" in result.lower()
# ---------------------------------------------------------------------------
# expand_file_refs_in_args — dict traversal with real files
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_expand_args_replaces_file_ref_in_nested_dict():
"""Nested @@agptfile: references in args are fully expanded."""
with tempfile.TemporaryDirectory() as sdk_cwd:
file_a = os.path.join(sdk_cwd, "a.txt")
file_b = os.path.join(sdk_cwd, "b.txt")
with open(file_a, "w") as f:
f.write("AAA")
with open(file_b, "w") as f:
f.write("BBB")
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
result = await expand_file_refs_in_args(
{
"outer": {
"content_a": f"@@agptfile:{file_a}",
"content_b": f"start @@agptfile:{file_b} end",
},
"count": 42,
},
user_id="u1",
session=_make_session(),
)
assert result["outer"]["content_a"] == "AAA"
assert result["outer"]["content_b"] == "start BBB end"
assert result["count"] == 42
# ---------------------------------------------------------------------------
# _read_file_handler — extended to accept workspace:// and local paths
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_read_file_handler_local_file():
"""_read_file_handler reads a local file when it's within sdk_cwd."""
with tempfile.TemporaryDirectory() as sdk_cwd:
test_file = os.path.join(sdk_cwd, "read_test.txt")
lines = [f"L{i}\n" for i in range(1, 6)]
with open(test_file, "w") as f:
f.writelines(lines)
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var, patch(
"backend.copilot.context._current_project_dir"
) as mock_proj_var, patch(
"backend.copilot.sdk.tool_adapter.get_execution_context",
return_value=("user-1", _make_session()),
):
mock_cwd_var.get.return_value = sdk_cwd
mock_proj_var.get.return_value = ""
result = await _read_file_handler(
{"file_path": test_file, "offset": 0, "limit": 5}
)
assert not result["isError"]
text = result["content"][0]["text"]
assert "L1" in text
assert "L5" in text
@pytest.mark.asyncio
async def test_read_file_handler_workspace_uri():
"""_read_file_handler handles workspace:// URIs via the workspace manager."""
mock_session = _make_session()
mock_manager = AsyncMock()
mock_manager.read_file_by_id.return_value = b"workspace file content\nline two\n"
with patch(
"backend.copilot.sdk.tool_adapter.get_execution_context",
return_value=("user-1", mock_session),
), patch(
"backend.copilot.sdk.file_ref.get_manager",
new=AsyncMock(return_value=mock_manager),
):
result = await _read_file_handler(
{"file_path": "workspace://file-id-abc", "offset": 0, "limit": 10}
)
assert not result["isError"], result["content"][0]["text"]
text = result["content"][0]["text"]
assert "workspace file content" in text
assert "line two" in text
@pytest.mark.asyncio
async def test_read_file_handler_workspace_uri_no_session():
"""_read_file_handler returns error when workspace:// is used without session."""
with patch(
"backend.copilot.sdk.tool_adapter.get_execution_context",
return_value=(None, None),
):
result = await _read_file_handler({"file_path": "workspace://some-id"})
assert result["isError"]
assert "session" in result["content"][0]["text"].lower()
@pytest.mark.asyncio
async def test_read_file_handler_access_denied():
"""_read_file_handler rejects paths outside allowed locations."""
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd, patch(
"backend.copilot.context._current_sandbox"
) as mock_sandbox, patch(
"backend.copilot.sdk.tool_adapter.get_execution_context",
return_value=("user-1", _make_session()),
):
mock_cwd.get.return_value = "/tmp/safe-dir"
mock_sandbox.get.return_value = None
result = await _read_file_handler({"file_path": "/etc/passwd"})
assert result["isError"]
assert "not allowed" in result["content"][0]["text"].lower()
# ---------------------------------------------------------------------------
# read_file_bytes — workspace:///path (virtual path) and E2B sandbox branch
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_read_file_bytes_workspace_virtual_path():
"""workspace:///path resolves via manager.read_file (is_path=True path)."""
session = _make_session()
mock_manager = AsyncMock()
mock_manager.read_file.return_value = b"virtual path content"
with patch(
"backend.copilot.sdk.file_ref.get_manager",
new=AsyncMock(return_value=mock_manager),
):
result = await read_file_bytes("workspace:///reports/q1.md", "user-1", session)
assert result == b"virtual path content"
mock_manager.read_file.assert_awaited_once_with("/reports/q1.md")
@pytest.mark.asyncio
async def test_read_file_bytes_e2b_sandbox_branch():
"""read_file_bytes reads from the E2B sandbox when a sandbox is active."""
session = _make_session()
mock_sandbox = AsyncMock()
mock_sandbox.files.read.return_value = bytearray(b"sandbox content")
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd, patch(
"backend.copilot.context._current_sandbox"
) as mock_sandbox_var, patch(
"backend.copilot.context._current_project_dir"
) as mock_proj:
mock_cwd.get.return_value = ""
mock_sandbox_var.get.return_value = mock_sandbox
mock_proj.get.return_value = ""
result = await read_file_bytes("/home/user/script.sh", None, session)
assert result == b"sandbox content"
mock_sandbox.files.read.assert_awaited_once_with(
"/home/user/script.sh", format="bytes"
)
@pytest.mark.asyncio
async def test_read_file_bytes_e2b_path_escapes_sandbox_raises():
"""read_file_bytes raises ValueError for paths that escape the sandbox root."""
session = _make_session()
mock_sandbox = AsyncMock()
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd, patch(
"backend.copilot.context._current_sandbox"
) as mock_sandbox_var, patch(
"backend.copilot.context._current_project_dir"
) as mock_proj:
mock_cwd.get.return_value = ""
mock_sandbox_var.get.return_value = mock_sandbox
mock_proj.get.return_value = ""
with pytest.raises(ValueError, match="not allowed"):
await read_file_bytes("/etc/passwd", None, session)

View File

@@ -0,0 +1,382 @@
"""Tests for the @@agptfile: reference protocol (file_ref.py)."""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.copilot.sdk.file_ref import (
_MAX_EXPAND_CHARS,
FileRef,
FileRefExpansionError,
_apply_line_range,
expand_file_refs_in_args,
expand_file_refs_in_string,
parse_file_ref,
)
# ---------------------------------------------------------------------------
# parse_file_ref
# ---------------------------------------------------------------------------
def test_parse_file_ref_workspace_id():
ref = parse_file_ref("@@agptfile:workspace://abc123")
assert ref == FileRef(uri="workspace://abc123", start_line=None, end_line=None)
def test_parse_file_ref_workspace_id_with_mime():
ref = parse_file_ref("@@agptfile:workspace://abc123#text/plain")
assert ref is not None
assert ref.uri == "workspace://abc123#text/plain"
assert ref.start_line is None
def test_parse_file_ref_workspace_path():
ref = parse_file_ref("@@agptfile:workspace:///reports/q1.md")
assert ref is not None
assert ref.uri == "workspace:///reports/q1.md"
def test_parse_file_ref_with_line_range():
ref = parse_file_ref("@@agptfile:workspace://abc123[10-50]")
assert ref == FileRef(uri="workspace://abc123", start_line=10, end_line=50)
def test_parse_file_ref_local_path():
ref = parse_file_ref("@@agptfile:/tmp/copilot-session/output.py[1-100]")
assert ref is not None
assert ref.uri == "/tmp/copilot-session/output.py"
assert ref.start_line == 1
assert ref.end_line == 100
def test_parse_file_ref_no_match():
assert parse_file_ref("just a normal string") is None
assert parse_file_ref("workspace://abc123") is None # missing @@agptfile: prefix
assert (
parse_file_ref("@@agptfile:workspace://abc123 extra") is None
) # not full match
def test_parse_file_ref_strips_whitespace():
ref = parse_file_ref(" @@agptfile:workspace://abc123 ")
assert ref is not None
assert ref.uri == "workspace://abc123"
def test_parse_file_ref_invalid_range_zero_start():
assert parse_file_ref("@@agptfile:workspace://abc123[0-5]") is None
def test_parse_file_ref_invalid_range_end_less_than_start():
assert parse_file_ref("@@agptfile:workspace://abc123[10-5]") is None
def test_parse_file_ref_invalid_range_zero_end():
assert parse_file_ref("@@agptfile:workspace://abc123[1-0]") is None
# ---------------------------------------------------------------------------
# _apply_line_range
# ---------------------------------------------------------------------------
TEXT = "line1\nline2\nline3\nline4\nline5\n"
def test_apply_line_range_no_range():
assert _apply_line_range(TEXT, None, None) == TEXT
def test_apply_line_range_start_only():
result = _apply_line_range(TEXT, 3, None)
assert result == "line3\nline4\nline5\n"
def test_apply_line_range_full():
result = _apply_line_range(TEXT, 2, 4)
assert result == "line2\nline3\nline4\n"
def test_apply_line_range_single_line():
result = _apply_line_range(TEXT, 2, 2)
assert result == "line2\n"
def test_apply_line_range_beyond_eof():
result = _apply_line_range(TEXT, 4, 999)
assert result == "line4\nline5\n"
# ---------------------------------------------------------------------------
# expand_file_refs_in_string
# ---------------------------------------------------------------------------
def _make_session(session_id: str = "sess-1") -> MagicMock:
session = MagicMock()
session.session_id = session_id
return session
async def _resolve_always(ref: FileRef, _user_id: str | None, _session: object) -> str:
"""Stub resolver that returns the URI and range as a descriptive string."""
if ref.start_line is not None:
return f"content:{ref.uri}[{ref.start_line}-{ref.end_line}]"
return f"content:{ref.uri}"
@pytest.mark.asyncio
async def test_expand_no_refs():
result = await expand_file_refs_in_string(
"no references here", user_id="u1", session=_make_session()
)
assert result == "no references here"
@pytest.mark.asyncio
async def test_expand_single_ref():
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_always),
):
result = await expand_file_refs_in_string(
"@@agptfile:workspace://abc123",
user_id="u1",
session=_make_session(),
)
assert result == "content:workspace://abc123"
@pytest.mark.asyncio
async def test_expand_ref_with_range():
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_always),
):
result = await expand_file_refs_in_string(
"@@agptfile:workspace://abc123[10-50]",
user_id="u1",
session=_make_session(),
)
assert result == "content:workspace://abc123[10-50]"
@pytest.mark.asyncio
async def test_expand_ref_embedded_in_text():
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_always),
):
result = await expand_file_refs_in_string(
"Here is the file: @@agptfile:workspace://abc123 — done",
user_id="u1",
session=_make_session(),
)
assert result == "Here is the file: content:workspace://abc123 — done"
@pytest.mark.asyncio
async def test_expand_multiple_refs():
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_always),
):
result = await expand_file_refs_in_string(
"@@agptfile:workspace://file1 and @@agptfile:workspace://file2[1-5]",
user_id="u1",
session=_make_session(),
)
assert result == "content:workspace://file1 and content:workspace://file2[1-5]"
@pytest.mark.asyncio
async def test_expand_invalid_range_zero_start_surfaces_inline():
"""expand_file_refs_in_string surfaces [file-ref error: ...] for zero-start ranges."""
result = await expand_file_refs_in_string(
"@@agptfile:workspace://abc123[0-5]",
user_id="u1",
session=_make_session(),
)
assert "[file-ref error:" in result
assert "line numbers must be >= 1" in result
@pytest.mark.asyncio
async def test_expand_invalid_range_end_less_than_start_surfaces_inline():
"""expand_file_refs_in_string surfaces [file-ref error: ...] when end < start."""
result = await expand_file_refs_in_string(
"prefix @@agptfile:workspace://abc123[10-5] suffix",
user_id="u1",
session=_make_session(),
)
assert "[file-ref error:" in result
assert "end line must be >= start line" in result
assert "prefix" in result
assert "suffix" in result
@pytest.mark.asyncio
async def test_expand_ref_error_surfaces_inline():
async def _raise(*args, **kwargs): # noqa: ARG001
raise ValueError("file not found")
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_raise),
):
result = await expand_file_refs_in_string(
"@@agptfile:workspace://bad",
user_id="u1",
session=_make_session(),
)
assert "[file-ref error:" in result
assert "file not found" in result
# ---------------------------------------------------------------------------
# expand_file_refs_in_args
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_expand_args_flat():
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_always),
):
result = await expand_file_refs_in_args(
{"content": "@@agptfile:workspace://abc123", "other": 42},
user_id="u1",
session=_make_session(),
)
assert result["content"] == "content:workspace://abc123"
assert result["other"] == 42
@pytest.mark.asyncio
async def test_expand_args_nested_dict():
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_always),
):
result = await expand_file_refs_in_args(
{"outer": {"inner": "@@agptfile:workspace://nested"}},
user_id="u1",
session=_make_session(),
)
assert result["outer"]["inner"] == "content:workspace://nested"
@pytest.mark.asyncio
async def test_expand_args_list():
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_always),
):
result = await expand_file_refs_in_args(
{
"items": [
"@@agptfile:workspace://a",
"plain",
"@@agptfile:workspace://b[1-3]",
]
},
user_id="u1",
session=_make_session(),
)
assert result["items"] == [
"content:workspace://a",
"plain",
"content:workspace://b[1-3]",
]
@pytest.mark.asyncio
async def test_expand_args_empty():
result = await expand_file_refs_in_args({}, user_id="u1", session=_make_session())
assert result == {}
@pytest.mark.asyncio
async def test_expand_args_no_refs():
result = await expand_file_refs_in_args(
{"key": "no refs here", "num": 1},
user_id="u1",
session=_make_session(),
)
assert result == {"key": "no refs here", "num": 1}
@pytest.mark.asyncio
async def test_expand_args_raises_on_file_ref_error():
"""expand_file_refs_in_args raises FileRefExpansionError instead of passing
the inline error string to the tool, blocking tool execution."""
async def _raise(*args, **kwargs): # noqa: ARG001
raise ValueError("path does not exist")
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_raise),
):
with pytest.raises(FileRefExpansionError) as exc_info:
await expand_file_refs_in_args(
{"prompt": "@@agptfile:/home/user/missing.txt"},
user_id="u1",
session=_make_session(),
)
assert "path does not exist" in str(exc_info.value)
# ---------------------------------------------------------------------------
# Per-file truncation and aggregate budget
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_expand_per_file_truncation():
"""Content exceeding _MAX_EXPAND_CHARS is truncated with a marker."""
oversized = "x" * (_MAX_EXPAND_CHARS + 100)
async def _resolve_oversized(ref: FileRef, _uid: str | None, _s: object) -> str:
return oversized
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_oversized),
):
result = await expand_file_refs_in_string(
"@@agptfile:workspace://big-file",
user_id="u1",
session=_make_session(),
)
assert len(result) <= _MAX_EXPAND_CHARS + len("\n... [truncated]") + 10
assert "[truncated]" in result
@pytest.mark.asyncio
async def test_expand_aggregate_budget_exhausted():
"""When the aggregate budget is exhausted, later refs get the budget message."""
# Each file returns just under 300K; after ~4 files the 1M budget is used.
big_chunk = "y" * 300_000
async def _resolve_big(ref: FileRef, _uid: str | None, _s: object) -> str:
return big_chunk
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_big),
):
# 5 refs @ 300K each = 1.5M → last ref(s) should hit the aggregate limit
refs = " ".join(f"@@agptfile:workspace://f{i}" for i in range(5))
result = await expand_file_refs_in_string(
refs,
user_id="u1",
session=_make_session(),
)
assert "budget exhausted" in result

View File

@@ -0,0 +1,42 @@
## MCP Tool Guide
### Workflow
`run_mcp_tool` follows a two-step pattern:
1. **Discover** — call with only `server_url` to list available tools on the server.
2. **Execute** — call again with `server_url`, `tool_name`, and `tool_arguments` to run a tool.
### Known hosted MCP servers
Use these URLs directly without asking the user:
| Service | URL |
|---|---|
| Notion | `https://mcp.notion.com/mcp` |
| Linear | `https://mcp.linear.app/mcp` |
| Stripe | `https://mcp.stripe.com` |
| Intercom | `https://mcp.intercom.com/mcp` |
| Cloudflare | `https://mcp.cloudflare.com/mcp` |
| Atlassian / Jira | `https://mcp.atlassian.com/mcp` |
For other services, search the MCP registry at https://registry.modelcontextprotocol.io/.
### Authentication
If the server requires credentials, a `SetupRequirementsResponse` is returned with an OAuth
login prompt. Once the user completes the flow and confirms, retry the same call immediately.
### Communication style
Avoid technical jargon like "MCP server", "OAuth", or "credentials" when talking to the user.
Use plain, friendly language instead:
| Instead of… | Say… |
|---|---|
| "Let me connect to Sentry's MCP server and discover what tools are available." | "I can connect to Sentry and help identify important issues." |
| "Let me connect to Sentry's MCP server now." | "Next, I'll connect to Sentry." |
| "The MCP server at mcp.sentry.dev requires authentication. Please connect your credentials to continue." | "To continue, sign in to Sentry and approve access." |
| "Sentry's MCP server needs OAuth authentication. You should see a prompt to connect your Sentry account…" | "You should see a prompt to sign in to Sentry. Once connected, I can help surface critical issues right away." |
Use **"connect to [Service]"** or **"sign in to [Service]"** — never "MCP server", "OAuth", or "credentials".

View File

@@ -221,12 +221,12 @@ class SDKResponseAdapter:
responses.append(StreamFinish())
else:
logger.warning(
f"Unexpected ResultMessage subtype: {sdk_message.subtype}"
"Unexpected ResultMessage subtype: %s", sdk_message.subtype
)
responses.append(StreamFinish())
else:
logger.debug(f"Unhandled SDK message type: {type(sdk_message).__name__}")
logger.debug("Unhandled SDK message type: %s", type(sdk_message).__name__)
return responses

View File

@@ -536,10 +536,12 @@ async def test_wait_for_stash_signaled():
result = await wait_for_stash(timeout=1.0)
assert result is True
assert _pto.get({}).get("WebSearch") == ["result data"]
pto = _pto.get()
assert pto is not None
assert pto.get("WebSearch") == ["result data"]
# Cleanup
_pto.set({}) # type: ignore[arg-type]
_pto.set({})
_stash_event.set(None)
@@ -554,7 +556,7 @@ async def test_wait_for_stash_timeout():
assert result is False
# Cleanup
_pto.set({}) # type: ignore[arg-type]
_pto.set({})
_stash_event.set(None)
@@ -573,10 +575,12 @@ async def test_wait_for_stash_already_stashed():
assert result is True
# But the stash itself is populated
assert _pto.get({}).get("Read") == ["file contents"]
pto = _pto.get()
assert pto is not None
assert pto.get("Read") == ["file contents"]
# Cleanup
_pto.set({}) # type: ignore[arg-type]
_pto.set({})
_stash_event.set(None)

View File

@@ -10,12 +10,13 @@ import re
from collections.abc import Callable
from typing import Any, cast
from backend.copilot.context import is_allowed_local_path
from .tool_adapter import (
BLOCKED_TOOLS,
DANGEROUS_PATTERNS,
MCP_TOOL_PREFIX,
WORKSPACE_SCOPED_TOOLS,
is_allowed_local_path,
stash_pending_tool_output,
)
@@ -51,7 +52,7 @@ def _validate_workspace_path(
if is_allowed_local_path(path, sdk_cwd):
return {}
logger.warning(f"Blocked {tool_name} outside workspace: {path}")
logger.warning("Blocked %s outside workspace: %s", tool_name, path)
workspace_hint = f" Allowed workspace: {sdk_cwd}" if sdk_cwd else ""
return _deny(
f"[SECURITY] Tool '{tool_name}' can only access files within the workspace "
@@ -70,7 +71,7 @@ def _validate_tool_access(
"""
# Block forbidden tools
if tool_name in BLOCKED_TOOLS:
logger.warning(f"Blocked tool access attempt: {tool_name}")
logger.warning("Blocked tool access attempt: %s", tool_name)
return _deny(
f"[SECURITY] Tool '{tool_name}' is blocked for security. "
"This is enforced by the platform and cannot be bypassed. "
@@ -110,7 +111,9 @@ def _validate_user_isolation(
# the tool itself via _validate_ephemeral_path.
path = tool_input.get("path", "") or tool_input.get("file_path", "")
if path and ".." in path:
logger.warning(f"Blocked path traversal attempt: {path} by user {user_id}")
logger.warning(
"Blocked path traversal attempt: %s by user %s", path, user_id
)
return {
"hookSpecificOutput": {
"hookEventName": "PreToolUse",
@@ -168,7 +171,7 @@ def create_security_hooks(
# Block background task execution first — denied calls
# should not consume a subtask slot.
if tool_input.get("run_in_background"):
logger.info(f"[SDK] Blocked background Task, user={user_id}")
logger.info("[SDK] Blocked background Task, user=%s", user_id)
return cast(
SyncHookJSONOutput,
_deny(
@@ -210,7 +213,7 @@ def create_security_hooks(
if tool_name == "Task" and tool_use_id is not None:
task_tool_use_ids.add(tool_use_id)
logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}")
logger.debug("[SDK] Tool start: %s, user=%s", tool_name, user_id)
return cast(SyncHookJSONOutput, {})
def _release_task_slot(tool_name: str, tool_use_id: str | None) -> None:

View File

@@ -9,8 +9,9 @@ import os
import pytest
from backend.copilot.context import _current_project_dir
from .security_hooks import _validate_tool_access, _validate_user_isolation
from .service import _is_tool_error_or_denial
SDK_CWD = "/tmp/copilot-abc123"
@@ -120,8 +121,6 @@ def test_read_no_cwd_denies_absolute():
def test_read_tool_results_allowed():
from .tool_adapter import _current_project_dir
home = os.path.expanduser("~")
path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt"
# is_allowed_local_path requires the session's encoded cwd to be set
@@ -133,16 +132,14 @@ def test_read_tool_results_allowed():
_current_project_dir.reset(token)
def test_read_claude_projects_session_dir_allowed():
"""Files within the current session's project dir are allowed."""
from .tool_adapter import _current_project_dir
def test_read_claude_projects_settings_json_denied():
"""SDK-internal artifacts like settings.json are NOT accessible — only tool-results/ is."""
home = os.path.expanduser("~")
path = f"{home}/.claude/projects/-tmp-copilot-abc123/settings.json"
token = _current_project_dir.set("-tmp-copilot-abc123")
try:
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
assert not _is_denied(result)
assert _is_denied(result)
finally:
_current_project_dir.reset(token)
@@ -357,76 +354,3 @@ async def test_task_slot_released_on_failure(_hooks):
context={},
)
assert not _is_denied(result)
# -- _is_tool_error_or_denial ------------------------------------------------
class TestIsToolErrorOrDenial:
def test_none_content(self):
assert _is_tool_error_or_denial(None) is False
def test_empty_content(self):
assert _is_tool_error_or_denial("") is False
def test_benign_output(self):
assert _is_tool_error_or_denial("All good, no issues.") is False
def test_security_marker(self):
assert _is_tool_error_or_denial("[SECURITY] Tool access blocked") is True
def test_cannot_be_bypassed(self):
assert _is_tool_error_or_denial("This restriction cannot be bypassed.") is True
def test_not_allowed(self):
assert _is_tool_error_or_denial("Operation not allowed in sandbox") is True
def test_background_task_denial(self):
assert (
_is_tool_error_or_denial(
"Background task execution is not supported. "
"Run tasks in the foreground instead."
)
is True
)
def test_subtask_limit_denial(self):
assert (
_is_tool_error_or_denial(
"Maximum 2 concurrent sub-tasks. "
"Wait for running sub-tasks to finish, "
"or continue in the main conversation."
)
is True
)
def test_denied_marker(self):
assert (
_is_tool_error_or_denial("Access denied: insufficient privileges") is True
)
def test_blocked_marker(self):
assert _is_tool_error_or_denial("Request blocked by security policy") is True
def test_failed_marker(self):
assert _is_tool_error_or_denial("Failed to execute tool: timeout") is True
def test_mcp_iserror(self):
assert _is_tool_error_or_denial('{"isError": true, "content": []}') is True
def test_benign_error_in_value(self):
"""Content like '0 errors found' should not trigger — 'error' was removed."""
assert _is_tool_error_or_denial("0 errors found") is False
def test_benign_permission_field(self):
"""Schema descriptions mentioning 'permission' should not trigger."""
assert (
_is_tool_error_or_denial(
'{"fields": [{"name": "permission_level", "type": "int"}]}'
)
is False
)
def test_benign_not_found_in_listing(self):
"""File listing containing 'not found' in filenames should not trigger."""
assert _is_tool_error_or_denial("readme.md\nfile-not-found-handler.py") is False

View File

@@ -40,11 +40,13 @@ from ..constants import COPILOT_ERROR_PREFIX, COPILOT_SYSTEM_PREFIX
from ..model import (
ChatMessage,
ChatSession,
Usage,
get_chat_session,
update_session_title,
upsert_chat_session,
)
from ..prompting import get_sdk_supplement
from ..rate_limit import record_token_usage
from ..response_model import (
StreamBaseResponse,
StreamError,
@@ -54,13 +56,14 @@ from ..response_model import (
StreamTextDelta,
StreamToolInputAvailable,
StreamToolOutputAvailable,
StreamUsage,
)
from ..service import (
_build_system_prompt,
_generate_session_title,
_is_langfuse_configured,
)
from ..tools.e2b_sandbox import get_or_create_sandbox
from ..tools.e2b_sandbox import get_or_create_sandbox, pause_sandbox_direct
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
from ..tools.workspace_files import get_manager
from ..tracking import track_user_message
@@ -75,8 +78,12 @@ from .tool_adapter import (
wait_for_stash,
)
from .transcript import (
COMPACT_THRESHOLD_BYTES,
TranscriptDownload,
cleanup_cli_project_dir,
compact_transcript,
download_transcript,
read_cli_session_file,
upload_transcript,
validate_transcript,
write_transcript_to_tempfile,
@@ -294,7 +301,7 @@ def _cleanup_sdk_tool_results(cwd: str) -> None:
"""
normalized = os.path.normpath(cwd)
if not normalized.startswith(_SDK_CWD_PREFIX):
logger.warning(f"[SDK] Rejecting cleanup for path outside workspace: {cwd}")
logger.warning("[SDK] Rejecting cleanup for path outside workspace: %s", cwd)
return
# Clean the CLI's project directory (transcripts + tool-results).
@@ -388,7 +395,7 @@ async def _compress_messages(
client=client,
)
except Exception as e:
logger.warning(f"[SDK] Context compression with LLM failed: {e}")
logger.warning("[SDK] Context compression with LLM failed: %s", e)
# Fall back to truncation-only (no LLM summarization)
result = await compress_context(
messages=messages_dict,
@@ -456,31 +463,6 @@ def _format_conversation_context(messages: list[ChatMessage]) -> str | None:
return "<conversation_history>\n" + "\n".join(lines) + "\n</conversation_history>"
def _is_tool_error_or_denial(content: str | None) -> bool:
"""Check if a tool message content indicates an error or denial.
Currently unused — ``_format_conversation_context`` includes all tool
results. Kept as a utility for future selective filtering.
"""
if not content:
return False
lower = content.lower()
return any(
marker in lower
for marker in (
"[security]",
"cannot be bypassed",
"not allowed",
"not supported", # background-task denial
"maximum", # subtask-limit denial
"denied",
"blocked",
"failed to", # internal tool execution failures
'"iserror": true', # MCP protocol error flag
)
)
async def _build_query_message(
current_message: str,
session: ChatSession,
@@ -649,6 +631,56 @@ async def _prepare_file_attachments(
return PreparedAttachments(hint=hint, image_blocks=image_blocks)
async def _maybe_compact_and_upload(
dl: TranscriptDownload,
user_id: str,
session_id: str,
log_prefix: str = "[Transcript]",
) -> str:
"""Compact an oversized transcript and upload the compacted version.
Returns the (possibly compacted) transcript content, or an empty string
if compaction was needed but failed.
"""
content = dl.content
if len(content) <= COMPACT_THRESHOLD_BYTES:
return content
logger.warning(
"%s Transcript oversized (%dB > %dB), compacting",
log_prefix,
len(content),
COMPACT_THRESHOLD_BYTES,
)
compacted = await compact_transcript(content, log_prefix=log_prefix)
if not compacted:
logger.warning(
"%s Compaction failed, skipping resume for this turn", log_prefix
)
return ""
# Keep the original message_count: it reflects the number of
# session.messages covered by this transcript, which the gap-fill
# logic uses as a slice index. Counting JSONL lines would give a
# smaller number (compacted messages != session message count) and
# cause already-covered messages to be re-injected.
try:
await upload_transcript(
user_id=user_id,
session_id=session_id,
content=compacted,
message_count=dl.message_count,
log_prefix=log_prefix,
)
except Exception:
logger.warning(
"%s Failed to upload compacted transcript",
log_prefix,
exc_info=True,
)
return compacted
async def stream_chat_completion_sdk(
session_id: str,
message: str | None = None,
@@ -760,6 +792,14 @@ async def stream_chat_completion_sdk(
_otel_ctx: Any = None
# Make sure there is no more code between the lock acquisition and try-block.
# Token usage accumulators — populated from ResultMessage at end of turn
turn_prompt_tokens = 0 # uncached input tokens only
turn_completion_tokens = 0
turn_cache_read_tokens = 0
turn_cache_creation_tokens = 0
total_tokens = 0 # computed once before StreamUsage, reused in finally
turn_cost_usd: float | None = None
try:
# Build system prompt (reuses non-SDK path with Langfuse support).
# Pre-compute the cwd here so the exact working directory path can be
@@ -784,28 +824,29 @@ async def stream_chat_completion_sdk(
async def _setup_e2b():
"""Set up E2B sandbox if configured, return sandbox or None."""
if config.use_e2b_sandbox and not config.e2b_api_key:
logger.warning(
"[E2B] [%s] E2B sandbox enabled but no API key configured "
"(CHAT_E2B_API_KEY / E2B_API_KEY) — falling back to bubblewrap",
session_id[:12],
)
return None
if config.use_e2b_sandbox and config.e2b_api_key:
try:
return await get_or_create_sandbox(
session_id,
api_key=config.e2b_api_key,
template=config.e2b_sandbox_template,
timeout=config.e2b_sandbox_timeout,
)
except Exception as e2b_err:
logger.error(
"[E2B] [%s] Setup failed: %s",
if not (e2b_api_key := config.active_e2b_api_key):
if config.use_e2b_sandbox:
logger.warning(
"[E2B] [%s] E2B sandbox enabled but no API key configured "
"(CHAT_E2B_API_KEY / E2B_API_KEY) — falling back to bubblewrap",
session_id[:12],
e2b_err,
exc_info=True,
)
return None
try:
return await get_or_create_sandbox(
session_id,
api_key=e2b_api_key,
template=config.e2b_sandbox_template,
timeout=config.e2b_sandbox_timeout,
on_timeout=config.e2b_sandbox_on_timeout,
)
except Exception as e2b_err:
logger.error(
"[E2B] [%s] Setup failed: %s",
session_id[:12],
e2b_err,
exc_info=True,
)
return None
async def _fetch_transcript():
@@ -837,7 +878,6 @@ async def stream_chat_completion_sdk(
system_prompt = base_system_prompt + get_sdk_supplement(
use_e2b=use_e2b, cwd=sdk_cwd
)
# Process transcript download result
transcript_msg_count = 0
if dl:
@@ -852,20 +892,33 @@ async def stream_chat_completion_sdk(
is_valid,
)
if is_valid:
# Load previous FULL context into builder
transcript_builder.load_previous(dl.content, log_prefix=log_prefix)
resume_file = write_transcript_to_tempfile(
dl.content, session_id, sdk_cwd
transcript_content = await _maybe_compact_and_upload(
dl,
user_id=user_id or "",
session_id=session_id,
log_prefix=log_prefix,
)
# Load previous context into builder (empty string is a no-op)
if transcript_content:
transcript_builder.load_previous(
transcript_content, log_prefix=log_prefix
)
resume_file = (
write_transcript_to_tempfile(
transcript_content, session_id, sdk_cwd
)
if transcript_content
else None
)
if resume_file:
use_resume = True
transcript_msg_count = dl.message_count
logger.debug(
f"{log_prefix} Using --resume ({len(dl.content)}B, "
f"{log_prefix} Using --resume ({len(transcript_content)}B, "
f"msg_count={transcript_msg_count})"
)
else:
logger.warning(f"{log_prefix} Transcript downloaded but invalid")
logger.warning("%s Transcript downloaded but invalid", log_prefix)
elif config.claude_agent_use_resume and user_id and len(session.messages) > 1:
logger.warning(
f"{log_prefix} No transcript available "
@@ -902,6 +955,11 @@ async def stream_chat_completion_sdk(
allowed = get_copilot_tool_names(use_e2b=use_e2b)
disallowed = get_sdk_disallowed_tools(use_e2b=use_e2b)
def _on_stderr(line: str) -> None:
sid = session_id[:12] if session_id else "?"
logger.info("[SDK] [%s] CLI stderr: %s", sid, line.rstrip())
sdk_options_kwargs: dict[str, Any] = {
"system_prompt": system_prompt,
"mcp_servers": {"copilot": mcp_server},
@@ -910,6 +968,7 @@ async def stream_chat_completion_sdk(
"hooks": security_hooks,
"cwd": sdk_cwd,
"max_buffer_size": config.claude_agent_max_buffer_size,
"stderr": _on_stderr,
}
if sdk_model:
sdk_options_kwargs["model"] = sdk_model
@@ -1082,6 +1141,19 @@ async def stream_chat_completion_sdk(
len(adapter.resolved_tool_calls),
)
# Log AssistantMessage API errors (e.g. invalid_request)
# so we can debug Anthropic API 400s surfaced by the CLI.
sdk_error = getattr(sdk_msg, "error", None)
if isinstance(sdk_msg, AssistantMessage) and sdk_error:
logger.error(
"[SDK] [%s] AssistantMessage has error=%s, "
"content_blocks=%d, content_preview=%s",
session_id[:12],
sdk_error,
len(sdk_msg.content),
str(sdk_msg.content)[:500],
)
# Race-condition fix: SDK hooks (PostToolUse) are
# executed asynchronously via start_soon() — the next
# message can arrive before the hook stashes output.
@@ -1116,7 +1188,7 @@ async def stream_chat_completion_sdk(
- len(adapter.resolved_tool_calls),
)
# Log ResultMessage details for debugging
# Log ResultMessage details and capture token usage
if isinstance(sdk_msg, ResultMessage):
logger.info(
"%s Received: ResultMessage %s "
@@ -1135,9 +1207,46 @@ async def stream_chat_completion_sdk(
sdk_msg.result or "(no error message provided)",
)
# Emit compaction end if SDK finished compacting
for ev in await compaction.emit_end_if_ready(session):
# Capture token usage from ResultMessage.
# Anthropic reports cached tokens separately:
# input_tokens = uncached only
# cache_read_input_tokens = served from cache
# cache_creation_input_tokens = written to cache
if sdk_msg.usage:
turn_prompt_tokens += sdk_msg.usage.get("input_tokens", 0)
turn_cache_read_tokens += sdk_msg.usage.get(
"cache_read_input_tokens", 0
)
turn_cache_creation_tokens += sdk_msg.usage.get(
"cache_creation_input_tokens", 0
)
turn_completion_tokens += sdk_msg.usage.get(
"output_tokens", 0
)
logger.info(
"%s Token usage: uncached=%d, cache_read=%d, cache_create=%d, output=%d",
log_prefix,
turn_prompt_tokens,
turn_cache_read_tokens,
turn_cache_creation_tokens,
turn_completion_tokens,
)
if sdk_msg.total_cost_usd is not None:
turn_cost_usd = sdk_msg.total_cost_usd
# Emit compaction end if SDK finished compacting.
# When compaction ends, sync TranscriptBuilder with
# the CLI's compacted session file so the uploaded
# transcript reflects compaction.
compaction_events = await compaction.emit_end_if_ready(session)
for ev in compaction_events:
yield ev
if compaction_events and sdk_cwd:
cli_content = await read_cli_session_file(sdk_cwd)
if cli_content:
transcript_builder.replace_entries(
cli_content, log_prefix=log_prefix
)
for response in adapter.convert_message(sdk_msg):
if isinstance(response, StreamStart):
@@ -1331,6 +1440,27 @@ async def stream_chat_completion_sdk(
) and not has_appended_assistant:
session.messages.append(assistant_response)
# Emit token usage to the client (must be in try to reach SSE stream).
# Session persistence of usage is in finally to stay consistent with
# rate-limit recording even if an exception interrupts between here
# and the finally block.
# Compute total_tokens once; reused in the finally block for
# session persistence and rate-limit recording.
total_tokens = (
turn_prompt_tokens
+ turn_cache_read_tokens
+ turn_cache_creation_tokens
+ turn_completion_tokens
)
if total_tokens > 0:
yield StreamUsage(
promptTokens=turn_prompt_tokens,
completionTokens=turn_completion_tokens,
totalTokens=total_tokens,
cacheReadTokens=turn_cache_read_tokens,
cacheCreationTokens=turn_cache_creation_tokens,
)
# Transcript upload is handled exclusively in the finally block
# to avoid double-uploads (the success path used to upload the
# old resume file, then the finally block overwrote it with the
@@ -1395,6 +1525,48 @@ async def stream_chat_completion_sdk(
except Exception:
logger.warning("OTEL context teardown failed", exc_info=True)
# --- Persist token usage to session + rate-limit counters ---
# Both must live in finally so they stay consistent even when an
# exception interrupts the try block after StreamUsage was yielded.
# total_tokens is computed once before StreamUsage yield above.
if total_tokens > 0:
if session is not None:
session.usage.append(
Usage(
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
total_tokens=total_tokens,
cache_read_tokens=turn_cache_read_tokens,
cache_creation_tokens=turn_cache_creation_tokens,
)
)
logger.info(
"%s Turn usage: uncached=%d, cache_read=%d, cache_create=%d, "
"output=%d, total=%d, cost_usd=%s",
log_prefix,
turn_prompt_tokens,
turn_cache_read_tokens,
turn_cache_creation_tokens,
turn_completion_tokens,
total_tokens,
turn_cost_usd,
)
if user_id and total_tokens > 0:
try:
await record_token_usage(
user_id=user_id,
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
cache_read_tokens=turn_cache_read_tokens,
cache_creation_tokens=turn_cache_creation_tokens,
)
except Exception as usage_err:
logger.warning(
"%s Failed to record token usage: %s",
log_prefix,
usage_err,
)
# --- Persist session messages ---
# This MUST run in finally to persist messages even when the generator
# is stopped early (e.g., user clicks stop, processor breaks stream loop).
@@ -1416,6 +1588,17 @@ async def stream_chat_completion_sdk(
exc_info=True,
)
# --- Pause E2B sandbox to stop billing between turns ---
# Fire-and-forget: pausing is best-effort and must not block the
# response or the transcript upload. The task is anchored to
# _background_tasks to prevent garbage collection.
# Use pause_sandbox_direct to skip the Redis lookup and reconnect
# round-trip — e2b_sandbox is the live object from this turn.
if e2b_sandbox is not None:
task = asyncio.create_task(pause_sandbox_direct(e2b_sandbox, session_id))
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
# --- Upload transcript for next-turn --resume ---
# This MUST run in finally so the transcript is uploaded even when
# the streaming loop raises an exception.
@@ -1479,6 +1662,6 @@ async def _update_title_async(
)
if title and user_id:
await update_session_title(session_id, user_id, title, only_if_empty=True)
logger.debug(f"[SDK] Generated title for {session_id}: {title}")
logger.debug("[SDK] Generated title for %s: %s", session_id, title)
except Exception as e:
logger.warning(f"[SDK] Failed to update session title: {e}")
logger.warning("[SDK] Failed to update session title: %s", e)

View File

@@ -1,9 +1,10 @@
"""Tests for SDK service helpers."""
import asyncio
import base64
import os
from dataclasses import dataclass
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@@ -212,7 +213,7 @@ class TestPromptSupplement:
# Workflows are now in individual tool descriptions (not separate sections)
# Check that key workflow concepts appear in tool descriptions
assert "suggested_goal" in docs or "clarifying_questions" in docs
assert "agent_json" in docs or "find_block" in docs
assert "run_mcp_tool" in docs
def test_baseline_supplement_completeness(self):
@@ -231,6 +232,48 @@ class TestPromptSupplement:
f"`{tool_name}`" in docs
), f"Tool '{tool_name}' missing from baseline supplement"
def test_pause_task_scheduled_before_transcript_upload(self):
"""Pause is scheduled as a background task before transcript upload begins.
The finally block in stream_response_sdk does:
(1) asyncio.create_task(pause_sandbox_direct(...)) — fire-and-forget
(2) await asyncio.shield(upload_transcript(...)) — awaited
Scheduling pause via create_task before awaiting upload ensures:
- Pause never blocks transcript upload (billing stops concurrently)
- On E2B timeout, pause silently fails; upload proceeds unaffected
"""
call_order: list[str] = []
async def _mock_pause(sandbox, session_id):
call_order.append("pause")
async def _mock_upload(**kwargs):
call_order.append("upload")
async def _simulate_teardown():
"""Mirror the service.py finally block teardown sequence."""
sandbox = MagicMock()
# (1) Schedule pause — mirrors lines ~1427-1429 in service.py
task = asyncio.create_task(_mock_pause(sandbox, "test-sess"))
# (2) Await transcript upload — mirrors lines ~1460-1468 in service.py
# Yielding to the event loop here lets the pause task start concurrently.
await _mock_upload(
user_id="u", session_id="test-sess", content="x", message_count=1
)
await task
asyncio.run(_simulate_teardown())
# Both must run; pause is scheduled before upload starts
assert "pause" in call_order
assert "upload" in call_order
# create_task schedules pause, then upload is awaited — pause runs
# concurrently during upload's first yield. The ordering guarantee is
# that create_task is CALLED before upload is AWAITED (see source order).
def test_baseline_supplement_no_duplicate_tools(self):
"""No tool should appear multiple times in baseline supplement."""
from backend.copilot.prompting import get_baseline_supplement

View File

@@ -9,14 +9,29 @@ import itertools
import json
import logging
import os
import re
import uuid
from contextvars import ContextVar
from typing import TYPE_CHECKING, Any
from claude_agent_sdk import create_sdk_mcp_server, tool
from backend.copilot.context import (
_current_project_dir,
_current_sandbox,
_current_sdk_cwd,
_current_session,
_current_user_id,
_encode_cwd_for_cli,
get_execution_context,
get_sdk_cwd,
is_allowed_local_path,
)
from backend.copilot.model import ChatSession
from backend.copilot.sdk.file_ref import (
FileRefExpansionError,
expand_file_refs_in_args,
read_file_bytes,
)
from backend.copilot.tools import TOOL_REGISTRY
from backend.copilot.tools.base import BaseTool
from backend.util.truncate import truncate
@@ -28,84 +43,13 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# Allowed base directory for the Read tool (SDK saves oversized tool results here).
# Restricted to ~/.claude/projects/ and further validated to require "tool-results"
# in the path — prevents reading settings, credentials, or other sensitive files.
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
# Max MCP response size in chars — keeps tool output under the SDK's 10 MB JSON buffer.
_MCP_MAX_CHARS = 500_000
# Context variable holding the encoded project directory name for the current
# session (e.g. "-private-tmp-copilot-<uuid>"). Set by set_execution_context()
# so that path validation can scope tool-results reads to the current session.
_current_project_dir: ContextVar[str] = ContextVar("_current_project_dir", default="")
def _encode_cwd_for_cli(cwd: str) -> str:
"""Encode a working directory path the same way the Claude CLI does.
The CLI replaces all non-alphanumeric characters with ``-``.
"""
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
"""Check whether *path* is an allowed host-filesystem path.
Allowed:
- Files under *sdk_cwd* (``/tmp/copilot-<session>/``)
- Files under ``~/.claude/projects/<encoded-cwd>/`` — the SDK's
project directory for this session (tool-results, transcripts, etc.)
Both checks are scoped to the **current session** so sessions cannot
read each other's data.
"""
if not path:
return False
if path.startswith("~"):
resolved = os.path.realpath(os.path.expanduser(path))
elif not os.path.isabs(path) and sdk_cwd:
resolved = os.path.realpath(os.path.join(sdk_cwd, path))
else:
resolved = os.path.realpath(path)
# Allow access within the SDK working directory
if sdk_cwd:
norm_cwd = os.path.realpath(sdk_cwd)
if resolved == norm_cwd or resolved.startswith(norm_cwd + os.sep):
return True
# Allow access within the current session's CLI project directory
# (~/.claude/projects/<encoded-cwd>/).
encoded = _current_project_dir.get("")
if encoded:
session_project = os.path.join(_SDK_PROJECTS_DIR, encoded)
if resolved == session_project or resolved.startswith(session_project + os.sep):
return True
return False
# MCP server naming - the SDK prefixes tool names as "mcp__{server_name}__{tool}"
MCP_SERVER_NAME = "copilot"
MCP_TOOL_PREFIX = f"mcp__{MCP_SERVER_NAME}__"
# Context variables to pass user/session info to tool execution
_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None)
_current_session: ContextVar[ChatSession | None] = ContextVar(
"current_session", default=None
)
# E2B cloud sandbox for the current turn (None when E2B is not configured).
# Passed to bash_exec so commands run on E2B instead of the local bwrap sandbox.
_current_sandbox: ContextVar["AsyncSandbox | None"] = ContextVar(
"_current_sandbox", default=None
)
# Raw SDK working directory path (e.g. /tmp/copilot-<session_id>).
# Used by workspace tools to save binary files for the CLI's built-in Read.
_current_sdk_cwd: ContextVar[str] = ContextVar("_current_sdk_cwd", default="")
# Stash for MCP tool outputs before the SDK potentially truncates them.
# Keyed by tool_name → full output string. Consumed (popped) by the
# response adapter when it builds StreamToolOutputAvailable.
@@ -149,24 +93,6 @@ def set_execution_context(
_stash_event.set(asyncio.Event())
def get_current_sandbox() -> "AsyncSandbox | None":
"""Return the E2B sandbox for the current turn, or None."""
return _current_sandbox.get()
def get_sdk_cwd() -> str:
"""Return the SDK ephemeral working directory for the current turn."""
return _current_sdk_cwd.get()
def get_execution_context() -> tuple[str | None, ChatSession | None]:
"""Get the current execution context."""
return (
_current_user_id.get(),
_current_session.get(),
)
def pop_pending_tool_output(tool_name: str) -> str | None:
"""Pop and return the oldest stashed output for *tool_name*.
@@ -259,7 +185,11 @@ async def _execute_tool_sync(
session: ChatSession,
args: dict[str, Any],
) -> dict[str, Any]:
"""Execute a tool synchronously and return MCP-formatted response."""
"""Execute a tool synchronously and return MCP-formatted response.
Note: ``@@agptfile:`` expansion is handled upstream in the ``_truncating`` wrapper
so all registered handlers (BaseTool, E2B, Read) expand uniformly.
"""
effective_id = f"sdk-{uuid.uuid4().hex[:12]}"
result = await base_tool.execute(
user_id=user_id,
@@ -304,7 +234,9 @@ def create_tool_handler(base_tool: BaseTool):
try:
return await _execute_tool_sync(base_tool, user_id, session, args)
except Exception as e:
logger.error(f"Error executing tool {base_tool.name}: {e}", exc_info=True)
logger.error(
"Error executing tool %s: %s", base_tool.name, e, exc_info=True
)
return _mcp_error(f"Failed to execute {base_tool.name}: {e}")
return tool_handler
@@ -320,42 +252,50 @@ def _build_input_schema(base_tool: BaseTool) -> dict[str, Any]:
async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
"""Read a local file with optional offset/limit.
"""Read a file with optional offset/limit.
Only allows paths that pass :func:`is_allowed_local_path` — the current
session's tool-results directory and ephemeral working directory.
Supports ``workspace://`` URIs (delegated to the workspace manager) and
local paths within the session's allowed directories (sdk_cwd + tool-results).
"""
file_path = args.get("file_path", "")
offset = args.get("offset", 0)
limit = args.get("limit", 2000)
offset = max(0, int(args.get("offset", 0)))
limit = max(1, int(args.get("limit", 2000)))
if not is_allowed_local_path(file_path):
return {
"content": [{"type": "text", "text": f"Access denied: {file_path}"}],
"isError": True,
}
def _mcp_err(text: str) -> dict[str, Any]:
return {"content": [{"type": "text", "text": text}], "isError": True}
def _mcp_ok(text: str) -> dict[str, Any]:
return {"content": [{"type": "text", "text": text}], "isError": False}
if file_path.startswith("workspace://"):
user_id, session = get_execution_context()
if session is None:
return _mcp_err("workspace:// file references require an active session")
try:
raw = await read_file_bytes(file_path, user_id, session)
except ValueError as exc:
return _mcp_err(str(exc))
lines = raw.decode("utf-8", errors="replace").splitlines(keepends=True)
selected = list(itertools.islice(lines, offset, offset + limit))
numbered = "".join(
f"{i + offset + 1:>6}\t{line}" for i, line in enumerate(selected)
)
return _mcp_ok(numbered)
if not is_allowed_local_path(file_path, get_sdk_cwd()):
return _mcp_err(f"Path not allowed: {file_path}")
resolved = os.path.realpath(os.path.expanduser(file_path))
try:
with open(resolved) as f:
selected = list(itertools.islice(f, offset, offset + limit))
content = "".join(selected)
# Cleanup happens in _cleanup_sdk_tool_results after session ends;
# don't delete here — the SDK may read in multiple chunks.
return {
"content": [{"type": "text", "text": content}],
"isError": False,
}
return _mcp_ok("".join(selected))
except FileNotFoundError:
return {
"content": [{"type": "text", "text": f"File not found: {file_path}"}],
"isError": True,
}
return _mcp_err(f"File not found: {file_path}")
except Exception as e:
return {
"content": [{"type": "text", "text": f"Error reading file: {e}"}],
"isError": True,
}
return _mcp_err(f"Error reading file: {e}")
_READ_TOOL_NAME = "Read"
@@ -414,9 +354,23 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
SDK's 10 MB JSON buffer, and stash the (truncated) output for the
response adapter before the SDK can apply its own head-truncation.
Also expands ``@@agptfile:`` references in args so every registered tool
(BaseTool, E2B file tools, Read) receives resolved content uniformly.
Applied once to every registered tool."""
async def wrapper(args: dict[str, Any]) -> dict[str, Any]:
user_id, session = get_execution_context()
if session is not None:
try:
args = await expand_file_refs_in_args(args, user_id, session)
except FileRefExpansionError as exc:
return _mcp_error(
f"@@agptfile: reference could not be resolved: {exc}. "
"Ensure the file exists before referencing it. "
"For sandbox paths use bash_exec to verify the file exists first; "
"for workspace files use a workspace:// URI."
)
result = await fn(args)
truncated = truncate(result, _MCP_MAX_CHARS)

View File

@@ -2,12 +2,12 @@
import pytest
from backend.copilot.context import get_sdk_cwd
from backend.util.truncate import truncate
from .tool_adapter import (
_MCP_MAX_CHARS,
_text_from_mcp_result,
get_sdk_cwd,
pop_pending_tool_output,
set_execution_context,
stash_pending_tool_output,

View File

@@ -13,10 +13,17 @@ filesystem for self-hosted) — no DB column needed.
import logging
import os
import re
import shutil
import time
from dataclasses import dataclass
from pathlib import Path
from uuid import uuid4
import openai
from backend.copilot.config import ChatConfig
from backend.util import json
from backend.util.prompt import CompressResult, compress_context
logger = logging.getLogger(__name__)
@@ -34,6 +41,11 @@ STRIPPABLE_TYPES = frozenset(
{"progress", "file-history-snapshot", "queue-operation", "summary", "pr-link"}
)
# JSONL protocol values used in transcript serialization.
STOP_REASON_END_TURN = "end_turn"
COMPACT_MSG_ID_PREFIX = "msg_compact_"
ENTRY_TYPE_MESSAGE = "message"
@dataclass
class TranscriptDownload:
@@ -82,7 +94,11 @@ def strip_progress_entries(content: str) -> str:
parent = entry.get("parentUuid", "")
if uid:
uuid_to_parent[uid] = parent
if entry.get("type", "") in STRIPPABLE_TYPES and uid:
if (
entry.get("type", "") in STRIPPABLE_TYPES
and uid
and not entry.get("isCompactSummary")
):
stripped_uuids.add(uid)
# Second pass: keep non-stripped entries, reparenting where needed.
@@ -93,7 +109,9 @@ def strip_progress_entries(content: str) -> str:
continue
parent = entry.get("parentUuid", "")
original_parent = parent
while parent in stripped_uuids:
seen_parents: set[str] = set()
while parent in stripped_uuids and parent not in seen_parents:
seen_parents.add(parent)
parent = uuid_to_parent.get(parent, "")
if parent != original_parent:
entry["parentUuid"] = parent
@@ -106,7 +124,9 @@ def strip_progress_entries(content: str) -> str:
if not isinstance(entry, dict):
result_lines.append(line)
continue
if entry.get("type", "") in STRIPPABLE_TYPES:
if entry.get("type", "") in STRIPPABLE_TYPES and not entry.get(
"isCompactSummary"
):
continue
uid = entry.get("uuid", "")
if uid in reparented:
@@ -137,32 +157,78 @@ def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
_SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-")
def cleanup_cli_project_dir(sdk_cwd: str) -> None:
"""Remove the CLI's project directory for a specific working directory.
def _cli_project_dir(sdk_cwd: str) -> str | None:
"""Return the CLI's project directory for a given working directory.
The CLI stores session data under ``~/.claude/projects/<encoded_cwd>/``.
Each SDK turn uses a unique ``sdk_cwd``, so the project directory is
safe to remove entirely after the transcript has been uploaded.
Returns ``None`` if the path would escape the projects base.
"""
import shutil
# Encode cwd the same way CLI does (replaces non-alphanumeric with -)
cwd_encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
projects_base = os.path.realpath(os.path.join(config_dir, "projects"))
project_dir = os.path.realpath(os.path.join(projects_base, cwd_encoded))
if not project_dir.startswith(projects_base + os.sep):
logger.warning(
f"[Transcript] Cleanup path escaped projects base: {project_dir}"
)
return
logger.warning("[Transcript] Project dir escaped base: %s", project_dir)
return None
return project_dir
async def read_cli_session_file(sdk_cwd: str) -> str | None:
"""Read the CLI's own session file, which reflects any mid-stream compaction.
After the CLI compacts context, its session file contains the compacted
conversation. Reading this file lets ``TranscriptBuilder`` replace its
uncompacted entries with the CLI's compacted version.
"""
import aiofiles
project_dir = _cli_project_dir(sdk_cwd)
if not project_dir or not os.path.isdir(project_dir):
return None
jsonl_files = list(Path(project_dir).glob("*.jsonl"))
if not jsonl_files:
logger.debug("[Transcript] No CLI session file in %s", project_dir)
return None
# Pick the most recently modified file (there should only be one per turn).
# Guard against races where a file is deleted between glob and stat.
candidates: list[tuple[float, Path]] = []
for p in jsonl_files:
try:
candidates.append((p.stat().st_mtime, p))
except OSError:
continue
if not candidates:
logger.debug("[Transcript] No readable CLI session file in %s", project_dir)
return None
# Resolve + prefix check to prevent symlink escapes.
session_file = max(candidates, key=lambda item: item[0])[1]
real_path = str(session_file.resolve())
if not real_path.startswith(project_dir + os.sep):
logger.warning("[Transcript] Session file escaped project dir: %s", real_path)
return None
try:
async with aiofiles.open(real_path) as f:
content = await f.read()
logger.info(
"[Transcript] Read CLI session file: %s (%d bytes)",
real_path,
len(content),
)
return content
except OSError as e:
logger.warning("[Transcript] Failed to read CLI session file: %s", e)
return None
def cleanup_cli_project_dir(sdk_cwd: str) -> None:
"""Remove the CLI's project directory for a specific working directory."""
project_dir = _cli_project_dir(sdk_cwd)
if not project_dir:
return
if os.path.isdir(project_dir):
shutil.rmtree(project_dir, ignore_errors=True)
logger.debug(f"[Transcript] Cleaned up CLI project dir: {project_dir}")
logger.debug("[Transcript] Cleaned up CLI project dir: %s", project_dir)
else:
logger.debug(f"[Transcript] Project dir not found: {project_dir}")
logger.debug("[Transcript] Project dir not found: %s", project_dir)
def write_transcript_to_tempfile(
@@ -180,7 +246,7 @@ def write_transcript_to_tempfile(
# Validate cwd is under the expected sandbox prefix (CodeQL sanitizer).
real_cwd = os.path.realpath(cwd)
if not real_cwd.startswith(_SAFE_CWD_PREFIX):
logger.warning(f"[Transcript] cwd outside sandbox: {cwd}")
logger.warning("[Transcript] cwd outside sandbox: %s", cwd)
return None
try:
@@ -190,17 +256,17 @@ def write_transcript_to_tempfile(
os.path.join(real_cwd, f"transcript-{safe_id}.jsonl")
)
if not jsonl_path.startswith(real_cwd):
logger.warning(f"[Transcript] Path escaped cwd: {jsonl_path}")
logger.warning("[Transcript] Path escaped cwd: %s", jsonl_path)
return None
with open(jsonl_path, "w") as f:
f.write(transcript_content)
logger.info(f"[Transcript] Wrote resume file: {jsonl_path}")
logger.info("[Transcript] Wrote resume file: %s", jsonl_path)
return jsonl_path
except OSError as e:
logger.warning(f"[Transcript] Failed to write resume file: {e}")
logger.warning("[Transcript] Failed to write resume file: %s", e)
return None
@@ -344,11 +410,14 @@ async def upload_transcript(
content=json.dumps(meta).encode("utf-8"),
)
except Exception as e:
logger.warning(f"{log_prefix} Failed to write metadata: {e}")
logger.warning("%s Failed to write metadata: %s", log_prefix, e)
logger.info(
f"{log_prefix} Uploaded {len(encoded)}B "
f"(stripped from {len(content)}B, msg_count={message_count})"
"%s Uploaded %dB (stripped from %dB, msg_count=%d)",
log_prefix,
len(encoded),
len(content),
message_count,
)
@@ -371,10 +440,10 @@ async def download_transcript(
data = await storage.retrieve(path)
content = data.decode("utf-8")
except FileNotFoundError:
logger.debug(f"{log_prefix} No transcript in storage")
logger.debug("%s No transcript in storage", log_prefix)
return None
except Exception as e:
logger.warning(f"{log_prefix} Failed to download transcript: {e}")
logger.warning("%s Failed to download transcript: %s", log_prefix, e)
return None
# Try to load metadata (best-effort — old transcripts won't have it)
@@ -394,10 +463,14 @@ async def download_transcript(
meta = json.loads(meta_data.decode("utf-8"), fallback={})
message_count = meta.get("message_count", 0)
uploaded_at = meta.get("uploaded_at", 0.0)
except (FileNotFoundError, Exception):
except FileNotFoundError:
pass # No metadata — treat as unknown (msg_count=0 → always fill gap)
except Exception as e:
logger.debug("%s Failed to load transcript metadata: %s", log_prefix, e)
logger.info(f"{log_prefix} Downloaded {len(content)}B (msg_count={message_count})")
logger.info(
"%s Downloaded %dB (msg_count=%d)", log_prefix, len(content), message_count
)
return TranscriptDownload(
content=content,
message_count=message_count,
@@ -405,15 +478,171 @@ async def download_transcript(
)
async def delete_transcript(user_id: str, session_id: str) -> None:
"""Delete transcript from bucket storage (e.g. after resume failure)."""
from backend.util.workspace_storage import get_workspace_storage
# ---------------------------------------------------------------------------
# Transcript compaction
# ---------------------------------------------------------------------------
storage = await get_workspace_storage()
path = _build_storage_path(user_id, session_id, storage)
# Transcripts above this byte threshold are compacted at download time.
COMPACT_THRESHOLD_BYTES = 400_000
def _flatten_assistant_content(blocks: list) -> str:
"""Flatten assistant content blocks into a single plain-text string."""
parts: list[str] = []
for block in blocks:
if isinstance(block, dict):
if block.get("type") == "text":
parts.append(block.get("text", ""))
elif block.get("type") == "tool_use":
parts.append(f"[tool_use: {block.get('name', '?')}]")
elif isinstance(block, str):
parts.append(block)
return "\n".join(parts) if parts else ""
def _flatten_tool_result_content(blocks: list) -> str:
"""Flatten tool_result and other content blocks into plain text.
Handles nested tool_result structures, text blocks, and raw strings.
Uses ``json.dumps`` as fallback for dict blocks without a ``text`` key
or where ``text`` is ``None``.
"""
str_parts: list[str] = []
for block in blocks:
if isinstance(block, dict) and block.get("type") == "tool_result":
inner = block.get("content", "")
if isinstance(inner, list):
for sub in inner:
if isinstance(sub, dict):
text = sub.get("text")
str_parts.append(
str(text) if text is not None else json.dumps(sub)
)
else:
str_parts.append(str(sub))
else:
str_parts.append(str(inner))
elif isinstance(block, dict) and block.get("type") == "text":
str_parts.append(str(block.get("text", "")))
elif isinstance(block, str):
str_parts.append(block)
return "\n".join(str_parts) if str_parts else ""
def _transcript_to_messages(content: str) -> list[dict]:
"""Convert JSONL transcript entries to message dicts for compress_context."""
messages: list[dict] = []
for line in content.strip().split("\n"):
if not line.strip():
continue
entry = json.loads(line, fallback=None)
if not isinstance(entry, dict):
continue
if entry.get("type", "") in STRIPPABLE_TYPES and not entry.get(
"isCompactSummary"
):
continue
msg = entry.get("message", {})
role = msg.get("role", "")
if not role:
continue
msg_dict: dict = {"role": role}
raw_content = msg.get("content")
if role == "assistant" and isinstance(raw_content, list):
msg_dict["content"] = _flatten_assistant_content(raw_content)
elif isinstance(raw_content, list):
msg_dict["content"] = _flatten_tool_result_content(raw_content)
else:
msg_dict["content"] = raw_content or ""
messages.append(msg_dict)
return messages
def _messages_to_transcript(messages: list[dict]) -> str:
"""Convert compressed message dicts back to JSONL transcript format."""
lines: list[str] = []
last_uuid: str | None = None
for msg in messages:
role = msg.get("role", "user")
entry_type = "assistant" if role == "assistant" else "user"
uid = str(uuid4())
content = msg.get("content", "")
if role == "assistant":
message: dict = {
"role": "assistant",
"model": "",
"id": f"{COMPACT_MSG_ID_PREFIX}{uuid4().hex[:24]}",
"type": ENTRY_TYPE_MESSAGE,
"content": [{"type": "text", "text": content}] if content else [],
"stop_reason": STOP_REASON_END_TURN,
"stop_sequence": None,
}
else:
message = {"role": role, "content": content}
entry = {
"type": entry_type,
"uuid": uid,
"parentUuid": last_uuid,
"message": message,
}
lines.append(json.dumps(entry, separators=(",", ":")))
last_uuid = uid
return "\n".join(lines) + "\n" if lines else ""
async def _run_compression(
messages: list[dict],
model: str,
cfg: ChatConfig,
log_prefix: str,
) -> CompressResult:
"""Run LLM-based compression with truncation fallback."""
try:
await storage.delete(path)
logger.info(f"[Transcript] Deleted transcript for session {session_id}")
async with openai.AsyncOpenAI(
api_key=cfg.api_key, base_url=cfg.base_url, timeout=30.0
) as client:
return await compress_context(messages=messages, model=model, client=client)
except Exception as e:
logger.warning(f"[Transcript] Failed to delete transcript: {e}")
logger.warning("%s LLM compaction failed, using truncation: %s", log_prefix, e)
return await compress_context(messages=messages, model=model, client=None)
async def compact_transcript(
content: str,
log_prefix: str = "[Transcript]",
) -> str | None:
"""Compact an oversized JSONL transcript using LLM summarization.
Converts transcript entries to plain messages, runs ``compress_context``
(the same compressor used for pre-query history), and rebuilds JSONL.
Returns the compacted JSONL string, or ``None`` on failure.
"""
cfg = ChatConfig()
messages = _transcript_to_messages(content)
if len(messages) < 2:
logger.warning("%s Too few messages to compact (%d)", log_prefix, len(messages))
return None
try:
result = await _run_compression(messages, cfg.model, cfg, log_prefix)
if not result.was_compacted:
logger.info("%s Transcript already within token budget", log_prefix)
return content
logger.info(
"%s Compacted transcript: %d->%d tokens (%d summarized, %d dropped)",
log_prefix,
result.original_token_count,
result.token_count,
result.messages_summarized,
result.messages_dropped,
)
compacted = _messages_to_transcript(result.messages)
if not validate_transcript(compacted):
logger.warning("%s Compacted transcript failed validation", log_prefix)
return None
return compacted
except Exception as e:
logger.error(
"%s Transcript compaction failed: %s", log_prefix, e, exc_info=True
)
return None

View File

@@ -31,6 +31,7 @@ class TranscriptEntry(BaseModel):
uuid: str
parentUuid: str | None
message: dict[str, Any]
isCompactSummary: bool | None = None
class TranscriptBuilder:
@@ -78,10 +79,12 @@ class TranscriptBuilder:
)
continue
# Load all non-strippable entries (user/assistant/system/etc.)
# Skip only STRIPPABLE_TYPES to match strip_progress_entries() behavior
# Skip STRIPPABLE_TYPES unless the entry is a compaction summary.
# Compaction summaries may have type "summary" but must be preserved
# so --resume can reconstruct the compacted conversation.
entry_type = data.get("type", "")
if entry_type in STRIPPABLE_TYPES:
is_compact = data.get("isCompactSummary", False)
if entry_type in STRIPPABLE_TYPES and not is_compact:
continue
entry = TranscriptEntry(
@@ -89,6 +92,7 @@ class TranscriptBuilder:
uuid=data.get("uuid") or str(uuid4()),
parentUuid=data.get("parentUuid"),
message=data.get("message", {}),
isCompactSummary=True if is_compact else None,
)
self._entries.append(entry)
self._last_uuid = entry.uuid
@@ -177,6 +181,33 @@ class TranscriptBuilder:
lines = [entry.model_dump_json(exclude_none=True) for entry in self._entries]
return "\n".join(lines) + "\n"
def replace_entries(self, content: str, log_prefix: str = "[Transcript]") -> None:
"""Replace all entries with compacted JSONL content.
Called after the CLI performs mid-stream compaction so the builder's
state reflects the compacted conversation instead of the full
pre-compaction history.
"""
prev_count = len(self._entries)
temp = TranscriptBuilder()
try:
temp.load_previous(content, log_prefix=log_prefix)
except Exception:
logger.exception(
"%s Failed to parse compacted transcript; keeping %d existing entries",
log_prefix,
prev_count,
)
return
self._entries = temp._entries
self._last_uuid = temp._last_uuid
logger.info(
"%s Replaced %d entries with %d compacted entries",
log_prefix,
prev_count,
len(self._entries),
)
@property
def entry_count(self) -> int:
"""Total number of entries in the complete context."""

View File

@@ -2,14 +2,25 @@
import os
import pytest
from backend.util import json
from .transcript import (
COMPACT_MSG_ID_PREFIX,
STRIPPABLE_TYPES,
_cli_project_dir,
_flatten_assistant_content,
_flatten_tool_result_content,
_messages_to_transcript,
_transcript_to_messages,
compact_transcript,
read_cli_session_file,
strip_progress_entries,
validate_transcript,
write_transcript_to_tempfile,
)
from .transcript_builder import TranscriptBuilder
def _make_jsonl(*entries: dict) -> str:
@@ -35,6 +46,14 @@ PROGRESS_ENTRY = {
"data": {"type": "bash_progress", "stdout": "running..."},
}
COMPACT_SUMMARY = {
"type": "summary",
"uuid": "cs1",
"parentUuid": None,
"isCompactSummary": True,
"message": {"role": "user", "content": "Summary of previous conversation..."},
}
VALID_TRANSCRIPT = _make_jsonl(METADATA_LINE, FILE_HISTORY, USER_MSG, ASST_MSG)
@@ -237,6 +256,121 @@ class TestStripProgressEntries:
# Should return just a newline (empty content stripped)
assert result.strip() == ""
# --- _cli_project_dir ---
class TestCliProjectDir:
def test_returns_path_for_valid_cwd(self, tmp_path, monkeypatch):
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(tmp_path))
projects = tmp_path / "projects"
projects.mkdir()
result = _cli_project_dir("/tmp/copilot-abc")
assert result is not None
assert "projects" in result
def test_returns_none_for_path_traversal(self, tmp_path, monkeypatch):
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(tmp_path))
projects = tmp_path / "projects"
projects.mkdir()
# A cwd that encodes to something with .. shouldn't escape
result = _cli_project_dir("/tmp/copilot-test")
# Should return a valid path (no traversal possible with alphanum encoding)
assert result is None or result.startswith(str(projects))
# --- read_cli_session_file ---
class TestReadCliSessionFile:
@pytest.mark.asyncio
async def test_reads_session_file(self, tmp_path, monkeypatch):
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(tmp_path))
# Create the CLI project directory structure
cwd = "/tmp/copilot-testread"
import re
encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
project_dir = tmp_path / "projects" / encoded
project_dir.mkdir(parents=True)
# Write a session file
session_file = project_dir / "test-session.jsonl"
session_file.write_text(json.dumps(ASST_MSG) + "\n")
result = await read_cli_session_file(cwd)
assert result is not None
assert "assistant" in result
@pytest.mark.asyncio
async def test_returns_none_when_no_files(self, tmp_path, monkeypatch):
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(tmp_path))
cwd = "/tmp/copilot-nofiles"
import re
encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
project_dir = tmp_path / "projects" / encoded
project_dir.mkdir(parents=True)
# No jsonl files
result = await read_cli_session_file(cwd)
assert result is None
@pytest.mark.asyncio
async def test_returns_none_when_dir_missing(self, tmp_path, monkeypatch):
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(tmp_path))
(tmp_path / "projects").mkdir()
result = await read_cli_session_file("/tmp/copilot-nonexistent")
assert result is None
# --- _transcript_to_messages / _messages_to_transcript ---
class TestTranscriptMessageConversion:
def test_roundtrip_preserves_roles(self):
transcript = _make_jsonl(USER_MSG, ASST_MSG)
messages = _transcript_to_messages(transcript)
assert len(messages) == 2
assert messages[0]["role"] == "user"
assert messages[1]["role"] == "assistant"
def test_messages_to_transcript_produces_valid_jsonl(self):
messages = [
{"role": "user", "content": "hi"},
{"role": "assistant", "content": "hello"},
]
result = _messages_to_transcript(messages)
assert validate_transcript(result) is True
def test_strips_strippable_types(self):
transcript = _make_jsonl(
{"type": "progress", "uuid": "p1", "message": {"role": "user"}},
USER_MSG,
ASST_MSG,
)
messages = _transcript_to_messages(transcript)
assert len(messages) == 2 # progress entry skipped
def test_flattens_assistant_content_blocks(self):
asst_with_blocks = {
"type": "assistant",
"uuid": "a1",
"message": {
"role": "assistant",
"content": [
{"type": "text", "text": "hello"},
{"type": "tool_use", "name": "bash"},
],
},
}
messages = _transcript_to_messages(_make_jsonl(asst_with_blocks))
assert len(messages) == 1
assert "hello" in messages[0]["content"]
assert "[tool_use: bash]" in messages[0]["content"]
def test_empty_messages_returns_empty(self):
result = _messages_to_transcript([])
assert result == ""
def test_no_strippable_entries(self):
"""When there's nothing to strip, output matches input structure."""
content = _make_jsonl(USER_MSG, ASST_MSG)
@@ -282,3 +416,654 @@ class TestStripProgressEntries:
lines = result.strip().split("\n")
asst_entry = json.loads(lines[-1])
assert asst_entry["parentUuid"] == "u1" # reparented
# --- TranscriptBuilder ---
class TestTranscriptBuilderReplaceEntries:
"""Tests for TranscriptBuilder.replace_entries — the compaction sync path."""
def test_replace_entries_with_valid_content(self):
builder = TranscriptBuilder()
builder.append_user("hello")
builder.append_assistant([{"type": "text", "text": "world"}])
assert builder.entry_count == 2
# Replace with compacted content (one user + one assistant)
compacted = _make_jsonl(USER_MSG, ASST_MSG)
builder.replace_entries(compacted)
assert builder.entry_count == 2
def test_replace_entries_keeps_old_on_corrupt_content(self):
builder = TranscriptBuilder()
builder.append_user("hello")
assert builder.entry_count == 1
# Corrupt content that fails to parse
builder.replace_entries("not valid json at all\n")
# Should still have old entries (load_previous skips invalid lines,
# but if ALL lines are invalid, temp builder is empty → exception path)
assert builder.entry_count >= 0 # doesn't crash
def test_replace_entries_with_empty_content(self):
builder = TranscriptBuilder()
builder.append_user("hello")
assert builder.entry_count == 1
builder.replace_entries("")
# Empty content → load_previous returns early → temp is empty
# replace_entries swaps to empty (0 entries)
assert builder.entry_count == 0
def test_replace_entries_filters_strippable_types(self):
"""Strippable types (progress, file-history-snapshot) are filtered out."""
builder = TranscriptBuilder()
builder.append_user("hello")
content = _make_jsonl(
{"type": "progress", "uuid": "p1", "message": {}},
USER_MSG,
ASST_MSG,
)
builder.replace_entries(content)
# Only user + assistant should remain (progress filtered)
assert builder.entry_count == 2
def test_replace_entries_preserves_uuids(self):
builder = TranscriptBuilder()
content = _make_jsonl(USER_MSG, ASST_MSG)
builder.replace_entries(content)
jsonl = builder.to_jsonl()
lines = jsonl.strip().split("\n")
first = json.loads(lines[0])
assert first["uuid"] == "u1"
class TestTranscriptBuilderBasic:
def test_append_user_and_assistant(self):
builder = TranscriptBuilder()
builder.append_user("hi")
builder.append_assistant([{"type": "text", "text": "hello"}])
assert builder.entry_count == 2
assert not builder.is_empty
def test_to_jsonl_empty(self):
builder = TranscriptBuilder()
assert builder.to_jsonl() == ""
assert builder.is_empty
def test_load_previous_and_append(self):
builder = TranscriptBuilder()
content = _make_jsonl(USER_MSG, ASST_MSG)
builder.load_previous(content)
assert builder.entry_count == 2
builder.append_user("new message")
assert builder.entry_count == 3
def test_consecutive_assistant_entries_share_message_id(self):
builder = TranscriptBuilder()
builder.append_user("hi")
builder.append_assistant([{"type": "text", "text": "part1"}])
builder.append_assistant([{"type": "text", "text": "part2"}])
jsonl = builder.to_jsonl()
lines = jsonl.strip().split("\n")
asst1 = json.loads(lines[1])
asst2 = json.loads(lines[2])
assert asst1["message"]["id"] == asst2["message"]["id"]
def test_non_consecutive_assistant_entries_get_new_id(self):
builder = TranscriptBuilder()
builder.append_user("hi")
builder.append_assistant([{"type": "text", "text": "response1"}])
builder.append_user("followup")
builder.append_assistant([{"type": "text", "text": "response2"}])
jsonl = builder.to_jsonl()
lines = jsonl.strip().split("\n")
asst1 = json.loads(lines[1])
asst2 = json.loads(lines[3])
assert asst1["message"]["id"] != asst2["message"]["id"]
class TestCompactSummaryRoundtrip:
"""Verify isCompactSummary survives export→reload roundtrip."""
def test_load_previous_preserves_compact_summary(self):
"""Compaction summary with type 'summary' should not be stripped."""
content = _make_jsonl(COMPACT_SUMMARY, USER_MSG, ASST_MSG)
builder = TranscriptBuilder()
builder.load_previous(content)
# summary type is in STRIPPABLE_TYPES, but isCompactSummary keeps it
assert builder.entry_count == 3
def test_export_reload_preserves_compact_summary(self):
"""Critical: isCompactSummary must survive to_jsonl → load_previous."""
content = _make_jsonl(COMPACT_SUMMARY, USER_MSG, ASST_MSG)
builder1 = TranscriptBuilder()
builder1.load_previous(content)
assert builder1.entry_count == 3
exported = builder1.to_jsonl()
# Verify isCompactSummary is in the exported JSONL
first_line = json.loads(exported.strip().split("\n")[0])
assert first_line.get("isCompactSummary") is True
# Reload and verify it's still preserved
builder2 = TranscriptBuilder()
builder2.load_previous(exported)
assert builder2.entry_count == 3
def test_strip_progress_preserves_compact_summary(self):
"""strip_progress_entries should keep isCompactSummary entries."""
content = _make_jsonl(COMPACT_SUMMARY, USER_MSG, ASST_MSG)
stripped = strip_progress_entries(content)
entries = [json.loads(line) for line in stripped.strip().split("\n")]
types = [e.get("type") for e in entries]
assert "summary" in types # Not stripped despite being in STRIPPABLE_TYPES
compact = [e for e in entries if e.get("isCompactSummary")]
assert len(compact) == 1
def test_regular_summary_still_stripped(self):
"""Non-compact summaries should still be stripped."""
regular_summary = {
"type": "summary",
"uuid": "rs1",
"summary": "Session summary",
}
content = _make_jsonl(regular_summary, USER_MSG, ASST_MSG)
stripped = strip_progress_entries(content)
entries = [json.loads(line) for line in stripped.strip().split("\n")]
types = [e.get("type") for e in entries]
assert "summary" not in types
def test_replace_entries_preserves_compact_summary(self):
"""replace_entries should preserve isCompactSummary entries."""
builder = TranscriptBuilder()
builder.append_user("old")
content = _make_jsonl(COMPACT_SUMMARY, USER_MSG, ASST_MSG)
builder.replace_entries(content)
assert builder.entry_count == 3
# Verify by re-exporting
exported = builder.to_jsonl()
first = json.loads(exported.strip().split("\n")[0])
assert first.get("isCompactSummary") is True
# --- _flatten_assistant_content ---
class TestFlattenAssistantContent:
def test_text_blocks(self):
blocks = [
{"type": "text", "text": "Hello"},
{"type": "text", "text": "World"},
]
assert _flatten_assistant_content(blocks) == "Hello\nWorld"
def test_tool_use_blocks(self):
blocks = [{"type": "tool_use", "name": "read_file", "id": "t1", "input": {}}]
assert _flatten_assistant_content(blocks) == "[tool_use: read_file]"
def test_mixed_blocks(self):
blocks = [
{"type": "text", "text": "Let me read that."},
{"type": "tool_use", "name": "read", "id": "t1", "input": {}},
]
result = _flatten_assistant_content(blocks)
assert "Let me read that." in result
assert "[tool_use: read]" in result
def test_string_blocks(self):
"""Plain strings in the list should be included."""
assert _flatten_assistant_content(["hello", "world"]) == "hello\nworld"
def test_empty_list(self):
assert _flatten_assistant_content([]) == ""
def test_tool_use_missing_name(self):
blocks = [{"type": "tool_use", "id": "t1", "input": {}}]
assert _flatten_assistant_content(blocks) == "[tool_use: ?]"
# --- _flatten_tool_result_content ---
class TestFlattenToolResultContent:
def test_tool_result_with_text(self):
blocks = [
{
"type": "tool_result",
"tool_use_id": "t1",
"content": [{"type": "text", "text": "file contents here"}],
}
]
assert _flatten_tool_result_content(blocks) == "file contents here"
def test_tool_result_with_string_content(self):
blocks = [
{"type": "tool_result", "tool_use_id": "t1", "content": "simple result"}
]
assert _flatten_tool_result_content(blocks) == "simple result"
def test_tool_result_with_nested_list(self):
blocks = [
{
"type": "tool_result",
"tool_use_id": "t1",
"content": [
{"type": "text", "text": "line 1"},
{"type": "text", "text": "line 2"},
],
}
]
assert _flatten_tool_result_content(blocks) == "line 1\nline 2"
def test_text_blocks(self):
blocks = [{"type": "text", "text": "some text"}]
assert _flatten_tool_result_content(blocks) == "some text"
def test_string_items(self):
assert _flatten_tool_result_content(["raw string"]) == "raw string"
def test_empty_list(self):
assert _flatten_tool_result_content([]) == ""
def test_tool_result_none_text_uses_json(self):
"""Dicts without text key fall back to json.dumps."""
blocks = [
{
"type": "tool_result",
"tool_use_id": "t1",
"content": [{"type": "image", "source": "data:..."}],
}
]
result = _flatten_tool_result_content(blocks)
assert "image" in result # json.dumps fallback includes the key
# --- _transcript_to_messages ---
class TestTranscriptToMessages:
def test_basic_conversation(self):
content = _make_jsonl(
{
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "hello"},
},
{
"type": "assistant",
"uuid": "a1",
"parentUuid": "u1",
"message": {
"role": "assistant",
"content": [{"type": "text", "text": "hi there"}],
},
},
)
msgs = _transcript_to_messages(content)
assert len(msgs) == 2
assert msgs[0] == {"role": "user", "content": "hello"}
assert msgs[1] == {"role": "assistant", "content": "hi there"}
def test_strips_progress_entries(self):
content = _make_jsonl(
{
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "hi"},
},
{
"type": "progress",
"uuid": "p1",
"message": {"role": "user", "content": "..."},
},
{
"type": "assistant",
"uuid": "a1",
"message": {
"role": "assistant",
"content": [{"type": "text", "text": "ok"}],
},
},
)
msgs = _transcript_to_messages(content)
assert len(msgs) == 2
assert msgs[0]["role"] == "user"
assert msgs[1]["role"] == "assistant"
def test_preserves_compact_summaries(self):
content = _make_jsonl(
{
"type": "summary",
"uuid": "cs1",
"isCompactSummary": True,
"message": {"role": "user", "content": "Summary of previous..."},
},
{
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "hi"},
},
)
msgs = _transcript_to_messages(content)
assert len(msgs) == 2
assert msgs[0]["content"] == "Summary of previous..."
def test_strips_regular_summary(self):
content = _make_jsonl(
{
"type": "summary",
"uuid": "s1",
"message": {"role": "user", "content": "Session summary"},
},
{
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "hi"},
},
)
msgs = _transcript_to_messages(content)
assert len(msgs) == 1
assert msgs[0]["content"] == "hi"
def test_skips_entries_without_role(self):
content = _make_jsonl(
{"type": "user", "uuid": "u1", "message": {}},
{
"type": "user",
"uuid": "u2",
"message": {"role": "user", "content": "hi"},
},
)
msgs = _transcript_to_messages(content)
assert len(msgs) == 1
def test_tool_result_content(self):
content = _make_jsonl(
{
"type": "user",
"uuid": "u1",
"message": {
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "t1",
"content": "file contents",
}
],
},
},
)
msgs = _transcript_to_messages(content)
assert len(msgs) == 1
assert "file contents" in msgs[0]["content"]
def test_empty_content(self):
assert _transcript_to_messages("") == []
assert _transcript_to_messages(" \n ") == []
def test_invalid_json_lines_skipped(self):
content = '{"type":"user","uuid":"u1","message":{"role":"user","content":"hi"}}\nnot json\n'
msgs = _transcript_to_messages(content)
assert len(msgs) == 1
# --- _messages_to_transcript ---
class TestMessagesToTranscript:
def test_basic_roundtrip_structure(self):
messages = [
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "hi there"},
]
result = _messages_to_transcript(messages)
assert result.endswith("\n")
lines = [json.loads(line) for line in result.strip().split("\n")]
assert len(lines) == 2
# User entry
assert lines[0]["type"] == "user"
assert lines[0]["message"]["role"] == "user"
assert lines[0]["message"]["content"] == "hello"
assert lines[0]["parentUuid"] is None
# Assistant entry
assert lines[1]["type"] == "assistant"
assert lines[1]["message"]["role"] == "assistant"
assert lines[1]["message"]["content"] == [{"type": "text", "text": "hi there"}]
assert lines[1]["message"]["id"].startswith(COMPACT_MSG_ID_PREFIX)
assert lines[1]["parentUuid"] == lines[0]["uuid"]
def test_parent_uuid_chain(self):
messages = [
{"role": "user", "content": "q1"},
{"role": "assistant", "content": "a1"},
{"role": "user", "content": "q2"},
]
result = _messages_to_transcript(messages)
lines = [json.loads(line) for line in result.strip().split("\n")]
assert lines[0]["parentUuid"] is None
assert lines[1]["parentUuid"] == lines[0]["uuid"]
assert lines[2]["parentUuid"] == lines[1]["uuid"]
def test_empty_messages(self):
assert _messages_to_transcript([]) == ""
def test_assistant_empty_content(self):
messages = [{"role": "assistant", "content": ""}]
result = _messages_to_transcript(messages)
entry = json.loads(result.strip())
assert entry["message"]["content"] == []
def test_output_is_valid_transcript(self):
messages = [
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "world"},
]
result = _messages_to_transcript(messages)
assert validate_transcript(result)
# --- _transcript_to_messages + _messages_to_transcript roundtrip ---
class TestTranscriptCompactionRoundtrip:
def test_content_preserved_through_roundtrip(self):
"""Messages→transcript→messages preserves content."""
original = [
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "4"},
{"role": "user", "content": "Thanks"},
]
transcript = _messages_to_transcript(original)
recovered = _transcript_to_messages(transcript)
assert len(recovered) == len(original)
for orig, rec in zip(original, recovered):
assert orig["role"] == rec["role"]
assert orig["content"] == rec["content"]
def test_full_transcript_to_messages_and_back(self):
"""Real-ish JSONL → messages → transcript → messages roundtrip."""
source = _make_jsonl(
{
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "explain python"},
},
{
"type": "assistant",
"uuid": "a1",
"parentUuid": "u1",
"message": {
"role": "assistant",
"content": [
{"type": "text", "text": "Python is a programming language."}
],
},
},
{
"type": "user",
"uuid": "u2",
"parentUuid": "a1",
"message": {
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "t1",
"content": "output of ls",
}
],
},
},
)
msgs1 = _transcript_to_messages(source)
assert len(msgs1) == 3
rebuilt = _messages_to_transcript(msgs1)
msgs2 = _transcript_to_messages(rebuilt)
assert len(msgs2) == len(msgs1)
for m1, m2 in zip(msgs1, msgs2):
assert m1["role"] == m2["role"]
# Content may differ in format (list vs string) but text is preserved
assert m1["content"] in m2["content"] or m2["content"] in m1["content"]
# --- compact_transcript ---
class TestCompactTranscript:
@pytest.mark.asyncio
async def test_too_few_messages_returns_none(self):
"""Transcripts with < 2 messages can't be compacted."""
single = _make_jsonl(
{"type": "user", "uuid": "u1", "message": {"role": "user", "content": "hi"}}
)
result = await compact_transcript(single)
assert result is None
@pytest.mark.asyncio
async def test_empty_transcript_returns_none(self):
result = await compact_transcript("")
assert result is None
@pytest.mark.asyncio
async def test_compaction_produces_valid_transcript(self, monkeypatch):
"""When compress_context compacts, result should be valid JSONL."""
from unittest.mock import AsyncMock
from backend.util.prompt import CompressResult
mock_result = CompressResult(
messages=[
{"role": "user", "content": "Summary of conversation"},
{"role": "assistant", "content": "Acknowledged"},
],
token_count=50,
was_compacted=True,
original_token_count=5000,
messages_summarized=10,
messages_dropped=5,
)
monkeypatch.setattr(
"backend.copilot.sdk.transcript._run_compression",
AsyncMock(return_value=mock_result),
)
source = _make_jsonl(
{
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "msg1"},
},
{
"type": "assistant",
"uuid": "a1",
"message": {
"role": "assistant",
"content": [{"type": "text", "text": "reply1"}],
},
},
{
"type": "user",
"uuid": "u2",
"message": {"role": "user", "content": "msg2"},
},
)
result = await compact_transcript(source)
assert result is not None
assert validate_transcript(result)
# Verify compacted content
msgs = _transcript_to_messages(result)
assert len(msgs) == 2
assert msgs[0]["content"] == "Summary of conversation"
@pytest.mark.asyncio
async def test_no_compaction_needed_returns_original(self, monkeypatch):
"""When compress_context says no compaction needed, return original."""
from unittest.mock import AsyncMock
from backend.util.prompt import CompressResult
mock_result = CompressResult(
messages=[], token_count=100, was_compacted=False, original_token_count=100
)
monkeypatch.setattr(
"backend.copilot.sdk.transcript._run_compression",
AsyncMock(return_value=mock_result),
)
source = _make_jsonl(
{
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "hi"},
},
{
"type": "assistant",
"uuid": "a1",
"message": {
"role": "assistant",
"content": [{"type": "text", "text": "hello"}],
},
},
)
result = await compact_transcript(source)
assert result == source # Unchanged
@pytest.mark.asyncio
async def test_compression_failure_returns_none(self, monkeypatch):
"""When _run_compression raises, compact_transcript returns None."""
from unittest.mock import AsyncMock
monkeypatch.setattr(
"backend.copilot.sdk.transcript._run_compression",
AsyncMock(side_effect=RuntimeError("LLM unavailable")),
)
source = _make_jsonl(
{
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "hi"},
},
{
"type": "assistant",
"uuid": "a1",
"message": {
"role": "assistant",
"content": [{"type": "text", "text": "hello"}],
},
},
)
result = await compact_transcript(source)
assert result is None

View File

@@ -23,6 +23,11 @@ from typing import Any, Literal
import orjson
from backend.api.model import CopilotCompletionPayload
from backend.data.notification_bus import (
AsyncRedisNotificationEventBus,
NotificationEvent,
)
from backend.data.redis_client import get_redis_async
from .config import ChatConfig
@@ -38,6 +43,7 @@ from .response_model import (
logger = logging.getLogger(__name__)
config = ChatConfig()
_notification_bus = AsyncRedisNotificationEventBus()
# Track background tasks for this pod (just the asyncio.Task reference, not subscribers)
_local_sessions: dict[str, asyncio.Task] = {}
@@ -745,6 +751,29 @@ async def mark_session_completed(
# Clean up local session reference if exists
_local_sessions.pop(session_id, None)
# Publish copilot completion notification via WebSocket
if meta:
parsed = _parse_session_meta(meta, session_id)
if parsed.user_id:
try:
await _notification_bus.publish(
NotificationEvent(
user_id=parsed.user_id,
payload=CopilotCompletionPayload(
type="copilot_completion",
event="session_completed",
session_id=session_id,
status=status,
),
)
)
except Exception as e:
logger.warning(
f"Failed to publish copilot completion notification "
f"for session {session_id}: {e}"
)
return True

View File

@@ -12,6 +12,7 @@ from .agent_browser import BrowserActTool, BrowserNavigateTool, BrowserScreensho
from .agent_output import AgentOutputTool
from .base import BaseTool
from .bash_exec import BashExecTool
from .continue_run_block import ContinueRunBlockTool
from .create_agent import CreateAgentTool
from .customize_agent import CustomizeAgentTool
from .edit_agent import EditAgentTool
@@ -19,7 +20,10 @@ from .feature_requests import CreateFeatureRequestTool, SearchFeatureRequestsToo
from .find_agent import FindAgentTool
from .find_block import FindBlockTool
from .find_library_agent import FindLibraryAgentTool
from .fix_agent import FixAgentGraphTool
from .get_agent_building_guide import GetAgentBuildingGuideTool
from .get_doc_page import GetDocPageTool
from .get_mcp_guide import GetMCPGuideTool
from .manage_folders import (
CreateFolderTool,
DeleteFolderTool,
@@ -32,6 +36,7 @@ from .run_agent import RunAgentTool
from .run_block import RunBlockTool
from .run_mcp_tool import RunMCPToolTool
from .search_docs import SearchDocsTool
from .validate_agent import ValidateAgentGraphTool
from .web_fetch import WebFetchTool
from .workspace_files import (
DeleteWorkspaceFileTool,
@@ -64,10 +69,13 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
"move_agents_to_folder": MoveAgentsToFolderTool(),
"run_agent": RunAgentTool(),
"run_block": RunBlockTool(),
"continue_run_block": ContinueRunBlockTool(),
"run_mcp_tool": RunMCPToolTool(),
"get_mcp_guide": GetMCPGuideTool(),
"view_agent_output": AgentOutputTool(),
"search_docs": SearchDocsTool(),
"get_doc_page": GetDocPageTool(),
"get_agent_building_guide": GetAgentBuildingGuideTool(),
# Web fetch for safe URL retrieval
"web_fetch": WebFetchTool(),
# Agent-browser multi-step automation (navigate, act, screenshot)
@@ -80,6 +88,9 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
# Feature request tools
"search_feature_requests": SearchFeatureRequestsTool(),
"create_feature_request": CreateFeatureRequestTool(),
# Agent generation tools (local validation/fixing)
"validate_agent_graph": ValidateAgentGraphTool(),
"fix_agent_graph": FixAgentGraphTool(),
# Workspace tools for CoPilot file operations
"list_workspace_files": ListWorkspaceFilesTool(),
"read_workspace_file": ReadWorkspaceFileTool(),

View File

@@ -33,7 +33,7 @@ import tempfile
from typing import Any
from backend.copilot.model import ChatSession
from backend.util.request import validate_url
from backend.util.request import validate_url_host
from .base import BaseTool
from .models import (
@@ -235,7 +235,7 @@ async def _restore_browser_state(
if url:
# Validate the saved URL to prevent SSRF via stored redirect targets.
try:
await validate_url(url, trusted_origins=[])
await validate_url_host(url)
except ValueError:
logger.warning(
"[browser] State restore: blocked SSRF URL %s", url[:200]
@@ -473,7 +473,7 @@ class BrowserNavigateTool(BaseTool):
)
try:
await validate_url(url, trusted_origins=[])
await validate_url_host(url)
except ValueError as e:
return ErrorResponse(
message=str(e),

View File

@@ -68,17 +68,18 @@ def _run_result(rc: int = 0, stdout: str = "", stderr: str = "") -> tuple:
# ---------------------------------------------------------------------------
# SSRF protection via shared validate_url (backend.util.request)
# SSRF protection via shared validate_url_host (backend.util.request)
# ---------------------------------------------------------------------------
# Patch target: validate_url is imported directly into agent_browser's module scope.
_VALIDATE_URL = "backend.copilot.tools.agent_browser.validate_url"
# Patch target: validate_url_host is imported directly into agent_browser's
# module scope.
_VALIDATE_URL = "backend.copilot.tools.agent_browser.validate_url_host"
class TestSsrfViaValidateUrl:
"""Verify that browser_navigate uses validate_url for SSRF protection.
"""Verify that browser_navigate uses validate_url_host for SSRF protection.
We mock validate_url itself (not the low-level socket) so these tests
We mock validate_url_host itself (not the low-level socket) so these tests
exercise the integration point, not the internals of request.py
(which has its own thorough test suite in request_test.py).
"""
@@ -89,7 +90,7 @@ class TestSsrfViaValidateUrl:
@pytest.mark.asyncio
async def test_blocked_ip_returns_blocked_url_error(self):
"""validate_url raises ValueError → tool returns blocked_url ErrorResponse."""
"""validate_url_host raises ValueError → tool returns blocked_url ErrorResponse."""
with patch(_VALIDATE_URL, new_callable=AsyncMock) as mock_validate:
mock_validate.side_effect = ValueError(
"Access to blocked IP 10.0.0.1 is not allowed."
@@ -124,8 +125,8 @@ class TestSsrfViaValidateUrl:
assert result.error == "blocked_url"
@pytest.mark.asyncio
async def test_validate_url_called_with_empty_trusted_origins(self):
"""Confirms no trusted-origins bypass is granted — all URLs are validated."""
async def test_validate_url_host_called_without_trusted_hostnames(self):
"""Confirms no trusted-hostnames bypass is granted — all URLs are validated."""
with patch(_VALIDATE_URL, new_callable=AsyncMock) as mock_validate:
mock_validate.return_value = (object(), False, ["1.2.3.4"])
with patch(
@@ -143,7 +144,7 @@ class TestSsrfViaValidateUrl:
session=self.session,
url="https://example.com",
)
mock_validate.assert_called_once_with("https://example.com", trusted_origins=[])
mock_validate.assert_called_once_with("https://example.com")
# ---------------------------------------------------------------------------

View File

@@ -1,20 +1,15 @@
"""Agent generator package - Creates agents from natural language."""
from .core import (
AgentGeneratorNotConfiguredError,
AgentJsonValidationError,
AgentSummary,
DecompositionResult,
DecompositionStep,
LibraryAgentSummary,
MarketplaceAgentSummary,
customize_template,
decompose_goal,
enrich_library_agents_from_steps,
extract_search_terms_from_steps,
extract_uuids_from_text,
generate_agent,
generate_agent_patch,
get_agent_as_json,
get_all_relevant_agents_for_generation,
get_library_agent_by_graph_id,
@@ -27,25 +22,20 @@ from .core import (
search_marketplace_agents_for_generation,
)
from .errors import get_user_message_for_error
from .service import health_check as check_external_service_health
from .service import is_external_service_configured
from .validation import AgentFixer, AgentValidator
__all__ = [
"AgentGeneratorNotConfiguredError",
"AgentFixer",
"AgentValidator",
"AgentJsonValidationError",
"AgentSummary",
"DecompositionResult",
"DecompositionStep",
"LibraryAgentSummary",
"MarketplaceAgentSummary",
"check_external_service_health",
"customize_template",
"decompose_goal",
"enrich_library_agents_from_steps",
"extract_search_terms_from_steps",
"extract_uuids_from_text",
"generate_agent",
"generate_agent_patch",
"get_agent_as_json",
"get_all_relevant_agents_for_generation",
"get_library_agent_by_graph_id",
@@ -54,7 +44,6 @@ __all__ = [
"get_library_agents_for_generation",
"get_user_message_for_error",
"graph_to_json",
"is_external_service_configured",
"json_to_graph",
"save_agent_to_library",
"search_marketplace_agents_for_generation",

View File

@@ -0,0 +1,66 @@
"""Block management for agent generation.
Provides cached access to block metadata for validation and fixing.
"""
import logging
from typing import Any, Type
from backend.blocks import get_blocks as get_block_classes
from backend.blocks._base import Block
logger = logging.getLogger(__name__)
__all__ = ["get_blocks_as_dicts", "reset_block_caches"]
# ---------------------------------------------------------------------------
# Module-level caches
# ---------------------------------------------------------------------------
_blocks_cache: list[dict[str, Any]] | None = None
def reset_block_caches() -> None:
"""Reset all module-level caches (useful after updating block descriptions)."""
global _blocks_cache
_blocks_cache = None
# ---------------------------------------------------------------------------
# 1. get_blocks_as_dicts
# ---------------------------------------------------------------------------
def get_blocks_as_dicts() -> list[dict[str, Any]]:
"""Get all available blocks as dicts (cached after first call).
Each dict contains the keys returned by ``Block.get_info().model_dump()``:
id, name, description, inputSchema, outputSchema, categories,
staticOutput, costs, contributors, uiType.
Returns:
List of block info dicts.
"""
global _blocks_cache
if _blocks_cache is not None:
return _blocks_cache
block_classes: dict[str, Type[Block]] = get_block_classes() # type: ignore[assignment]
blocks: list[dict[str, Any]] = []
for block_cls in block_classes.values():
try:
instance = block_cls()
info = instance.get_info().model_dump()
# Use optimized description if available (loaded at startup)
if instance.optimized_description:
info["description"] = instance.optimized_description
blocks.append(info)
except Exception:
logger.warning(
"Failed to load block info for %s, skipping",
getattr(block_cls, "__name__", "unknown"),
exc_info=True,
)
_blocks_cache = blocks
logger.info("Cached %d block dicts", len(blocks))
return _blocks_cache

View File

@@ -10,13 +10,7 @@ from backend.data.db_accessors import graph_db, library_db, store_db
from backend.data.graph import Graph, Link, Node
from backend.util.exceptions import DatabaseError, NotFoundError
from .service import (
customize_template_external,
decompose_goal_external,
generate_agent_external,
generate_agent_patch_external,
is_external_service_configured,
)
from .helpers import UUID_RE_STR
logger = logging.getLogger(__name__)
@@ -78,38 +72,7 @@ class DecompositionResult(TypedDict, total=False):
AgentSummary = LibraryAgentSummary | MarketplaceAgentSummary | dict[str, Any]
def _to_dict_list(
agents: Sequence[AgentSummary] | Sequence[dict[str, Any]] | None,
) -> list[dict[str, Any]] | None:
"""Convert typed agent summaries to plain dicts for external service calls."""
if agents is None:
return None
return [dict(a) for a in agents]
class AgentGeneratorNotConfiguredError(Exception):
"""Raised when the external Agent Generator service is not configured."""
pass
def _check_service_configured() -> None:
"""Check if the external Agent Generator service is configured.
Raises:
AgentGeneratorNotConfiguredError: If the service is not configured.
"""
if not is_external_service_configured():
raise AgentGeneratorNotConfiguredError(
"Agent Generator service is not configured. "
"Set AGENTGENERATOR_HOST environment variable to enable agent generation."
)
_UUID_PATTERN = re.compile(
r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}",
re.IGNORECASE,
)
_UUID_PATTERN = re.compile(UUID_RE_STR, re.IGNORECASE)
def extract_uuids_from_text(text: str) -> list[str]:
@@ -553,69 +516,6 @@ async def enrich_library_agents_from_steps(
return all_agents
async def decompose_goal(
description: str,
context: str = "",
library_agents: Sequence[AgentSummary] | None = None,
) -> DecompositionResult | None:
"""Break down a goal into steps or return clarifying questions.
Args:
description: Natural language goal description
context: Additional context (e.g., answers to previous questions)
library_agents: User's library agents available for sub-agent composition
Returns:
DecompositionResult with either:
- {"type": "clarifying_questions", "questions": [...]}
- {"type": "instructions", "steps": [...]}
Or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for decompose_goal")
result = await decompose_goal_external(
description, context, _to_dict_list(library_agents)
)
return result # type: ignore[return-value]
async def generate_agent(
instructions: DecompositionResult | dict[str, Any],
library_agents: Sequence[AgentSummary] | Sequence[dict[str, Any]] | None = None,
) -> dict[str, Any] | None:
"""Generate agent JSON from instructions.
Args:
instructions: Structured instructions from decompose_goal
library_agents: User's library agents available for sub-agent composition
Returns:
Agent JSON dict, error dict {"type": "error", ...}, or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for generate_agent")
result = await generate_agent_external(
dict(instructions), _to_dict_list(library_agents)
)
if result:
if isinstance(result, dict) and result.get("type") == "error":
return result
if "id" not in result:
result["id"] = str(uuid.uuid4())
if "version" not in result:
result["version"] = 1
if "is_active" not in result:
result["is_active"] = True
return result
class AgentJsonValidationError(Exception):
"""Raised when agent JSON is invalid or missing required fields."""
@@ -792,70 +692,3 @@ async def get_agent_as_json(
return None
return graph_to_json(graph)
async def generate_agent_patch(
update_request: str,
current_agent: dict[str, Any],
library_agents: Sequence[AgentSummary] | None = None,
) -> dict[str, Any] | None:
"""Update an existing agent using natural language.
The external Agent Generator service handles:
- Generating the patch
- Applying the patch
- Fixing and validating the result
Args:
update_request: Natural language description of changes
current_agent: Current agent JSON
library_agents: User's library agents available for sub-agent composition
Returns:
Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
error dict {"type": "error", ...}, or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for generate_agent_patch")
return await generate_agent_patch_external(
update_request,
current_agent,
_to_dict_list(library_agents),
)
async def customize_template(
template_agent: dict[str, Any],
modification_request: str,
context: str = "",
) -> dict[str, Any] | None:
"""Customize a template/marketplace agent using natural language.
This is used when users want to modify a template or marketplace agent
to fit their specific needs before adding it to their library.
The external Agent Generator service handles:
- Understanding the modification request
- Applying changes to the template
- Fixing and validating the result
Args:
template_agent: The template agent JSON to customize
modification_request: Natural language description of customizations
context: Additional context (e.g., answers to previous questions)
Returns:
Customized agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
error dict {"type": "error", ...}, or None on unexpected error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for customize_template")
return await customize_template_external(
template_agent, modification_request, context
)

View File

@@ -1,165 +0,0 @@
"""Dummy Agent Generator for testing.
Returns mock responses matching the format expected from the external service.
Enable via AGENTGENERATOR_USE_DUMMY=true in settings.
WARNING: This is for testing only. Do not use in production.
"""
import asyncio
import logging
import uuid
from typing import Any
logger = logging.getLogger(__name__)
# Dummy decomposition result (instructions type)
DUMMY_DECOMPOSITION_RESULT: dict[str, Any] = {
"type": "instructions",
"steps": [
{
"description": "Get input from user",
"action": "input",
"block_name": "AgentInputBlock",
},
{
"description": "Process the input",
"action": "process",
"block_name": "TextFormatterBlock",
},
{
"description": "Return output to user",
"action": "output",
"block_name": "AgentOutputBlock",
},
],
}
# Block IDs from backend/blocks/io.py
AGENT_INPUT_BLOCK_ID = "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b"
AGENT_OUTPUT_BLOCK_ID = "363ae599-353e-4804-937e-b2ee3cef3da4"
def _generate_dummy_agent_json() -> dict[str, Any]:
"""Generate a minimal valid agent JSON for testing."""
input_node_id = str(uuid.uuid4())
output_node_id = str(uuid.uuid4())
return {
"id": str(uuid.uuid4()),
"version": 1,
"is_active": True,
"name": "Dummy Test Agent",
"description": "A dummy agent generated for testing purposes",
"nodes": [
{
"id": input_node_id,
"block_id": AGENT_INPUT_BLOCK_ID,
"input_default": {
"name": "input",
"title": "Input",
"description": "Enter your input",
"placeholder_values": [],
},
"metadata": {"position": {"x": 0, "y": 0}},
},
{
"id": output_node_id,
"block_id": AGENT_OUTPUT_BLOCK_ID,
"input_default": {
"name": "output",
"title": "Output",
"description": "Agent output",
"format": "{output}",
},
"metadata": {"position": {"x": 400, "y": 0}},
},
],
"links": [
{
"id": str(uuid.uuid4()),
"source_id": input_node_id,
"sink_id": output_node_id,
"source_name": "result",
"sink_name": "value",
"is_static": False,
},
],
}
async def decompose_goal_dummy(
description: str,
context: str = "",
library_agents: list[dict[str, Any]] | None = None,
) -> dict[str, Any]:
"""Return dummy decomposition result."""
logger.info("Using dummy agent generator for decompose_goal")
return DUMMY_DECOMPOSITION_RESULT.copy()
async def generate_agent_dummy(
instructions: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
session_id: str | None = None,
) -> dict[str, Any]:
"""Return dummy agent synchronously (blocks for 30s, returns agent JSON).
Note: operation_id and session_id parameters are ignored - we always use synchronous mode.
"""
logger.info(
"Using dummy agent generator (sync mode): returning agent JSON after 30s"
)
await asyncio.sleep(30)
return _generate_dummy_agent_json()
async def generate_agent_patch_dummy(
update_request: str,
current_agent: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
session_id: str | None = None,
) -> dict[str, Any]:
"""Return dummy patched agent synchronously (blocks for 30s, returns patched agent JSON).
Note: operation_id and session_id parameters are ignored - we always use synchronous mode.
"""
logger.info(
"Using dummy agent generator patch (sync mode): returning patched agent after 30s"
)
await asyncio.sleep(30)
patched = current_agent.copy()
patched["description"] = (
f"{current_agent.get('description', '')} (updated: {update_request})"
)
return patched
async def customize_template_dummy(
template_agent: dict[str, Any],
modification_request: str,
context: str = "",
) -> dict[str, Any]:
"""Return dummy customized template (returns template with updated description)."""
logger.info("Using dummy agent generator for customize_template")
customized = template_agent.copy()
customized["description"] = (
f"{template_agent.get('description', '')} (customized: {modification_request})"
)
return customized
async def get_blocks_dummy() -> list[dict[str, Any]]:
"""Return dummy blocks list."""
logger.info("Using dummy agent generator for get_blocks")
return [
{"id": AGENT_INPUT_BLOCK_ID, "name": "AgentInputBlock"},
{"id": AGENT_OUTPUT_BLOCK_ID, "name": "AgentOutputBlock"},
]
async def health_check_dummy() -> bool:
"""Always returns healthy for dummy service."""
return True

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

Some files were not shown because too many files have changed in this diff Show More