Compare commits

..

34 Commits

Author SHA1 Message Date
Nicholas Tindle
6aed43d708 fix(classic): register finish result before task continuation
AgentFinished is caught before execute() registers a result, leaving
the finish episode with result=None. The interaction loop sees this as
"episode in progress" and reuses the old finish proposal instead of
calling the LLM for the new task. Register a success result before
continuing so the loop calls propose_action() for the new task.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-04 17:45:30 +02:00
Nicholas Tindle
17e1578c46 feat(classic): preserve action history across task continuations
Stop clearing episodes when the user enters a new task after finishing.
The compression system (4 recent episodes full, older ones summarized,
1024 token budget) already handles context overflow. Keeping history
lets the agent build on prior work instead of starting from zero.

Restart the process for a clean slate.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-03 18:36:23 +02:00
Krzysztof Czerwinski
09e42041ce fix(frontend): AutoPilot notification follow-ups — branding, UX, persistence, and cross-tab sync (#12428)
AutoPilot (copilot) notifications had several follow-up issues after
initial implementation: old "Otto" branding, UX quirks, a service-worker
crash, notification state that didn't persist or sync across tabs, a
broken notification sound, and noisy Sentry alerts from SSR.

### Changes 🏗️

- **Rename "Otto" → "AutoPilot"** in all notification surfaces: browser
notifications, document title badge, permission dialog copy, and
notification banner copy
- **Agent Activity icon**: changed from `Bell` to `Pulse` (Phosphor) in
the navbar dropdown
- **Centered dialog buttons**: the "Stay in the loop" permission dialog
buttons are now centered instead of right-aligned
- **Service worker notification fix**: wrapped `new Notification()` in
try-catch so it degrades gracefully in service worker / PWA contexts
instead of throwing `TypeError: Illegal constructor`
- **Persist notification state**: `completedSessionIDs` is now stored in
localStorage (`copilot-completed-sessions`) so it survives page
refreshes and new tabs
- **Cross-tab sync**: a `storage` event listener keeps
`completedSessionIDs` and `document.title` in sync across all open tabs
— clearing a notification in one tab clears it everywhere
- **Fix notification sound**: corrected the sound file path from
`/sounds/notification.mp3` to `/notification.mp3` and added a
`.gitignore` exception (root `.gitignore` has a blanket `*.mp3` ignore
rule from legacy AutoGPT agent days)
- **Fix SSR Sentry noise**: guarded the Copilot Zustand store
initialization with a client-side check so `storage.get()` is never
called during SSR, eliminating spurious Sentry alerts (BUILDER-7CB, 7CC,
7C7) while keeping the Sentry reporting in `local-storage.ts` intact for
genuinely unexpected SSR access

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Verify "AutoPilot" appears (not "Otto") in browser notification,
document title, permission dialog, and banner
  - [x] Verify Pulse icon in navbar Agent Activity dropdown
  - [x] Verify "Stay in the loop" dialog buttons are centered
- [x] Open two tabs on copilot → trigger completion → both tabs show
badge/checkmark
  - [x] Click completed session in tab 1 → badge clears in both tabs
  - [x] Refresh a tab → completed session state is preserved
  - [x] Verify notification sound plays on completion
  - [x] Verify no Sentry alerts from SSR localStorage access

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-03 11:44:22 +00:00
Zamil Majdy
a50e95f210 feat(backend/copilot): add include_graph option to find_library_agent (#12622)
## Why

The copilot's `edit_agent` tool requires the LLM to provide a complete
agent JSON (all nodes + links), but the LLM had **no way to see the
current graph structure** before editing. It was editing blindly —
guessing/hallucinating the entire node+link structure and replacing the
graph wholesale.

## What

- Add `include_graph` boolean parameter (default `false`) to the
existing `find_library_agent` tool
- When `true`, each returned `AgentInfo` includes a `graph` field with
the full graph JSON (nodes, links, `input_default` values)
- Update the agent generation guide to instruct the LLM to always fetch
the current graph before editing

## How

- Added `graph: dict[str, Any] | None` field to `AgentInfo` model
- Added `_enrich_agents_with_graph()` helper in `agent_search.py` that
calls the existing `get_agent_as_json()` utility to fetch full graph
data
- Threaded `include_graph` parameter through `find_library_agent` →
`search_agents` → `_search_library`
- Updated `agent_generation_guide.md` to add an "if editing" step that
fetches the graph first

No new tools introduced — reuses existing `find_library_agent` with one
optional flag.

## Test plan

- [x] Unit tests: 2 new tests added
(`test_include_graph_fetches_nodes_and_links`,
`test_include_graph_false_does_not_fetch`)
- [x] All 7 `agent_search_test.py` tests pass
- [x] All pre-commit hooks pass (lint, format, typecheck)
- [ ] Verify copilot correctly uses `include_graph=true` before editing
an agent (manual test)
2026-04-03 11:20:57 +00:00
Zamil Majdy
92b395d82a fix(backend): use OpenRouter client for simulator to support non-OpenAI models (#12656)
## Why

Dry-run block simulation is failing in production with `404 - model
gemini-2.5-flash does not exist`. The simulator's default model
(`google/gemini-2.5-flash`) is a non-OpenAI model that requires
OpenRouter routing, but the shared `get_openai_client()` prefers the
direct OpenAI key, creating a client that can't handle non-OpenAI
models. The old code also stripped the provider prefix, sending
`gemini-2.5-flash` to OpenAI's API.

## What

- Added `prefer_openrouter` keyword parameter to `get_openai_client()` —
when True, prefers the OpenRouter key (returns None if unavailable,
rather than falling back to an incompatible direct OpenAI client)
- Simulator now calls `get_openai_client(prefer_openrouter=True)` so
`google/gemini-2.5-flash` routes correctly through OpenRouter
- Removed the redundant `SIMULATION_MODEL` env var override and the
now-unnecessary provider prefix stripping from `_simulator_model()`

## How

`get_openai_client()` is decorated with `@cached(ttl_seconds=3600)`
which keys by args, so `get_openai_client()` and
`get_openai_client(prefer_openrouter=True)` are cached independently.
When `prefer_openrouter=True` and no OpenRouter key exists, returns
`None` instead of falling back — the simulator already handles `None`
with a clear error message.

### Checklist
- [x] All 24 dry-run tests pass
- [x] Test asserts `get_openai_client` is called with
`prefer_openrouter=True`
- [x] Format, lint, and pyright pass
- [x] No changes to user-facing APIs
- [ ] Deploy to staging and verify simulation works

---------

Co-authored-by: Nicholas Tindle <nicholas.tindle@agpt.co>
2026-04-03 11:19:09 +00:00
Ubbe
86abfbd394 feat(frontend): redesign onboarding wizard with Autopilot-first flow (#12640)
### Why / What / How

<img width="800" height="827" alt="Screenshot 2026-04-02 at 15 40 24"
src="https://github.com/user-attachments/assets/69a381c1-2884-434b-9406-4a3f7eec87cf"
/>
<img width="800" height="825" alt="Screenshot 2026-04-02 at 15 40 41"
src="https://github.com/user-attachments/assets/c6191a68-a8ba-482b-ba47-c06c71d69f0c"
/>
<img width="800" height="825" alt="Screenshot 2026-04-02 at 15 40 48"
src="https://github.com/user-attachments/assets/31b632b9-59cb-4bf7-a6a0-6158846fcf9a"
/>
<img width="800" height="812" alt="Screenshot 2026-04-02 at 15 40 54"
src="https://github.com/user-attachments/assets/64e38a15-2e56-4c0e-bd84-987bf6076bf7"
/>



**Why:** The existing onboarding flow was outdated and didn't align with
the new Autopilot-first experience. New users need a streamlined,
visually polished wizard that collects their role and pain points to
personalize Autopilot suggestions.

**What:** Complete redesign of the onboarding wizard as a 4-step flow:
Welcome → Role selection → Pain points → Preparing workspace. Uses the
design system throughout (atoms/molecules), adds animations, and syncs
steps with URL search params.

**How:** 
- Zustand store manages wizard state (name, role, pain points, current
step)
- Steps synced to `?step=N` URL params for browser navigation support
- Pain points reordered based on selected role (e.g. Sales sees "Finding
leads" first)
- Design system components used exclusively (no raw shadcn `ui/`
imports)
- New reusable components: `FadeIn` (atom), `TypingText` (molecule) with
Storybook stories
- `AutoGPTLogo` made sizeable via Tailwind className prop, migrated in
Navbar
- Fixed `SetupAnalytics` crash (client component was rendered inside
`<head>`)

### Changes 🏗️

- **New onboarding wizard** (`steps/WelcomeStep`, `RoleStep`,
`PainPointsStep`, `PreparingStep`)
- **New shared components**: `ProgressBar`, `StepIndicator`,
`SelectableCard`, `CardCarousel`
- **New design system components**: `FadeIn` atom with stories,
`TypingText` molecule with stories
- **`AutoGPTLogo`** — size now controlled via `className` prop instead
of numeric `size`
- **Navbar** — migrated from legacy `IconAutoGPTLogo` to design system
`AutoGPTLogo`
- **Layout fix** — moved `SetupAnalytics` from `<head>` to `<body>` to
fix React hydration crash
- **Role-based pain point ordering** — top picks surfaced first based on
role selection
- **URL-synced steps** — `?step=N` search params for back/forward
navigation
- Removed old onboarding pages (1-welcome through 6-congrats, reset
page)
- Emoji/image assets for role selection cards

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Complete onboarding flow from step 1 through 4 as a new user
  - [x] Verify back button navigates to previous step
  - [x] Verify progress bar advances correctly (hidden on step 4)
  - [x] Verify step indicator dots show for steps 1-3
  - [x] Verify role selection reorders pain points on next step
  - [x] Verify "Other" role/pain point shows text input
  - [x] Verify typing animation on PreparingStep title
  - [x] Verify fade-in animations on all steps
  - [x] Verify URL updates with `?step=N` on navigation
  - [x] Verify browser back/forward works with step URLs
  - [x] Verify mobile horizontal scroll on card grids
  - [x] Verify `pnpm types` passes cleanly

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-03 18:06:57 +07:00
Nicholas Tindle
a7f4093424 ci(platform): set up Codecov coverage reporting across platform and classic (#12655)
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-03 03:48:30 -05:00
Nicholas Tindle
e33b1e2105 feat(classic): update classic autogpt a bit to make it more useful for my day to day (#11797)
## Summary

This PR modernizes AutoGPT Classic to make it more useful for day-to-day
autonomous agent development. Major changes include consolidating the
project structure, adding new prompt strategies, modernizing the
benchmark system, and improving the development experience.

**Note: AutoGPT Classic is an experimental, unsupported project
preserved for educational/historical purposes. Dependencies will not be
actively updated.**

## Changes 🏗️

### Project Structure & Build System
- **Consolidated Poetry projects** - Merged `forge/`,
`original_autogpt/`, and benchmark packages into a single
`pyproject.toml` at `classic/` root
- **Removed old benchmark infrastructure** - Deleted the complex
`agbenchmark` package (3000+ lines) in favor of the new
`direct_benchmark` harness
- **Removed frontend** - Deleted `benchmark/frontend/` React app (no
longer needed)
- **Cleaned up CI workflows** - Simplified GitHub Actions workflows for
the consolidated project structure
- **Added CLAUDE.md** - Documentation for working with the codebase
using Claude Code

### New Direct Benchmark System
- **`direct_benchmark` harness** - New streamlined benchmark runner
with:
  - Rich TUI with multi-panel layout showing parallel test execution
  - Incremental resume and selective reset capabilities
  - CI mode for non-interactive environments
  - Step-level logging with colored prefixes
  - "Would have passed" tracking for timed-out challenges
  - Copy-paste completion blocks for sharing results

### Multiple Prompt Strategies
Added pluggable prompt strategy system supporting:
- **one_shot** - Single-prompt completion
- **plan_execute** - Plan first, then execute steps
- **rewoo** - Reasoning without observation (deferred tool execution)
- **react** - Reason + Act iterative loop
- **lats** - Language Agent Tree Search (MCTS-based exploration)
- **sub_agent** - Multi-agent delegation architecture
- **debate** - Multi-agent debate for consensus

### LLM Provider Improvements
- Added support for modern **Anthropic Claude models**
(claude-3.5-sonnet, claude-3-haiku, etc.)
- Added **Groq** provider support
- Improved tool call error feedback for LLM self-correction
- Fixed deprecated API usage

### Web Components
- **Replaced Selenium with Playwright** for web browsing (better async
support, faster)
- Added **lightweight web fetch component** for simple URL fetching
- **Modernized web search** with tiered provider system (Tavily, Serper,
Google)

### Agent Capabilities
- **Workspace permissions system** - Pattern-based allow/deny lists for
agent commands
- **Rich interactive selector** for command approval with scopes
(once/agent/workspace/deny)
- **TodoComponent** with LLM-powered task decomposition
- **Platform blocks integration** - Connect to AutoGPT Platform API for
additional blocks
- **Sub-agent architecture** - Agents can spawn and coordinate
sub-agents

### Developer Experience
- **Python 3.12+ support** with CI testing on 3.12, 3.13, 3.14
- **Current working directory as default workspace** - Run `autogpt`
from any project directory
- Simplified log format (removed timestamps)
- Improved configuration and setup flow
- External benchmark adapters for GAIA, SWE-bench, and AgentBench

### Bug Fixes
- Fixed N/A command loop when using native tool calling
- Fixed auto-advance plan steps in Plan-Execute strategy
- Fixed approve+feedback to execute command then send feedback
- Fixed parallel tool calls in action history
- Always recreate Docker containers for code execution
- Various pyright type errors resolved
- Linting and formatting issues fixed across codebase

## Test Plan

- [x] CI lint, type, and test checks pass
- [x] Run `poetry install` from `classic/` directory
- [x] Run `poetry run autogpt` and verify CLI starts
- [x] Run `poetry run direct-benchmark run --tests ReadFile` to verify
benchmark works

## Notes

- This is a WIP PR for personal use improvements
- The project is marked as **unsupported** - no active maintenance
planned
- Contains known vulnerabilities in dependencies (intentionally not
updated)

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> **Medium Risk**
> CI/build workflows are substantially reworked (runner matrix removal,
path/layout changes, new benchmark runner), so breakage is most likely
in automation and packaging rather than runtime behavior.
> 
> **Overview**
> **Modernizes the `classic/` project layout and automation around a
single consolidated Poetry project** (root
`classic/pyproject.toml`/`poetry.lock`) and updates docs
(`classic/README.md`, new `classic/CLAUDE.md`) accordingly.
> 
> **Replaces the old `agbenchmark` CI usage with `direct-benchmark` in
GitHub Actions**, including new/updated benchmark smoke and regression
workflows, standardized `working-directory: classic`, and a move to
**Python 3.12** on Ubuntu-only runners (plus updated caching, coverage
flags, and required `ANTHROPIC_API_KEY` wiring).
> 
> Cleans up repo/dev tooling by removing the classic frontend workflow,
deleting the Forge VCR cassette submodule (`.gitmodules`) and associated
CI steps, consolidating `flake8`/`isort`/`pyright` pre-commit hooks to
run from `classic/`, updating ignores for new report/workspace
artifacts, and updating `classic/Dockerfile.autogpt` to build from
Python 3.12 with the consolidated project structure.
> 
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
de67834dac. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
2026-04-03 07:16:36 +00:00
Zamil Majdy
fff101e037 feat(backend): add SQL query block with multi-database support for CoPilot analytics (#12569)
## Summary
- Add a read-only SQL query block for CoPilot/AutoPilot analytics access
- Supports **multiple databases**: PostgreSQL, MySQL, SQLite, MSSQL via
SQLAlchemy
- Enforces read-only queries (SELECT only) with defense-in-depth SQL
validation using sqlparse
- SSRF protection: blocks connections to private/internal IPs
- Credentials stored securely via the platform credential system

## Changes
- New `SQLQueryBlock` in `backend/blocks/sql_query_block.py` with
`DatabaseType` enum
- SQLAlchemy-based execution with dialect-specific read-only and timeout
settings
- Connection URL validation ensuring driver matches selected database
type
- Comprehensive test suite (62 tests) including URL validation,
sanitization, serialization
- Documentation in `docs/integrations/block-integrations/data.md`
- Added `DATABASE` provider to `ProviderName` enum

### Checklist 📋
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan

#### Test plan:
- [x] Unit tests pass for query validation, URL validation, error
sanitization, value serialization
- [x] Read-only enforcement rejects INSERT/UPDATE/DELETE/DROP
- [x] Multi-statement injection blocked
- [x] SSRF protection blocks private IPs
- [x] Connection URL driver validation works for all 4 database types

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-03 06:43:40 +00:00
Zamil Majdy
f1ac05b2e0 fix(backend): propagate dry-run mode to special blocks with LLM-powered simulation (#12575)
## Summary
- **OrchestratorBlock & AgentExecutorBlock** now execute for real in
dry-run mode so the orchestrator can make LLM calls and agent executors
can spawn child graphs. Their downstream tool blocks and child-graph
blocks are still simulated via `simulate_block()`. Credential fields
from node defaults are restored since `validate_exec()` wipes them in
dry-run mode. Agent-mode iterations capped at 1 in dry-run.
- **All blocks** (including MCPToolBlock) are simulated via a single
generic `simulate_block()` path. The LLM prompt is grounded by
`inspect.getsource(block.run)`, giving the simulator access to the exact
implementation of each block's `run()` method. This produces realistic
mock responses for any block type without needing block-specific
simulation logic.
- Updated agent generation guide to document special block dry-run
behavior.
- Minor frontend fixes: exported `formatCents` from
`RateLimitResetDialog` for reuse in `UsagePanelContent`, used `useRef`
for stable callback references in `useResetRateLimit` to avoid stale
closures.
- 74 tests (21 existing dry-run + 53 new simulator tests covering prompt
building, passthrough logic, and special block dry-run).

## Design

The simulator (`backend/executor/simulator.py`) uses a two-tier
approach:

1. **Passthrough blocks** (OrchestratorBlock, AgentExecutorBlock):
`prepare_dry_run()` returns modified input_data so these blocks execute
for real in `manager.py`. OrchestratorBlock gets `max_iterations=1`
(agent mode) or 0 (traditional mode). AgentExecutorBlock spawns real
child graph executions whose blocks inherit `dry_run=True`.

2. **All other blocks**: `simulate_block()` builds an LLM prompt
containing:
   - Block name and description
   - Input/output schemas (JSON Schema)
   - The block's `run()` source code via `inspect.getsource(block.run)`
- The actual input values (with credentials stripped and long values
truncated)

The LLM then role-plays the block's execution, producing realistic
outputs grounded in the actual implementation.

Special handling for input/output blocks: `AgentInputBlock` and
`AgentOutputBlock` are pure passthrough (no LLM call needed).

## Test plan
- [x] All 74 tests pass (`pytest backend/copilot/tools/test_dry_run.py
backend/executor/simulator_test.py`)
- [x] Pre-commit hooks pass (ruff, isort, black, pyright, frontend
typecheck)
- [x] CI: all checks green
- [x] E2E: dry-run execution completes with `is_dry_run=true`, cost=0,
no errors
- [x] E2E: normal (non-dry-run) execution unchanged
- [x] E2E: Create agent with OrchestratorBlock + tool blocks, run with
`dry_run=True`, verify orchestrator makes real LLM calls while tool
blocks are simulated
- [x] E2E: AgentExecutorBlock spawns child graph in dry-run, child
blocks are LLM-simulated
- [x] E2E: Builder simulate button works end-to-end with special blocks

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-02 17:09:55 +00:00
Zamil Majdy
f115607779 fix(copilot): recognize Agent tool name and route CLI state into workspace (#12635)
### Why / What / How

**Why:** The Claude Agent SDK CLI renamed the sub-agent tool from
`"Task"` to `"Agent"` in v2.x. Our security hooks only checked for
`"Task"`, so all sub-agent security controls were silently bypassed on
production: concurrency limiting didn't apply, and slot tracking was
broken. This was discovered via Langfuse trace analysis of session
`62b1b2b9` where background sub-agents ran unchecked.

Additionally, the CLI writes sub-agent output to `/tmp/claude-<uid>/`
and project state to `$HOME/.claude/` — both outside the per-session
workspace (`/tmp/copilot-<session>/`). This caused `PermissionError` in
E2B sandboxes and silently lost sub-agent results.

The frontend also had no rendering for the `Agent` / `TaskOutput` SDK
built-in tools — they fell through to the generic "other" category with
no context-aware display.

**What:**
1. Fix the sub-agent tool name recognition (`"Task"` → `{"Task",
"Agent"}`)
2. Allow `run_in_background` — the SDK handles async lifecycle cleanly
(returns `isAsync:true`, model polls via `TaskOutput`)
3. Route CLI state into the workspace via `CLAUDE_CODE_TMPDIR` and
`HOME` env vars
4. Add lifecycle hooks (`SubagentStart`/`SubagentStop`) for
observability
5. Add frontend `"agent"` tool category with proper UI rendering

**How:**
- Security hooks check `tool_name in _SUBAGENT_TOOLS` (frozenset of
`"Task"` and `"Agent"`)
- Background agents are allowed but still count against `max_subtasks`
concurrency limit
- Frontend detects `isAsync: true` output → shows "Agent started
(background)" not "Agent completed"
- `TaskOutput` tool shows retrieval status and collected results
- Robot icon and agent-specific accordion rendering for both foreground
and background agents

### Changes 🏗️

**Backend:**
- **`security_hooks.py`**: Replace `tool_name == "Task"` with `tool_name
in _SUBAGENT_TOOLS`. Remove `run_in_background` deny block (SDK handles
async lifecycle). Add `SubagentStart`/`SubagentStop` hooks.
- **`tool_adapter.py`**: Add `"Agent"` to `_SDK_BUILTIN_ALWAYS` list
alongside `"Task"`.
- **`service.py`**: Set `CLAUDE_CODE_TMPDIR=sdk_cwd` and `HOME=sdk_cwd`
in SDK subprocess env.
- **`security_hooks_test.py`**: Update background tests (allowed, not
blocked). Add test for background agents counting against concurrency
limit.

**Frontend:**
- **`GenericTool/helpers.ts`**: Add `"agent"` tool category for `Agent`,
`Task`, `TaskOutput`. Agent-specific animation text detecting `isAsync`
output. Input summaries from description/prompt fields.
- **`GenericTool/GenericTool.tsx`**: Add `RobotIcon` for agent category.
Add `getAgentAccordionData()` with async-aware title/content.
`TaskOutput` shows retrieval status.
- **`useChatSession.ts`**: Fix pre-existing TS error (void mutation
body).

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] All security hooks tests pass (background allowed + limit
enforced)
  - [x] Pre-commit hooks (ruff, black, isort, pyright, tsc) all pass
  - [x] E2E test: copilot agent create+run scenario PASS
- [ ] Deploy to dev and test copilot sub-agent spawning with background
mode

#### For configuration changes:
- [x] `.env.default` is updated or already compatible
- [x] `docker-compose.yml` is updated or already compatible
2026-04-03 00:09:19 +07:00
Zamil Majdy
1aef8b7155 fix(backend/copilot): fix tool output file reading between E2B and host (#12646)
### Why / What / How

**Why:** When copilot tools return large outputs (e.g. 3MB+ base64
images from API calls), the agent cannot process them in the E2B
sandbox. Three compounding issues prevent seamless file access:
1. The `<tool-output-truncated path="...">` tag uses a bare `path=`
attribute that the model confuses with a local filesystem path (it's
actually a workspace path)
2. `is_allowed_local_path` rejects `tool-outputs/` directories (only
`tool-results/` was allowed)
3. SDK-internal files read via the `Read` tool are not available in the
E2B sandbox for `bash_exec` processing

**What:** Fixes all three issues so that large tool outputs can be
seamlessly read and processed in both host and E2B contexts.

**How:**
- Changed `path=` → `workspace_path=` in the truncation tag to
disambiguate workspace vs filesystem paths
- Added `save_to_path` guidance in the retrieval instructions for E2B
users
- Extended `is_allowed_local_path` to accept both `tool-results` and
`tool-outputs` directories
- Added automatic bridging: when E2B is active and `Read` accesses an
SDK-internal file, the file is automatically copied to `/tmp/<filename>`
in the sandbox
- Updated system prompting to explain both SDK tool-result bridging and
workspace `<tool-output-truncated>` handling

### Changes 🏗️

- **`tools/base.py`**: `_persist_and_summarize` now uses
`workspace_path=` attribute and includes `save_to_path` example for E2B
processing
- **`context.py`**: `is_allowed_local_path` accepts both `tool-results`
and `tool-outputs` directory names
- **`sdk/e2b_file_tools.py`**: `_handle_read_file` bridges SDK-internal
files to `/tmp/` in E2B sandbox; new `_bridge_to_sandbox` helper
- **`prompting.py`**: Updated "SDK tool-result files" section and added
"Large tool outputs saved to workspace" section
- **Tests**: Added `tool-outputs` path validation tests in
`context_test.py` and `e2b_file_tools_test.py`; updated `base_test.py`
assertion for `workspace_path`

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] `poetry run pytest backend/copilot/tools/base_test.py` — all 9
tests pass (persistence, truncation, binary fields)
  - [x] `poetry run format` and `poetry run lint` pass clean
  - [x] All pre-commit hooks pass
- [ ] `context_test.py`, `e2b_file_tools_test.py`,
`security_hooks_test.py` — blocked by pre-existing DB migration issue on
worktree (missing `User.subscriptionTier` column); CI will validate
these
2026-04-03 00:08:04 +07:00
Nicholas Tindle
0da949ba42 feat(e2b): set git committer identity from user's GitHub profile (#12650)
## Summary

Sets git author/committer identity in E2B sandboxes using the user's
connected GitHub account profile, so commits are properly attributed.

## Changes

### `integration_creds.py`
- Added `get_github_user_git_identity(user_id)` that fetches the user's
name and email from the GitHub `/user` API
- Uses TTL cache (10 min) to avoid repeated API calls
- Falls back to GitHub noreply email
(`{id}+{login}@users.noreply.github.com`) when user has a private email
- Falls back to `login` if `name` is not set

### `bash_exec.py`
- After injecting integration env vars, calls
`get_github_user_git_identity()` and sets `GIT_AUTHOR_NAME`,
`GIT_AUTHOR_EMAIL`, `GIT_COMMITTER_NAME`, `GIT_COMMITTER_EMAIL`
- Only sets these if the user has a connected GitHub account

### `bash_exec_test.py`
- Added tests covering: identity set from GitHub profile, no identity
when GitHub not connected, no injection when no user_id

## Why
Previously, commits made inside E2B sandboxes had no author identity
set, leading to unattributed commits. This dynamically resolves identity
from the user's actual GitHub account rather than hardcoding a default.

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> **Medium Risk**
> Adds outbound calls to GitHub’s `/user` API during `bash_exec` runs
and injects returned identity into the sandbox environment, which could
impact reliability (network/timeouts) and attribution behavior. Caching
mitigates repeated calls but incorrect/expired tokens or API failures
may lead to missing identity in commits.
> 
> **Overview**
> Sets git author/committer environment variables in the E2B `bash_exec`
path by fetching the connected user’s GitHub profile and injecting
`GIT_AUTHOR_*`/`GIT_COMMITTER_*` into the sandbox env.
> 
> Introduces `get_github_user_git_identity()` with TTL caching
(including a short-lived null cache), fallback to GitHub noreply email
when needed, and ensures `invalidate_user_provider_cache()` also clears
identity caches for the `github` provider. Updates tests to cover
identity injection behavior and the new cache invalidation semantics.
> 
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
955ec81efe. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->

---------

Co-authored-by: AutoGPT <autopilot@agpt.co>
2026-04-02 15:07:22 +00:00
Zamil Majdy
6b031085bd feat(platform): add generic ask_question copilot tool (#12647)
### Why / What / How

**Why:** The copilot can ask clarifying questions in plain text, but
that text gets collapsed into hidden "reasoning" UI when the LLM also
calls tools in the same turn. This makes clarification questions
invisible to users. The existing `ClarificationNeededResponse` model and
`ClarificationQuestionsCard` UI component were built for this purpose
but had no tool wiring them up.

**What:** Adds a generic `ask_question` tool that produces a visible,
interactive clarification card instead of collapsible plain text. Unlike
the agent-generation-specific `clarify_agent_request` proposed in
#12601, this tool is workflow-agnostic — usable for agent building,
editing, troubleshooting, or any flow needing user input.

**How:** 
- Backend: New `AskQuestionTool` reuses existing
`ClarificationNeededResponse` model. Registered in `TOOL_REGISTRY` and
`ToolName` permissions.
- Frontend: New `AskQuestion/` renderer reuses
`ClarificationQuestionsCard` from CreateAgent. Registered in
`CUSTOM_TOOL_TYPES` (prevents collapse into reasoning) and
`MessagePartRenderer`.
- Guide: `agent_generation_guide.md` updated to reference `ask_question`
for the clarification step.

### Changes 🏗️

- **`copilot/tools/ask_question.py`** — New generic tool: takes
`question`, optional `options[]` and `keyword`, returns
`ClarificationNeededResponse`
- **`copilot/tools/__init__.py`** — Register `ask_question` in
`TOOL_REGISTRY`
- **`copilot/permissions.py`** — Add `ask_question` to `ToolName`
literal
- **`copilot/sdk/agent_generation_guide.md`** — Reference `ask_question`
tool in clarification step
- **`ChatMessagesContainer/helpers.ts`** — Add `tool-ask_question` to
`CUSTOM_TOOL_TYPES`
- **`MessagePartRenderer.tsx`** — Add switch case for
`tool-ask_question`
- **`AskQuestion/AskQuestion.tsx`** — Renderer reusing
`ClarificationQuestionsCard`
- **`AskQuestion/helpers.ts`** — Output parsing and animation text

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Backend format + pyright pass
  - [x] Frontend lint + types pass
  - [x] Pre-commit hooks pass
- [ ] Manual test: copilot uses `ask_question` and card renders visibly
(not collapsed)
2026-04-02 12:56:48 +00:00
Toran Bruce Richards
11b846dd49 fix(blocks): rename placeholder_values to options on AgentDropdownInputBlock (#12595)
## Summary

Resolves [REQ-78](https://linear.app/autogpt/issue/REQ-78): The
`placeholder_values` field on `AgentDropdownInputBlock` is misleadingly
named. In every major UI framework "placeholder" means non-binding hint
text that disappears on focus, but this field actually creates a
dropdown selector that restricts the user to only those values.

## Changes

### Core rename (`autogpt_platform/backend/backend/blocks/io.py`)
- Renamed `placeholder_values` → `options` on
`AgentDropdownInputBlock.Input`
- Added clear field description: *"If provided, renders the input as a
dropdown selector restricted to these values. Leave empty for free-text
input."*
- Updated class docstring to describe actual behavior
- Overrode `model_construct()` to remap legacy `placeholder_values` →
`options` for **backward compatibility** with existing persisted agent
JSON

### Tests (`autogpt_platform/backend/backend/blocks/test/test_block.py`)
- Updated existing tests to use canonical `options` field name
- Added 2 new backward-compat tests verifying legacy
`placeholder_values` still works through both `model_construct()` and
`Graph._generate_schema()` paths

### Documentation
- Updated
`autogpt_platform/backend/backend/copilot/sdk/agent_generation_guide.md`
— changed field name in CoPilot SDK guide
- Updated `docs/integrations/block-integrations/basic.md` — changed
field name and description in public docs

### Load tests
(`autogpt_platform/backend/load-tests/tests/api/graph-execution-test.js`)
- Removed spurious `placeholder_values: {}` from AgentInputBlock node
(this field never existed on AgentInputBlock)
- Fixed execution input to use `value` instead of `placeholder_values`

## Backward Compatibility

Existing agents with `placeholder_values` in their persisted
`input_default` JSON will continue to work — the `model_construct()`
override transparently remaps the old key to `options`. No database
migration needed since the field is stored inside a JSON blob, not as a
dedicated column.

## Testing

- All existing tests updated and passing
- 2 new backward-compat tests added
- No frontend changes needed (frontend reads `enum` from generated JSON
Schema, not the field name directly)

---------

Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
2026-04-02 05:56:17 +00:00
Zamil Majdy
b9e29c96bd fix(backend/copilot): detect prompt-too-long in AssistantMessage content and ResultMessage success subtype (#12642)
## Why

PR #12625 fixed the prompt-too-long retry mechanism for most paths, but
two SDK-specific paths were still broken. The dev session `d2f7cba3`
kept accumulating synthetic "Prompt is too long" error entries on every
turn, growing the transcript from 2.5 MB → 3.2 MB, making recovery
impossible.

Root causes identified from production logs (`[T25]`, `[T28]`):

**Path 1 — AssistantMessage content check:**
When the Claude API rejects a prompt, the SDK surfaces it as
`AssistantMessage(error="invalid_request", content=[TextBlock("Prompt is
too long")])`. Our check only inspected `error_text = str(sdk_error)`
which is `"invalid_request"` — not a prompt-too-long pattern. The
content was then streamed out as `StreamText`, setting `events_yielded =
1`, which blocked retry even when the ResultMessage fired.

**Path 2 — ResultMessage success subtype:**
After the SDK auto-compacts internally (via `PreCompact` hook) and the
compacted transcript is _still_ too long, the SDK returns
`ResultMessage(subtype="success", result="Prompt is too long")`. Our
check only ran for `subtype="error"`. With `subtype="success"`, the
stream "completed normally", appended the synthetic error entry to the
transcript via `transcript_builder`, and uploaded it to GCS — causing
the transcript to grow on each failed turn.

## What

- **AssistantMessage handler**: when `sdk_error` is set, also check the
content text. `sdk_error` being non-`None` confirms this is an API error
message (not user-generated content), so content inspection is safe.
- **ResultMessage handler**: check `result` for prompt-too-long patterns
regardless of `subtype`, covering the SDK auto-compact path where
`subtype="success"` with `result="Prompt is too long"`.

## How

Two targeted one-line condition expansions in `_run_stream_attempt`,
plus two new integration tests in `retry_scenarios_test.py` that
reproduce each broken path and verify retry fires correctly.

## Changes

- `backend/copilot/sdk/service.py`: fix AssistantMessage content check +
ResultMessage subtype-independent check
- `backend/copilot/sdk/retry_scenarios_test.py`: add 2 integration tests
for the new scenarios

## Checklist

- [x] Tests added for both new scenarios (45 total, all pass)
- [x] Formatted (`poetry run format`)
- [x] No false-positive risk: AssistantMessage check gated behind
`sdk_error is not None`
- [x] Root cause verified from production pod logs
2026-04-01 22:32:09 +00:00
Zamil Majdy
4ac0ba570a fix(backend): fix copilot credential loading across event loops (#12628)
## Why

CoPilot autopilot sessions are inconsistently failing to load user
credentials (specifically GitHub OAuth). Some sessions proceed normally,
some show "provide credentials" prompts despite the user having valid
creds, and some are completely blocked.

Production logs confirmed the root cause: `RuntimeError: Task got Future
<Future pending> attached to a different loop` in the credential refresh
path, cascading into null-cache poisoning that blocks credential lookups
for 60 seconds.

## What

Three interrelated bugs in the credential system:

1. **`refresh_if_needed` always acquired Redis locks even with
`lock=False`** — The `lock` parameter only controlled the inner
credential lock, but the outer "refresh" scope lock was always acquired.
The copilot executor uses multiple worker threads with separate event
loops; the `asyncio.Lock` inside `AsyncRedisKeyedMutex` was bound to one
loop and failed on others.

2. **Stale event loop in `locks()` singleton** — Both
`IntegrationCredentialsManager` and `IntegrationCredentialsStore` cached
their `AsyncRedisKeyedMutex` without tracking which event loop created
it. When a different worker thread (with a different loop) reused the
singleton, it got the "Future attached to different loop" error.

3. **Null-cache poisoning on refresh failure** — When OAuth refresh
failed (due to the event loop error), the code fell through to cache "no
credentials found" for 60 seconds via `_null_cache`. This blocked ALL
subsequent credential lookups for that user+provider, even though the
credentials existed and could refresh fine on retry.

## How

- Split `refresh_if_needed` into `_refresh_locked` / `_refresh_unlocked`
so `lock=False` truly skips ALL Redis locking (safe for copilot's
best-effort background injection)
- Added event loop tracking to `locks()` in both
`IntegrationCredentialsManager` and `IntegrationCredentialsStore` —
recreates the mutex when the running loop changes
- Only populate `_null_cache` when the user genuinely has no
credentials; skip caching when OAuth refresh failed transiently
- Updated existing test to verify null-cache is not poisoned on refresh
failure

## Test plan

- [x] All 14 existing `integration_creds_test.py` tests pass
- [x] Updated
`test_oauth2_refresh_failure_returns_none_without_null_cache` verifies
null-cache is not populated on refresh failure
- [x] Format, lint, and typecheck pass
- [ ] Deploy to staging and verify copilot sessions consistently load
GitHub credentials
2026-04-02 00:11:38 +07:00
Zamil Majdy
d61a2c6cd0 Revert "fix(backend/copilot): detect prompt-too-long in AssistantMessage content and ResultMessage success subtype"
This reverts commit 1c301b4b61.
2026-04-01 18:59:38 +02:00
Zamil Majdy
1c301b4b61 fix(backend/copilot): detect prompt-too-long in AssistantMessage content and ResultMessage success subtype
The SDK returns AssistantMessage(error="invalid_request", content=[TextBlock("Prompt is too long")])
followed by ResultMessage(subtype="success", result="Prompt is too long") when the transcript is
rejected after internal auto-compaction. Both paths bypassed the retry mechanism:

- AssistantMessage handler only checked error_text ("invalid_request"), not the content which
  holds the actual error description. The content was then streamed as text, setting events_yielded=1,
  which blocked retry even when ResultMessage fired.
- ResultMessage handler only triggered prompt-too-long detection for subtype="error", not
  subtype="success". The stream "completed normally", stored the synthetic error entry in the
  transcript, and uploaded it — causing the transcript to grow unboundedly on each failed turn.

Fixes:
1. AssistantMessage handler: when sdk_error is set (confirmed error message), also check content
   text. sdk_error being set guarantees this is an API error, not user-generated content, so
   content inspection is safe.
2. ResultMessage handler: check result for prompt-too-long regardless of subtype, covering the
   case where the SDK auto-compacts internally but the result is still too long.

Adds integration tests for both new scenarios.
2026-04-01 18:28:46 +02:00
Zamil Majdy
24d0c35ed3 fix(backend/copilot): prompt-too-long retry, compaction churn, model-aware compression, and truncated tool call recovery (#12625)
## Why

CoPilot has several context management issues that degrade long
sessions:
1. "Prompt is too long" errors crash the session instead of triggering
retry/compaction
2. Stale thinking blocks bloat transcripts, causing unnecessary
compaction every turn
3. Compression target is hardcoded regardless of model context window
size
4. Truncated tool calls (empty `{}` args from max_tokens) kill the
session instead of guiding the model to self-correct

## What

**Fix 1: Prompt-too-long retry bypass (SENTRY-1207)**
The SDK surfaces "prompt too long" via `AssistantMessage.error` and
`ResultMessage.result` — neither triggered the retry/compaction loop
(only Python exceptions did). Now both paths are intercepted and
re-raised.

**Fix 2: Strip stale thinking blocks before upload**
Thinking/redacted_thinking blocks in non-last assistant entries are
10-50K tokens each but only needed for API signature verification in the
*last* message. Stripping before upload reduces transcript size and
prevents per-turn compaction.

**Fix 3: Model-aware compression target**
`compress_context()` now computes `target_tokens` from the model's
context window (e.g. 140K for Opus 200K) instead of a hardcoded 120K
default. Larger models retain more history; smaller models compress more
aggressively.

**Fix 4: Self-correcting truncated tool calls**
When the model's response exceeds max_tokens, tool call inputs get
silently truncated to `{}`. Previously this tripped a circuit breaker
after 3 attempts. Now the MCP wrapper detects empty args and returns
guidance: "write in chunks with `cat >>`, pass via
`@@agptfile:filename`". The model can self-correct instead of the
session dying.

## How

- **service.py**: `_is_prompt_too_long` checks in both
`AssistantMessage.error` and `ResultMessage` error handlers. Circuit
breaker limit raised from 3→5.
- **transcript.py**: `strip_stale_thinking_blocks()` reverse-scans for
last assistant `message.id`, strips thinking blocks from all others.
Called in `upload_transcript()`.
- **prompt.py**: `get_compression_target(model)` computes
`context_window - 60K overhead`. `compress_context()` uses it when
`target_tokens` is None.
- **tool_adapter.py**: `_truncating` wrapper intercepts empty args on
tools with required params, returns actionable guidance instead of
failing.

## Related

- Fixes SENTRY-1207
- Sessions: `d2f7cba3` (repeated compaction), `08b807d4` (prompt too
long), `130d527c` (truncated tool calls)
- Extends #12413, consolidates #12626

## Test plan

- [x] 6 unit tests for `strip_stale_thinking_blocks`
- [x] 1 integration test for ResultMessage prompt-too-long → compaction
retry
- [x] Pyright clean (0 errors), all pre-commit hooks pass
- [ ] E2E: Load transcripts from affected sessions and verify behavior
2026-04-01 15:10:57 +00:00
Zamil Majdy
8aae7751dc fix(backend/copilot): prevent duplicate block execution from pre-launch arg mismatch (#12632)
## Why

CoPilot sessions are duplicating Linear tickets and GitHub PRs.
Investigation of 5 production sessions (March 31st) found that 3/5
created duplicate Linear issues — each with consecutive IDs at the exact
same timestamp, but only one visible in Langfuse traces.

Production gcloud logs confirm: **279 arg mismatch warnings per day**,
**37 duplicate block execution pairs**, and all LinearCreateIssueBlock
failures in pairs.

Related: SECRT-2204

## What

Replace the speculative pre-launch mechanism with the SDK's native
parallel dispatch via `readOnlyHint` tool annotations. Remove ~580 lines
of pre-launch infrastructure code.

## How

### Root cause
The pre-launch mechanism had three compounding bugs:
1. **Arg mismatch**: The SDK CLI normalises args between the
`AssistantMessage` (used for pre-launch) and the MCP `tools/call`
dispatch, causing frequent mismatches (279/day in prod)
2. **FIFO desync on denial**: Security hooks can deny tool calls,
causing the CLI to skip the MCP dispatch — but the pre-launched task
stays in the FIFO queue, misaligning all subsequent matches
3. **Cancel race**: `task.cancel()` is best-effort in asyncio — if the
HTTP call to Linear/GitHub already completed, the side effect is
irreversible

### Fix
- **Removed** `pre_launch_tool_call()`, `cancel_pending_tool_tasks()`,
`_tool_task_queues` ContextVar, all FIFO queue logic, and all 4
`cancel_pending_tool_tasks()` calls in `service.py`
- **Added** `readOnlyHint=True` annotations on 15+ read-only tools
(`find_block`, `search_docs`, `list_workspace_files`, etc.) — the SDK
CLI natively dispatches these in parallel ([ref:
anthropics/claude-code#14353](https://github.com/anthropics/claude-code/issues/14353))
- Side-effect tools (`run_block`, `bash_exec`, `create_agent`, etc.)
have no annotation → CLI runs them sequentially → no duplicate execution
risk

### Net change: -578 lines, +105 lines
2026-04-01 13:42:54 +00:00
An Vy Le
725da7e887 dx(backend/copilot): clarify ambiguous agent goals using find_block before generation (#12601)
### Why / What / How

**Why:** When a user asks CoPilot to build an agent with an ambiguous
goal (output format, delivery channel, data source, or trigger
unspecified), the agent generator previously made assumptions and jumped
straight into JSON generation. This produced agents that didn't match
what the user actually wanted, requiring multiple correction cycles.

**What:** Adds a "Clarifying Before Building" section to the agent
generation guide. When the goal is ambiguous, CoPilot first calls
`find_block` to discover what the platform actually supports for the
ambiguous dimension, then asks the user one concrete question grounded
in real platform options (e.g. "The platform supports Gmail, Slack, and
Google Docs — which should the agent use for delivery?"). Only after the
user answers does the full agent generation workflow proceed.

**How:** The clarification instruction is added to
`agent_generation_guide.md` — the guide loaded on-demand via
`get_agent_building_guide` when the LLM is about to build an agent. This
avoids polluting the system prompt supplement (which loads for every
CoPilot conversation, not just agent building). No dedicated tool is
needed — the LLM asks naturally in conversation text after discovering
real platform options via `find_block`.

### Changes 🏗️

- `backend/copilot/sdk/agent_generation_guide.md`: Adds "Clarifying
Before Building" section before the workflow steps. Instructs the model
to call `find_block` for the ambiguous dimension, ask the user one
grounded question, wait for the answer, then proceed to generation.
- `backend/copilot/prompting_test.py`: New test file verifying the guide
contains the clarification section and references `find_block`.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [ ] Ask CoPilot to "build an agent to send a report" (ambiguous
output) — verify it calls `find_block` for delivery options and asks one
grounded question before generating JSON
- [ ] Ask CoPilot to "build an agent to scrape prices from Amazon and
email me daily" (specific goal) — verify it skips clarification and
proceeds directly to agent generation
- [ ] Verify the clarification question lists real block options (e.g.
Gmail, Slack, Google Docs) rather than abstract options

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
2026-04-01 13:32:12 +00:00
seer-by-sentry[bot]
bd9e9ec614 fix(frontend): remove LaunchDarkly local storage bootstrapping (#12606)
### Why / What / How

<!-- Why: Why does this PR exist? What problem does it solve, or what's
broken/missing without it? -->
This PR fixes
[BUILDER-7HD](https://sentry.io/organizations/significant-gravitas/issues/7374387984/).
The issue was that: LaunchDarkly SDK fails to construct streaming URL
due to non-string `_url` from malformed `localStorage` bootstrap data.
<!-- What: What does this PR change? Summarize the changes at a high
level. -->
Removed the `bootstrap: "localStorage"` option from the LaunchDarkly
provider configuration.
<!-- How: How does it work? Describe the approach, key implementation
details, or architecture decisions. -->
This change ensures that LaunchDarkly no longer attempts to load initial
feature flag values from local storage. Flag values will now always be
fetched directly from the LaunchDarkly service, preventing potential
issues with stale local storage data.

### Changes 🏗️

<!-- List the key changes. Keep it higher level than the diff but
specific enough to highlight what's new/modified. -->
- Removed the `bootstrap: "localStorage"` option from the LaunchDarkly
provider configuration.
- LaunchDarkly will now always fetch flag values directly from its
service, bypassing local storage.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [ ] I have made a test plan
- [ ] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
- [ ] Verify that LaunchDarkly flags are loaded correctly without
issues.
- [ ] Ensure no errors related to `localStorage` or streaming URL
construction appear in the console.

<details>
  <summary>Example test plan</summary>
  
  - [ ] Create from scratch and execute an agent with at least 3 blocks
- [ ] Import an agent from file upload, and confirm it executes
correctly
  - [ ] Upload agent to marketplace
- [ ] Import an agent from marketplace and confirm it executes correctly
  - [ ] Edit an agent from monitor, and confirm it executes correctly
</details>

#### For configuration changes:

- [ ] `.env.default` is updated or already compatible with my changes
- [ ] `docker-compose.yml` is updated or already compatible with my
changes
- [ ] I have included a list of my configuration changes in the PR
description (under **Changes**)

<details>
  <summary>Examples of configuration changes</summary>

  - Changing ports
  - Adding new services that need to communicate with each other
  - Secrets or environment variable changes
  - New or infrastructure changes such as databases
</details>

---------

Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
Co-authored-by: seer-by-sentry[bot] <157164994+seer-by-sentry[bot]@users.noreply.github.com>
2026-04-01 19:12:54 +07:00
Nicholas Tindle
88589764b5 dx(platform): normalize agent instructions for Claude and Codex (#12592)
### Why / What / How

Why: repo guidance was split between Claude-specific `CLAUDE.md` files
and Codex-specific `AGENTS.md` files, which duplicated instruction
content and made the same repository behave differently across agents.
The repo also had Claude skills under `.claude/skills` but no
Codex-visible repo skill path.

What: this PR bridges the repo's Claude skills into Codex and normalizes
shared instruction files so `AGENTS.md` becomes the canonical source
while each `CLAUDE.md` imports its sibling `AGENTS.md`.

How: add a repo-local `.agents/skills` symlink pointing to
`../.claude/skills`; move nested `CLAUDE.md` content into sibling
`AGENTS.md` files; replace each repo `CLAUDE.md` with a one-line
`@AGENTS.md` shim so Claude and Codex read the same scoped guidance
without duplicating text. The root `CLAUDE.md` now imports the root
`AGENTS.md` rather than symlinking to it.

Note: the instruction-file normalization commit was created with
`--no-verify` because the repo's frontend pre-commit `tsc` hook
currently fails on unrelated existing errors, largely missing
`autogpt_platform/frontend/src/app/api/__generated__/*` modules.

### Changes 🏗️

- Add `.agents/skills` as a repo-local symlink to `../.claude/skills` so
Codex discovers the existing Claude repo skills.
- Add a real root `CLAUDE.md` shim that imports the canonical root
`AGENTS.md`.
- Promote nested scoped instruction content into sibling `AGENTS.md`
files under `autogpt_platform/`, `autogpt_platform/backend/`,
`autogpt_platform/frontend/`, `autogpt_platform/frontend/src/tests/`,
and `docs/`.
- Replace the corresponding nested `CLAUDE.md` files with one-line
`@AGENTS.md` shims.
- Preserve the existing scoped instruction hierarchy while making the
shared content cross-compatible between Claude and Codex.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Verified `.agents/skills` resolves to `../.claude/skills`
  - [x] Verified each repo `CLAUDE.md` now contains only `@AGENTS.md`
- [x] Verified the expected `AGENTS.md` files exist at the root and
nested scoped directories
- [x] Verified the branch contains only the intended agent-guidance
commits relative to `dev` and the working tree is clean

#### For configuration changes:

- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)

No runtime configuration changes are included in this PR.

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> **Low Risk**
> Low risk: documentation/instruction-file reshuffle plus an
`.agents/skills` pointer; no runtime code paths are modified.
> 
> **Overview**
> Unifies agent guidance so **`AGENTS.md` becomes canonical** and all
corresponding `CLAUDE.md` files become 1-line shims (`@AGENTS.md`) at
the repo root, `autogpt_platform/`, backend, frontend, frontend tests,
and `docs/`.
> 
> Adds `.agents/skills` pointing to `../.claude/skills` so non-Claude
agents discover the same shared skills/instructions, eliminating
duplicated/agent-specific guidance content.
> 
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
839483c3b6. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->
2026-04-01 09:08:51 +00:00
Zamil Majdy
c659f3b058 fix(copilot): fix dry-run simulation showing INCOMPLETE/error status (#12580)
## Summary
- **Backend**: Strip empty `error` pins from dry-run simulation outputs
that the simulator always includes (set to `""` meaning "no error").
This was causing the LLM to misinterpret successful simulations as
failures and report "INCOMPLETE" status to users
- **Backend**: Add explicit "Status: COMPLETED" to dry-run response
message to prevent LLM misinterpretation
- **Backend**: Update simulation prompt to exclude `error` from the
"MUST include" keys list, and instruct LLM to omit error unless
simulating a logical failure
- **Frontend**: Fix `isRunBlockErrorOutput()` type guard that was too
broad (`"error" in output` matched BlockOutputResponse objects, not just
ErrorResponse), causing dry-run results to be displayed as errors
- **Frontend**: Fix `parseOutput()` fallback matching to not classify
BlockOutputResponse as ErrorResponse
- **Frontend**: Filter out empty error pins from `BlockOutputCard`
display and accordion metadata output key counting
- **Frontend**: Clear stale execution results before dry-run/no-input
runs so the UI shows fresh output
- **Frontend**: Fix first-click simulate race condition by invalidating
execution details query after WebSocket subscription confirms

## Test plan
- [x] All 12 existing + 5 new dry-run tests pass (`poetry run pytest
backend/copilot/tools/test_dry_run.py -x -v`)
- [x] All 23 helpers tests pass (`poetry run pytest
backend/copilot/tools/helpers_test.py -x -v`)
- [x] All 13 run_block tests pass (`poetry run pytest
backend/copilot/tools/run_block_test.py -x -v`)
- [x] Backend linting passes (ruff check + format)
- [x] Frontend linting passes (next lint)
- [ ] Manual: trigger dry-run on a block with error output pin (e.g.
Komodo Image Generator) — should show "Simulated" status with clean
output, no misleading "error" section
- [ ] Manual: first click on Simulate button should immediately show
results (no race condition)

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Nicholas Tindle <nicholas.tindle@agpt.co>
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
2026-03-31 21:03:00 +00:00
Zamil Majdy
80581a8364 fix(copilot): add tool call circuit breakers and intermediate persistence (#12604)
## Why

CoPilot session `d2f7cba3` took **82 minutes** and cost **$20.66** for a
single user message. Root causes:
1. Redis session meta key expired after 1h, making the session invisible
to the resume endpoint — causing empty page on reload
2. Redis stream key also expired during sub-agent gaps (task_progress
events produced no chunks)
3. No intermediate persistence — session messages only saved to DB after
the entire turn completes
4. Sub-agents retried similar WebSearch queries (addressed via prompt
guidance)

## What

### Redis TTL fixes (root cause of empty session on reload)
- `publish_chunk()` now periodically refreshes **both** the session meta
key AND stream key TTL (every 60s).
- `task_progress` SDK events now emit `StreamHeartbeat` chunks, ensuring
`publish_chunk` is called even during long sub-agent gaps where no real
chunks are produced.
- Without this fix, turns exceeding the 1h `stream_ttl` lose their
"running" status and stream data, making `get_active_session()` return
False.

### Intermediate DB persistence
- Session messages flushed to DB every **30 seconds** or **10 new
messages** during the stream loop.
- Uses `asyncio.shield(upsert_chat_session())` matching the existing
`finally` block pattern.

### Orphaned message cleanup on rollback
- On stream attempt rollback, orphaned messages persisted by
intermediate flushes are now cleaned up from the DB via
`delete_messages_from_sequence`.
- Prevents stale messages from resurfacing on page reload after a failed
retry.

### Prompt guidance
- Added web search best practices to code supplement (search efficiency,
sub-agent scope separation).

### Approach: root cause fixes, not capability limits
- **No tool call caps** — artificial limits on WebSearch or total tool
calls would reduce autopilot capability without addressing why searches
were redundant.
- **Task tool remains enabled** — sub-agent delegation via Task is a
core capability. The existing `max_subtasks` concurrency guard is
sufficient.
- The real fixes (TTL refresh, persistence, prompt guidance) address the
underlying bugs and behavioral issues.

## How

### Files changed
- `stream_registry.py` — Redis meta + stream key TTL refresh in
`publish_chunk()`, module-level keepalive tracker
- `response_adapter.py` — `task_progress` SystemMessage →
StreamHeartbeat emission
- `service.py` — Intermediate DB persistence in `_run_stream_attempt`
stream loop, orphan cleanup on rollback
- `db.py` — `delete_messages_from_sequence` for rollback cleanup
- `prompting.py` — Web search best practices

### GCP log evidence
```
# Meta key expired during 82-min turn:
09:49 — GET_SESSION: active_session=False, msg_count=1  ← meta gone
10:18 — Session persisted in finally with 189 messages   ← turn completed

# T13 (1h45min) same bug reproduced live:
16:20 — task_progress events still arriving, but active_session=False

# Actual cost:
Turn usage: cache_read=347916, cache_create=212472, output=12375, cost_usd=20.66
```

### Test plan
- [x] task_progress emits StreamHeartbeat
- [x] Task background blocked, foreground allowed, slot release on
completion/failure
- [x] CI green (lint, type-check, tests, e2e, CodeQL)

---------

Co-authored-by: Zamil Majdy <majdy.zamil@gmail.com>
2026-03-31 21:01:56 +00:00
lif
3c046eb291 fix(frontend): show all agent outputs instead of only the last one (#12504)
Fixes #9175

### Changes 🏗️

The Agent Outputs panel only displayed the last execution result per
output node, discarding all prior outputs during a run.

**Root cause:** In `AgentOutputs.tsx`, the `outputs` useMemo extracted
only the last element from `nodeExecutionResults`:
```tsx
const latestResult = executionResults[executionResults.length - 1];
```

**Fix:** Changed `.map()` to `.flatMap()` over output nodes, iterating
through all `executionResults` for each node. Each execution result now
gets its own renderer lookup and metadata entry, so the panel shows
every output produced during the run.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Verified TypeScript compiles without errors
- [x] Confirmed the flatMap logic correctly iterates all execution
results
  - [x] Verified existing filter for null renderers is preserved
- [x] Run an agent with multiple outputs and confirm all show in the
panel

---------

Signed-off-by: majiayu000 <1835304752@qq.com>
Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
2026-03-31 20:31:12 +00:00
Zamil Majdy
3e25488b2d feat(copilot): add session-level dry_run flag to autopilot sessions (#12582)
## Summary
- Adds a session-level `dry_run` flag that forces ALL tool calls
(`run_block`, `run_agent`) in a copilot/autopilot session to use dry-run
simulation mode
- Stores the flag in a typed `ChatSessionMetadata` JSON model on the
`ChatSession` DB row, accessed via `session.dry_run` property
- Adds `dry_run` to the AutoPilot block Input schema so graph builders
can create dry-run autopilot nodes
- Refactors multiple copilot tools from `**kwargs` to explicit
parameters for type safety

## Changes
- **Prisma schema**: Added `metadata` JSON column to `ChatSession` model
with migration
- **Python models**: Added `ChatSessionMetadata` model with `dry_run`
field, added `metadata` field to `ChatSessionInfo` and `ChatSession`,
updated `from_db()`, `new()`, and `create_chat_session()`
- **Session propagation**: `set_execution_context(user_id, session)`
called from `baseline/service.py` so tool handlers can read
session-level flags via `session.dry_run`
- **Tool enforcement**: `run_block` and `run_agent` check
`session.dry_run` and force `dry_run=True` when set; `run_agent` blocks
scheduling in dry-run sessions
- **AutoPilot block**: Added `dry_run` input field, passes it when
creating sessions
- **Chat API**: Added `CreateSessionRequest` model with `dry_run` field
to `POST /sessions` endpoint; added `metadata` to session responses
- **Frontend**: Updated `useChatSession.ts` to pass body to the create
session mutation
- **Tool refactoring**: Multiple copilot tools refactored from
`**kwargs` to explicit named parameters (agent_browser, manage_folders,
workspace_files, connect_integration, agent_output, bash_exec, etc.) for
better type safety

## Test plan
- [x] Unit tests for `ChatSession.new()` with dry_run parameter
- [x] Unit tests for `RunBlockTool` session dry_run override
- [x] Unit tests for `RunAgentTool` session dry_run override
- [x] Unit tests for session dry_run blocks scheduling
- [x] Existing dry_run tests still pass (12/12)
- [x] Existing permissions tests still pass
- [x] All pre-commit hooks pass (ruff, isort, pyright, tsc)
- [ ] Manual: Create autopilot session with `dry_run=True`, verify
run_block/run_agent calls use simulation

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-31 16:27:36 +00:00
Abhimanyu Yadav
57b17dc8e1 feat(platform): generic managed credential system with AgentMail auto-provisioning (#12537)
### Why / What / How

**Why:** We need a third credential type: **system-provided but unique
per user** (managed credentials). Currently we have system credentials
(same for all users) and user credentials (user provides their own
keys). Managed credentials bridge the gap — the platform provisions them
automatically, one per user, for integrations like AgentMail where each
user needs their own pod-scoped API key.

**What:**
- Generic **managed credential provider registry** — any integration can
register a provider that auto-provisions per-user credentials
- **AgentMail** is the first consumer: creates a pod + pod-scoped API
key using the org-level API key
- Managed credentials appear in the credential dropdown like normal API
keys but with `autogpt_managed=True` — users **cannot update or delete**
them
- **Auto-provisioning** on `GET /credentials` — lazily creates managed
credentials when users browse their credential list
- **Account deletion cleanup** utility — revokes external resources
(pods, API keys) before user deletion
- **Frontend UX** — hides the delete button for managed credentials on
the integrations page

**How:**

### Backend

**New files:**
- `backend/integrations/managed_credentials.py` —
`ManagedCredentialProvider` ABC, global registry,
`ensure_managed_credentials()` (with per-user asyncio lock +
`asyncio.gather` for concurrency), `cleanup_managed_credentials()`
- `backend/integrations/managed_providers/__init__.py` —
`register_all()` called at startup
- `backend/integrations/managed_providers/agentmail.py` —
`AgentMailManagedProvider` with `provision()` (creates pod + API key via
agentmail SDK) and `deprovision()` (deletes pod)

**Modified files:**
- `credentials_store.py` — `autogpt_managed` guards on update/delete,
`has_managed_credential()` / `add_managed_credential()` helpers
- `model.py` — `autogpt_managed: bool` + `metadata: dict` on
`_BaseCredentials`
- `router.py` — calls `ensure_managed_credentials()` in list endpoints,
removed explicit `/agentmail/connect` endpoint
- `user.py` — `cleanup_user_managed_credentials()` for account deletion
- `rest_api.py` — registers managed providers at startup
- `settings.py` — `agentmail_api_key` setting

### Frontend
- Added `autogpt_managed` to `CredentialsMetaResponse` type
- Conditionally hides delete button on integrations page for managed
credentials

### Key design decisions
- **Auto-provision in API layer, not data layer** — keeps
`get_all_creds()` side-effect-free
- **Race-safe** — per-(user, provider) asyncio lock with double-check
pattern prevents duplicate pods
- **Idempotent** — AgentMail SDK `client_id` ensures pod creation is
idempotent; `add_managed_credential()` uses upsert under Redis lock
- **Error-resilient** — provisioning failures are logged but never block
credential listing

### Changes 🏗️

| File | Action | Description |
|------|--------|-------------|
| `backend/integrations/managed_credentials.py` | NEW | ABC, registry,
ensure/cleanup |
| `backend/integrations/managed_providers/__init__.py` | NEW | Registers
all providers at startup |
| `backend/integrations/managed_providers/agentmail.py` | NEW |
AgentMail provisioning/deprovisioning |
| `backend/integrations/credentials_store.py` | MODIFY | Guards +
managed credential helpers |
| `backend/data/model.py` | MODIFY | `autogpt_managed` + `metadata`
fields |
| `backend/api/features/integrations/router.py` | MODIFY |
Auto-provision on list, removed `/agentmail/connect` |
| `backend/data/user.py` | MODIFY | Account deletion cleanup |
| `backend/api/rest_api.py` | MODIFY | Provider registration at startup
|
| `backend/util/settings.py` | MODIFY | `agentmail_api_key` setting |
| `frontend/.../integrations/page.tsx` | MODIFY | Hide delete for
managed creds |
| `frontend/.../types.ts` | MODIFY | `autogpt_managed` field |

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] 23 tests pass in `router_test.py` (9 new tests for
ensure/cleanup/auto-provisioning)
  - [x] `poetry run format && poetry run lint` — clean
  - [x] OpenAPI schema regenerated
- [x] Manual: verify managed credential appears in AgentMail block
dropdown
  - [x] Manual: verify delete button hidden for managed credentials
- [x] Manual: verify managed credential cannot be deleted via API (403)

#### For configuration changes:
- [x] `.env.default` is updated with `AGENTMAIL_API_KEY=`

---------

Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
2026-03-31 12:56:18 +00:00
Krishna Chaitanya
a20188ae59 fix(blocks): validate non-empty input in AIConversationBlock before LLM call (#12545)
### Why / What / How

**Why:** When `AIConversationBlock` receives an empty messages list and
an empty prompt, the block blindly forwards the empty array to the
downstream LLM API, which returns a cryptic `400 Bad Request` error:
`"Invalid 'messages': empty array. Expected an array with minimum length
1."` This is confusing for users who don't understand why their agent
failed.

**What:** Add early input validation in `AIConversationBlock.run()` that
raises a clear `ValueError` when both `messages` and `prompt` are empty.
Also add three unit tests covering the validation logic.

**How:** A simple guard clause at the top of the `run` method checks `if
not input_data.messages and not input_data.prompt` before the LLM call
is made. If both are empty, a descriptive `ValueError` is raised. If
either one has content, the block proceeds normally.

### Changes

- `autogpt_platform/backend/backend/blocks/llm.py`: Add validation guard
in `AIConversationBlock.run()` to reject empty messages + empty prompt
before calling the LLM
- `autogpt_platform/backend/backend/blocks/test/test_llm.py`: Add
`TestAIConversationBlockValidation` with three tests:
- `test_empty_messages_and_empty_prompt_raises_error` — validates the
guard clause
- `test_empty_messages_with_prompt_succeeds` — ensures prompt-only usage
still works
- `test_nonempty_messages_with_empty_prompt_succeeds` — ensures
messages-only usage still works

### Checklist

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Lint passes (`ruff check`)
  - [x] Formatting passes (`ruff format`)
- [x] New unit tests validate the empty-input guard and the happy paths

Closes #11875

---------

Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
2026-03-31 12:43:42 +00:00
goingforstudying-ctrl
c410be890e fix: add empty choices guard in extract_openai_tool_calls() (#12540)
## Summary

`extract_openai_tool_calls()` in `llm.py` crashes with `IndexError` when
the LLM provider returns a response with an empty `choices` list.

### Changes 🏗️

- Added a guard check `if not response.choices: return None` before
accessing `response.choices[0]`
- This is consistent with the function's existing pattern of returning
`None` when no tool calls are found

### Bug Details

When an LLM provider returns a response with an empty choices list
(e.g., due to content filtering, rate limiting, or API errors),
`response.choices[0]` raises `IndexError`. This can crash the entire
agent execution pipeline.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- Verified that the function returns `None` when `response.choices` is
empty
- Verified existing behavior is unchanged when `response.choices` is
non-empty

---------

Co-authored-by: goingforstudying-ctrl <forgithubuse@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
2026-03-31 20:10:27 +07:00
Zamil Majdy
37d9863552 feat(platform): add extended thinking execution mode to OrchestratorBlock (#12512)
## Summary
- Adds `ExecutionMode` enum with `BUILT_IN` (default built-in tool-call
loop) and `EXTENDED_THINKING` (delegates to Claude Agent SDK for richer
reasoning)
- Extracts shared `tool_call_loop` into `backend/util/tool_call_loop.py`
— reusable by both OrchestratorBlock agent mode and copilot baseline
- Refactors copilot baseline to use the shared `tool_call_loop` with
callback-driven iteration

## ExecutionMode enum
`ExecutionMode` (`backend/blocks/orchestrator.py`) controls how
OrchestratorBlock executes tool calls:
- **`BUILT_IN`** — Default mode. Runs the built-in tool-call loop
(supports all LLM providers).
- **`EXTENDED_THINKING`** — Delegates to the Claude Agent SDK for
extended thinking and multi-step planning. Requires Anthropic-compatible
providers (`anthropic` / `open_router`) and direct API credentials
(subscription mode not supported). Validates both provider and model
name at runtime.

## Shared tool_call_loop
`backend/util/tool_call_loop.py` provides a generic, provider-agnostic
conversation loop:
1. Call LLM with tools → 2. Extract tool calls → 3. Execute tools → 4.
Update conversation → 5. Repeat

Callers provide three callbacks:
- `llm_call`: wraps any LLM provider (OpenAI streaming, Anthropic,
llm.llm_call, etc.)
- `execute_tool`: wraps any tool execution (TOOL_REGISTRY, graph block
execution, etc.)
- `update_conversation`: formats messages for the specific protocol

## OrchestratorBlock EXTENDED_THINKING mode
- `_create_graph_mcp_server()` converts graph-connected blocks to MCP
tools
- `_execute_tools_sdk_mode()` runs `ClaudeSDKClient` with those MCP
tools
- Agent mode refactored to use shared `tool_call_loop`

## Copilot baseline refactored
- Streaming callbacks buffer `Stream*` events during loop execution
- Events are drained after `tool_call_loop` returns
- Same conversation logic, less code duplication

## SDK environment builder extraction
- `build_sdk_env()` extracted to `backend/copilot/sdk/env.py` for reuse
by both copilot SDK service and OrchestratorBlock

## Provider validation
EXTENDED_THINKING mode validates `provider in ('anthropic',
'open_router')` and `model_name.startswith('claude')` because the Claude
Agent SDK requires an Anthropic API key or OpenRouter key. Subscription
mode is not supported — it uses the platform's internal credit system
which doesn't provide raw API keys needed by the SDK. The validation
raises a clear `ValueError` if an unsupported provider or model is used.

## PR Dependencies
This PR builds on #12511 (Claude SDK client). It can be reviewed
independently — #12511 only adds the SDK client module which this PR
imports. If #12511 merges first, this PR will have no conflicts.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] All pre-commit hooks pass (typecheck, lint, format)
  - [x] Existing OrchestratorBlock tests still pass
- [x] Copilot baseline behavior unchanged (same stream events, same tool
execution)
- [x] Manual: OrchestratorBlock with execution_mode=EXTENDED_THINKING +
downstream blocks → SDK calls tools
  - [x] Agent mode regression test (non-SDK path works as before)
  - [x] SDK mode error handling (invalid provider raises ValueError)
2026-03-31 20:04:13 +07:00
Krishna Chaitanya
2f42ff9b47 fix(blocks): validate email recipients in Gmail blocks before API call (#12546)
### Why / What / How

**Why:** When a user or LLM supplies a malformed recipient string (e.g.
a bare username, a JSON blob, or an empty value) to `GmailSendBlock`,
`GmailCreateDraftBlock`, or any reply block, the Gmail API returns an
opaque `HttpError 400: "Invalid To header"`. This surfaces as a
`BlockUnknownError` with no actionable guidance, making it impossible
for the LLM to self-correct. (Fixes #11954)

**What:** Adds a lightweight `validate_email_recipients()` function that
checks every recipient against a simplified RFC 5322 pattern
(`local@domain.tld`) and raises a clear `ValueError` listing all invalid
entries before any API call is made.

**How:** The validation is called in two shared code paths —
`create_mime_message()` (used by send and draft blocks) and
`_build_reply_message()` (used by reply blocks) — so all Gmail blocks
that compose outgoing email benefit from it with zero per-block changes.
The regex is intentionally permissive (any `x@y.z` passes) to avoid
false positives on unusual but valid addresses.

### Changes 🏗️

- Added `validate_email_recipients()` helper in `gmail.py` with a
compiled regex
- Hooked validation into `create_mime_message()` for `to`, `cc`, and
`bcc` fields
- Hooked validation into `_build_reply_message()` for reply/draft-reply
blocks
- Added `TestValidateEmailRecipients` test class covering valid,
invalid, mixed, empty, JSON-string, and field-name scenarios

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Verified `validate_email_recipients` correctly accepts valid
emails (`user@example.com`, `a@b.com`, `test@sub.domain.co`)
- [x] Verified it rejects malformed entries (bare names, missing domain
dot, empty strings, JSON strings)
- [x] Verified error messages include the field name and all invalid
entries
  - [x] Verified empty recipient lists pass without error
  - [x] Confirmed `gmail.py` and test file parse correctly (AST check)

---------

Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
2026-03-31 12:37:33 +00:00
Zamil Majdy
914efc53e5 fix(backend): disambiguate duplicate tool names in OrchestratorBlock (#12555)
## Why
The OrchestratorBlock fails with `Tool names must be unique` when
multiple nodes use the same block type (e.g., two "Web Search" blocks
connected as tools). The Anthropic API rejects the request because
duplicate tool names are sent.

## What
- Detect duplicate tool names after building tool signatures
- Append `_1`, `_2`, etc. suffixes to disambiguate
- Enrich descriptions of duplicate tools with their hardcoded default
values so the LLM can distinguish between them
- Clean up internal `_hardcoded_defaults` metadata before sending to API
- Exclude sensitive/credential fields from default value descriptions

## How
- After `_create_tool_node_signatures` builds all tool functions, count
name occurrences
- For duplicates: rename with suffix and append `[Pre-configured:
key=value]` to description using the node's `input_default` (excluding
linked fields that the LLM provides)
- Added defensive `isinstance(defaults, dict)` check for compatibility
with test mocks
- Suffix collision avoidance: skips candidates that collide with
existing tool names
- Long tool names truncated to fit within 64-character API limit
- 47 unit tests covering: basic dedup, description enrichment, unique
names unchanged, no metadata leaks, single tool, triple duplicates,
linked field exclusion, mixed unique/duplicate scenarios, sensitive
field exclusion, long name truncation, suffix collision, malformed
tools, missing description, empty list, 10-tool all-same-name, multiple
distinct groups, large default truncation, suffix collision cascade,
parameter preservation, boundary name lengths, nested dict/list
defaults, null defaults, customized name priority, required fields

## Test plan
- [x] All 47 tests in `test_orchestrator_tool_dedup.py` pass
- [x] All 11 existing orchestrator unit tests pass (dict, dynamic
fields, responses API)
- [x] Pre-commit hooks pass (ruff, black, isort, pyright)
- [ ] Manual test: connect two same-type blocks to an orchestrator and
verify the LLM call succeeds

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-31 11:54:10 +00:00
2540 changed files with 61723 additions and 826068 deletions

1
.agents/skills Symbolic link
View File

@@ -0,0 +1 @@
../.claude/skills

10
.claude/settings.json Normal file
View File

@@ -0,0 +1,10 @@
{
"permissions": {
"allowedTools": [
"Read", "Grep", "Glob",
"Bash(ls:*)", "Bash(cat:*)", "Bash(grep:*)", "Bash(find:*)",
"Bash(git status:*)", "Bash(git diff:*)", "Bash(git log:*)", "Bash(git worktree:*)",
"Bash(tmux:*)", "Bash(sleep:*)", "Bash(branchlet:*)"
]
}
}

View File

@@ -6,11 +6,19 @@ on:
paths:
- '.github/workflows/classic-autogpt-ci.yml'
- 'classic/original_autogpt/**'
- 'classic/direct_benchmark/**'
- 'classic/forge/**'
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
pull_request:
branches: [ master, dev, release-* ]
paths:
- '.github/workflows/classic-autogpt-ci.yml'
- 'classic/original_autogpt/**'
- 'classic/direct_benchmark/**'
- 'classic/forge/**'
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
concurrency:
group: ${{ format('classic-autogpt-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
@@ -19,47 +27,22 @@ concurrency:
defaults:
run:
shell: bash
working-directory: classic/original_autogpt
working-directory: classic
jobs:
test:
permissions:
contents: read
timeout-minutes: 30
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]
platform-os: [ubuntu, macos, macos-arm64, windows]
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
runs-on: ubuntu-latest
steps:
# Quite slow on macOS (2~4 minutes to set up Docker)
# - name: Set up Docker (macOS)
# if: runner.os == 'macOS'
# uses: crazy-max/ghaction-setup-docker@v3
- name: Start MinIO service (Linux)
if: runner.os == 'Linux'
- name: Start MinIO service
working-directory: '.'
run: |
docker pull minio/minio:edge-cicd
docker run -d -p 9000:9000 minio/minio:edge-cicd
- name: Start MinIO service (macOS)
if: runner.os == 'macOS'
working-directory: ${{ runner.temp }}
run: |
brew install minio/stable/minio
mkdir data
minio server ./data &
# No MinIO on Windows:
# - Windows doesn't support running Linux Docker containers
# - It doesn't seem possible to start background processes on Windows. They are
# killed after the step returns.
# See: https://github.com/actions/runner/issues/598#issuecomment-2011890429
- name: Checkout repository
uses: actions/checkout@v4
with:
@@ -71,41 +54,23 @@ jobs:
git config --global user.name "Auto-GPT-Bot"
git config --global user.email "github-bot@agpt.co"
- name: Set up Python ${{ matrix.python-version }}
- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
python-version: "3.12"
- id: get_date
name: Get date
run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
- name: Set up Python dependency cache
# On Windows, unpacking cached dependencies takes longer than just installing them
if: runner.os != 'Windows'
uses: actions/cache@v4
with:
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
key: poetry-${{ runner.os }}-${{ hashFiles('classic/original_autogpt/poetry.lock') }}
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
- name: Install Poetry (Unix)
if: runner.os != 'Windows'
run: |
curl -sSL https://install.python-poetry.org | python3 -
if [ "${{ runner.os }}" = "macOS" ]; then
PATH="$HOME/.local/bin:$PATH"
echo "$HOME/.local/bin" >> $GITHUB_PATH
fi
- name: Install Poetry (Windows)
if: runner.os == 'Windows'
shell: pwsh
run: |
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
$env:PATH += ";$env:APPDATA\Python\Scripts"
echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
- name: Install Poetry
run: curl -sSL https://install.python-poetry.org | python3 -
- name: Install Python dependencies
run: poetry install
@@ -116,12 +81,13 @@ jobs:
--cov=autogpt --cov-branch --cov-report term-missing --cov-report xml \
--numprocesses=logical --durations=10 \
--junitxml=junit.xml -o junit_family=legacy \
tests/unit tests/integration
original_autogpt/tests/unit original_autogpt/tests/integration
env:
CI: true
PLAIN_OUTPUT: True
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
S3_ENDPOINT_URL: ${{ runner.os != 'Windows' && 'http://127.0.0.1:9000' || '' }}
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
S3_ENDPOINT_URL: http://127.0.0.1:9000
AWS_ACCESS_KEY_ID: minioadmin
AWS_SECRET_ACCESS_KEY: minioadmin
@@ -135,11 +101,11 @@ jobs:
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: autogpt-agent,${{ runner.os }}
flags: autogpt-agent
- name: Upload logs to artifact
if: always()
uses: actions/upload-artifact@v4
with:
name: test-logs
path: classic/original_autogpt/logs/
path: classic/logs/

View File

@@ -148,7 +148,7 @@ jobs:
--entrypoint poetry ${{ env.IMAGE_NAME }} run \
pytest -v --cov=autogpt --cov-branch --cov-report term-missing \
--numprocesses=4 --durations=10 \
tests/unit tests/integration 2>&1 | tee test_output.txt
original_autogpt/tests/unit original_autogpt/tests/integration 2>&1 | tee test_output.txt
test_failure=${PIPESTATUS[0]}

View File

@@ -10,10 +10,9 @@ on:
- '.github/workflows/classic-autogpts-ci.yml'
- 'classic/original_autogpt/**'
- 'classic/forge/**'
- 'classic/benchmark/**'
- 'classic/run'
- 'classic/cli.py'
- 'classic/setup.py'
- 'classic/direct_benchmark/**'
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
- '!**/*.md'
pull_request:
branches: [ master, dev, release-* ]
@@ -21,10 +20,9 @@ on:
- '.github/workflows/classic-autogpts-ci.yml'
- 'classic/original_autogpt/**'
- 'classic/forge/**'
- 'classic/benchmark/**'
- 'classic/run'
- 'classic/cli.py'
- 'classic/setup.py'
- 'classic/direct_benchmark/**'
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
- '!**/*.md'
defaults:
@@ -35,13 +33,9 @@ defaults:
jobs:
serve-agent-protocol:
runs-on: ubuntu-latest
strategy:
matrix:
agent-name: [ original_autogpt ]
fail-fast: false
timeout-minutes: 20
env:
min-python-version: '3.10'
min-python-version: '3.12'
steps:
- name: Checkout repository
uses: actions/checkout@v4
@@ -55,22 +49,22 @@ jobs:
python-version: ${{ env.min-python-version }}
- name: Install Poetry
working-directory: ./classic/${{ matrix.agent-name }}/
run: |
curl -sSL https://install.python-poetry.org | python -
- name: Run regression tests
- name: Install dependencies
run: poetry install
- name: Run smoke tests with direct-benchmark
run: |
./run agent start ${{ matrix.agent-name }}
cd ${{ matrix.agent-name }}
poetry run agbenchmark --mock --test=BasicRetrieval --test=Battleship --test=WebArenaTask_0
poetry run agbenchmark --test=WriteFile
poetry run direct-benchmark run \
--strategies one_shot \
--models claude \
--tests ReadFile,WriteFile \
--json
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
AGENT_NAME: ${{ matrix.agent-name }}
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
REQUESTS_CA_BUNDLE: /etc/ssl/certs/ca-certificates.crt
HELICONE_CACHE_ENABLED: false
HELICONE_PROPERTY_AGENT: ${{ matrix.agent-name }}
REPORTS_FOLDER: ${{ format('../../reports/{0}', matrix.agent-name) }}
TELEMETRY_ENVIRONMENT: autogpt-ci
TELEMETRY_OPT_IN: ${{ github.ref_name == 'master' }}
NONINTERACTIVE_MODE: "true"
CI: true

View File

@@ -1,18 +1,24 @@
name: Classic - AGBenchmark CI
name: Classic - Direct Benchmark CI
on:
push:
branches: [ master, dev, ci-test* ]
paths:
- 'classic/benchmark/**'
- '!classic/benchmark/reports/**'
- 'classic/direct_benchmark/**'
- 'classic/original_autogpt/**'
- 'classic/forge/**'
- .github/workflows/classic-benchmark-ci.yml
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
pull_request:
branches: [ master, dev, release-* ]
paths:
- 'classic/benchmark/**'
- '!classic/benchmark/reports/**'
- 'classic/direct_benchmark/**'
- 'classic/original_autogpt/**'
- 'classic/forge/**'
- .github/workflows/classic-benchmark-ci.yml
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
concurrency:
group: ${{ format('benchmark-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
@@ -23,95 +29,16 @@ defaults:
shell: bash
env:
min-python-version: '3.10'
min-python-version: '3.12'
jobs:
test:
permissions:
contents: read
benchmark-tests:
runs-on: ubuntu-latest
timeout-minutes: 30
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]
platform-os: [ubuntu, macos, macos-arm64, windows]
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
defaults:
run:
shell: bash
working-directory: classic/benchmark
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0
submodules: true
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Set up Python dependency cache
# On Windows, unpacking cached dependencies takes longer than just installing them
if: runner.os != 'Windows'
uses: actions/cache@v4
with:
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
key: poetry-${{ runner.os }}-${{ hashFiles('classic/benchmark/poetry.lock') }}
- name: Install Poetry (Unix)
if: runner.os != 'Windows'
run: |
curl -sSL https://install.python-poetry.org | python3 -
if [ "${{ runner.os }}" = "macOS" ]; then
PATH="$HOME/.local/bin:$PATH"
echo "$HOME/.local/bin" >> $GITHUB_PATH
fi
- name: Install Poetry (Windows)
if: runner.os == 'Windows'
shell: pwsh
run: |
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
$env:PATH += ";$env:APPDATA\Python\Scripts"
echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
- name: Install Python dependencies
run: poetry install
- name: Run pytest with coverage
run: |
poetry run pytest -vv \
--cov=agbenchmark --cov-branch --cov-report term-missing --cov-report xml \
--durations=10 \
--junitxml=junit.xml -o junit_family=legacy \
tests
env:
CI: true
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
- name: Upload test results to Codecov
if: ${{ !cancelled() }} # Run even if tests fail
uses: codecov/test-results-action@v1
with:
token: ${{ secrets.CODECOV_TOKEN }}
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: agbenchmark,${{ runner.os }}
self-test-with-agent:
runs-on: ubuntu-latest
strategy:
matrix:
agent-name: [forge]
fail-fast: false
timeout-minutes: 20
working-directory: classic
steps:
- name: Checkout repository
uses: actions/checkout@v4
@@ -124,53 +51,120 @@ jobs:
with:
python-version: ${{ env.min-python-version }}
- name: Set up Python dependency cache
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
- name: Install Poetry
run: |
curl -sSL https://install.python-poetry.org | python -
curl -sSL https://install.python-poetry.org | python3 -
- name: Install dependencies
run: poetry install
- name: Run basic benchmark tests
run: |
echo "Testing ReadFile challenge with one_shot strategy..."
poetry run direct-benchmark run \
--fresh \
--strategies one_shot \
--models claude \
--tests ReadFile \
--json
echo "Testing WriteFile challenge..."
poetry run direct-benchmark run \
--fresh \
--strategies one_shot \
--models claude \
--tests WriteFile \
--json
env:
CI: true
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
NONINTERACTIVE_MODE: "true"
- name: Test category filtering
run: |
echo "Testing coding category..."
poetry run direct-benchmark run \
--fresh \
--strategies one_shot \
--models claude \
--categories coding \
--tests ReadFile,WriteFile \
--json
env:
CI: true
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
NONINTERACTIVE_MODE: "true"
- name: Test multiple strategies
run: |
echo "Testing multiple strategies..."
poetry run direct-benchmark run \
--fresh \
--strategies one_shot,plan_execute \
--models claude \
--tests ReadFile \
--parallel 2 \
--json
env:
CI: true
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
NONINTERACTIVE_MODE: "true"
# Run regression tests on maintain challenges
regression-tests:
runs-on: ubuntu-latest
timeout-minutes: 45
if: github.ref == 'refs/heads/master' || github.ref == 'refs/heads/dev'
defaults:
run:
shell: bash
working-directory: classic
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0
submodules: true
- name: Set up Python ${{ env.min-python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ env.min-python-version }}
- name: Set up Python dependency cache
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
- name: Install Poetry
run: |
curl -sSL https://install.python-poetry.org | python3 -
- name: Install dependencies
run: poetry install
- name: Run regression tests
working-directory: classic
run: |
./run agent start ${{ matrix.agent-name }}
cd ${{ matrix.agent-name }}
set +e # Ignore non-zero exit codes and continue execution
echo "Running the following command: poetry run agbenchmark --maintain --mock"
poetry run agbenchmark --maintain --mock
EXIT_CODE=$?
set -e # Stop ignoring non-zero exit codes
# Check if the exit code was 5, and if so, exit with 0 instead
if [ $EXIT_CODE -eq 5 ]; then
echo "regression_tests.json is empty."
fi
echo "Running the following command: poetry run agbenchmark --mock"
poetry run agbenchmark --mock
echo "Running the following command: poetry run agbenchmark --mock --category=data"
poetry run agbenchmark --mock --category=data
echo "Running the following command: poetry run agbenchmark --mock --category=coding"
poetry run agbenchmark --mock --category=coding
# echo "Running the following command: poetry run agbenchmark --test=WriteFile"
# poetry run agbenchmark --test=WriteFile
cd ../benchmark
poetry install
echo "Adding the BUILD_SKILL_TREE environment variable. This will attempt to add new elements in the skill tree. If new elements are added, the CI fails because they should have been pushed"
export BUILD_SKILL_TREE=true
# poetry run agbenchmark --mock
# CHANGED=$(git diff --name-only | grep -E '(agbenchmark/challenges)|(../classic/frontend/assets)') || echo "No diffs"
# if [ ! -z "$CHANGED" ]; then
# echo "There are unstaged changes please run agbenchmark and commit those changes since they are needed."
# echo "$CHANGED"
# exit 1
# else
# echo "No unstaged changes."
# fi
echo "Running regression tests (previously beaten challenges)..."
poetry run direct-benchmark run \
--fresh \
--strategies one_shot \
--models claude \
--maintain \
--parallel 4 \
--json
env:
CI: true
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
TELEMETRY_ENVIRONMENT: autogpt-benchmark-ci
TELEMETRY_OPT_IN: ${{ github.ref_name == 'master' }}
NONINTERACTIVE_MODE: "true"

View File

@@ -6,13 +6,15 @@ on:
paths:
- '.github/workflows/classic-forge-ci.yml'
- 'classic/forge/**'
- '!classic/forge/tests/vcr_cassettes'
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
pull_request:
branches: [ master, dev, release-* ]
paths:
- '.github/workflows/classic-forge-ci.yml'
- 'classic/forge/**'
- '!classic/forge/tests/vcr_cassettes'
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
concurrency:
group: ${{ format('forge-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
@@ -21,131 +23,60 @@ concurrency:
defaults:
run:
shell: bash
working-directory: classic/forge
working-directory: classic
jobs:
test:
permissions:
contents: read
timeout-minutes: 30
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]
platform-os: [ubuntu, macos, macos-arm64, windows]
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
runs-on: ubuntu-latest
steps:
# Quite slow on macOS (2~4 minutes to set up Docker)
# - name: Set up Docker (macOS)
# if: runner.os == 'macOS'
# uses: crazy-max/ghaction-setup-docker@v3
- name: Start MinIO service (Linux)
if: runner.os == 'Linux'
- name: Start MinIO service
working-directory: '.'
run: |
docker pull minio/minio:edge-cicd
docker run -d -p 9000:9000 minio/minio:edge-cicd
- name: Start MinIO service (macOS)
if: runner.os == 'macOS'
working-directory: ${{ runner.temp }}
run: |
brew install minio/stable/minio
mkdir data
minio server ./data &
# No MinIO on Windows:
# - Windows doesn't support running Linux Docker containers
# - It doesn't seem possible to start background processes on Windows. They are
# killed after the step returns.
# See: https://github.com/actions/runner/issues/598#issuecomment-2011890429
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0
submodules: true
- name: Checkout cassettes
if: ${{ startsWith(github.event_name, 'pull_request') }}
env:
PR_BASE: ${{ github.event.pull_request.base.ref }}
PR_BRANCH: ${{ github.event.pull_request.head.ref }}
PR_AUTHOR: ${{ github.event.pull_request.user.login }}
run: |
cassette_branch="${PR_AUTHOR}-${PR_BRANCH}"
cassette_base_branch="${PR_BASE}"
cd tests/vcr_cassettes
if ! git ls-remote --exit-code --heads origin $cassette_base_branch ; then
cassette_base_branch="master"
fi
if git ls-remote --exit-code --heads origin $cassette_branch ; then
git fetch origin $cassette_branch
git fetch origin $cassette_base_branch
git checkout $cassette_branch
# Pick non-conflicting cassette updates from the base branch
git merge --no-commit --strategy-option=ours origin/$cassette_base_branch
echo "Using cassettes from mirror branch '$cassette_branch'," \
"synced to upstream branch '$cassette_base_branch'."
else
git checkout -b $cassette_branch
echo "Branch '$cassette_branch' does not exist in cassette submodule." \
"Using cassettes from '$cassette_base_branch'."
fi
- name: Set up Python ${{ matrix.python-version }}
- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
python-version: "3.12"
- name: Set up Python dependency cache
# On Windows, unpacking cached dependencies takes longer than just installing them
if: runner.os != 'Windows'
uses: actions/cache@v4
with:
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
key: poetry-${{ runner.os }}-${{ hashFiles('classic/forge/poetry.lock') }}
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
- name: Install Poetry (Unix)
if: runner.os != 'Windows'
run: |
curl -sSL https://install.python-poetry.org | python3 -
if [ "${{ runner.os }}" = "macOS" ]; then
PATH="$HOME/.local/bin:$PATH"
echo "$HOME/.local/bin" >> $GITHUB_PATH
fi
- name: Install Poetry (Windows)
if: runner.os == 'Windows'
shell: pwsh
run: |
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
$env:PATH += ";$env:APPDATA\Python\Scripts"
echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
- name: Install Poetry
run: curl -sSL https://install.python-poetry.org | python3 -
- name: Install Python dependencies
run: poetry install
- name: Install Playwright browsers
run: poetry run playwright install chromium
- name: Run pytest with coverage
run: |
poetry run pytest -vv \
--cov=forge --cov-branch --cov-report term-missing --cov-report xml \
--durations=10 \
--junitxml=junit.xml -o junit_family=legacy \
forge
forge/forge forge/tests
env:
CI: true
PLAIN_OUTPUT: True
# API keys - tests that need these will skip if not available
# Secrets are not available to fork PRs (GitHub security feature)
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
S3_ENDPOINT_URL: ${{ runner.os != 'Windows' && 'http://127.0.0.1:9000' || '' }}
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
S3_ENDPOINT_URL: http://127.0.0.1:9000
AWS_ACCESS_KEY_ID: minioadmin
AWS_SECRET_ACCESS_KEY: minioadmin
@@ -159,85 +90,11 @@ jobs:
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: forge,${{ runner.os }}
- id: setup_git_auth
name: Set up git token authentication
# Cassettes may be pushed even when tests fail
if: success() || failure()
run: |
config_key="http.${{ github.server_url }}/.extraheader"
if [ "${{ runner.os }}" = 'macOS' ]; then
base64_pat=$(echo -n "pat:${{ secrets.PAT_REVIEW }}" | base64)
else
base64_pat=$(echo -n "pat:${{ secrets.PAT_REVIEW }}" | base64 -w0)
fi
git config "$config_key" \
"Authorization: Basic $base64_pat"
cd tests/vcr_cassettes
git config "$config_key" \
"Authorization: Basic $base64_pat"
echo "config_key=$config_key" >> $GITHUB_OUTPUT
- id: push_cassettes
name: Push updated cassettes
# For pull requests, push updated cassettes even when tests fail
if: github.event_name == 'push' || (! github.event.pull_request.head.repo.fork && (success() || failure()))
env:
PR_BRANCH: ${{ github.event.pull_request.head.ref }}
PR_AUTHOR: ${{ github.event.pull_request.user.login }}
run: |
if [ "${{ startsWith(github.event_name, 'pull_request') }}" = "true" ]; then
is_pull_request=true
cassette_branch="${PR_AUTHOR}-${PR_BRANCH}"
else
cassette_branch="${{ github.ref_name }}"
fi
cd tests/vcr_cassettes
# Commit & push changes to cassettes if any
if ! git diff --quiet; then
git add .
git commit -m "Auto-update cassettes"
git push origin HEAD:$cassette_branch
if [ ! $is_pull_request ]; then
cd ../..
git add tests/vcr_cassettes
git commit -m "Update cassette submodule"
git push origin HEAD:$cassette_branch
fi
echo "updated=true" >> $GITHUB_OUTPUT
else
echo "updated=false" >> $GITHUB_OUTPUT
echo "No cassette changes to commit"
fi
- name: Post Set up git token auth
if: steps.setup_git_auth.outcome == 'success'
run: |
git config --unset-all '${{ steps.setup_git_auth.outputs.config_key }}'
git submodule foreach git config --unset-all '${{ steps.setup_git_auth.outputs.config_key }}'
- name: Apply "behaviour change" label and comment on PR
if: ${{ startsWith(github.event_name, 'pull_request') }}
run: |
PR_NUMBER="${{ github.event.pull_request.number }}"
TOKEN="${{ secrets.PAT_REVIEW }}"
REPO="${{ github.repository }}"
if [[ "${{ steps.push_cassettes.outputs.updated }}" == "true" ]]; then
echo "Adding label and comment..."
echo $TOKEN | gh auth login --with-token
gh issue edit $PR_NUMBER --add-label "behaviour change"
gh issue comment $PR_NUMBER --body "You changed AutoGPT's behaviour on ${{ runner.os }}. The cassettes have been updated and will be merged to the submodule when this Pull Request gets merged."
fi
flags: forge
- name: Upload logs to artifact
if: always()
uses: actions/upload-artifact@v4
with:
name: test-logs
path: classic/forge/logs/
path: classic/logs/

View File

@@ -1,60 +0,0 @@
name: Classic - Frontend CI/CD
on:
push:
branches:
- master
- dev
- 'ci-test*' # This will match any branch that starts with "ci-test"
paths:
- 'classic/frontend/**'
- '.github/workflows/classic-frontend-ci.yml'
pull_request:
paths:
- 'classic/frontend/**'
- '.github/workflows/classic-frontend-ci.yml'
jobs:
build:
permissions:
contents: write
pull-requests: write
runs-on: ubuntu-latest
env:
BUILD_BRANCH: ${{ format('classic-frontend-build/{0}', github.ref_name) }}
steps:
- name: Checkout Repo
uses: actions/checkout@v4
- name: Setup Flutter
uses: subosito/flutter-action@v2
with:
flutter-version: '3.13.2'
- name: Build Flutter to Web
run: |
cd classic/frontend
flutter build web --base-href /app/
# - name: Commit and Push to ${{ env.BUILD_BRANCH }}
# if: github.event_name == 'push'
# run: |
# git config --local user.email "action@github.com"
# git config --local user.name "GitHub Action"
# git add classic/frontend/build/web
# git checkout -B ${{ env.BUILD_BRANCH }}
# git commit -m "Update frontend build to ${GITHUB_SHA:0:7}" -a
# git push -f origin ${{ env.BUILD_BRANCH }}
- name: Create PR ${{ env.BUILD_BRANCH }} -> ${{ github.ref_name }}
if: github.event_name == 'push'
uses: peter-evans/create-pull-request@v8
with:
add-paths: classic/frontend/build/web
base: ${{ github.ref_name }}
branch: ${{ env.BUILD_BRANCH }}
delete-branch: true
title: "Update frontend build in `${{ github.ref_name }}`"
body: "This PR updates the frontend build based on commit ${{ github.sha }}."
commit-message: "Update frontend build based on commit ${{ github.sha }}"

View File

@@ -7,7 +7,9 @@ on:
- '.github/workflows/classic-python-checks-ci.yml'
- 'classic/original_autogpt/**'
- 'classic/forge/**'
- 'classic/benchmark/**'
- 'classic/direct_benchmark/**'
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
- '**.py'
- '!classic/forge/tests/vcr_cassettes'
pull_request:
@@ -16,7 +18,9 @@ on:
- '.github/workflows/classic-python-checks-ci.yml'
- 'classic/original_autogpt/**'
- 'classic/forge/**'
- 'classic/benchmark/**'
- 'classic/direct_benchmark/**'
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
- '**.py'
- '!classic/forge/tests/vcr_cassettes'
@@ -27,44 +31,13 @@ concurrency:
defaults:
run:
shell: bash
working-directory: classic
jobs:
get-changed-parts:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- id: changes-in
name: Determine affected subprojects
uses: dorny/paths-filter@v3
with:
filters: |
original_autogpt:
- classic/original_autogpt/autogpt/**
- classic/original_autogpt/tests/**
- classic/original_autogpt/poetry.lock
forge:
- classic/forge/forge/**
- classic/forge/tests/**
- classic/forge/poetry.lock
benchmark:
- classic/benchmark/agbenchmark/**
- classic/benchmark/tests/**
- classic/benchmark/poetry.lock
outputs:
changed-parts: ${{ steps.changes-in.outputs.changes }}
lint:
needs: get-changed-parts
runs-on: ubuntu-latest
env:
min-python-version: "3.10"
strategy:
matrix:
sub-package: ${{ fromJson(needs.get-changed-parts.outputs.changed-parts) }}
fail-fast: false
min-python-version: "3.12"
steps:
- name: Checkout repository
@@ -81,42 +54,31 @@ jobs:
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: ${{ runner.os }}-poetry-${{ hashFiles(format('{0}/poetry.lock', matrix.sub-package)) }}
key: ${{ runner.os }}-poetry-${{ hashFiles('classic/poetry.lock') }}
- name: Install Poetry
run: curl -sSL https://install.python-poetry.org | python3 -
# Install dependencies
- name: Install Python dependencies
run: poetry -C classic/${{ matrix.sub-package }} install
run: poetry install
# Lint
- name: Lint (isort)
run: poetry run isort --check .
working-directory: classic/${{ matrix.sub-package }}
- name: Lint (Black)
if: success() || failure()
run: poetry run black --check .
working-directory: classic/${{ matrix.sub-package }}
- name: Lint (Flake8)
if: success() || failure()
run: poetry run flake8 .
working-directory: classic/${{ matrix.sub-package }}
types:
needs: get-changed-parts
runs-on: ubuntu-latest
env:
min-python-version: "3.10"
strategy:
matrix:
sub-package: ${{ fromJson(needs.get-changed-parts.outputs.changed-parts) }}
fail-fast: false
min-python-version: "3.12"
steps:
- name: Checkout repository
@@ -133,19 +95,16 @@ jobs:
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: ${{ runner.os }}-poetry-${{ hashFiles(format('{0}/poetry.lock', matrix.sub-package)) }}
key: ${{ runner.os }}-poetry-${{ hashFiles('classic/poetry.lock') }}
- name: Install Poetry
run: curl -sSL https://install.python-poetry.org | python3 -
# Install dependencies
- name: Install Python dependencies
run: poetry -C classic/${{ matrix.sub-package }} install
run: poetry install
# Typecheck
- name: Typecheck
if: success() || failure()
run: poetry run pyright
working-directory: classic/${{ matrix.sub-package }}

View File

@@ -269,12 +269,14 @@ jobs:
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
- name: Run pytest
- name: Run pytest with coverage
run: |
if [[ "${{ runner.debug }}" == "1" ]]; then
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG \
--cov=backend --cov-branch --cov-report term-missing --cov-report xml
else
poetry run pytest -s -vv
poetry run pytest -s -vv \
--cov=backend --cov-branch --cov-report term-missing --cov-report xml
fi
env:
LOG_LEVEL: ${{ runner.debug && 'DEBUG' || 'INFO' }}
@@ -287,11 +289,13 @@ jobs:
REDIS_PORT: "6379"
ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=" # DO NOT USE IN PRODUCTION!!
# - name: Upload coverage reports to Codecov
# uses: codecov/codecov-action@v4
# with:
# token: ${{ secrets.CODECOV_TOKEN }}
# flags: backend,${{ runner.os }}
- name: Upload coverage reports to Codecov
if: ${{ !cancelled() }}
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: platform-backend
files: ./autogpt_platform/backend/coverage.xml
env:
CI: true

View File

@@ -148,3 +148,11 @@ jobs:
- name: Run Integration Tests
run: pnpm test:unit
- name: Upload coverage reports to Codecov
if: ${{ !cancelled() }}
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: platform-frontend
files: ./autogpt_platform/frontend/coverage/cobertura-coverage.xml

10
.gitignore vendored
View File

@@ -3,6 +3,7 @@
classic/original_autogpt/keys.py
classic/original_autogpt/*.json
auto_gpt_workspace/*
.autogpt/
*.mpeg
.env
# Root .env files
@@ -16,6 +17,7 @@ log-ingestion.txt
/logs
*.log
*.mp3
!autogpt_platform/frontend/public/notification.mp3
mem.sqlite3
venvAutoGPT
@@ -159,6 +161,10 @@ CURRENT_BULLETIN.md
# AgBenchmark
classic/benchmark/agbenchmark/reports/
classic/reports/
classic/direct_benchmark/reports/
classic/.benchmark_workspaces/
classic/direct_benchmark/.benchmark_workspaces/
# Nodejs
package-lock.json
@@ -177,9 +183,13 @@ autogpt_platform/backend/settings.py
*.ign.*
.test-contents
**/.claude/settings.local.json
.claude/settings.local.json
CLAUDE.local.md
/autogpt_platform/backend/logs
# Test database
test.db
.next
# Implementation plans (generated by AI agents)
plans/

3
.gitmodules vendored
View File

@@ -1,3 +0,0 @@
[submodule "classic/forge/tests/vcr_cassettes"]
path = classic/forge/tests/vcr_cassettes
url = https://github.com/Significant-Gravitas/Auto-GPT-test-cassettes

View File

@@ -84,51 +84,16 @@ repos:
stages: [pre-commit, post-checkout]
- id: poetry-install
name: Check & Install dependencies - Classic - AutoGPT
alias: poetry-install-classic-autogpt
name: Check & Install dependencies - Classic
alias: poetry-install-classic
entry: >
bash -c '
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
else
git diff --cached --name-only
fi | grep -qE "^classic/(original_autogpt|forge)/poetry\.lock$" || exit 0;
poetry -C classic/original_autogpt install
'
# include forge source (since it's a path dependency)
always_run: true
language: system
pass_filenames: false
stages: [pre-commit, post-checkout]
- id: poetry-install
name: Check & Install dependencies - Classic - Forge
alias: poetry-install-classic-forge
entry: >
bash -c '
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
else
git diff --cached --name-only
fi | grep -qE "^classic/forge/poetry\.lock$" || exit 0;
poetry -C classic/forge install
'
always_run: true
language: system
pass_filenames: false
stages: [pre-commit, post-checkout]
- id: poetry-install
name: Check & Install dependencies - Classic - Benchmark
alias: poetry-install-classic-benchmark
entry: >
bash -c '
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
else
git diff --cached --name-only
fi | grep -qE "^classic/benchmark/poetry\.lock$" || exit 0;
poetry -C classic/benchmark install
fi | grep -qE "^classic/poetry\.lock$" || exit 0;
poetry -C classic install
'
always_run: true
language: system
@@ -223,26 +188,10 @@ repos:
language: system
- id: isort
name: Lint (isort) - Classic - AutoGPT
alias: isort-classic-autogpt
entry: poetry -P classic/original_autogpt run isort -p autogpt
files: ^classic/original_autogpt/
types: [file, python]
language: system
- id: isort
name: Lint (isort) - Classic - Forge
alias: isort-classic-forge
entry: poetry -P classic/forge run isort -p forge
files: ^classic/forge/
types: [file, python]
language: system
- id: isort
name: Lint (isort) - Classic - Benchmark
alias: isort-classic-benchmark
entry: poetry -P classic/benchmark run isort -p agbenchmark
files: ^classic/benchmark/
name: Lint (isort) - Classic
alias: isort-classic
entry: bash -c 'cd classic && poetry run isort $(echo "$@" | sed "s|classic/||g")' --
files: ^classic/(original_autogpt|forge|direct_benchmark)/
types: [file, python]
language: system
@@ -256,26 +205,13 @@ repos:
- repo: https://github.com/PyCQA/flake8
rev: 7.0.0
# To have flake8 load the config of the individual subprojects, we have to call
# them separately.
# Use consolidated flake8 config at classic/.flake8
hooks:
- id: flake8
name: Lint (Flake8) - Classic - AutoGPT
alias: flake8-classic-autogpt
files: ^classic/original_autogpt/(autogpt|scripts|tests)/
args: [--config=classic/original_autogpt/.flake8]
- id: flake8
name: Lint (Flake8) - Classic - Forge
alias: flake8-classic-forge
files: ^classic/forge/(forge|tests)/
args: [--config=classic/forge/.flake8]
- id: flake8
name: Lint (Flake8) - Classic - Benchmark
alias: flake8-classic-benchmark
files: ^classic/benchmark/(agbenchmark|tests)/((?!reports).)*[/.]
args: [--config=classic/benchmark/.flake8]
name: Lint (Flake8) - Classic
alias: flake8-classic
files: ^classic/(original_autogpt|forge|direct_benchmark)/
args: [--config=classic/.flake8]
- repo: local
hooks:
@@ -311,29 +247,10 @@ repos:
pass_filenames: false
- id: pyright
name: Typecheck - Classic - AutoGPT
alias: pyright-classic-autogpt
entry: poetry -C classic/original_autogpt run pyright
# include forge source (since it's a path dependency) but exclude *_test.py files:
files: ^(classic/original_autogpt/((autogpt|scripts|tests)/|poetry\.lock$)|classic/forge/(forge/.*(?<!_test)\.py|poetry\.lock)$)
types: [file]
language: system
pass_filenames: false
- id: pyright
name: Typecheck - Classic - Forge
alias: pyright-classic-forge
entry: poetry -C classic/forge run pyright
files: ^classic/forge/(forge/|poetry\.lock$)
types: [file]
language: system
pass_filenames: false
- id: pyright
name: Typecheck - Classic - Benchmark
alias: pyright-classic-benchmark
entry: poetry -C classic/benchmark run pyright
files: ^classic/benchmark/(agbenchmark/|tests/|poetry\.lock$)
name: Typecheck - Classic
alias: pyright-classic
entry: poetry -C classic run pyright
files: ^classic/(original_autogpt|forge|direct_benchmark)/.*\.py$|^classic/poetry\.lock$
types: [file]
language: system
pass_filenames: false
@@ -360,26 +277,9 @@ repos:
# pass_filenames: false
# - id: pytest
# name: Run tests - Classic - AutoGPT (excl. slow tests)
# alias: pytest-classic-autogpt
# entry: bash -c 'cd classic/original_autogpt && poetry run pytest --cov=autogpt -m "not slow" tests/unit tests/integration'
# # include forge source (since it's a path dependency) but exclude *_test.py files:
# files: ^(classic/original_autogpt/((autogpt|tests)/|poetry\.lock$)|classic/forge/(forge/.*(?<!_test)\.py|poetry\.lock)$)
# language: system
# pass_filenames: false
# - id: pytest
# name: Run tests - Classic - Forge (excl. slow tests)
# alias: pytest-classic-forge
# entry: bash -c 'cd classic/forge && poetry run pytest --cov=forge -m "not slow"'
# files: ^classic/forge/(forge/|tests/|poetry\.lock$)
# language: system
# pass_filenames: false
# - id: pytest
# name: Run tests - Classic - Benchmark
# alias: pytest-classic-benchmark
# entry: bash -c 'cd classic/benchmark && poetry run pytest --cov=benchmark'
# files: ^classic/benchmark/(agbenchmark/|tests/|poetry\.lock$)
# name: Run tests - Classic (excl. slow tests)
# alias: pytest-classic
# entry: bash -c 'cd classic && poetry run pytest -m "not slow"'
# files: ^classic/(original_autogpt|forge|direct_benchmark)/
# language: system
# pass_filenames: false

View File

@@ -1,6 +1,6 @@
# AutoGPT Platform Contribution Guide
This guide provides context for Codex when updating the **autogpt_platform** folder.
This guide provides context for coding agents when updating the **autogpt_platform** folder.
## Directory overview

1
CLAUDE.md Normal file
View File

@@ -0,0 +1 @@
@AGENTS.md

120
autogpt_platform/AGENTS.md Normal file
View File

@@ -0,0 +1,120 @@
# AutoGPT Platform
This file provides guidance to coding agents when working with code in this repository.
## Repository Overview
AutoGPT Platform is a monorepo containing:
- **Backend** (`backend`): Python FastAPI server with async support
- **Frontend** (`frontend`): Next.js React application
- **Shared Libraries** (`autogpt_libs`): Common Python utilities
## Component Documentation
- **Backend**: See @backend/AGENTS.md for backend-specific commands, architecture, and development tasks
- **Frontend**: See @frontend/AGENTS.md for frontend-specific commands, architecture, and development patterns
## Key Concepts
1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
2. **Blocks**: Reusable components in `backend/backend/blocks/` that perform specific tasks
3. **Integrations**: OAuth and API connections stored per user
4. **Store**: Marketplace for sharing agent templates
5. **Virus Scanning**: ClamAV integration for file upload security
### Environment Configuration
#### Configuration Files
- **Backend**: `backend/.env.default` (defaults) → `backend/.env` (user overrides)
- **Frontend**: `frontend/.env.default` (defaults) → `frontend/.env` (user overrides)
- **Platform**: `.env.default` (Supabase/shared defaults) → `.env` (user overrides)
#### Docker Environment Loading Order
1. `.env.default` files provide base configuration (tracked in git)
2. `.env` files provide user-specific overrides (gitignored)
3. Docker Compose `environment:` sections provide service-specific overrides
4. Shell environment variables have highest precedence
#### Key Points
- All services use hardcoded defaults in docker-compose files (no `${VARIABLE}` substitutions)
- The `env_file` directive loads variables INTO containers at runtime
- Backend/Frontend services use YAML anchors for consistent configuration
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
### Branching Strategy
- **`dev`** is the main development branch. All PRs should target `dev`.
- **`master`** is the production branch. Only used for production releases.
### Creating Pull Requests
- Create the PR against the `dev` branch of the repository.
- **Split PRs by concern** — each PR should have a single clear purpose. For example, "usage tracking" and "credit charging" should be separate PRs even if related. Combining multiple concerns makes it harder for reviewers to understand what belongs to what.
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)
- Use conventional commit messages (see below)
- **Structure the PR description with Why / What / How** — Why: the motivation (what problem it solves, what's broken/missing without it); What: high-level summary of changes; How: approach, key implementation details, or architecture decisions. Reviewers need all three to judge whether the approach fits the problem.
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description
- Always use `--body-file` to pass PR body — avoids shell interpretation of backticks and special characters:
```bash
PR_BODY=$(mktemp)
cat > "$PR_BODY" << 'PREOF'
## Summary
- use `backticks` freely here
PREOF
gh pr create --title "..." --body-file "$PR_BODY" --base dev
rm "$PR_BODY"
```
- Run the github pre-commit hooks to ensure code quality.
### Test-Driven Development (TDD)
When fixing a bug or adding a feature, follow a test-first approach:
1. **Write a failing test first** — create a test that reproduces the bug or validates the new behavior, marked with `@pytest.mark.xfail` (backend) or `.fixme` (Playwright). Run it to confirm it fails for the right reason.
2. **Implement the fix/feature** — write the minimal code to make the test pass.
3. **Remove the xfail marker** — once the test passes, remove the `xfail`/`.fixme` annotation and run the full test suite to confirm nothing else broke.
This ensures every change is covered by a test and that the test actually validates the intended behavior.
### Reviewing/Revising Pull Requests
Use `/pr-review` to review a PR or `/pr-address` to address comments.
When fetching comments manually:
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate` — top-level reviews
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments --paginate` — inline review comments (always paginate to avoid missing comments beyond page 1)
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` — PR conversation comments
### Conventional Commits
Use this format for commit messages and Pull Request titles:
**Conventional Commit Types:**
- `feat`: Introduces a new feature to the codebase
- `fix`: Patches a bug in the codebase
- `refactor`: Code change that neither fixes a bug nor adds a feature; also applies to removing features
- `ci`: Changes to CI configuration
- `docs`: Documentation-only changes
- `dx`: Improvements to the developer experience
**Recommended Base Scopes:**
- `platform`: Changes affecting both frontend and backend
- `frontend`
- `backend`
- `infra`
- `blocks`: Modifications/additions of individual blocks
**Subscope Examples:**
- `backend/executor`
- `backend/db`
- `frontend/builder` (includes changes to the block UI component)
- `infra/prod`
Use these scopes and subscopes for clarity and consistency in commit messages.

View File

@@ -1,120 +1 @@
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Repository Overview
AutoGPT Platform is a monorepo containing:
- **Backend** (`backend`): Python FastAPI server with async support
- **Frontend** (`frontend`): Next.js React application
- **Shared Libraries** (`autogpt_libs`): Common Python utilities
## Component Documentation
- **Backend**: See @backend/CLAUDE.md for backend-specific commands, architecture, and development tasks
- **Frontend**: See @frontend/CLAUDE.md for frontend-specific commands, architecture, and development patterns
## Key Concepts
1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
2. **Blocks**: Reusable components in `backend/backend/blocks/` that perform specific tasks
3. **Integrations**: OAuth and API connections stored per user
4. **Store**: Marketplace for sharing agent templates
5. **Virus Scanning**: ClamAV integration for file upload security
### Environment Configuration
#### Configuration Files
- **Backend**: `backend/.env.default` (defaults) → `backend/.env` (user overrides)
- **Frontend**: `frontend/.env.default` (defaults) → `frontend/.env` (user overrides)
- **Platform**: `.env.default` (Supabase/shared defaults) → `.env` (user overrides)
#### Docker Environment Loading Order
1. `.env.default` files provide base configuration (tracked in git)
2. `.env` files provide user-specific overrides (gitignored)
3. Docker Compose `environment:` sections provide service-specific overrides
4. Shell environment variables have highest precedence
#### Key Points
- All services use hardcoded defaults in docker-compose files (no `${VARIABLE}` substitutions)
- The `env_file` directive loads variables INTO containers at runtime
- Backend/Frontend services use YAML anchors for consistent configuration
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
### Branching Strategy
- **`dev`** is the main development branch. All PRs should target `dev`.
- **`master`** is the production branch. Only used for production releases.
### Creating Pull Requests
- Create the PR against the `dev` branch of the repository.
- **Split PRs by concern** — each PR should have a single clear purpose. For example, "usage tracking" and "credit charging" should be separate PRs even if related. Combining multiple concerns makes it harder for reviewers to understand what belongs to what.
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)
- Use conventional commit messages (see below)
- **Structure the PR description with Why / What / How** — Why: the motivation (what problem it solves, what's broken/missing without it); What: high-level summary of changes; How: approach, key implementation details, or architecture decisions. Reviewers need all three to judge whether the approach fits the problem.
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description
- Always use `--body-file` to pass PR body — avoids shell interpretation of backticks and special characters:
```bash
PR_BODY=$(mktemp)
cat > "$PR_BODY" << 'PREOF'
## Summary
- use `backticks` freely here
PREOF
gh pr create --title "..." --body-file "$PR_BODY" --base dev
rm "$PR_BODY"
```
- Run the github pre-commit hooks to ensure code quality.
### Test-Driven Development (TDD)
When fixing a bug or adding a feature, follow a test-first approach:
1. **Write a failing test first** — create a test that reproduces the bug or validates the new behavior, marked with `@pytest.mark.xfail` (backend) or `.fixme` (Playwright). Run it to confirm it fails for the right reason.
2. **Implement the fix/feature** — write the minimal code to make the test pass.
3. **Remove the xfail marker** — once the test passes, remove the `xfail`/`.fixme` annotation and run the full test suite to confirm nothing else broke.
This ensures every change is covered by a test and that the test actually validates the intended behavior.
### Reviewing/Revising Pull Requests
Use `/pr-review` to review a PR or `/pr-address` to address comments.
When fetching comments manually:
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate` — top-level reviews
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments --paginate` — inline review comments (always paginate to avoid missing comments beyond page 1)
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` — PR conversation comments
### Conventional Commits
Use this format for commit messages and Pull Request titles:
**Conventional Commit Types:**
- `feat`: Introduces a new feature to the codebase
- `fix`: Patches a bug in the codebase
- `refactor`: Code change that neither fixes a bug nor adds a feature; also applies to removing features
- `ci`: Changes to CI configuration
- `docs`: Documentation-only changes
- `dx`: Improvements to the developer experience
**Recommended Base Scopes:**
- `platform`: Changes affecting both frontend and backend
- `frontend`
- `backend`
- `infra`
- `blocks`: Modifications/additions of individual blocks
**Subscope Examples:**
- `backend/executor`
- `backend/db`
- `frontend/builder` (includes changes to the block UI component)
- `infra/prod`
Use these scopes and subscopes for clarity and consistency in commit messages.
@AGENTS.md

View File

@@ -178,6 +178,7 @@ SMTP_USERNAME=
SMTP_PASSWORD=
# Business & Marketing Tools
AGENTMAIL_API_KEY=
APOLLO_API_KEY=
ENRICHLAYER_API_KEY=
AYRSHARE_API_KEY=

View File

@@ -0,0 +1,227 @@
# Backend
This file provides guidance to coding agents when working with the backend.
## Essential Commands
To run something with Python package dependencies you MUST use `poetry run ...`.
```bash
# Install dependencies
poetry install
# Run database migrations
poetry run prisma migrate dev
# Start all services (database, redis, rabbitmq, clamav)
docker compose up -d
# Run the backend as a whole
poetry run app
# Run tests
poetry run test
# Run specific test
poetry run pytest path/to/test_file.py::test_function_name
# Run block tests (tests that validate all blocks work correctly)
poetry run pytest backend/blocks/test/test_block.py -xvs
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
# Lint and format
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
poetry run format # Black + isort
poetry run lint # ruff
```
More details can be found in @TESTING.md
### Creating/Updating Snapshots
When you first write a test or when the expected output changes:
```bash
poetry run pytest path/to/test.py --snapshot-update
```
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
## Architecture
- **API Layer**: FastAPI with REST and WebSocket endpoints
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
- **Queue System**: RabbitMQ for async task processing
- **Execution Engine**: Separate executor service processes agent workflows
- **Authentication**: JWT-based with Supabase integration
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
## Code Style
- **Top-level imports only** — no local/inner imports (lazy imports only for heavy optional deps like `openpyxl`)
- **Absolute imports** — use `from backend.module import ...` for cross-package imports. Single-dot relative (`from .sibling import ...`) is acceptable for sibling modules within the same package (e.g., blocks). Avoid double-dot relative imports (`from ..parent import ...`) — use the absolute path instead
- **No duck typing** — no `hasattr`/`getattr`/`isinstance` for type dispatch; use typed interfaces/unions/protocols
- **Pydantic models** over dataclass/namedtuple/dict for structured data
- **No linter suppressors** — no `# type: ignore`, `# noqa`, `# pyright: ignore`; fix the type/code
- **List comprehensions** over manual loop-and-append
- **Early return** — guard clauses first, avoid deep nesting
- **f-strings vs printf syntax in log statements** — Use `%s` for deferred interpolation in `debug` statements, f-strings elsewhere for readability: `logger.debug("Processing %s items", count)`, `logger.info(f"Processing {count} items")`
- **Sanitize error paths** — `os.path.basename()` in error messages to avoid leaking directory structure
- **TOCTOU awareness** — avoid check-then-act patterns for file access and credit charging
- **`Security()` vs `Depends()`** — use `Security()` for auth deps to get proper OpenAPI security spec
- **Redis pipelines** — `transaction=True` for atomicity on multi-step operations
- **`max(0, value)` guards** — for computed values that should never be negative
- **SSE protocol** — `data:` lines for frontend-parsed events (must match Zod schema), `: comment` lines for heartbeats/status
- **File length** — keep files under ~300 lines; if a file grows beyond this, split by responsibility (e.g. extract helpers, models, or a sub-module into a new file). Never keep appending to a long file.
- **Function length** — keep functions under ~40 lines; extract named helpers when a function grows longer. Long functions are a sign of mixed concerns, not complexity.
- **Top-down ordering** — define the main/public function or class first, then the helpers it uses below. A reader should encounter high-level logic before implementation details.
## Testing Approach
- Uses pytest with snapshot testing for API responses
- Test files are colocated with source files (`*_test.py`)
- Mock at boundaries — mock where the symbol is **used**, not where it's **defined**
- After refactoring, update mock targets to match new module paths
- Use `AsyncMock` for async functions (`from unittest.mock import AsyncMock`)
### Test-Driven Development (TDD)
When fixing a bug or adding a feature, write the test **before** the implementation:
```python
# 1. Write a failing test marked xfail
@pytest.mark.xfail(reason="Bug #1234: widget crashes on empty input")
def test_widget_handles_empty_input():
result = widget.process("")
assert result == Widget.EMPTY_RESULT
# 2. Run it — confirm it fails (XFAIL)
# poetry run pytest path/to/test.py::test_widget_handles_empty_input -xvs
# 3. Implement the fix
# 4. Remove xfail, run again — confirm it passes
def test_widget_handles_empty_input():
result = widget.process("")
assert result == Widget.EMPTY_RESULT
```
This catches regressions and proves the fix actually works. **Every bug fix should include a test that would have caught it.**
## Database Schema
Key models (defined in `schema.prisma`):
- `User`: Authentication and profile data
- `AgentGraph`: Workflow definitions with version control
- `AgentGraphExecution`: Execution history and results
- `AgentNode`: Individual nodes in a workflow
- `StoreListing`: Marketplace listings for sharing agents
## Environment Configuration
- **Backend**: `.env.default` (defaults) → `.env` (user overrides)
## Common Development Tasks
### Adding a new block
Follow the comprehensive [Block SDK Guide](@../../docs/platform/block-sdk-guide.md) which covers:
- Provider configuration with `ProviderBuilder`
- Block schema definition
- Authentication (API keys, OAuth, webhooks)
- Testing and validation
- File organization
Quick steps:
1. Create new file in `backend/blocks/`
2. Configure provider using `ProviderBuilder` in `_config.py`
3. Inherit from `Block` base class
4. Define input/output schemas using `BlockSchema`
5. Implement async `run` method
6. Generate unique block ID using `uuid.uuid4()`
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph-based editor or would they struggle to connect productively?
ex: do the inputs and outputs tie well together?
If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.
#### Handling files in blocks with `store_media_file()`
When blocks need to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. The `return_format` parameter determines what you get back:
| Format | Use When | Returns |
|--------|----------|---------|
| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) |
| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) |
| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs |
**Examples:**
```python
# INPUT: Need to process file locally with ffmpeg
local_path = await store_media_file(
file=input_data.video,
execution_context=execution_context,
return_format="for_local_processing",
)
# local_path = "video.mp4" - use with Path/ffmpeg/etc
# INPUT: Need to send to external API like Replicate
image_b64 = await store_media_file(
file=input_data.image,
execution_context=execution_context,
return_format="for_external_api",
)
# image_b64 = "data:image/png;base64,iVBORw0..." - send to API
# OUTPUT: Returning result from block
result_url = await store_media_file(
file=generated_image_url,
execution_context=execution_context,
return_format="for_block_output",
)
yield "image_url", result_url
# In CoPilot: result_url = "workspace://abc123"
# In graphs: result_url = "data:image/png;base64,..."
```
**Key points:**
- `for_block_output` is the ONLY format that auto-adapts to execution context
- Always use `for_block_output` for block outputs unless you have a specific reason not to
- Never hardcode workspace checks - let `for_block_output` handle it
### Modifying the API
1. Update route in `backend/api/features/`
2. Add/update Pydantic models in same directory
3. Write tests alongside the route file
4. Run `poetry run test` to verify
## Workspace & Media Files
**Read [Workspace & Media Architecture](../../docs/platform/workspace-media-architecture.md) when:**
- Working on CoPilot file upload/download features
- Building blocks that handle `MediaFileType` inputs/outputs
- Modifying `WorkspaceManager` or `store_media_file()`
- Debugging file persistence or virus scanning issues
Covers: `WorkspaceManager` (persistent storage with session scoping), `store_media_file()` (media normalization pipeline), and responsibility boundaries for virus scanning and persistence.
## Security Implementation
### Cache Protection Middleware
- Located in `backend/api/middleware/security.py`
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
- Uses an allow list approach - only explicitly permitted paths can be cached
- Cacheable paths include: static assets (`static/*`, `_next/static/*`), health checks, public store pages, documentation
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
- Applied to both main API server and external API applications

View File

@@ -1,227 +1 @@
# CLAUDE.md - Backend
This file provides guidance to Claude Code when working with the backend.
## Essential Commands
To run something with Python package dependencies you MUST use `poetry run ...`.
```bash
# Install dependencies
poetry install
# Run database migrations
poetry run prisma migrate dev
# Start all services (database, redis, rabbitmq, clamav)
docker compose up -d
# Run the backend as a whole
poetry run app
# Run tests
poetry run test
# Run specific test
poetry run pytest path/to/test_file.py::test_function_name
# Run block tests (tests that validate all blocks work correctly)
poetry run pytest backend/blocks/test/test_block.py -xvs
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
# Lint and format
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
poetry run format # Black + isort
poetry run lint # ruff
```
More details can be found in @TESTING.md
### Creating/Updating Snapshots
When you first write a test or when the expected output changes:
```bash
poetry run pytest path/to/test.py --snapshot-update
```
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
## Architecture
- **API Layer**: FastAPI with REST and WebSocket endpoints
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
- **Queue System**: RabbitMQ for async task processing
- **Execution Engine**: Separate executor service processes agent workflows
- **Authentication**: JWT-based with Supabase integration
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
## Code Style
- **Top-level imports only** — no local/inner imports (lazy imports only for heavy optional deps like `openpyxl`)
- **Absolute imports** — use `from backend.module import ...` for cross-package imports. Single-dot relative (`from .sibling import ...`) is acceptable for sibling modules within the same package (e.g., blocks). Avoid double-dot relative imports (`from ..parent import ...`) — use the absolute path instead
- **No duck typing** — no `hasattr`/`getattr`/`isinstance` for type dispatch; use typed interfaces/unions/protocols
- **Pydantic models** over dataclass/namedtuple/dict for structured data
- **No linter suppressors** — no `# type: ignore`, `# noqa`, `# pyright: ignore`; fix the type/code
- **List comprehensions** over manual loop-and-append
- **Early return** — guard clauses first, avoid deep nesting
- **f-strings vs printf syntax in log statements** — Use `%s` for deferred interpolation in `debug` statements, f-strings elsewhere for readability: `logger.debug("Processing %s items", count)`, `logger.info(f"Processing {count} items")`
- **Sanitize error paths** — `os.path.basename()` in error messages to avoid leaking directory structure
- **TOCTOU awareness** — avoid check-then-act patterns for file access and credit charging
- **`Security()` vs `Depends()`** — use `Security()` for auth deps to get proper OpenAPI security spec
- **Redis pipelines** — `transaction=True` for atomicity on multi-step operations
- **`max(0, value)` guards** — for computed values that should never be negative
- **SSE protocol** — `data:` lines for frontend-parsed events (must match Zod schema), `: comment` lines for heartbeats/status
- **File length** — keep files under ~300 lines; if a file grows beyond this, split by responsibility (e.g. extract helpers, models, or a sub-module into a new file). Never keep appending to a long file.
- **Function length** — keep functions under ~40 lines; extract named helpers when a function grows longer. Long functions are a sign of mixed concerns, not complexity.
- **Top-down ordering** — define the main/public function or class first, then the helpers it uses below. A reader should encounter high-level logic before implementation details.
## Testing Approach
- Uses pytest with snapshot testing for API responses
- Test files are colocated with source files (`*_test.py`)
- Mock at boundaries — mock where the symbol is **used**, not where it's **defined**
- After refactoring, update mock targets to match new module paths
- Use `AsyncMock` for async functions (`from unittest.mock import AsyncMock`)
### Test-Driven Development (TDD)
When fixing a bug or adding a feature, write the test **before** the implementation:
```python
# 1. Write a failing test marked xfail
@pytest.mark.xfail(reason="Bug #1234: widget crashes on empty input")
def test_widget_handles_empty_input():
result = widget.process("")
assert result == Widget.EMPTY_RESULT
# 2. Run it — confirm it fails (XFAIL)
# poetry run pytest path/to/test.py::test_widget_handles_empty_input -xvs
# 3. Implement the fix
# 4. Remove xfail, run again — confirm it passes
def test_widget_handles_empty_input():
result = widget.process("")
assert result == Widget.EMPTY_RESULT
```
This catches regressions and proves the fix actually works. **Every bug fix should include a test that would have caught it.**
## Database Schema
Key models (defined in `schema.prisma`):
- `User`: Authentication and profile data
- `AgentGraph`: Workflow definitions with version control
- `AgentGraphExecution`: Execution history and results
- `AgentNode`: Individual nodes in a workflow
- `StoreListing`: Marketplace listings for sharing agents
## Environment Configuration
- **Backend**: `.env.default` (defaults) → `.env` (user overrides)
## Common Development Tasks
### Adding a new block
Follow the comprehensive [Block SDK Guide](@../../docs/content/platform/block-sdk-guide.md) which covers:
- Provider configuration with `ProviderBuilder`
- Block schema definition
- Authentication (API keys, OAuth, webhooks)
- Testing and validation
- File organization
Quick steps:
1. Create new file in `backend/blocks/`
2. Configure provider using `ProviderBuilder` in `_config.py`
3. Inherit from `Block` base class
4. Define input/output schemas using `BlockSchema`
5. Implement async `run` method
6. Generate unique block ID using `uuid.uuid4()`
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph-based editor or would they struggle to connect productively?
ex: do the inputs and outputs tie well together?
If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.
#### Handling files in blocks with `store_media_file()`
When blocks need to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. The `return_format` parameter determines what you get back:
| Format | Use When | Returns |
|--------|----------|---------|
| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) |
| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) |
| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs |
**Examples:**
```python
# INPUT: Need to process file locally with ffmpeg
local_path = await store_media_file(
file=input_data.video,
execution_context=execution_context,
return_format="for_local_processing",
)
# local_path = "video.mp4" - use with Path/ffmpeg/etc
# INPUT: Need to send to external API like Replicate
image_b64 = await store_media_file(
file=input_data.image,
execution_context=execution_context,
return_format="for_external_api",
)
# image_b64 = "data:image/png;base64,iVBORw0..." - send to API
# OUTPUT: Returning result from block
result_url = await store_media_file(
file=generated_image_url,
execution_context=execution_context,
return_format="for_block_output",
)
yield "image_url", result_url
# In CoPilot: result_url = "workspace://abc123"
# In graphs: result_url = "data:image/png;base64,..."
```
**Key points:**
- `for_block_output` is the ONLY format that auto-adapts to execution context
- Always use `for_block_output` for block outputs unless you have a specific reason not to
- Never hardcode workspace checks - let `for_block_output` handle it
### Modifying the API
1. Update route in `backend/api/features/`
2. Add/update Pydantic models in same directory
3. Write tests alongside the route file
4. Run `poetry run test` to verify
## Workspace & Media Files
**Read [Workspace & Media Architecture](../../docs/platform/workspace-media-architecture.md) when:**
- Working on CoPilot file upload/download features
- Building blocks that handle `MediaFileType` inputs/outputs
- Modifying `WorkspaceManager` or `store_media_file()`
- Debugging file persistence or virus scanning issues
Covers: `WorkspaceManager` (persistent storage with session scoping), `store_media_file()` (media normalization pipeline), and responsibility boundaries for virus scanning and persistence.
## Security Implementation
### Cache Protection Middleware
- Located in `backend/api/middleware/security.py`
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
- Uses an allow list approach - only explicitly permitted paths can be cached
- Cacheable paths include: static assets (`static/*`, `_next/static/*`), health checks, public store pages, documentation
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
- Applied to both main API server and external API applications
@AGENTS.md

View File

@@ -31,7 +31,10 @@ from backend.data.model import (
UserPasswordCredentials,
is_sdk_default,
)
from backend.integrations.credentials_store import provider_matches
from backend.integrations.credentials_store import (
is_system_credential,
provider_matches,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
@@ -618,6 +621,11 @@ async def delete_credential(
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
)
if is_system_credential(cred_id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="System-managed credentials cannot be deleted",
)
creds = await creds_manager.store.get_creds_by_id(auth.user_id, cred_id)
if not creds:
raise HTTPException(

View File

@@ -72,7 +72,7 @@ class RunAgentRequest(BaseModel):
def _create_ephemeral_session(user_id: str) -> ChatSession:
"""Create an ephemeral session for stateless API requests."""
return ChatSession.new(user_id)
return ChatSession.new(user_id, dry_run=False)
@tools_router.post(

View File

@@ -11,7 +11,7 @@ from autogpt_libs import auth
from fastapi import APIRouter, HTTPException, Query, Response, Security
from fastapi.responses import StreamingResponse
from prisma.models import UserWorkspaceFile
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator
from backend.copilot import service as chat_service
from backend.copilot import stream_registry
@@ -20,6 +20,7 @@ from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_
from backend.copilot.model import (
ChatMessage,
ChatSession,
ChatSessionMetadata,
append_and_save_message,
create_chat_session,
delete_chat_session,
@@ -112,12 +113,25 @@ class StreamChatRequest(BaseModel):
) # Workspace file IDs attached to this message
class CreateSessionRequest(BaseModel):
"""Request model for creating a new chat session.
``dry_run`` is a **top-level** field — do not nest it inside ``metadata``.
Extra/unknown fields are rejected (422) to prevent silent mis-use.
"""
model_config = ConfigDict(extra="forbid")
dry_run: bool = False
class CreateSessionResponse(BaseModel):
"""Response model containing information on a newly created chat session."""
id: str
created_at: str
user_id: str | None
metadata: ChatSessionMetadata = ChatSessionMetadata()
class ActiveStreamInfo(BaseModel):
@@ -138,6 +152,7 @@ class SessionDetailResponse(BaseModel):
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
total_prompt_tokens: int = 0
total_completion_tokens: int = 0
metadata: ChatSessionMetadata = ChatSessionMetadata()
class SessionSummaryResponse(BaseModel):
@@ -248,6 +263,7 @@ async def list_sessions(
)
async def create_session(
user_id: Annotated[str, Security(auth.get_user_id)],
request: CreateSessionRequest | None = None,
) -> CreateSessionResponse:
"""
Create a new chat session.
@@ -256,22 +272,28 @@ async def create_session(
Args:
user_id: The authenticated user ID parsed from the JWT (required).
request: Optional request body. When provided, ``dry_run=True``
forces run_block and run_agent calls to use dry-run simulation.
Returns:
CreateSessionResponse: Details of the created session.
"""
dry_run = request.dry_run if request else False
logger.info(
f"Creating session with user_id: "
f"...{user_id[-8:] if len(user_id) > 8 else '<redacted>'}"
f"{', dry_run=True' if dry_run else ''}"
)
session = await create_chat_session(user_id)
session = await create_chat_session(user_id, dry_run=dry_run)
return CreateSessionResponse(
id=session.session_id,
created_at=session.started_at.isoformat(),
user_id=session.user_id,
metadata=session.metadata,
)
@@ -420,6 +442,7 @@ async def get_session(
active_stream=active_stream_info,
total_prompt_tokens=total_prompt,
total_completion_tokens=total_completion,
metadata=session.metadata,
)
@@ -527,6 +550,8 @@ async def reset_copilot_usage(
try:
# Verify the user is actually at or over their daily limit.
# (rate_limit_reset_cost intentionally omitted — this object is only
# used for limit checks, not returned to the client.)
usage_status = await get_usage_status(
user_id=user_id,
daily_token_limit=daily_limit,
@@ -1174,7 +1199,7 @@ async def health_check() -> dict:
)
# Create and retrieve session to verify full data layer
session = await create_chat_session(health_check_user_id)
session = await create_chat_session(health_check_user_id, dry_run=False)
await get_chat_session(session.session_id, health_check_user_id)
return {

View File

@@ -469,3 +469,60 @@ def test_suggested_prompts_empty_prompts(
assert response.status_code == 200
assert response.json() == {"themes": []}
# ─── Create session: dry_run contract ─────────────────────────────────
def _mock_create_chat_session(mocker: pytest_mock.MockerFixture):
"""Mock create_chat_session to return a fake session."""
from backend.copilot.model import ChatSession
async def _fake_create(user_id: str, *, dry_run: bool):
return ChatSession.new(user_id, dry_run=dry_run)
return mocker.patch(
"backend.api.features.chat.routes.create_chat_session",
new_callable=AsyncMock,
side_effect=_fake_create,
)
def test_create_session_dry_run_true(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""Sending ``{"dry_run": true}`` sets metadata.dry_run to True."""
_mock_create_chat_session(mocker)
response = client.post("/sessions", json={"dry_run": True})
assert response.status_code == 200
assert response.json()["metadata"]["dry_run"] is True
def test_create_session_dry_run_default_false(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""Empty body defaults dry_run to False."""
_mock_create_chat_session(mocker)
response = client.post("/sessions")
assert response.status_code == 200
assert response.json()["metadata"]["dry_run"] is False
def test_create_session_rejects_nested_metadata(
test_user_id: str,
) -> None:
"""Sending ``{"metadata": {"dry_run": true}}`` must return 422, not silently
default to ``dry_run=False``. This guards against the common mistake of
nesting dry_run inside metadata instead of providing it at the top level."""
response = client.post(
"/sessions",
json={"metadata": {"dry_run": True}},
)
assert response.status_code == 422

View File

@@ -40,11 +40,15 @@ from backend.data.onboarding import OnboardingStep, complete_onboarding_step
from backend.data.user import get_user_integrations
from backend.executor.utils import add_graph_execution
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
from backend.integrations.credentials_store import provider_matches
from backend.integrations.credentials_store import (
is_system_credential,
provider_matches,
)
from backend.integrations.creds_manager import (
IntegrationCredentialsManager,
create_mcp_oauth_handler,
)
from backend.integrations.managed_credentials import ensure_managed_credentials
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks import get_webhook_manager
@@ -110,6 +114,7 @@ class CredentialsMetaResponse(BaseModel):
default=None,
description="Host pattern for host-scoped or MCP server URL for MCP credentials",
)
is_managed: bool = False
@model_validator(mode="before")
@classmethod
@@ -148,6 +153,7 @@ def to_meta_response(cred: Credentials) -> CredentialsMetaResponse:
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
host=CredentialsMetaResponse.get_host(cred),
is_managed=cred.is_managed,
)
@@ -224,6 +230,9 @@ async def callback(
async def list_credentials(
user_id: Annotated[str, Security(get_user_id)],
) -> list[CredentialsMetaResponse]:
# Fire-and-forget: provision missing managed credentials in the background.
# The credential appears on the next page load; listing is never blocked.
asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store))
credentials = await creds_manager.store.get_all_creds(user_id)
return [
@@ -238,6 +247,7 @@ async def list_credentials_by_provider(
],
user_id: Annotated[str, Security(get_user_id)],
) -> list[CredentialsMetaResponse]:
asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store))
credentials = await creds_manager.store.get_creds_by_provider(user_id, provider)
return [
@@ -332,6 +342,11 @@ async def delete_credentials(
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
)
if is_system_credential(cred_id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="System-managed credentials cannot be deleted",
)
creds = await creds_manager.store.get_creds_by_id(user_id, cred_id)
if not creds:
raise HTTPException(
@@ -342,6 +357,11 @@ async def delete_credentials(
status_code=status.HTTP_404_NOT_FOUND,
detail="Credentials not found",
)
if creds.is_managed:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="AutoGPT-managed credentials cannot be deleted",
)
try:
await remove_all_webhooks_for_credentials(user_id, creds, force)

View File

@@ -1,6 +1,7 @@
"""Tests for credentials API security: no secret leakage, SDK defaults filtered."""
from unittest.mock import AsyncMock, patch
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, MagicMock, patch
import fastapi
import fastapi.testclient
@@ -276,3 +277,294 @@ class TestCreateCredentialNoSecretInResponse:
assert resp.status_code == 403
mock_mgr.create.assert_not_called()
class TestManagedCredentials:
"""AutoGPT-managed credentials cannot be deleted by users."""
def test_delete_is_managed_returns_403(self):
cred = APIKeyCredentials(
id="managed-cred-1",
provider="agent_mail",
title="AgentMail (managed by AutoGPT)",
api_key=SecretStr("sk-managed-key"),
is_managed=True,
)
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.store.get_creds_by_id = AsyncMock(return_value=cred)
resp = client.request("DELETE", "/agent_mail/credentials/managed-cred-1")
assert resp.status_code == 403
assert "AutoGPT-managed" in resp.json()["detail"]
def test_list_credentials_includes_is_managed_field(self):
managed = APIKeyCredentials(
id="managed-1",
provider="agent_mail",
title="AgentMail (managed)",
api_key=SecretStr("sk-key"),
is_managed=True,
)
regular = APIKeyCredentials(
id="regular-1",
provider="openai",
title="My Key",
api_key=SecretStr("sk-key"),
)
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.store.get_all_creds = AsyncMock(return_value=[managed, regular])
resp = client.get("/credentials")
assert resp.status_code == 200
data = resp.json()
managed_cred = next(c for c in data if c["id"] == "managed-1")
regular_cred = next(c for c in data if c["id"] == "regular-1")
assert managed_cred["is_managed"] is True
assert regular_cred["is_managed"] is False
# ---------------------------------------------------------------------------
# Managed credential provisioning infrastructure
# ---------------------------------------------------------------------------
def _make_managed_cred(
provider: str = "agent_mail", pod_id: str = "pod-abc"
) -> APIKeyCredentials:
return APIKeyCredentials(
id="managed-auto",
provider=provider,
title="AgentMail (managed by AutoGPT)",
api_key=SecretStr("sk-pod-key"),
is_managed=True,
metadata={"pod_id": pod_id},
)
def _make_store_mock(**kwargs) -> MagicMock:
"""Create a store mock with a working async ``locks()`` context manager."""
@asynccontextmanager
async def _noop_locked(key):
yield
locks_obj = MagicMock()
locks_obj.locked = _noop_locked
store = MagicMock(**kwargs)
store.locks = AsyncMock(return_value=locks_obj)
return store
class TestEnsureManagedCredentials:
"""Unit tests for the ensure/cleanup helpers in managed_credentials.py."""
@pytest.mark.asyncio
async def test_provisions_when_missing(self):
"""Provider.provision() is called when no managed credential exists."""
from backend.integrations.managed_credentials import (
_PROVIDERS,
_provisioned_users,
ensure_managed_credentials,
)
cred = _make_managed_cred()
provider = MagicMock()
provider.provider_name = "test_provider"
provider.is_available = AsyncMock(return_value=True)
provider.provision = AsyncMock(return_value=cred)
store = _make_store_mock()
store.has_managed_credential = AsyncMock(return_value=False)
store.add_managed_credential = AsyncMock()
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["test_provider"] = provider
_provisioned_users.pop("user-1", None)
try:
await ensure_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
_provisioned_users.pop("user-1", None)
provider.provision.assert_awaited_once_with("user-1")
store.add_managed_credential.assert_awaited_once_with("user-1", cred)
@pytest.mark.asyncio
async def test_skips_when_already_exists(self):
"""Provider.provision() is NOT called when managed credential exists."""
from backend.integrations.managed_credentials import (
_PROVIDERS,
_provisioned_users,
ensure_managed_credentials,
)
provider = MagicMock()
provider.provider_name = "test_provider"
provider.is_available = AsyncMock(return_value=True)
provider.provision = AsyncMock()
store = _make_store_mock()
store.has_managed_credential = AsyncMock(return_value=True)
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["test_provider"] = provider
_provisioned_users.pop("user-1", None)
try:
await ensure_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
_provisioned_users.pop("user-1", None)
provider.provision.assert_not_awaited()
@pytest.mark.asyncio
async def test_skips_when_unavailable(self):
"""Provider.provision() is NOT called when provider is not available."""
from backend.integrations.managed_credentials import (
_PROVIDERS,
_provisioned_users,
ensure_managed_credentials,
)
provider = MagicMock()
provider.provider_name = "test_provider"
provider.is_available = AsyncMock(return_value=False)
provider.provision = AsyncMock()
store = _make_store_mock()
store.has_managed_credential = AsyncMock()
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["test_provider"] = provider
_provisioned_users.pop("user-1", None)
try:
await ensure_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
_provisioned_users.pop("user-1", None)
provider.provision.assert_not_awaited()
store.has_managed_credential.assert_not_awaited()
@pytest.mark.asyncio
async def test_provision_failure_does_not_propagate(self):
"""A failed provision is logged but does not raise."""
from backend.integrations.managed_credentials import (
_PROVIDERS,
_provisioned_users,
ensure_managed_credentials,
)
provider = MagicMock()
provider.provider_name = "test_provider"
provider.is_available = AsyncMock(return_value=True)
provider.provision = AsyncMock(side_effect=RuntimeError("boom"))
store = _make_store_mock()
store.has_managed_credential = AsyncMock(return_value=False)
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["test_provider"] = provider
_provisioned_users.pop("user-1", None)
try:
await ensure_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
_provisioned_users.pop("user-1", None)
# No exception raised — provisioning failure is swallowed.
class TestCleanupManagedCredentials:
"""Unit tests for cleanup_managed_credentials."""
@pytest.mark.asyncio
async def test_calls_deprovision_for_managed_creds(self):
from backend.integrations.managed_credentials import (
_PROVIDERS,
cleanup_managed_credentials,
)
cred = _make_managed_cred()
provider = MagicMock()
provider.provider_name = "agent_mail"
provider.deprovision = AsyncMock()
store = MagicMock()
store.get_all_creds = AsyncMock(return_value=[cred])
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["agent_mail"] = provider
try:
await cleanup_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
provider.deprovision.assert_awaited_once_with("user-1", cred)
@pytest.mark.asyncio
async def test_skips_non_managed_creds(self):
from backend.integrations.managed_credentials import (
_PROVIDERS,
cleanup_managed_credentials,
)
regular = _make_api_key_cred()
provider = MagicMock()
provider.provider_name = "openai"
provider.deprovision = AsyncMock()
store = MagicMock()
store.get_all_creds = AsyncMock(return_value=[regular])
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["openai"] = provider
try:
await cleanup_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
provider.deprovision.assert_not_awaited()
@pytest.mark.asyncio
async def test_deprovision_failure_does_not_propagate(self):
from backend.integrations.managed_credentials import (
_PROVIDERS,
cleanup_managed_credentials,
)
cred = _make_managed_cred()
provider = MagicMock()
provider.provider_name = "agent_mail"
provider.deprovision = AsyncMock(side_effect=RuntimeError("boom"))
store = MagicMock()
store.get_all_creds = AsyncMock(return_value=[cred])
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["agent_mail"] = provider
try:
await cleanup_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
# No exception raised — cleanup failure is swallowed.

View File

@@ -481,6 +481,11 @@ async def create_library_agent(
sensitive_action_safe_mode=sensitive_action_safe_mode,
).model_dump()
),
**(
{"Folder": {"connect": {"id": folder_id}}}
if folder_id and graph_entry is graph
else {}
),
},
},
include=library_agent_include(

View File

@@ -12,6 +12,7 @@ Tests cover:
5. Complete OAuth flow end-to-end
"""
import asyncio
import base64
import hashlib
import secrets
@@ -58,14 +59,27 @@ async def test_user(server, test_user_id: str):
yield test_user_id
# Cleanup - delete in correct order due to foreign key constraints
await PrismaOAuthAccessToken.prisma().delete_many(where={"userId": test_user_id})
await PrismaOAuthRefreshToken.prisma().delete_many(where={"userId": test_user_id})
await PrismaOAuthAuthorizationCode.prisma().delete_many(
where={"userId": test_user_id}
)
await PrismaOAuthApplication.prisma().delete_many(where={"ownerId": test_user_id})
await PrismaUser.prisma().delete(where={"id": test_user_id})
# Cleanup - delete in correct order due to foreign key constraints.
# Wrap in try/except because the event loop or Prisma engine may already
# be closed during session teardown on Python 3.12+.
try:
await asyncio.gather(
PrismaOAuthAccessToken.prisma().delete_many(where={"userId": test_user_id}),
PrismaOAuthRefreshToken.prisma().delete_many(
where={"userId": test_user_id}
),
PrismaOAuthAuthorizationCode.prisma().delete_many(
where={"userId": test_user_id}
),
)
await asyncio.gather(
PrismaOAuthApplication.prisma().delete_many(
where={"ownerId": test_user_id}
),
PrismaUser.prisma().delete(where={"id": test_user_id}),
)
except RuntimeError:
pass
@pytest_asyncio.fixture

View File

@@ -0,0 +1,61 @@
from unittest.mock import AsyncMock
import fastapi
import fastapi.testclient
import pytest
from backend.api.features.v1 import v1_router
app = fastapi.FastAPI()
app.include_router(v1_router)
client = fastapi.testclient.TestClient(app)
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
from autogpt_libs.auth.jwt_utils import get_jwt_payload
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
yield
app.dependency_overrides.clear()
def test_onboarding_profile_success(mocker):
mock_extract = mocker.patch(
"backend.api.features.v1.extract_business_understanding",
new_callable=AsyncMock,
)
mock_upsert = mocker.patch(
"backend.api.features.v1.upsert_business_understanding",
new_callable=AsyncMock,
)
from backend.data.understanding import BusinessUnderstandingInput
mock_extract.return_value = BusinessUnderstandingInput.model_construct(
user_name="John",
user_role="Founder/CEO",
pain_points=["Finding leads"],
suggested_prompts={"Learn": ["How do I automate lead gen?"]},
)
mock_upsert.return_value = AsyncMock()
response = client.post(
"/onboarding/profile",
json={
"user_name": "John",
"user_role": "Founder/CEO",
"pain_points": ["Finding leads", "Email & outreach"],
},
)
assert response.status_code == 200
mock_extract.assert_awaited_once()
mock_upsert.assert_awaited_once()
def test_onboarding_profile_missing_fields():
response = client.post(
"/onboarding/profile",
json={"user_name": "John"},
)
assert response.status_code == 422

View File

@@ -1 +0,0 @@
# Platform bot linking API

View File

@@ -1,35 +0,0 @@
"""Bot API key authentication for platform linking endpoints."""
import hmac
import os
from fastapi import HTTPException, Request
from backend.util.settings import Settings
async def get_bot_api_key(request: Request) -> str | None:
"""Extract the bot API key from the X-Bot-API-Key header."""
return request.headers.get("x-bot-api-key")
def check_bot_api_key(api_key: str | None) -> None:
"""Validate the bot API key. Uses constant-time comparison.
Reads the key from env on each call so rotated secrets take effect
without restarting the process.
"""
configured_key = os.getenv("PLATFORM_BOT_API_KEY", "")
if not configured_key:
settings = Settings()
if settings.config.enable_auth:
raise HTTPException(
status_code=503,
detail="Bot API key not configured.",
)
# Auth disabled (local dev) — allow without key
return
if not api_key or not hmac.compare_digest(api_key, configured_key):
raise HTTPException(status_code=401, detail="Invalid bot API key.")

View File

@@ -1,170 +0,0 @@
"""
Bot Chat Proxy endpoints.
Allows the bot service to send messages to CoPilot on behalf of
linked users, authenticated via bot API key.
"""
import asyncio
import logging
from uuid import uuid4
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from prisma.models import PlatformLink
from backend.copilot import stream_registry
from backend.copilot.executor.utils import enqueue_copilot_turn
from backend.copilot.model import (
ChatMessage,
append_and_save_message,
create_chat_session,
get_chat_session,
)
from backend.copilot.response_model import StreamFinish
from .auth import check_bot_api_key, get_bot_api_key
from .models import BotChatRequest, BotChatSessionResponse
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post(
"/chat/session",
response_model=BotChatSessionResponse,
summary="Create a CoPilot session for a linked user (bot-facing)",
)
async def bot_create_session(
request: BotChatRequest,
x_bot_api_key: str | None = Depends(get_bot_api_key),
) -> BotChatSessionResponse:
"""Creates a new CoPilot chat session on behalf of a linked user."""
check_bot_api_key(x_bot_api_key)
link = await PlatformLink.prisma().find_first(where={"userId": request.user_id})
if not link:
raise HTTPException(status_code=404, detail="User has no platform links.")
session = await create_chat_session(request.user_id)
return BotChatSessionResponse(session_id=session.session_id)
@router.post(
"/chat/stream",
summary="Stream a CoPilot response for a linked user (bot-facing)",
)
async def bot_chat_stream(
request: BotChatRequest,
x_bot_api_key: str | None = Depends(get_bot_api_key),
):
"""
Send a message to CoPilot on behalf of a linked user and stream
the response back as Server-Sent Events.
The bot authenticates with its API key — no user JWT needed.
"""
check_bot_api_key(x_bot_api_key)
user_id = request.user_id
# Verify user has a platform link
link = await PlatformLink.prisma().find_first(where={"userId": user_id})
if not link:
raise HTTPException(status_code=404, detail="User has no platform links.")
# Get or create session
session_id = request.session_id
if session_id:
session = await get_chat_session(session_id, user_id)
if not session:
raise HTTPException(status_code=404, detail="Session not found.")
else:
session = await create_chat_session(user_id)
session_id = session.session_id
# Save user message
message = ChatMessage(role="user", content=request.message)
await append_and_save_message(session_id, message)
# Create a turn and enqueue
turn_id = str(uuid4())
await stream_registry.create_session(
session_id=session_id,
user_id=user_id,
tool_call_id="chat_stream",
tool_name="chat",
turn_id=turn_id,
)
subscribe_from_id = "0-0"
await enqueue_copilot_turn(
session_id=session_id,
user_id=user_id,
message=request.message,
turn_id=turn_id,
is_user_message=True,
)
logger.info(
"Bot chat: user ...%s, session %s, turn %s",
user_id[-8:],
session_id,
turn_id,
)
async def event_generator():
subscriber_queue = None
try:
subscriber_queue = await stream_registry.subscribe_to_session(
session_id=session_id,
user_id=user_id,
last_message_id=subscribe_from_id,
)
if subscriber_queue is None:
yield StreamFinish().to_sse()
yield "data: [DONE]\n\n"
return
while True:
try:
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
if isinstance(chunk, str):
yield chunk
else:
yield chunk.to_sse()
if isinstance(chunk, StreamFinish) or (
isinstance(chunk, str) and "[DONE]" in chunk
):
break
except asyncio.TimeoutError:
yield ": keepalive\n\n"
except Exception:
logger.exception("Bot chat stream error for session %s", session_id)
yield 'data: {"type": "error", "content": "Stream error"}\n\n'
yield "data: [DONE]\n\n"
finally:
if subscriber_queue is not None:
await stream_registry.unsubscribe_from_session(
session_id=session_id,
subscriber_queue=subscriber_queue,
)
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Session-Id": session_id,
},
)

View File

@@ -1,107 +0,0 @@
"""Pydantic models for the platform bot linking API."""
from datetime import datetime
from enum import Enum
from typing import Literal
from pydantic import BaseModel, Field
class Platform(str, Enum):
"""Supported platform types (mirrors Prisma PlatformType)."""
DISCORD = "DISCORD"
TELEGRAM = "TELEGRAM"
SLACK = "SLACK"
TEAMS = "TEAMS"
WHATSAPP = "WHATSAPP"
GITHUB = "GITHUB"
LINEAR = "LINEAR"
# ── Request Models ─────────────────────────────────────────────────────
class CreateLinkTokenRequest(BaseModel):
"""Request from the bot service to create a linking token."""
platform: Platform = Field(description="Platform name")
platform_user_id: str = Field(
description="The user's ID on the platform",
min_length=1,
max_length=255,
)
platform_username: str | None = Field(
default=None,
description="Display name (best effort)",
max_length=255,
)
channel_id: str | None = Field(
default=None,
description="Channel ID for sending confirmation back",
max_length=255,
)
class ResolveRequest(BaseModel):
"""Resolve a platform identity to an AutoGPT user."""
platform: Platform
platform_user_id: str = Field(min_length=1, max_length=255)
class BotChatRequest(BaseModel):
"""Request from the bot to chat as a linked user."""
user_id: str = Field(description="The linked AutoGPT user ID")
message: str = Field(
description="The user's message", min_length=1, max_length=32000
)
session_id: str | None = Field(
default=None,
description="Existing chat session ID. If omitted, a new session is created.",
)
# ── Response Models ────────────────────────────────────────────────────
class LinkTokenResponse(BaseModel):
token: str
expires_at: datetime
link_url: str
class LinkTokenStatusResponse(BaseModel):
status: Literal["pending", "linked", "expired"]
user_id: str | None = None
class ResolveResponse(BaseModel):
linked: bool
user_id: str | None = None
class PlatformLinkInfo(BaseModel):
id: str
platform: str
platform_user_id: str
platform_username: str | None
linked_at: datetime
class ConfirmLinkResponse(BaseModel):
success: bool
platform: str
platform_user_id: str
platform_username: str | None
class DeleteLinkResponse(BaseModel):
success: bool
class BotChatSessionResponse(BaseModel):
"""Returned when creating a new session via the bot proxy."""
session_id: str

View File

@@ -1,340 +0,0 @@
"""
Platform Bot Linking API routes.
Enables linking external chat platform identities (Discord, Telegram, Slack, etc.)
to AutoGPT user accounts. Used by the multi-platform CoPilot bot.
Flow:
1. Bot calls POST /api/platform-linking/tokens to create a link token
for an unlinked platform user.
2. Bot sends the user a link: {frontend}/link/{token}
3. User clicks the link, logs in to AutoGPT, and the frontend calls
POST /api/platform-linking/tokens/{token}/confirm to complete the link.
4. Bot can poll GET /api/platform-linking/tokens/{token}/status or just
check on next message via GET /api/platform-linking/resolve.
"""
import logging
import os
import secrets
from datetime import datetime, timedelta, timezone
from typing import Annotated
from autogpt_libs import auth
from fastapi import APIRouter, Depends, HTTPException, Path, Security
from prisma.models import PlatformLink, PlatformLinkToken
from .auth import check_bot_api_key, get_bot_api_key
from .models import (
ConfirmLinkResponse,
CreateLinkTokenRequest,
DeleteLinkResponse,
LinkTokenResponse,
LinkTokenStatusResponse,
PlatformLinkInfo,
ResolveRequest,
ResolveResponse,
)
logger = logging.getLogger(__name__)
router = APIRouter()
LINK_TOKEN_EXPIRY_MINUTES = 30
# Path parameter with validation for link tokens
TokenPath = Annotated[
str,
Path(max_length=64, pattern=r"^[A-Za-z0-9_-]+$"),
]
# ── Bot-facing endpoints (API key auth) ───────────────────────────────
@router.post(
"/tokens",
response_model=LinkTokenResponse,
summary="Create a link token for an unlinked platform user",
)
async def create_link_token(
request: CreateLinkTokenRequest,
x_bot_api_key: str | None = Depends(get_bot_api_key),
) -> LinkTokenResponse:
"""
Called by the bot service when it encounters an unlinked user.
Generates a one-time token the user can use to link their account.
"""
check_bot_api_key(x_bot_api_key)
platform = request.platform.value
# Check if already linked
existing = await PlatformLink.prisma().find_first(
where={
"platform": platform,
"platformUserId": request.platform_user_id,
}
)
if existing:
raise HTTPException(
status_code=409,
detail="This platform account is already linked.",
)
# Invalidate any existing pending tokens for this user
await PlatformLinkToken.prisma().update_many(
where={
"platform": platform,
"platformUserId": request.platform_user_id,
"usedAt": None,
},
data={"usedAt": datetime.now(timezone.utc)},
)
# Generate token
token = secrets.token_urlsafe(32)
expires_at = datetime.now(timezone.utc) + timedelta(
minutes=LINK_TOKEN_EXPIRY_MINUTES
)
await PlatformLinkToken.prisma().create(
data={
"token": token,
"platform": platform,
"platformUserId": request.platform_user_id,
"platformUsername": request.platform_username,
"channelId": request.channel_id,
"expiresAt": expires_at,
}
)
logger.info(
"Created link token for %s (expires %s)",
platform,
expires_at.isoformat(),
)
link_base_url = os.getenv(
"PLATFORM_LINK_BASE_URL", "https://platform.agpt.co/link"
)
link_url = f"{link_base_url}/{token}"
return LinkTokenResponse(
token=token,
expires_at=expires_at,
link_url=link_url,
)
@router.get(
"/tokens/{token}/status",
response_model=LinkTokenStatusResponse,
summary="Check if a link token has been consumed",
)
async def get_link_token_status(
token: TokenPath,
x_bot_api_key: str | None = Depends(get_bot_api_key),
) -> LinkTokenStatusResponse:
"""
Called by the bot service to check if a user has completed linking.
"""
check_bot_api_key(x_bot_api_key)
link_token = await PlatformLinkToken.prisma().find_unique(where={"token": token})
if not link_token:
raise HTTPException(status_code=404, detail="Token not found")
if link_token.usedAt is not None:
# Token was used — find the linked account
link = await PlatformLink.prisma().find_first(
where={
"platform": link_token.platform,
"platformUserId": link_token.platformUserId,
}
)
return LinkTokenStatusResponse(
status="linked",
user_id=link.userId if link else None,
)
if link_token.expiresAt.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc):
return LinkTokenStatusResponse(status="expired")
return LinkTokenStatusResponse(status="pending")
@router.post(
"/resolve",
response_model=ResolveResponse,
summary="Resolve a platform identity to an AutoGPT user",
)
async def resolve_platform_user(
request: ResolveRequest,
x_bot_api_key: str | None = Depends(get_bot_api_key),
) -> ResolveResponse:
"""
Called by the bot service on every incoming message to check if
the platform user has a linked AutoGPT account.
"""
check_bot_api_key(x_bot_api_key)
link = await PlatformLink.prisma().find_first(
where={
"platform": request.platform.value,
"platformUserId": request.platform_user_id,
}
)
if not link:
return ResolveResponse(linked=False)
return ResolveResponse(linked=True, user_id=link.userId)
# ── User-facing endpoints (JWT auth) ──────────────────────────────────
@router.post(
"/tokens/{token}/confirm",
response_model=ConfirmLinkResponse,
dependencies=[Security(auth.requires_user)],
summary="Confirm a link token (user must be authenticated)",
)
async def confirm_link_token(
token: TokenPath,
user_id: Annotated[str, Security(auth.get_user_id)],
) -> ConfirmLinkResponse:
"""
Called by the frontend when the user clicks the link and is logged in.
Consumes the token and creates the platform link.
Uses atomic update_many to prevent race conditions on double-click.
"""
link_token = await PlatformLinkToken.prisma().find_unique(where={"token": token})
if not link_token:
raise HTTPException(status_code=404, detail="Token not found.")
if link_token.usedAt is not None:
raise HTTPException(status_code=410, detail="This link has already been used.")
if link_token.expiresAt.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc):
raise HTTPException(status_code=410, detail="This link has expired.")
# Atomically mark token as used (only if still unused)
updated = await PlatformLinkToken.prisma().update_many(
where={"token": token, "usedAt": None},
data={"usedAt": datetime.now(timezone.utc)},
)
if updated == 0:
raise HTTPException(status_code=410, detail="This link has already been used.")
# Check if this platform identity is already linked
existing = await PlatformLink.prisma().find_first(
where={
"platform": link_token.platform,
"platformUserId": link_token.platformUserId,
}
)
if existing:
detail = (
"This platform account is already linked to your account."
if existing.userId == user_id
else "This platform account is already linked to another user."
)
raise HTTPException(status_code=409, detail=detail)
# Create the link — catch unique constraint race condition
try:
await PlatformLink.prisma().create(
data={
"userId": user_id,
"platform": link_token.platform,
"platformUserId": link_token.platformUserId,
"platformUsername": link_token.platformUsername,
}
)
except Exception as exc:
if "unique" in str(exc).lower():
raise HTTPException(
status_code=409,
detail="This platform account was just linked by another request.",
) from exc
raise
logger.info(
"Linked %s:%s to user ...%s",
link_token.platform,
link_token.platformUserId,
user_id[-8:],
)
return ConfirmLinkResponse(
success=True,
platform=link_token.platform,
platform_user_id=link_token.platformUserId,
platform_username=link_token.platformUsername,
)
@router.get(
"/links",
response_model=list[PlatformLinkInfo],
dependencies=[Security(auth.requires_user)],
summary="List all platform links for the authenticated user",
)
async def list_my_links(
user_id: Annotated[str, Security(auth.get_user_id)],
) -> list[PlatformLinkInfo]:
"""Returns all platform identities linked to the current user's account."""
links = await PlatformLink.prisma().find_many(
where={"userId": user_id},
order={"linkedAt": "desc"},
)
return [
PlatformLinkInfo(
id=link.id,
platform=link.platform,
platform_user_id=link.platformUserId,
platform_username=link.platformUsername,
linked_at=link.linkedAt,
)
for link in links
]
@router.delete(
"/links/{link_id}",
response_model=DeleteLinkResponse,
dependencies=[Security(auth.requires_user)],
summary="Unlink a platform identity",
)
async def delete_link(
link_id: str,
user_id: Annotated[str, Security(auth.get_user_id)],
) -> DeleteLinkResponse:
"""
Removes a platform link. The user will need to re-link if they
want to use the bot on that platform again.
"""
link = await PlatformLink.prisma().find_unique(where={"id": link_id})
if not link:
raise HTTPException(status_code=404, detail="Link not found.")
if link.userId != user_id:
raise HTTPException(status_code=403, detail="Not your link.")
await PlatformLink.prisma().delete(where={"id": link_id})
logger.info(
"Unlinked %s:%s from user ...%s",
link.platform,
link.platformUserId,
user_id[-8:],
)
return DeleteLinkResponse(success=True)

View File

@@ -1,138 +0,0 @@
"""Tests for platform bot linking API routes."""
from unittest.mock import patch
import pytest
from fastapi import HTTPException
from backend.api.features.platform_linking.auth import check_bot_api_key
from backend.api.features.platform_linking.models import (
ConfirmLinkResponse,
CreateLinkTokenRequest,
DeleteLinkResponse,
LinkTokenStatusResponse,
Platform,
ResolveRequest,
)
class TestPlatformEnum:
def test_all_platforms_exist(self):
assert Platform.DISCORD.value == "DISCORD"
assert Platform.TELEGRAM.value == "TELEGRAM"
assert Platform.SLACK.value == "SLACK"
assert Platform.TEAMS.value == "TEAMS"
assert Platform.WHATSAPP.value == "WHATSAPP"
assert Platform.GITHUB.value == "GITHUB"
assert Platform.LINEAR.value == "LINEAR"
class TestBotApiKeyAuth:
@patch.dict("os.environ", {"PLATFORM_BOT_API_KEY": ""}, clear=False)
@patch("backend.api.features.platform_linking.auth.Settings")
def test_no_key_configured_allows_when_auth_disabled(self, mock_settings_cls):
mock_settings_cls.return_value.config.enable_auth = False
check_bot_api_key(None)
@patch.dict("os.environ", {"PLATFORM_BOT_API_KEY": ""}, clear=False)
@patch("backend.api.features.platform_linking.auth.Settings")
def test_no_key_configured_rejects_when_auth_enabled(self, mock_settings_cls):
mock_settings_cls.return_value.config.enable_auth = True
with pytest.raises(HTTPException) as exc_info:
check_bot_api_key(None)
assert exc_info.value.status_code == 503
@patch.dict("os.environ", {"PLATFORM_BOT_API_KEY": "secret123"}, clear=False)
def test_valid_key(self):
check_bot_api_key("secret123")
@patch.dict("os.environ", {"PLATFORM_BOT_API_KEY": "secret123"}, clear=False)
def test_invalid_key_rejected(self):
with pytest.raises(HTTPException) as exc_info:
check_bot_api_key("wrong")
assert exc_info.value.status_code == 401
@patch.dict("os.environ", {"PLATFORM_BOT_API_KEY": "secret123"}, clear=False)
def test_missing_key_rejected(self):
with pytest.raises(HTTPException) as exc_info:
check_bot_api_key(None)
assert exc_info.value.status_code == 401
class TestCreateLinkTokenRequest:
def test_valid_request(self):
req = CreateLinkTokenRequest(
platform=Platform.DISCORD,
platform_user_id="353922987235213313",
)
assert req.platform == Platform.DISCORD
assert req.platform_user_id == "353922987235213313"
def test_empty_platform_user_id_rejected(self):
from pydantic import ValidationError
with pytest.raises(ValidationError):
CreateLinkTokenRequest(
platform=Platform.DISCORD,
platform_user_id="",
)
def test_too_long_platform_user_id_rejected(self):
from pydantic import ValidationError
with pytest.raises(ValidationError):
CreateLinkTokenRequest(
platform=Platform.DISCORD,
platform_user_id="x" * 256,
)
def test_invalid_platform_rejected(self):
from pydantic import ValidationError
with pytest.raises(ValidationError):
CreateLinkTokenRequest.model_validate(
{"platform": "INVALID", "platform_user_id": "123"}
)
class TestResolveRequest:
def test_valid_request(self):
req = ResolveRequest(
platform=Platform.TELEGRAM,
platform_user_id="123456789",
)
assert req.platform == Platform.TELEGRAM
def test_empty_id_rejected(self):
from pydantic import ValidationError
with pytest.raises(ValidationError):
ResolveRequest(
platform=Platform.SLACK,
platform_user_id="",
)
class TestResponseModels:
def test_link_token_status_literal(self):
resp = LinkTokenStatusResponse(status="pending")
assert resp.status == "pending"
resp = LinkTokenStatusResponse(status="linked", user_id="abc")
assert resp.status == "linked"
resp = LinkTokenStatusResponse(status="expired")
assert resp.status == "expired"
def test_confirm_link_response(self):
resp = ConfirmLinkResponse(
success=True,
platform="DISCORD",
platform_user_id="123",
platform_username="testuser",
)
assert resp.success is True
def test_delete_link_response(self):
resp = DeleteLinkResponse(success=True)
assert resp.success is True

View File

@@ -63,12 +63,17 @@ from backend.data.onboarding import (
UserOnboardingUpdate,
complete_onboarding_step,
complete_re_run_agent,
format_onboarding_for_extraction,
get_recommended_agents,
get_user_onboarding,
onboarding_enabled,
reset_user_onboarding,
update_user_onboarding,
)
from backend.data.tally import extract_business_understanding
from backend.data.understanding import (
BusinessUnderstandingInput,
upsert_business_understanding,
)
from backend.data.user import (
get_or_create_user,
get_user_by_id,
@@ -282,35 +287,33 @@ async def get_onboarding_agents(
return await get_recommended_agents(user_id)
class OnboardingStatusResponse(pydantic.BaseModel):
"""Response for onboarding status check."""
class OnboardingProfileRequest(pydantic.BaseModel):
"""Request body for onboarding profile submission."""
is_onboarding_enabled: bool
is_chat_enabled: bool
user_name: str = pydantic.Field(min_length=1, max_length=100)
user_role: str = pydantic.Field(min_length=1, max_length=100)
pain_points: list[str] = pydantic.Field(default_factory=list, max_length=20)
class OnboardingStatusResponse(pydantic.BaseModel):
"""Response for onboarding completion check."""
is_completed: bool
@v1_router.get(
"/onboarding/enabled",
summary="Is onboarding enabled",
"/onboarding/completed",
summary="Check if onboarding is completed",
tags=["onboarding", "public"],
response_model=OnboardingStatusResponse,
dependencies=[Security(requires_user)],
)
async def is_onboarding_enabled(
async def is_onboarding_completed(
user_id: Annotated[str, Security(get_user_id)],
) -> OnboardingStatusResponse:
# Check if chat is enabled for user
is_chat_enabled = await is_feature_enabled(Flag.CHAT, user_id, False)
# If chat is enabled, skip legacy onboarding
if is_chat_enabled:
return OnboardingStatusResponse(
is_onboarding_enabled=False,
is_chat_enabled=True,
)
user_onboarding = await get_user_onboarding(user_id)
return OnboardingStatusResponse(
is_onboarding_enabled=await onboarding_enabled(),
is_chat_enabled=False,
is_completed=OnboardingStep.VISIT_COPILOT in user_onboarding.completedSteps,
)
@@ -325,6 +328,38 @@ async def reset_onboarding(user_id: Annotated[str, Security(get_user_id)]):
return await reset_user_onboarding(user_id)
@v1_router.post(
"/onboarding/profile",
summary="Submit onboarding profile",
tags=["onboarding"],
dependencies=[Security(requires_user)],
)
async def submit_onboarding_profile(
data: OnboardingProfileRequest,
user_id: Annotated[str, Security(get_user_id)],
):
formatted = format_onboarding_for_extraction(
user_name=data.user_name,
user_role=data.user_role,
pain_points=data.pain_points,
)
try:
understanding_input = await extract_business_understanding(formatted)
except Exception:
understanding_input = BusinessUnderstandingInput.model_construct()
# Ensure the direct fields are set even if LLM missed them
understanding_input.user_name = data.user_name
understanding_input.user_role = data.user_role
if not understanding_input.pain_points:
understanding_input.pain_points = data.pain_points
await upsert_business_understanding(user_id, understanding_input)
return {"status": "ok"}
########################################################
##################### Blocks ###########################
########################################################

View File

@@ -30,8 +30,6 @@ import backend.api.features.library.routes
import backend.api.features.mcp.routes as mcp_routes
import backend.api.features.oauth
import backend.api.features.otto.routes
import backend.api.features.platform_linking.chat_proxy
import backend.api.features.platform_linking.routes
import backend.api.features.postmark.postmark
import backend.api.features.store.model
import backend.api.features.store.routes
@@ -120,6 +118,11 @@ async def lifespan_context(app: fastapi.FastAPI):
AutoRegistry.patch_integrations()
# Register managed credential providers (e.g. AgentMail)
from backend.integrations.managed_providers import register_all
register_all()
await backend.data.block.initialize_blocks()
await backend.data.user.migrate_and_encrypt_user_integrations()
@@ -363,16 +366,6 @@ app.include_router(
tags=["oauth"],
prefix="/api/oauth",
)
app.include_router(
backend.api.features.platform_linking.routes.router,
tags=["platform-linking"],
prefix="/api/platform-linking",
)
app.include_router(
backend.api.features.platform_linking.chat_proxy.router,
tags=["platform-linking"],
prefix="/api/platform-linking",
)
app.mount("/external-api", external_api)

View File

@@ -698,13 +698,30 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
if should_pause:
return
# Validate the input data (original or reviewer-modified) once
if error := self.input_schema.validate_data(input_data):
raise BlockInputError(
message=f"Unable to execute block with invalid input data: {error}",
block_name=self.name,
block_id=self.id,
)
# Validate the input data (original or reviewer-modified) once.
# In dry-run mode, credential fields may contain sentinel None values
# that would fail JSON schema required checks. We still validate the
# non-credential fields so blocks that execute for real during dry-run
# (e.g. AgentExecutorBlock) get proper input validation.
is_dry_run = getattr(kwargs.get("execution_context"), "dry_run", False)
if is_dry_run:
cred_field_names = set(self.input_schema.get_credentials_fields().keys())
non_cred_data = {
k: v for k, v in input_data.items() if k not in cred_field_names
}
if error := self.input_schema.validate_data(non_cred_data):
raise BlockInputError(
message=f"Unable to execute block with invalid input data: {error}",
block_name=self.name,
block_id=self.id,
)
else:
if error := self.input_schema.validate_data(input_data):
raise BlockInputError(
message=f"Unable to execute block with invalid input data: {error}",
block_name=self.name,
block_id=self.id,
)
# Use the validated input data
async for output_name, output_data in self.run(

View File

@@ -49,11 +49,17 @@ class AgentExecutorBlock(Block):
@classmethod
def get_missing_input(cls, data: BlockInput) -> set[str]:
required_fields = cls.get_input_schema(data).get("required", [])
return set(required_fields) - set(data)
# Check against the nested `inputs` dict, not the top-level node
# data — required fields like "topic" live inside data["inputs"],
# not at data["topic"].
provided = data.get("inputs", {})
return set(required_fields) - set(provided)
@classmethod
def get_mismatch_error(cls, data: BlockInput) -> str | None:
return validate_with_jsonschema(cls.get_input_schema(data), data)
return validate_with_jsonschema(
cls.get_input_schema(data), data.get("inputs", {})
)
class Output(BlockSchema):
# Use BlockSchema to avoid automatic error field that could clash with graph outputs
@@ -88,6 +94,7 @@ class AgentExecutorBlock(Block):
execution_context=execution_context.model_copy(
update={"parent_execution_id": graph_exec_id},
),
dry_run=execution_context.dry_run,
)
logger = execution_utils.LogMetadata(
@@ -149,14 +156,19 @@ class AgentExecutorBlock(Block):
ExecutionStatus.TERMINATED,
ExecutionStatus.FAILED,
]:
logger.debug(
f"Execution {log_id} received event {event.event_type} with status {event.status}"
logger.info(
f"Execution {log_id} skipping event {event.event_type} status={event.status} "
f"node={getattr(event, 'node_exec_id', '?')}"
)
continue
if event.event_type == ExecutionEventType.GRAPH_EXEC_UPDATE:
# If the graph execution is COMPLETED, TERMINATED, or FAILED,
# we can stop listening for further events.
logger.info(
f"Execution {log_id} graph completed with status {event.status}, "
f"yielded {len(yielded_node_exec_ids)} outputs"
)
self.merge_stats(
NodeExecutionStats(
extra_cost=event.stats.cost if event.stats else 0,

View File

@@ -146,6 +146,21 @@ class AutoPilotBlock(Block):
advanced=True,
)
dry_run: bool = SchemaField(
description=(
"When enabled, run_block and run_agent tool calls in this "
"autopilot session are forced to use dry-run simulation mode. "
"No real API calls, side effects, or credits are consumed "
"by those tools. Useful for testing agent wiring and "
"previewing outputs. "
"Only applies when creating a new session (session_id is empty). "
"When reusing an existing session_id, the session's original "
"dry_run setting is preserved."
),
default=False,
advanced=True,
)
# timeout_seconds removed: the SDK manages its own heartbeat-based
# timeouts internally; wrapping with asyncio.timeout corrupts the
# SDK's internal stream (see service.py CRITICAL comment).
@@ -232,11 +247,11 @@ class AutoPilotBlock(Block):
},
)
async def create_session(self, user_id: str) -> str:
async def create_session(self, user_id: str, *, dry_run: bool) -> str:
"""Create a new chat session and return its ID (mockable for tests)."""
from backend.copilot.model import create_chat_session # avoid circular import
session = await create_chat_session(user_id)
session = await create_chat_session(user_id, dry_run=dry_run)
return session.session_id
async def execute_copilot(
@@ -367,7 +382,9 @@ class AutoPilotBlock(Block):
# even if the downstream stream fails (avoids orphaned sessions).
sid = input_data.session_id
if not sid:
sid = await self.create_session(execution_context.user_id)
sid = await self.create_session(
execution_context.user_id, dry_run=input_data.dry_run
)
# NOTE: No asyncio.timeout() here — the SDK manages its own
# heartbeat-based timeouts internally. Wrapping with asyncio.timeout

View File

@@ -1,5 +1,6 @@
import asyncio
import base64
import re
from abc import ABC
from email import encoders
from email.mime.base import MIMEBase
@@ -8,7 +9,7 @@ from email.mime.text import MIMEText
from email.policy import SMTP
from email.utils import getaddresses, parseaddr
from pathlib import Path
from typing import List, Literal, Optional
from typing import List, Literal, Optional, Protocol, runtime_checkable
from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
@@ -42,8 +43,52 @@ NO_WRAP_POLICY = SMTP.clone(max_line_length=0)
def serialize_email_recipients(recipients: list[str]) -> str:
"""Serialize recipients list to comma-separated string."""
return ", ".join(recipients)
"""Serialize recipients list to comma-separated string.
Strips leading/trailing whitespace from each address to keep MIME
headers clean (mirrors the strip done in ``validate_email_recipients``).
"""
return ", ".join(addr.strip() for addr in recipients)
# RFC 5322 simplified pattern: local@domain where domain has at least one dot
_EMAIL_RE = re.compile(r"^[^@\s]+@[^@\s]+\.[^@\s]+$")
def validate_email_recipients(recipients: list[str], field_name: str = "to") -> None:
"""Validate that all recipients are plausible email addresses.
Raises ``ValueError`` with a user-friendly message listing every
invalid entry so the caller (or LLM) can correct them in one pass.
"""
invalid = [addr for addr in recipients if not _EMAIL_RE.match(addr.strip())]
if invalid:
formatted = ", ".join(f"'{a}'" for a in invalid)
raise ValueError(
f"Invalid email address(es) in '{field_name}': {formatted}. "
f"Each entry must be a valid email address (e.g. user@example.com)."
)
@runtime_checkable
class HasRecipients(Protocol):
to: list[str]
cc: list[str]
bcc: list[str]
def validate_all_recipients(input_data: HasRecipients) -> None:
"""Validate to/cc/bcc recipient fields on an input namespace.
Calls ``validate_email_recipients`` for ``to`` (required) and
``cc``/``bcc`` (when non-empty), raising ``ValueError`` on the
first field that contains an invalid address.
"""
validate_email_recipients(input_data.to, "to")
if input_data.cc:
validate_email_recipients(input_data.cc, "cc")
if input_data.bcc:
validate_email_recipients(input_data.bcc, "bcc")
def _make_mime_text(
@@ -100,14 +145,16 @@ async def create_mime_message(
) -> str:
"""Create a MIME message with attachments and return base64-encoded raw message."""
validate_all_recipients(input_data)
message = MIMEMultipart()
message["to"] = serialize_email_recipients(input_data.to)
message["subject"] = input_data.subject
if input_data.cc:
message["cc"] = ", ".join(input_data.cc)
message["cc"] = serialize_email_recipients(input_data.cc)
if input_data.bcc:
message["bcc"] = ", ".join(input_data.bcc)
message["bcc"] = serialize_email_recipients(input_data.bcc)
# Use the new helper function with content_type if available
content_type = getattr(input_data, "content_type", None)
@@ -1167,13 +1214,15 @@ async def _build_reply_message(
references.append(headers["message-id"])
# Create MIME message
validate_all_recipients(input_data)
msg = MIMEMultipart()
if input_data.to:
msg["To"] = ", ".join(input_data.to)
msg["To"] = serialize_email_recipients(input_data.to)
if input_data.cc:
msg["Cc"] = ", ".join(input_data.cc)
msg["Cc"] = serialize_email_recipients(input_data.cc)
if input_data.bcc:
msg["Bcc"] = ", ".join(input_data.bcc)
msg["Bcc"] = serialize_email_recipients(input_data.bcc)
msg["Subject"] = subject
if headers.get("message-id"):
msg["In-Reply-To"] = headers["message-id"]
@@ -1685,13 +1734,16 @@ To: {original_to}
else:
body = f"{forward_header}\n\n{original_body}"
# Validate all recipient lists before building the MIME message
validate_all_recipients(input_data)
# Create MIME message
msg = MIMEMultipart()
msg["To"] = ", ".join(input_data.to)
msg["To"] = serialize_email_recipients(input_data.to)
if input_data.cc:
msg["Cc"] = ", ".join(input_data.cc)
msg["Cc"] = serialize_email_recipients(input_data.cc)
if input_data.bcc:
msg["Bcc"] = ", ".join(input_data.bcc)
msg["Bcc"] = serialize_email_recipients(input_data.bcc)
msg["Subject"] = subject
# Add body with proper content type

View File

@@ -2,6 +2,8 @@ import copy
from datetime import date, time
from typing import Any, Optional
from pydantic import AliasChoices, Field
from backend.blocks._base import (
Block,
BlockCategory,
@@ -467,7 +469,8 @@ class AgentFileInputBlock(AgentInputBlock):
class AgentDropdownInputBlock(AgentInputBlock):
"""
A specialized text input block that relies on placeholder_values to present a dropdown.
A specialized text input block that presents a dropdown selector
restricted to a fixed set of values.
"""
class Input(AgentInputBlock.Input):
@@ -477,16 +480,23 @@ class AgentDropdownInputBlock(AgentInputBlock):
advanced=False,
title="Default Value",
)
placeholder_values: list = SchemaField(
description="Possible values for the dropdown.",
# Use Field() directly (not SchemaField) to pass validation_alias,
# which handles backward compat for legacy "placeholder_values" across
# all construction paths (model_construct, __init__, model_validate).
options: list = Field(
default_factory=list,
advanced=False,
title="Dropdown Options",
description=(
"If provided, renders the input as a dropdown selector "
"restricted to these values. Leave empty for free-text input."
),
validation_alias=AliasChoices("options", "placeholder_values"),
json_schema_extra={"advanced": False, "secret": False},
)
def generate_schema(self):
schema = super().generate_schema()
if possible_values := self.placeholder_values:
if possible_values := self.options:
schema["enum"] = possible_values
return schema
@@ -504,13 +514,13 @@ class AgentDropdownInputBlock(AgentInputBlock):
{
"value": "Option A",
"name": "dropdown_1",
"placeholder_values": ["Option A", "Option B", "Option C"],
"options": ["Option A", "Option B", "Option C"],
"description": "Dropdown example 1",
},
{
"value": "Option C",
"name": "dropdown_2",
"placeholder_values": ["Option A", "Option B", "Option C"],
"options": ["Option A", "Option B", "Option C"],
"description": "Dropdown example 2",
},
],

View File

@@ -724,6 +724,9 @@ def convert_openai_tool_fmt_to_anthropic(
def extract_openai_reasoning(response) -> str | None:
"""Extract reasoning from OpenAI-compatible response if available."""
"""Note: This will likely not working since the reasoning is not present in another Response API"""
if not response.choices:
logger.warning("LLM response has empty choices in extract_openai_reasoning")
return None
reasoning = None
choice = response.choices[0]
if hasattr(choice, "reasoning") and getattr(choice, "reasoning", None):
@@ -739,6 +742,9 @@ def extract_openai_reasoning(response) -> str | None:
def extract_openai_tool_calls(response) -> list[ToolContentBlock] | None:
"""Extract tool calls from OpenAI-compatible response."""
if not response.choices:
logger.warning("LLM response has empty choices in extract_openai_tool_calls")
return None
if response.choices[0].message.tool_calls:
return [
ToolContentBlock(
@@ -972,6 +978,8 @@ async def llm_call(
response_format=response_format, # type: ignore
max_tokens=max_tokens,
)
if not response.choices:
raise ValueError("Groq returned empty choices in response")
return LLMResponse(
raw_response=response.choices[0].message,
prompt=prompt,
@@ -1031,12 +1039,8 @@ async def llm_call(
parallel_tool_calls=parallel_tool_calls_param,
)
# If there's no response, raise an error
if not response.choices:
if response:
raise ValueError(f"OpenRouter error: {response}")
else:
raise ValueError("No response from OpenRouter.")
raise ValueError(f"OpenRouter returned empty choices: {response}")
tool_calls = extract_openai_tool_calls(response)
reasoning = extract_openai_reasoning(response)
@@ -1073,12 +1077,8 @@ async def llm_call(
parallel_tool_calls=parallel_tool_calls_param,
)
# If there's no response, raise an error
if not response.choices:
if response:
raise ValueError(f"Llama API error: {response}")
else:
raise ValueError("No response from Llama API.")
raise ValueError(f"Llama API returned empty choices: {response}")
tool_calls = extract_openai_tool_calls(response)
reasoning = extract_openai_reasoning(response)
@@ -1108,6 +1108,8 @@ async def llm_call(
messages=prompt, # type: ignore
max_tokens=max_tokens,
)
if not completion.choices:
raise ValueError("AI/ML API returned empty choices in response")
return LLMResponse(
raw_response=completion.choices[0].message,
@@ -1144,6 +1146,9 @@ async def llm_call(
parallel_tool_calls=parallel_tool_calls_param,
)
if not response.choices:
raise ValueError(f"v0 API returned empty choices: {response}")
tool_calls = extract_openai_tool_calls(response)
reasoning = extract_openai_reasoning(response)
@@ -2011,6 +2016,19 @@ class AIConversationBlock(AIBlockBase):
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
has_messages = any(
isinstance(m, dict)
and isinstance(m.get("content"), str)
and bool(m["content"].strip())
for m in (input_data.messages or [])
)
has_prompt = bool(input_data.prompt and input_data.prompt.strip())
if not has_messages and not has_prompt:
raise ValueError(
"Cannot call LLM with no messages and no prompt. "
"Provide at least one message or a non-empty prompt."
)
response = await self.llm_call(
AIStructuredResponseGeneratorBlock.Input(
prompt=input_data.prompt,

View File

@@ -89,6 +89,12 @@ class MCPToolBlock(Block):
default={},
hidden=True,
)
tool_description: str = SchemaField(
description="Description of the selected MCP tool. "
"Populated automatically when a tool is selected.",
default="",
hidden=True,
)
tool_arguments: dict[str, Any] = SchemaField(
description="Arguments to pass to the selected MCP tool. "

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,323 @@
import asyncio
from typing import Any, Literal
from pydantic import SecretStr
from sqlalchemy.engine.url import URL
from sqlalchemy.exc import DBAPIError, OperationalError, ProgrammingError
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.blocks.sql_query_helpers import (
_DATABASE_TYPE_DEFAULT_PORT,
_DATABASE_TYPE_TO_DRIVER,
DatabaseType,
_execute_query,
_sanitize_error,
_validate_query_is_read_only,
_validate_single_statement,
)
from backend.data.model import (
CredentialsField,
CredentialsMetaInput,
SchemaField,
UserPasswordCredentials,
)
from backend.integrations.providers import ProviderName
from backend.util.request import resolve_and_check_blocked
TEST_CREDENTIALS = UserPasswordCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="database",
username=SecretStr("test_user"),
password=SecretStr("test_pass"),
title="Mock Database credentials",
)
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.title,
}
DatabaseCredentials = UserPasswordCredentials
DatabaseCredentialsInput = CredentialsMetaInput[
Literal[ProviderName.DATABASE],
Literal["user_password"],
]
def DatabaseCredentialsField() -> DatabaseCredentialsInput:
return CredentialsField(
description="Database username and password",
)
class SQLQueryBlock(Block):
class Input(BlockSchemaInput):
database_type: DatabaseType = SchemaField(
default=DatabaseType.POSTGRES,
description="Database engine",
advanced=False,
)
host: SecretStr = SchemaField(
description=(
"Database hostname or IP address. "
"Treated as a secret to avoid leaking infrastructure details. "
"Private/internal IPs are blocked (SSRF protection)."
),
placeholder="db.example.com",
secret=True,
)
port: int | None = SchemaField(
default=None,
description=(
"Database port (leave empty for default: "
"PostgreSQL: 5432, MySQL: 3306, MSSQL: 1433)"
),
ge=1,
le=65535,
)
database: str = SchemaField(
description="Name of the database to connect to",
placeholder="my_database",
)
query: str = SchemaField(
description="SQL query to execute",
placeholder="SELECT * FROM analytics.daily_active_users LIMIT 10",
)
read_only: bool = SchemaField(
default=True,
description=(
"When enabled (default), only SELECT queries are allowed "
"and the database session is set to read-only mode. "
"Disable to allow write operations (INSERT, UPDATE, DELETE, etc.)."
),
)
timeout: int = SchemaField(
default=30,
description="Query timeout in seconds (max 120)",
ge=1,
le=120,
)
max_rows: int = SchemaField(
default=1000,
description="Maximum number of rows to return (max 10000)",
ge=1,
le=10000,
)
credentials: DatabaseCredentialsInput = DatabaseCredentialsField()
class Output(BlockSchemaOutput):
results: list[dict[str, Any]] = SchemaField(
description="Query results as a list of row dictionaries"
)
columns: list[str] = SchemaField(
description="Column names from the query result"
)
row_count: int = SchemaField(description="Number of rows returned")
truncated: bool = SchemaField(
description=(
"True when the result set was capped by max_rows, "
"indicating additional rows exist in the database"
)
)
affected_rows: int = SchemaField(
description="Number of rows affected by a write query (INSERT/UPDATE/DELETE)"
)
error: str = SchemaField(description="Error message if the query failed")
def __init__(self):
super().__init__(
id="4dc35c0f-4fd8-465e-9616-5a216f1ba2bc",
description=(
"Execute a SQL query. Read-only by default for safety "
"-- disable to allow write operations. "
"Supports PostgreSQL, MySQL, and MSSQL via SQLAlchemy."
),
categories={BlockCategory.DATA},
input_schema=SQLQueryBlock.Input,
output_schema=SQLQueryBlock.Output,
test_input={
"query": "SELECT 1 AS test_col",
"database_type": DatabaseType.POSTGRES,
"host": "localhost",
"database": "test_db",
"timeout": 30,
"max_rows": 1000,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("results", [{"test_col": 1}]),
("columns", ["test_col"]),
("row_count", 1),
("truncated", False),
],
test_mock={
"execute_query": lambda *_args, **_kwargs: (
[{"test_col": 1}],
["test_col"],
-1,
False,
),
"check_host_allowed": lambda *_args, **_kwargs: ["127.0.0.1"],
},
)
@staticmethod
async def check_host_allowed(host: str) -> list[str]:
"""Validate that the given host is not a private/blocked address.
Returns the list of resolved IP addresses so the caller can pin the
connection to the validated IP (preventing DNS rebinding / TOCTOU).
Raises ValueError or OSError if the host is blocked.
Extracted as a method so it can be mocked during block tests.
"""
return await resolve_and_check_blocked(host)
@staticmethod
def execute_query(
connection_url: URL | str,
query: str,
timeout: int,
max_rows: int,
read_only: bool = True,
database_type: DatabaseType = DatabaseType.POSTGRES,
) -> tuple[list[dict[str, Any]], list[str], int, bool]:
"""Execute a SQL query and return (rows, columns, affected_rows, truncated).
Delegates to ``_execute_query`` in ``sql_query_helpers``.
Extracted as a method so it can be mocked during block tests.
"""
return _execute_query(
connection_url=connection_url,
query=query,
timeout=timeout,
max_rows=max_rows,
read_only=read_only,
database_type=database_type,
)
async def run(
self,
input_data: Input,
*,
credentials: DatabaseCredentials,
**_kwargs: Any,
) -> BlockOutput:
# Validate query structure and read-only constraints.
error = self._validate_query(input_data)
if error:
yield "error", error
return
# Validate host and resolve for SSRF protection.
host, pinned_host, error = await self._resolve_host(input_data)
if error:
yield "error", error
return
# Build connection URL and execute.
port = input_data.port or _DATABASE_TYPE_DEFAULT_PORT[input_data.database_type]
username = credentials.username.get_secret_value()
connection_url = URL.create(
drivername=_DATABASE_TYPE_TO_DRIVER[input_data.database_type],
username=username,
password=credentials.password.get_secret_value(),
host=pinned_host,
port=port,
database=input_data.database,
)
conn_str = connection_url.render_as_string(hide_password=True)
db_name = input_data.database
def _sanitize(err: Exception) -> str:
return _sanitize_error(
str(err).strip(),
conn_str,
host=pinned_host,
original_host=host,
username=username,
port=port,
database=db_name,
)
try:
results, columns, affected, truncated = await asyncio.to_thread(
self.execute_query,
connection_url=connection_url,
query=input_data.query,
timeout=input_data.timeout,
max_rows=input_data.max_rows,
read_only=input_data.read_only,
database_type=input_data.database_type,
)
yield "results", results
yield "columns", columns
yield "row_count", len(results)
yield "truncated", truncated
if affected >= 0:
yield "affected_rows", affected
except OperationalError as e:
yield (
"error",
self._classify_operational_error(
_sanitize(e),
input_data.timeout,
),
)
except ProgrammingError as e:
yield "error", f"SQL error: {_sanitize(e)}"
except DBAPIError as e:
yield "error", f"Database error: {_sanitize(e)}"
except ModuleNotFoundError:
yield (
"error",
(
f"Database driver not available for "
f"{input_data.database_type.value}. "
f"Please contact the platform administrator."
),
)
@staticmethod
def _validate_query(input_data: "SQLQueryBlock.Input") -> str | None:
"""Validate query structure and read-only constraints."""
stmt_error, parsed_stmt = _validate_single_statement(input_data.query)
if stmt_error:
return stmt_error
assert parsed_stmt is not None
if input_data.read_only:
return _validate_query_is_read_only(parsed_stmt)
return None
async def _resolve_host(
self, input_data: "SQLQueryBlock.Input"
) -> tuple[str, str, str | None]:
"""Validate and resolve the database host. Returns (host, pinned_ip, error)."""
host = input_data.host.get_secret_value().strip()
if not host:
return "", "", "Database host is required."
if host.startswith("/"):
return host, "", "Unix socket connections are not allowed."
try:
resolved_ips = await self.check_host_allowed(host)
except (ValueError, OSError) as e:
return host, "", f"Blocked host: {str(e).strip()}"
return host, resolved_ips[0], None
@staticmethod
def _classify_operational_error(sanitized_msg: str, timeout: int) -> str:
"""Classify an already-sanitized OperationalError for user display."""
lower = sanitized_msg.lower()
if "timeout" in lower or "cancel" in lower:
return f"Query timed out after {timeout}s."
if "connect" in lower:
return f"Failed to connect to database: {sanitized_msg}"
return f"Database error: {sanitized_msg}"

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,430 @@
import re
from datetime import date, datetime, time
from decimal import Decimal
from enum import Enum
from typing import Any
import sqlparse
from sqlalchemy import create_engine, text
from sqlalchemy.engine.url import URL
class DatabaseType(str, Enum):
POSTGRES = "postgres"
MYSQL = "mysql"
MSSQL = "mssql"
# Defense-in-depth: reject queries containing data-modifying keywords.
# These are checked against parsed SQL tokens (not raw text) so column names
# and string literals do not cause false positives.
_DISALLOWED_KEYWORDS = {
"INSERT",
"UPDATE",
"DELETE",
"DROP",
"ALTER",
"CREATE",
"TRUNCATE",
"GRANT",
"REVOKE",
"COPY",
"EXECUTE",
"CALL",
"SET",
"RESET",
"DISCARD",
"NOTIFY",
"DO",
# MySQL file exfiltration: LOAD DATA LOCAL INFILE reads server/client files
"LOAD",
# MySQL REPLACE is INSERT-or-UPDATE; data modification
"REPLACE",
# ANSI MERGE (UPSERT) modifies data
"MERGE",
# MSSQL BULK INSERT loads external files into tables
"BULK",
# MSSQL EXEC / EXEC sp_name runs stored procedures (arbitrary code)
"EXEC",
}
# Map DatabaseType enum values to the expected SQLAlchemy driver prefix.
_DATABASE_TYPE_TO_DRIVER = {
DatabaseType.POSTGRES: "postgresql",
DatabaseType.MYSQL: "mysql+pymysql",
DatabaseType.MSSQL: "mssql+pymssql",
}
# Connection timeout in seconds passed to the DBAPI driver (connect_timeout /
# login_timeout). This bounds how long the driver waits to establish a TCP
# connection to the database server. It is separate from the per-statement
# timeout configured via SET commands inside _configure_session().
_CONNECT_TIMEOUT_SECONDS = 10
# Default ports for each database type.
_DATABASE_TYPE_DEFAULT_PORT = {
DatabaseType.POSTGRES: 5432,
DatabaseType.MYSQL: 3306,
DatabaseType.MSSQL: 1433,
}
def _sanitize_error(
error_msg: str,
connection_string: str,
*,
host: str = "",
original_host: str = "",
username: str = "",
port: int = 0,
database: str = "",
) -> str:
"""Remove connection string, credentials, and infrastructure details
from error messages so they are safe to expose to the LLM.
Scrubs:
- The full connection string
- URL-embedded credentials (``://user:pass@``)
- ``password=<value>`` key-value pairs
- The database hostname / IP used for the connection
- The original (pre-resolution) hostname provided by the user
- Any IPv4 addresses that appear in the message
- Any bracketed IPv6 addresses (e.g. ``[::1]``, ``[fe80::1%eth0]``)
- The database username
- The database port number
- The database name
"""
sanitized = error_msg.replace(connection_string, "<connection_string>")
sanitized = re.sub(r"password=[^\s&]+", "password=***", sanitized)
sanitized = re.sub(r"://[^@]+@", "://***:***@", sanitized)
# Replace the known host (may be an IP already) before the generic IP pass.
# Also replace the original (pre-DNS-resolution) hostname if it differs.
if original_host and original_host != host:
sanitized = sanitized.replace(original_host, "<host>")
if host:
sanitized = sanitized.replace(host, "<host>")
# Replace any remaining IPv4 addresses (e.g. resolved IPs the driver logs)
sanitized = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", "<ip>", sanitized)
# Replace bracketed IPv6 addresses (e.g. "[::1]", "[fe80::1%eth0]")
sanitized = re.sub(r"\[[0-9a-fA-F:]+(?:%[^\]]+)?\]", "<ip>", sanitized)
# Replace the database username (handles double-quoted, single-quoted,
# and unquoted formats across PostgreSQL, MySQL, and MSSQL error messages).
if username:
sanitized = re.sub(
r"""for user ["']?""" + re.escape(username) + r"""["']?""",
"for user <user>",
sanitized,
)
# Catch remaining bare occurrences in various quote styles:
# - PostgreSQL: "FATAL: role "myuser" does not exist"
# - MySQL: "Access denied for user 'myuser'@'host'"
# - MSSQL: "Login failed for user 'myuser'"
sanitized = sanitized.replace(f'"{username}"', "<user>")
sanitized = sanitized.replace(f"'{username}'", "<user>")
# Replace the port number (handles "port 5432" and ":5432" formats)
if port:
port_str = re.escape(str(port))
sanitized = re.sub(
r"(?:port |:)" + port_str + r"(?![0-9])",
lambda m: ("port " if m.group().startswith("p") else ":") + "<port>",
sanitized,
)
# Replace the database name to avoid leaking internal infrastructure names.
# Use word-boundary regex to prevent mangling when the database name is a
# common substring (e.g. "test", "data", "on").
if database:
sanitized = re.sub(r"\b" + re.escape(database) + r"\b", "<database>", sanitized)
return sanitized
def _extract_keyword_tokens(parsed: sqlparse.sql.Statement) -> list[str]:
"""Extract keyword tokens from a parsed SQL statement.
Uses sqlparse token type classification to collect Keyword/DML/DDL/DCL
tokens. String literals and identifiers have different token types, so
they are naturally excluded from the result.
"""
return [
token.normalized.upper()
for token in parsed.flatten()
if token.ttype
in (
sqlparse.tokens.Keyword,
sqlparse.tokens.Keyword.DML,
sqlparse.tokens.Keyword.DDL,
sqlparse.tokens.Keyword.DCL,
)
]
def _has_disallowed_into(stmt: sqlparse.sql.Statement) -> bool:
"""Check if a statement contains a disallowed ``INTO`` clause.
``SELECT ... INTO @variable`` is a valid read-only MySQL syntax that stores
a query result into a session-scoped user variable. All other forms of
``INTO`` are data-modifying or file-writing and must be blocked:
* ``SELECT ... INTO new_table`` (PostgreSQL / MSSQL creates a table)
* ``SELECT ... INTO OUTFILE`` (MySQL writes to the filesystem)
* ``SELECT ... INTO DUMPFILE`` (MySQL writes to the filesystem)
* ``INSERT INTO ...`` (already blocked by INSERT being in the
disallowed set, but we reject INTO as well for defense-in-depth)
Returns ``True`` if the statement contains a disallowed ``INTO``.
"""
flat = list(stmt.flatten())
for i, token in enumerate(flat):
if not (
token.ttype in (sqlparse.tokens.Keyword,)
and token.normalized.upper() == "INTO"
):
continue
# Look at the first non-whitespace token after INTO.
j = i + 1
while j < len(flat) and flat[j].ttype is sqlparse.tokens.Text.Whitespace:
j += 1
if j >= len(flat):
# INTO at the very end malformed, block it.
return True
next_token = flat[j]
# MySQL user variable: either a single Name starting with "@"
# (e.g. ``@total``) or a bare ``@`` Operator token followed by a Name.
if next_token.ttype is sqlparse.tokens.Name and next_token.value.startswith(
"@"
):
continue
if next_token.ttype is sqlparse.tokens.Operator and next_token.value == "@":
continue
# Everything else (table name, OUTFILE, DUMPFILE, etc.) is disallowed.
return True
return False
def _validate_query_is_read_only(stmt: sqlparse.sql.Statement) -> str | None:
"""Validate that a parsed SQL statement is read-only (SELECT/WITH only).
Accepts an already-parsed statement from ``_validate_single_statement``
to avoid re-parsing. Checks:
1. Statement type must be SELECT (sqlparse classifies WITH...SELECT as SELECT)
2. No disallowed keywords (INSERT, UPDATE, DELETE, DROP, etc.)
3. No disallowed INTO clauses (allows MySQL ``SELECT ... INTO @variable``)
Returns an error message if the query is not read-only, None otherwise.
"""
# sqlparse returns 'SELECT' for SELECT and WITH...SELECT queries
if stmt.get_type() != "SELECT":
return "Only SELECT queries are allowed."
# Defense-in-depth: check parsed keyword tokens for disallowed keywords
for kw in _extract_keyword_tokens(stmt):
# Normalize multi-word tokens (e.g. "SET LOCAL" -> "SET")
base_kw = kw.split()[0] if " " in kw else kw
if base_kw in _DISALLOWED_KEYWORDS:
return f"Disallowed SQL keyword: {kw}"
# Contextual check for INTO: allow MySQL @variable syntax, block everything else
if _has_disallowed_into(stmt):
return "Disallowed SQL keyword: INTO"
return None
def _validate_single_statement(
query: str,
) -> tuple[str | None, sqlparse.sql.Statement | None]:
"""Validate that the query contains exactly one non-empty SQL statement.
Returns (error_message, parsed_statement). If error_message is not None,
the query is invalid and parsed_statement will be None.
"""
stripped = query.strip().rstrip(";").strip()
if not stripped:
return "Query is empty.", None
# Parse the SQL using sqlparse for proper tokenization
statements = sqlparse.parse(stripped)
# Filter out empty statements and comment-only statements
statements = [
s
for s in statements
if s.tokens
and str(s).strip()
and not all(
t.is_whitespace or t.ttype in sqlparse.tokens.Comment for t in s.flatten()
)
]
if not statements:
return "Query is empty.", None
# Reject multiple statements -- prevents injection via semicolons
if len(statements) > 1:
return "Only single statements are allowed.", None
return None, statements[0]
def _serialize_value(value: Any) -> Any:
"""Convert database-specific types to JSON-serializable Python types."""
if isinstance(value, Decimal):
# NaN / Infinity are not valid JSON numbers; serialize as strings.
if value.is_nan() or value.is_infinite():
return str(value)
# Use int for whole numbers; use str for fractional to preserve exact
# precision (float would silently round high-precision analytics values).
if value == value.to_integral_value():
return int(value)
return str(value)
if isinstance(value, (datetime, date, time)):
return value.isoformat()
if isinstance(value, memoryview):
return bytes(value).hex()
if isinstance(value, bytes):
return value.hex()
return value
def _configure_session(
conn: Any,
dialect_name: str,
timeout_ms: str,
read_only: bool,
) -> None:
"""Set session-level timeout and read-only mode for the given dialect.
Timeout limitations by database:
* **PostgreSQL** ``statement_timeout`` reliably cancels any running
statement (SELECT or DML) after the configured duration.
* **MySQL** ``MAX_EXECUTION_TIME`` only applies to **read-only SELECT**
statements. DML (INSERT/UPDATE/DELETE) and DDL are *not* bounded by
this hint; they rely on the server's ``wait_timeout`` /
``interactive_timeout`` instead. There is no session-level setting in
MySQL that reliably cancels long-running writes.
* **MSSQL** ``SET LOCK_TIMEOUT`` only limits how long the server waits
to acquire a **lock**. CPU-bound queries (e.g. large scans, hash
joins) that do not block on locks will *not* be cancelled. MSSQL has
no session-level ``statement_timeout`` equivalent; the closest
mechanism is Resource Governor (requires sysadmin configuration) or
``CONTEXT_INFO``-based external monitoring.
Note: SQLite is not supported by this block. The ``_configure_session``
function is a no-op for unrecognised dialect names, so an SQLite engine
would skip all SET commands silently. The block's ``DatabaseType`` enum
intentionally excludes SQLite.
"""
if dialect_name == "postgresql":
conn.execute(text("SET statement_timeout = " + timeout_ms))
if read_only:
conn.execute(text("SET default_transaction_read_only = ON"))
elif dialect_name == "mysql":
# NOTE: MAX_EXECUTION_TIME only applies to SELECT statements.
# Write queries (INSERT/UPDATE/DELETE) are not bounded by this
# setting; they rely on the database's wait_timeout instead.
# See docstring above for full limitations.
conn.execute(text("SET SESSION MAX_EXECUTION_TIME = " + timeout_ms))
if read_only:
conn.execute(text("SET SESSION TRANSACTION READ ONLY"))
elif dialect_name == "mssql":
# MSSQL: SET LOCK_TIMEOUT limits lock-wait time (ms) only.
# CPU-bound queries without lock contention are NOT cancelled.
# See docstring above for full limitations.
conn.execute(text("SET LOCK_TIMEOUT " + timeout_ms))
# MSSQL lacks a session-level read-only mode like
# PostgreSQL/MySQL. Read-only enforcement is handled by
# the SQL validation layer (_validate_query_is_read_only)
# and the ROLLBACK in the finally block.
def _run_in_transaction(
conn: Any,
dialect_name: str,
query: str,
max_rows: int,
read_only: bool,
) -> tuple[list[dict[str, Any]], list[str], int, bool]:
"""Execute a query inside an explicit transaction, returning results.
Returns ``(rows, columns, affected_rows, truncated)`` where *truncated*
is ``True`` when ``fetchmany`` returned exactly ``max_rows`` rows,
indicating that additional rows may exist in the result set.
"""
# MSSQL uses T-SQL "BEGIN TRANSACTION"; others use "BEGIN".
begin_stmt = "BEGIN TRANSACTION" if dialect_name == "mssql" else "BEGIN"
conn.execute(text(begin_stmt))
try:
result = conn.execute(text(query))
affected = result.rowcount if not result.returns_rows else -1
columns = list(result.keys()) if result.returns_rows else []
rows = result.fetchmany(max_rows) if result.returns_rows else []
truncated = len(rows) == max_rows
results = [
{col: _serialize_value(val) for col, val in zip(columns, row)}
for row in rows
]
except Exception:
try:
conn.execute(text("ROLLBACK"))
except Exception:
pass
raise
else:
conn.execute(text("ROLLBACK" if read_only else "COMMIT"))
return results, columns, affected, truncated
def _execute_query(
connection_url: URL | str,
query: str,
timeout: int,
max_rows: int,
read_only: bool = True,
database_type: DatabaseType = DatabaseType.POSTGRES,
) -> tuple[list[dict[str, Any]], list[str], int, bool]:
"""Execute a SQL query and return (rows, columns, affected_rows, truncated).
Uses SQLAlchemy to connect to any supported database.
For SELECT queries, rows are limited to ``max_rows`` via DBAPI fetchmany.
``truncated`` is ``True`` when the result set was capped by ``max_rows``.
For write queries, affected_rows contains the rowcount from the driver.
When ``read_only`` is True, the database session is set to read-only
mode and the transaction is always rolled back.
"""
# Determine driver-specific connection timeout argument.
# pymssql uses "login_timeout", while PostgreSQL/MySQL use "connect_timeout".
timeout_key = (
"login_timeout" if database_type == DatabaseType.MSSQL else "connect_timeout"
)
engine = create_engine(
connection_url, connect_args={timeout_key: _CONNECT_TIMEOUT_SECONDS}
)
try:
with engine.connect() as conn:
# Use AUTOCOMMIT so SET commands take effect immediately.
conn = conn.execution_options(isolation_level="AUTOCOMMIT")
# Compute timeout in milliseconds. The value is Pydantic-validated
# (ge=1, le=120), but we use int() as defense-in-depth.
# NOTE: SET commands do not support bind parameters in most
# databases, so we use str(int(...)) for safe interpolation.
timeout_ms = str(int(timeout * 1000))
_configure_session(conn, engine.dialect.name, timeout_ms, read_only)
return _run_in_transaction(
conn, engine.dialect.name, query, max_rows, read_only
)
finally:
engine.dispose()

View File

@@ -300,13 +300,27 @@ def test_agent_input_block_ignores_legacy_placeholder_values():
def test_dropdown_input_block_produces_enum():
"""Verify AgentDropdownInputBlock.Input.generate_schema() produces enum."""
options = ["Option A", "Option B"]
"""Verify AgentDropdownInputBlock.Input.generate_schema() produces enum
using the canonical 'options' field name."""
opts = ["Option A", "Option B"]
instance = AgentDropdownInputBlock.Input.model_construct(
name="choice", value=None, placeholder_values=options
name="choice", value=None, options=opts
)
schema = instance.generate_schema()
assert schema.get("enum") == options
assert schema.get("enum") == opts
def test_dropdown_input_block_legacy_placeholder_values_produces_enum():
"""Verify backward compat: passing legacy 'placeholder_values' to
AgentDropdownInputBlock still produces enum via model_construct remap."""
opts = ["Option A", "Option B"]
instance = AgentDropdownInputBlock.Input.model_construct(
name="choice", value=None, placeholder_values=opts
)
schema = instance.generate_schema()
assert (
schema.get("enum") == opts
), "Legacy placeholder_values should be remapped to options"
def test_generate_schema_integration_legacy_placeholder_values():
@@ -329,11 +343,11 @@ def test_generate_schema_integration_legacy_placeholder_values():
def test_generate_schema_integration_dropdown_produces_enum():
"""Test the full Graph._generate_schema path with AgentDropdownInputBlock
— verifies enum IS produced for dropdown blocks."""
— verifies enum IS produced for dropdown blocks using canonical field name."""
dropdown_input_default = {
"name": "color",
"value": None,
"placeholder_values": ["Red", "Green", "Blue"],
"options": ["Red", "Green", "Blue"],
}
result = BaseGraph._generate_schema(
(AgentDropdownInputBlock.Input, dropdown_input_default),
@@ -344,3 +358,36 @@ def test_generate_schema_integration_dropdown_produces_enum():
"Green",
"Blue",
], "Graph schema should contain enum from AgentDropdownInputBlock"
def test_generate_schema_integration_dropdown_legacy_placeholder_values():
"""Test the full Graph._generate_schema path with AgentDropdownInputBlock
using legacy 'placeholder_values' — verifies backward compat produces enum."""
legacy_dropdown_input_default = {
"name": "color",
"value": None,
"placeholder_values": ["Red", "Green", "Blue"],
}
result = BaseGraph._generate_schema(
(AgentDropdownInputBlock.Input, legacy_dropdown_input_default),
)
color_props = result["properties"]["color"]
assert color_props.get("enum") == [
"Red",
"Green",
"Blue",
], "Legacy placeholder_values should still produce enum via model_construct remap"
def test_dropdown_input_block_init_legacy_placeholder_values():
"""Verify backward compat: constructing AgentDropdownInputBlock.Input via
model_validate with legacy 'placeholder_values' correctly maps to 'options'."""
opts = ["Option A", "Option B"]
instance = AgentDropdownInputBlock.Input.model_validate(
{"name": "choice", "value": None, "placeholder_values": opts}
)
assert (
instance.options == opts
), "Legacy placeholder_values should be remapped to options via model_validate"
schema = instance.generate_schema()
assert schema.get("enum") == opts

View File

@@ -488,6 +488,154 @@ class TestLLMStatsTracking:
assert outputs["response"] == {"result": "test"}
class TestAIConversationBlockValidation:
"""Test that AIConversationBlock validates inputs before calling the LLM."""
@pytest.mark.asyncio
async def test_empty_messages_and_empty_prompt_raises_error(self):
"""Empty messages with no prompt should raise ValueError, not a cryptic API error."""
block = llm.AIConversationBlock()
input_data = llm.AIConversationBlock.Input(
messages=[],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
with pytest.raises(ValueError, match="no messages and no prompt"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
@pytest.mark.asyncio
async def test_empty_messages_with_prompt_succeeds(self):
"""Empty messages but a non-empty prompt should proceed without error."""
block = llm.AIConversationBlock()
async def mock_llm_call(input_data, credentials):
return {"response": "OK"}
with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
input_data = llm.AIConversationBlock.Input(
messages=[],
prompt="Hello, how are you?",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
outputs = {}
async for name, data in block.run(
input_data, credentials=llm.TEST_CREDENTIALS
):
outputs[name] = data
assert outputs["response"] == "OK"
@pytest.mark.asyncio
async def test_nonempty_messages_with_empty_prompt_succeeds(self):
"""Non-empty messages with no prompt should proceed without error."""
block = llm.AIConversationBlock()
async def mock_llm_call(input_data, credentials):
return {"response": "response from conversation"}
with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
input_data = llm.AIConversationBlock.Input(
messages=[{"role": "user", "content": "Hello"}],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
outputs = {}
async for name, data in block.run(
input_data, credentials=llm.TEST_CREDENTIALS
):
outputs[name] = data
assert outputs["response"] == "response from conversation"
@pytest.mark.asyncio
async def test_messages_with_empty_content_raises_error(self):
"""Messages with empty content strings should be treated as no messages."""
block = llm.AIConversationBlock()
input_data = llm.AIConversationBlock.Input(
messages=[{"role": "user", "content": ""}],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
with pytest.raises(ValueError, match="no messages and no prompt"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
@pytest.mark.asyncio
async def test_messages_with_whitespace_content_raises_error(self):
"""Messages with whitespace-only content should be treated as no messages."""
block = llm.AIConversationBlock()
input_data = llm.AIConversationBlock.Input(
messages=[{"role": "user", "content": " "}],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
with pytest.raises(ValueError, match="no messages and no prompt"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
@pytest.mark.asyncio
async def test_messages_with_none_entry_raises_error(self):
"""Messages list containing None should be treated as no messages."""
block = llm.AIConversationBlock()
input_data = llm.AIConversationBlock.Input(
messages=[None],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
with pytest.raises(ValueError, match="no messages and no prompt"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
@pytest.mark.asyncio
async def test_messages_with_empty_dict_raises_error(self):
"""Messages list containing empty dict should be treated as no messages."""
block = llm.AIConversationBlock()
input_data = llm.AIConversationBlock.Input(
messages=[{}],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
with pytest.raises(ValueError, match="no messages and no prompt"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
@pytest.mark.asyncio
async def test_messages_with_none_content_raises_error(self):
"""Messages with content=None should not crash with AttributeError."""
block = llm.AIConversationBlock()
input_data = llm.AIConversationBlock.Input(
messages=[{"role": "user", "content": None}],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
with pytest.raises(ValueError, match="no messages and no prompt"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
class TestAITextSummarizerValidation:
"""Test that AITextSummarizerBlock validates LLM responses are strings."""

View File

@@ -0,0 +1,87 @@
"""Tests for empty-choices guard in extract_openai_tool_calls() and extract_openai_reasoning()."""
from unittest.mock import MagicMock
from backend.blocks.llm import extract_openai_reasoning, extract_openai_tool_calls
class TestExtractOpenaiToolCallsEmptyChoices:
"""extract_openai_tool_calls() must return None when choices is empty."""
def test_returns_none_for_empty_choices(self):
response = MagicMock()
response.choices = []
assert extract_openai_tool_calls(response) is None
def test_returns_none_for_none_choices(self):
response = MagicMock()
response.choices = None
assert extract_openai_tool_calls(response) is None
def test_returns_tool_calls_when_choices_present(self):
tool = MagicMock()
tool.id = "call_1"
tool.type = "function"
tool.function.name = "my_func"
tool.function.arguments = '{"a": 1}'
message = MagicMock()
message.tool_calls = [tool]
choice = MagicMock()
choice.message = message
response = MagicMock()
response.choices = [choice]
result = extract_openai_tool_calls(response)
assert result is not None
assert len(result) == 1
assert result[0].function.name == "my_func"
def test_returns_none_when_no_tool_calls(self):
message = MagicMock()
message.tool_calls = None
choice = MagicMock()
choice.message = message
response = MagicMock()
response.choices = [choice]
assert extract_openai_tool_calls(response) is None
class TestExtractOpenaiReasoningEmptyChoices:
"""extract_openai_reasoning() must return None when choices is empty."""
def test_returns_none_for_empty_choices(self):
response = MagicMock()
response.choices = []
assert extract_openai_reasoning(response) is None
def test_returns_none_for_none_choices(self):
response = MagicMock()
response.choices = None
assert extract_openai_reasoning(response) is None
def test_returns_reasoning_from_choice(self):
choice = MagicMock()
choice.reasoning = "Step-by-step reasoning"
choice.message = MagicMock(spec=[]) # no 'reasoning' attr on message
response = MagicMock(spec=[]) # no 'reasoning' attr on response
response.choices = [choice]
result = extract_openai_reasoning(response)
assert result == "Step-by-step reasoning"
def test_returns_none_when_no_reasoning(self):
choice = MagicMock(spec=[]) # no 'reasoning' attr
choice.message = MagicMock(spec=[]) # no 'reasoning' attr
response = MagicMock(spec=[]) # no 'reasoning' attr
response.choices = [choice]
result = extract_openai_reasoning(response)
assert result is None

View File

@@ -1074,6 +1074,7 @@ async def test_orchestrator_uses_customized_name_for_blocks():
mock_node.block_id = StoreValueBlock().id
mock_node.metadata = {"customized_name": "My Custom Tool Name"}
mock_node.block = StoreValueBlock()
mock_node.input_default = {}
# Create a mock link
mock_link = MagicMock(spec=Link)
@@ -1105,6 +1106,7 @@ async def test_orchestrator_falls_back_to_block_name():
mock_node.block_id = StoreValueBlock().id
mock_node.metadata = {} # No customized_name
mock_node.block = StoreValueBlock()
mock_node.input_default = {}
# Create a mock link
mock_link = MagicMock(spec=Link)

View File

@@ -0,0 +1,202 @@
"""Tests for ExecutionMode enum and provider validation in the orchestrator.
Covers:
- ExecutionMode enum members exist and have stable values
- EXTENDED_THINKING provider validation (anthropic/open_router allowed, others rejected)
- EXTENDED_THINKING model-name validation (must start with "claude")
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.blocks.llm import LlmModel
from backend.blocks.orchestrator import ExecutionMode, OrchestratorBlock
# ---------------------------------------------------------------------------
# ExecutionMode enum integrity
# ---------------------------------------------------------------------------
class TestExecutionModeEnum:
"""Guard against accidental renames or removals of enum members."""
def test_built_in_exists(self):
assert hasattr(ExecutionMode, "BUILT_IN")
assert ExecutionMode.BUILT_IN.value == "built_in"
def test_extended_thinking_exists(self):
assert hasattr(ExecutionMode, "EXTENDED_THINKING")
assert ExecutionMode.EXTENDED_THINKING.value == "extended_thinking"
def test_exactly_two_members(self):
"""If a new mode is added, this test should be updated intentionally."""
assert set(ExecutionMode.__members__.keys()) == {
"BUILT_IN",
"EXTENDED_THINKING",
}
def test_string_enum(self):
"""ExecutionMode is a str enum so it serialises cleanly to JSON."""
assert isinstance(ExecutionMode.BUILT_IN, str)
assert isinstance(ExecutionMode.EXTENDED_THINKING, str)
def test_round_trip_from_value(self):
"""Constructing from the string value should return the same member."""
assert ExecutionMode("built_in") is ExecutionMode.BUILT_IN
assert ExecutionMode("extended_thinking") is ExecutionMode.EXTENDED_THINKING
# ---------------------------------------------------------------------------
# Provider validation (inline in OrchestratorBlock.run)
# ---------------------------------------------------------------------------
def _make_model_stub(provider: str, value: str):
"""Create a lightweight stub that behaves like LlmModel for validation."""
metadata = MagicMock()
metadata.provider = provider
stub = MagicMock()
stub.metadata = metadata
stub.value = value
return stub
class TestExtendedThinkingProviderValidation:
"""The orchestrator rejects EXTENDED_THINKING for non-Anthropic providers."""
def test_anthropic_provider_accepted(self):
"""provider='anthropic' + claude model should not raise."""
model = _make_model_stub("anthropic", "claude-opus-4-6")
provider = model.metadata.provider
model_name = model.value
assert provider in ("anthropic", "open_router")
assert model_name.startswith("claude")
def test_open_router_provider_accepted(self):
"""provider='open_router' + claude model should not raise."""
model = _make_model_stub("open_router", "claude-sonnet-4-6")
provider = model.metadata.provider
model_name = model.value
assert provider in ("anthropic", "open_router")
assert model_name.startswith("claude")
def test_openai_provider_rejected(self):
"""provider='openai' should be rejected for EXTENDED_THINKING."""
model = _make_model_stub("openai", "gpt-4o")
provider = model.metadata.provider
assert provider not in ("anthropic", "open_router")
def test_groq_provider_rejected(self):
model = _make_model_stub("groq", "llama-3.3-70b-versatile")
provider = model.metadata.provider
assert provider not in ("anthropic", "open_router")
def test_non_claude_model_rejected_even_if_anthropic_provider(self):
"""A hypothetical non-Claude model with provider='anthropic' is rejected."""
model = _make_model_stub("anthropic", "not-a-claude-model")
model_name = model.value
assert not model_name.startswith("claude")
def test_real_gpt4o_model_rejected(self):
"""Verify a real LlmModel enum member (GPT4O) fails the provider check."""
model = LlmModel.GPT4O
provider = model.metadata.provider
assert provider not in ("anthropic", "open_router")
def test_real_claude_model_passes(self):
"""Verify a real LlmModel enum member (CLAUDE_4_6_SONNET) passes."""
model = LlmModel.CLAUDE_4_6_SONNET
provider = model.metadata.provider
model_name = model.value
assert provider in ("anthropic", "open_router")
assert model_name.startswith("claude")
# ---------------------------------------------------------------------------
# Integration-style: exercise the validation branch via OrchestratorBlock.run
# ---------------------------------------------------------------------------
def _make_input_data(model, execution_mode=ExecutionMode.EXTENDED_THINKING):
"""Build a minimal MagicMock that satisfies OrchestratorBlock.run's early path."""
inp = MagicMock()
inp.execution_mode = execution_mode
inp.model = model
inp.prompt = "test"
inp.sys_prompt = ""
inp.conversation_history = []
inp.last_tool_output = None
inp.prompt_values = {}
return inp
async def _collect_run_outputs(block, input_data, **kwargs):
"""Exhaust the OrchestratorBlock.run async generator, collecting outputs."""
outputs = []
async for item in block.run(input_data, **kwargs):
outputs.append(item)
return outputs
class TestExtendedThinkingValidationRaisesInBlock:
"""Call OrchestratorBlock.run far enough to trigger the ValueError."""
@pytest.mark.asyncio
async def test_non_anthropic_provider_raises_valueerror(self):
"""EXTENDED_THINKING + openai provider raises ValueError."""
block = OrchestratorBlock()
input_data = _make_input_data(model=LlmModel.GPT4O)
with (
patch.object(
block,
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=[],
),
pytest.raises(ValueError, match="Anthropic-compatible"),
):
await _collect_run_outputs(
block,
input_data,
credentials=MagicMock(),
graph_id="g",
node_id="n",
graph_exec_id="ge",
node_exec_id="ne",
user_id="u",
graph_version=1,
execution_context=MagicMock(),
execution_processor=MagicMock(),
)
@pytest.mark.asyncio
async def test_non_claude_model_with_anthropic_provider_raises(self):
"""A model with anthropic provider but non-claude name raises ValueError."""
block = OrchestratorBlock()
fake_model = _make_model_stub("anthropic", "not-a-claude-model")
input_data = _make_input_data(model=fake_model)
with (
patch.object(
block,
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=[],
),
pytest.raises(ValueError, match="only supports Claude models"),
):
await _collect_run_outputs(
block,
input_data,
credentials=MagicMock(),
graph_id="g",
node_id="n",
graph_exec_id="ge",
node_exec_id="ne",
user_id="u",
graph_version=1,
execution_context=MagicMock(),
execution_processor=MagicMock(),
)

File diff suppressed because it is too large Load Diff

View File

@@ -9,12 +9,16 @@ shared tool registry as the SDK path.
import asyncio
import logging
import uuid
from collections.abc import AsyncGenerator
from typing import Any
from collections.abc import AsyncGenerator, Sequence
from dataclasses import dataclass, field
from functools import partial
from typing import Any, cast
import orjson
from langfuse import propagate_attributes
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolParam
from backend.copilot.context import set_execution_context
from backend.copilot.model import (
ChatMessage,
ChatSession,
@@ -48,7 +52,17 @@ from backend.copilot.token_tracking import persist_and_record_usage
from backend.copilot.tools import execute_tool, get_available_tools
from backend.copilot.tracking import track_user_message
from backend.util.exceptions import NotFoundError
from backend.util.prompt import compress_context
from backend.util.prompt import (
compress_context,
estimate_token_count,
estimate_token_count_str,
)
from backend.util.tool_call_loop import (
LLMLoopResponse,
LLMToolCall,
ToolCallResult,
tool_call_loop,
)
logger = logging.getLogger(__name__)
@@ -59,6 +73,247 @@ _background_tasks: set[asyncio.Task[Any]] = set()
_MAX_TOOL_ROUNDS = 30
@dataclass
class _BaselineStreamState:
"""Mutable state shared between the tool-call loop callbacks.
Extracted from ``stream_chat_completion_baseline`` so that the callbacks
can be module-level functions instead of deeply nested closures.
"""
pending_events: list[StreamBaseResponse] = field(default_factory=list)
assistant_text: str = ""
text_block_id: str = field(default_factory=lambda: str(uuid.uuid4()))
text_started: bool = False
turn_prompt_tokens: int = 0
turn_completion_tokens: int = 0
async def _baseline_llm_caller(
messages: list[dict[str, Any]],
tools: Sequence[Any],
*,
state: _BaselineStreamState,
) -> LLMLoopResponse:
"""Stream an OpenAI-compatible response and collect results.
Extracted from ``stream_chat_completion_baseline`` for readability.
"""
state.pending_events.append(StreamStartStep())
round_text = ""
try:
client = _get_openai_client()
typed_messages = cast(list[ChatCompletionMessageParam], messages)
if tools:
typed_tools = cast(list[ChatCompletionToolParam], tools)
response = await client.chat.completions.create(
model=config.model,
messages=typed_messages,
tools=typed_tools,
stream=True,
stream_options={"include_usage": True},
)
else:
response = await client.chat.completions.create(
model=config.model,
messages=typed_messages,
stream=True,
stream_options={"include_usage": True},
)
tool_calls_by_index: dict[int, dict[str, str]] = {}
async for chunk in response:
if chunk.usage:
state.turn_prompt_tokens += chunk.usage.prompt_tokens or 0
state.turn_completion_tokens += chunk.usage.completion_tokens or 0
delta = chunk.choices[0].delta if chunk.choices else None
if not delta:
continue
if delta.content:
if not state.text_started:
state.pending_events.append(StreamTextStart(id=state.text_block_id))
state.text_started = True
round_text += delta.content
state.pending_events.append(
StreamTextDelta(id=state.text_block_id, delta=delta.content)
)
if delta.tool_calls:
for tc in delta.tool_calls:
idx = tc.index
if idx not in tool_calls_by_index:
tool_calls_by_index[idx] = {
"id": "",
"name": "",
"arguments": "",
}
entry = tool_calls_by_index[idx]
if tc.id:
entry["id"] = tc.id
if tc.function and tc.function.name:
entry["name"] = tc.function.name
if tc.function and tc.function.arguments:
entry["arguments"] += tc.function.arguments
# Close text block
if state.text_started:
state.pending_events.append(StreamTextEnd(id=state.text_block_id))
state.text_started = False
state.text_block_id = str(uuid.uuid4())
finally:
# Always persist partial text so the session history stays consistent,
# even when the stream is interrupted by an exception.
state.assistant_text += round_text
# Always emit StreamFinishStep to match the StreamStartStep,
# even if an exception occurred during streaming.
state.pending_events.append(StreamFinishStep())
# Convert to shared format
llm_tool_calls = [
LLMToolCall(
id=tc["id"],
name=tc["name"],
arguments=tc["arguments"] or "{}",
)
for tc in tool_calls_by_index.values()
]
return LLMLoopResponse(
response_text=round_text or None,
tool_calls=llm_tool_calls,
raw_response=None, # Not needed for baseline conversation updater
prompt_tokens=0, # Tracked via state accumulators
completion_tokens=0,
)
async def _baseline_tool_executor(
tool_call: LLMToolCall,
tools: Sequence[Any],
*,
state: _BaselineStreamState,
user_id: str | None,
session: ChatSession,
) -> ToolCallResult:
"""Execute a tool via the copilot tool registry.
Extracted from ``stream_chat_completion_baseline`` for readability.
"""
tool_call_id = tool_call.id
tool_name = tool_call.name
raw_args = tool_call.arguments or "{}"
try:
tool_args = orjson.loads(raw_args)
except orjson.JSONDecodeError as parse_err:
parse_error = f"Invalid JSON arguments for tool '{tool_name}': {parse_err}"
logger.warning("[Baseline] %s", parse_error)
state.pending_events.append(
StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
output=parse_error,
success=False,
)
)
return ToolCallResult(
tool_call_id=tool_call_id,
tool_name=tool_name,
content=parse_error,
is_error=True,
)
state.pending_events.append(
StreamToolInputStart(toolCallId=tool_call_id, toolName=tool_name)
)
state.pending_events.append(
StreamToolInputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
input=tool_args,
)
)
try:
result: StreamToolOutputAvailable = await execute_tool(
tool_name=tool_name,
parameters=tool_args,
user_id=user_id,
session=session,
tool_call_id=tool_call_id,
)
state.pending_events.append(result)
tool_output = (
result.output if isinstance(result.output, str) else str(result.output)
)
return ToolCallResult(
tool_call_id=tool_call_id,
tool_name=tool_name,
content=tool_output,
)
except Exception as e:
error_output = f"Tool execution error: {e}"
logger.error(
"[Baseline] Tool %s failed: %s",
tool_name,
error_output,
exc_info=True,
)
state.pending_events.append(
StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
output=error_output,
success=False,
)
)
return ToolCallResult(
tool_call_id=tool_call_id,
tool_name=tool_name,
content=error_output,
is_error=True,
)
def _baseline_conversation_updater(
messages: list[dict[str, Any]],
response: LLMLoopResponse,
tool_results: list[ToolCallResult] | None = None,
) -> None:
"""Update OpenAI message list with assistant response + tool results.
Extracted from ``stream_chat_completion_baseline`` for readability.
"""
if tool_results:
# Build assistant message with tool_calls
assistant_msg: dict[str, Any] = {"role": "assistant"}
if response.response_text:
assistant_msg["content"] = response.response_text
assistant_msg["tool_calls"] = [
{
"id": tc.id,
"type": "function",
"function": {"name": tc.name, "arguments": tc.arguments},
}
for tc in response.tool_calls
]
messages.append(assistant_msg)
for tr in tool_results:
messages.append(
{
"role": "tool",
"tool_call_id": tr.tool_call_id,
"content": tr.content,
}
)
else:
if response.response_text:
messages.append({"role": "assistant", "content": response.response_text})
async def _update_title_async(
session_id: str, message: str, user_id: str | None
) -> None:
@@ -203,6 +458,9 @@ async def stream_chat_completion_baseline(
tools = get_available_tools()
# Propagate execution context so tool handlers can read session-level flags.
set_execution_context(user_id, session)
yield StreamStart(messageId=message_id, sessionId=session_id)
# Propagate user/session context to Langfuse so all LLM calls within
@@ -219,191 +477,32 @@ async def stream_chat_completion_baseline(
except Exception:
logger.warning("[Baseline] Langfuse trace context setup failed")
assistant_text = ""
text_block_id = str(uuid.uuid4())
text_started = False
step_open = False
# Token usage accumulators — populated from streaming chunks
turn_prompt_tokens = 0
turn_completion_tokens = 0
_stream_error = False # Track whether an error occurred during streaming
state = _BaselineStreamState()
# Bind extracted module-level callbacks to this request's state/session
# using functools.partial so they satisfy the Protocol signatures.
_bound_llm_caller = partial(_baseline_llm_caller, state=state)
_bound_tool_executor = partial(
_baseline_tool_executor, state=state, user_id=user_id, session=session
)
try:
for _round in range(_MAX_TOOL_ROUNDS):
# Open a new step for each LLM round
yield StreamStartStep()
step_open = True
loop_result = None
async for loop_result in tool_call_loop(
messages=openai_messages,
tools=tools,
llm_call=_bound_llm_caller,
execute_tool=_bound_tool_executor,
update_conversation=_baseline_conversation_updater,
max_iterations=_MAX_TOOL_ROUNDS,
):
# Drain buffered events after each iteration (real-time streaming)
for evt in state.pending_events:
yield evt
state.pending_events.clear()
# Stream a response from the model
create_kwargs: dict[str, Any] = dict(
model=config.model,
messages=openai_messages,
stream=True,
stream_options={"include_usage": True},
)
if tools:
create_kwargs["tools"] = tools
response = await _get_openai_client().chat.completions.create(**create_kwargs) # type: ignore[arg-type] # dynamic kwargs
# Accumulate streamed response (text + tool calls)
round_text = ""
tool_calls_by_index: dict[int, dict[str, str]] = {}
async for chunk in response:
# Capture token usage from the streaming chunk.
# OpenRouter normalises all providers into OpenAI format
# where prompt_tokens already includes cached tokens
# (unlike Anthropic's native API). Use += to sum all
# tool-call rounds since each API call is independent.
# NOTE: stream_options={"include_usage": True} is not
# universally supported — some providers (Mistral, Llama
# via OpenRouter) always return chunk.usage=None. When
# that happens, tokens stay 0 and the tiktoken fallback
# below activates. Fail-open: one round is estimated.
if chunk.usage:
turn_prompt_tokens += chunk.usage.prompt_tokens or 0
turn_completion_tokens += chunk.usage.completion_tokens or 0
delta = chunk.choices[0].delta if chunk.choices else None
if not delta:
continue
# Text content
if delta.content:
if not text_started:
yield StreamTextStart(id=text_block_id)
text_started = True
round_text += delta.content
yield StreamTextDelta(id=text_block_id, delta=delta.content)
# Tool call fragments (streamed incrementally)
if delta.tool_calls:
for tc in delta.tool_calls:
idx = tc.index
if idx not in tool_calls_by_index:
tool_calls_by_index[idx] = {
"id": "",
"name": "",
"arguments": "",
}
entry = tool_calls_by_index[idx]
if tc.id:
entry["id"] = tc.id
if tc.function and tc.function.name:
entry["name"] = tc.function.name
if tc.function and tc.function.arguments:
entry["arguments"] += tc.function.arguments
# Close text block if we had one this round
if text_started:
yield StreamTextEnd(id=text_block_id)
text_started = False
text_block_id = str(uuid.uuid4())
# Accumulate text for session persistence
assistant_text += round_text
# No tool calls -> model is done
if not tool_calls_by_index:
yield StreamFinishStep()
step_open = False
break
# Close step before tool execution
yield StreamFinishStep()
step_open = False
# Append the assistant message with tool_calls to context.
assistant_msg: dict[str, Any] = {"role": "assistant"}
if round_text:
assistant_msg["content"] = round_text
assistant_msg["tool_calls"] = [
{
"id": tc["id"],
"type": "function",
"function": {
"name": tc["name"],
"arguments": tc["arguments"] or "{}",
},
}
for tc in tool_calls_by_index.values()
]
openai_messages.append(assistant_msg)
# Execute each tool call and stream events
for tc in tool_calls_by_index.values():
tool_call_id = tc["id"]
tool_name = tc["name"]
raw_args = tc["arguments"] or "{}"
try:
tool_args = orjson.loads(raw_args)
except orjson.JSONDecodeError as parse_err:
parse_error = (
f"Invalid JSON arguments for tool '{tool_name}': {parse_err}"
)
logger.warning("[Baseline] %s", parse_error)
yield StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
output=parse_error,
success=False,
)
openai_messages.append(
{
"role": "tool",
"tool_call_id": tool_call_id,
"content": parse_error,
}
)
continue
yield StreamToolInputStart(toolCallId=tool_call_id, toolName=tool_name)
yield StreamToolInputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
input=tool_args,
)
# Execute via shared tool registry
try:
result: StreamToolOutputAvailable = await execute_tool(
tool_name=tool_name,
parameters=tool_args,
user_id=user_id,
session=session,
tool_call_id=tool_call_id,
)
yield result
tool_output = (
result.output
if isinstance(result.output, str)
else str(result.output)
)
except Exception as e:
error_output = f"Tool execution error: {e}"
logger.error(
"[Baseline] Tool %s failed: %s",
tool_name,
error_output,
exc_info=True,
)
yield StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
output=error_output,
success=False,
)
tool_output = error_output
# Append tool result to context for next round
openai_messages.append(
{
"role": "tool",
"tool_call_id": tool_call_id,
"content": tool_output,
}
)
else:
# for-loop exhausted without break -> tool-round limit hit
if loop_result and not loop_result.finished_naturally:
limit_msg = (
f"Exceeded {_MAX_TOOL_ROUNDS} tool-call rounds "
"without a final response."
@@ -418,11 +517,28 @@ async def stream_chat_completion_baseline(
_stream_error = True
error_msg = str(e) or type(e).__name__
logger.error("[Baseline] Streaming error: %s", error_msg, exc_info=True)
# Close any open text/step before emitting error
if text_started:
yield StreamTextEnd(id=text_block_id)
if step_open:
yield StreamFinishStep()
# Close any open text block. The llm_caller's finally block
# already appended StreamFinishStep to pending_events, so we must
# insert StreamTextEnd *before* StreamFinishStep to preserve the
# protocol ordering:
# StreamStartStep -> StreamTextStart -> ...deltas... ->
# StreamTextEnd -> StreamFinishStep
# Appending (or yielding directly) would place it after
# StreamFinishStep, violating the protocol.
if state.text_started:
# Find the last StreamFinishStep and insert before it.
insert_pos = len(state.pending_events)
for i in range(len(state.pending_events) - 1, -1, -1):
if isinstance(state.pending_events[i], StreamFinishStep):
insert_pos = i
break
state.pending_events.insert(
insert_pos, StreamTextEnd(id=state.text_block_id)
)
# Drain pending events in correct order
for evt in state.pending_events:
yield evt
state.pending_events.clear()
yield StreamError(errorText=error_msg, code="baseline_error")
# Still persist whatever we got
finally:
@@ -442,26 +558,21 @@ async def stream_chat_completion_baseline(
# Skip fallback when an error occurred and no output was produced —
# charging rate-limit tokens for completely failed requests is unfair.
if (
turn_prompt_tokens == 0
and turn_completion_tokens == 0
and not (_stream_error and not assistant_text)
state.turn_prompt_tokens == 0
and state.turn_completion_tokens == 0
and not (_stream_error and not state.assistant_text)
):
from backend.util.prompt import (
estimate_token_count,
estimate_token_count_str,
)
turn_prompt_tokens = max(
state.turn_prompt_tokens = max(
estimate_token_count(openai_messages, model=config.model), 1
)
turn_completion_tokens = estimate_token_count_str(
assistant_text, model=config.model
state.turn_completion_tokens = estimate_token_count_str(
state.assistant_text, model=config.model
)
logger.info(
"[Baseline] No streaming usage reported; estimated tokens: "
"prompt=%d, completion=%d",
turn_prompt_tokens,
turn_completion_tokens,
state.turn_prompt_tokens,
state.turn_completion_tokens,
)
# Persist token usage to session and record for rate limiting.
@@ -471,15 +582,15 @@ async def stream_chat_completion_baseline(
await persist_and_record_usage(
session=session,
user_id=user_id,
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
prompt_tokens=state.turn_prompt_tokens,
completion_tokens=state.turn_completion_tokens,
log_prefix="[Baseline]",
)
# Persist assistant response
if assistant_text:
if state.assistant_text:
session.messages.append(
ChatMessage(role="assistant", content=assistant_text)
ChatMessage(role="assistant", content=state.assistant_text)
)
try:
await upsert_chat_session(session)
@@ -491,11 +602,11 @@ async def stream_chat_completion_baseline(
# aclose() — doing so raises RuntimeError on client disconnect.
# On GeneratorExit the client is already gone, so unreachable yields
# are harmless; on normal completion they reach the SSE stream.
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
if state.turn_prompt_tokens > 0 or state.turn_completion_tokens > 0:
yield StreamUsage(
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
total_tokens=turn_prompt_tokens + turn_completion_tokens,
prompt_tokens=state.turn_prompt_tokens,
completion_tokens=state.turn_completion_tokens,
total_tokens=state.turn_prompt_tokens + state.turn_completion_tokens,
)
yield StreamFinish()

View File

@@ -31,7 +31,7 @@ async def test_baseline_multi_turn(setup_test_user, test_user_id):
if not api_key:
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
session = await create_chat_session(test_user_id)
session = await create_chat_session(test_user_id, dry_run=False)
session = await upsert_chat_session(session)
# --- Turn 1: send a message with a unique keyword ---

View File

@@ -20,6 +20,10 @@ class ChatConfig(BaseSettings):
default="openai/gpt-4o-mini",
description="Model to use for generating session titles (should be fast/cheap)",
)
simulation_model: str = Field(
default="google/gemini-2.5-flash",
description="Model for dry-run block simulation (should be fast/cheap with good JSON output)",
)
api_key: str | None = Field(default=None, description="OpenAI API key")
base_url: str | None = Field(
default=OPENROUTER_BASE_URL,
@@ -178,7 +182,7 @@ class ChatConfig(BaseSettings):
Single source of truth for "will the SDK route through OpenRouter?".
Checks the flag *and* that ``api_key`` + a valid ``base_url`` are
present — mirrors the fallback logic in ``_build_sdk_env``.
present — mirrors the fallback logic in ``build_sdk_env``.
"""
if not self.use_openrouter:
return False

View File

@@ -149,7 +149,8 @@ def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
Allowed:
- Files under *sdk_cwd* (``/tmp/copilot-<session>/``)
- Files under ``~/.claude/projects/<encoded-cwd>/<uuid>/tool-results/...``.
- Files under ``~/.claude/projects/<encoded-cwd>/<uuid>/tool-results/...``
or ``tool-outputs/...``.
The SDK nests tool-results under a conversation UUID directory;
the UUID segment is validated with ``_UUID_RE``.
"""
@@ -174,17 +175,20 @@ def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
# Defence-in-depth: ensure project_dir didn't escape the base.
if not project_dir.startswith(SDK_PROJECTS_DIR + os.sep):
return False
# Only allow: <encoded-cwd>/<uuid>/tool-results/<file>
# Only allow: <encoded-cwd>/<uuid>/<tool-dir>/<file>
# The SDK always creates a conversation UUID directory between
# the project dir and tool-results/.
# the project dir and the tool directory.
# Accept both "tool-results" (SDK's persisted outputs) and
# "tool-outputs" (the model sometimes confuses workspace paths
# with filesystem paths and generates this variant).
if resolved.startswith(project_dir + os.sep):
relative = resolved[len(project_dir) + 1 :]
parts = relative.split(os.sep)
# Require exactly: [<uuid>, "tool-results", <file>, ...]
# Require exactly: [<uuid>, "tool-results"|"tool-outputs", <file>, ...]
if (
len(parts) >= 3
and _UUID_RE.match(parts[0])
and parts[1] == "tool-results"
and parts[1] in ("tool-results", "tool-outputs")
):
return True

View File

@@ -134,6 +134,21 @@ def test_is_allowed_local_path_tool_results_with_uuid():
_current_project_dir.set("")
def test_is_allowed_local_path_tool_outputs_with_uuid():
"""Files under <encoded-cwd>/<uuid>/tool-outputs/ are also allowed."""
encoded = "test-encoded-dir"
conv_uuid = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
path = os.path.join(
SDK_PROJECTS_DIR, encoded, conv_uuid, "tool-outputs", "output.json"
)
_current_project_dir.set(encoded)
try:
assert is_allowed_local_path(path, sdk_cwd=None)
finally:
_current_project_dir.set("")
def test_is_allowed_local_path_tool_results_without_uuid_rejected():
"""Direct <encoded-cwd>/tool-results/ (no UUID) is rejected."""
encoded = "test-encoded-dir"
@@ -159,7 +174,7 @@ def test_is_allowed_local_path_sibling_of_tool_results_is_rejected():
def test_is_allowed_local_path_valid_uuid_wrong_segment_name_rejected():
"""A valid UUID dir but non-'tool-results' second segment is rejected."""
"""A valid UUID dir but non-'tool-results'/'tool-outputs' second segment is rejected."""
encoded = "test-encoded-dir"
uuid_str = "12345678-1234-5678-9abc-def012345678"
path = os.path.join(

View File

@@ -18,7 +18,13 @@ from prisma.types import (
from backend.data import db
from backend.util.json import SafeJson, sanitize_string
from .model import ChatMessage, ChatSession, ChatSessionInfo, invalidate_session_cache
from .model import (
ChatMessage,
ChatSession,
ChatSessionInfo,
ChatSessionMetadata,
invalidate_session_cache,
)
logger = logging.getLogger(__name__)
@@ -35,6 +41,7 @@ async def get_chat_session(session_id: str) -> ChatSession | None:
async def create_chat_session(
session_id: str,
user_id: str,
metadata: ChatSessionMetadata | None = None,
) -> ChatSessionInfo:
"""Create a new chat session in the database."""
data = ChatSessionCreateInput(
@@ -43,6 +50,7 @@ async def create_chat_session(
credentials=SafeJson({}),
successfulAgentRuns=SafeJson({}),
successfulAgentSchedules=SafeJson({}),
metadata=SafeJson((metadata or ChatSessionMetadata()).model_dump()),
)
prisma_session = await PrismaChatSession.prisma().create(data=data)
return ChatSessionInfo.from_db(prisma_session)
@@ -57,7 +65,12 @@ async def update_chat_session(
total_completion_tokens: int | None = None,
title: str | None = None,
) -> ChatSession | None:
"""Update a chat session's metadata."""
"""Update a chat session's mutable fields.
Note: ``metadata`` (which includes ``dry_run``) is intentionally omitted —
it is set once at creation time and treated as immutable for the lifetime
of the session.
"""
data: ChatSessionUpdateInput = {"updatedAt": datetime.now(UTC)}
if credentials is not None:

View File

@@ -59,6 +59,16 @@ _null_cache: TTLCache[tuple[str, str], bool] = TTLCache(
maxsize=_CACHE_MAX_SIZE, ttl=_NULL_CACHE_TTL
)
# GitHub user identity caches (keyed by user_id only, not provider tuple).
# Declared here so invalidate_user_provider_cache() can reference them.
_GH_IDENTITY_CACHE_TTL = 600.0 # 10 min — profile data rarely changes
_gh_identity_cache: TTLCache[str, dict[str, str]] = TTLCache(
maxsize=_CACHE_MAX_SIZE, ttl=_GH_IDENTITY_CACHE_TTL
)
_gh_identity_null_cache: TTLCache[str, bool] = TTLCache(
maxsize=_CACHE_MAX_SIZE, ttl=_NULL_CACHE_TTL
)
def invalidate_user_provider_cache(user_id: str, provider: str) -> None:
"""Remove the cached entry for *user_id*/*provider* from both caches.
@@ -66,11 +76,19 @@ def invalidate_user_provider_cache(user_id: str, provider: str) -> None:
Call this after storing new credentials so that the next
``get_provider_token()`` call performs a fresh DB lookup instead of
serving a stale TTL-cached result.
For GitHub specifically, also clears the git-identity caches so that
``get_github_user_git_identity()`` re-fetches the user's profile on
the next call instead of serving stale identity data.
"""
key = (user_id, provider)
_token_cache.pop(key, None)
_null_cache.pop(key, None)
if provider == "github":
_gh_identity_cache.pop(user_id, None)
_gh_identity_null_cache.pop(user_id, None)
# Register this module's cache-bust function with the credentials manager so
# that any create/update/delete operation immediately evicts stale cache
@@ -123,6 +141,7 @@ async def get_provider_token(user_id: str, provider: str) -> str | None:
[c for c in creds_list if c.type == "oauth2"],
key=lambda c: 0 if "repo" in (cast(OAuth2Credentials, c).scopes or []) else 1,
)
refresh_failed = False
for creds in oauth2_creds:
if creds.type == "oauth2":
try:
@@ -141,6 +160,7 @@ async def get_provider_token(user_id: str, provider: str) -> str | None:
# Do NOT fall back to the stale token — it is likely expired
# or revoked. Returning None forces the caller to re-auth,
# preventing the LLM from receiving a non-functional token.
refresh_failed = True
continue
_token_cache[cache_key] = token
return token
@@ -152,8 +172,12 @@ async def get_provider_token(user_id: str, provider: str) -> str | None:
_token_cache[cache_key] = token
return token
# No credentials found — cache to avoid repeated DB hits.
_null_cache[cache_key] = True
# Only cache "not connected" when the user truly has no credentials for this
# provider. If we had OAuth credentials but refresh failed (e.g. transient
# network error, event-loop mismatch), do NOT cache the negative result —
# the next call should retry the refresh instead of being blocked for 60 s.
if not refresh_failed:
_null_cache[cache_key] = True
return None
@@ -171,3 +195,76 @@ async def get_integration_env_vars(user_id: str) -> dict[str, str]:
for var in var_names:
env[var] = token
return env
# ---------------------------------------------------------------------------
# GitHub user identity (for git committer env vars)
# ---------------------------------------------------------------------------
async def get_github_user_git_identity(user_id: str) -> dict[str, str] | None:
"""Fetch the GitHub user's name and email for git committer env vars.
Uses the ``/user`` GitHub API endpoint with the user's stored token.
Returns a dict with ``GIT_AUTHOR_NAME``, ``GIT_AUTHOR_EMAIL``,
``GIT_COMMITTER_NAME``, and ``GIT_COMMITTER_EMAIL`` if the user has a
connected GitHub account. Returns ``None`` otherwise.
Results are cached for 10 minutes; "not connected" results are cached for
60 s (same as null-token cache).
"""
if user_id in _gh_identity_null_cache:
return None
if cached := _gh_identity_cache.get(user_id):
return cached
token = await get_provider_token(user_id, "github")
if not token:
_gh_identity_null_cache[user_id] = True
return None
import aiohttp
try:
async with aiohttp.ClientSession() as session:
async with session.get(
"https://api.github.com/user",
headers={
"Authorization": f"token {token}",
"Accept": "application/vnd.github+json",
},
timeout=aiohttp.ClientTimeout(total=5),
) as resp:
if resp.status != 200:
logger.warning(
"[git-identity] GitHub /user returned %s for user %s",
resp.status,
user_id,
)
return None
data = await resp.json()
except Exception as exc:
logger.warning(
"[git-identity] Failed to fetch GitHub profile for user %s: %s",
user_id,
exc,
)
return None
name = data.get("name") or data.get("login") or "AutoGPT User"
# GitHub may return email=null if the user has set their email to private.
# Fall back to the noreply address GitHub generates for every account.
email = data.get("email")
if not email:
gh_id = data.get("id", "")
login = data.get("login", "user")
email = f"{gh_id}+{login}@users.noreply.github.com"
identity = {
"GIT_AUTHOR_NAME": name,
"GIT_AUTHOR_EMAIL": email,
"GIT_COMMITTER_NAME": name,
"GIT_COMMITTER_EMAIL": email,
}
_gh_identity_cache[user_id] = identity
return identity

View File

@@ -9,6 +9,8 @@ from backend.copilot.integration_creds import (
_NULL_CACHE_TTL,
_TOKEN_CACHE_TTL,
PROVIDER_ENV_VARS,
_gh_identity_cache,
_gh_identity_null_cache,
_null_cache,
_token_cache,
get_integration_env_vars,
@@ -49,9 +51,13 @@ def clear_caches():
"""Ensure clean caches before and after every test."""
_token_cache.clear()
_null_cache.clear()
_gh_identity_cache.clear()
_gh_identity_null_cache.clear()
yield
_token_cache.clear()
_null_cache.clear()
_gh_identity_cache.clear()
_gh_identity_null_cache.clear()
class TestInvalidateUserProviderCache:
@@ -77,6 +83,34 @@ class TestInvalidateUserProviderCache:
invalidate_user_provider_cache(_USER, _PROVIDER)
assert other_key in _token_cache
def test_clears_gh_identity_cache_for_github_provider(self):
"""When provider is 'github', identity caches must also be cleared."""
_gh_identity_cache[_USER] = {
"GIT_AUTHOR_NAME": "Old Name",
"GIT_AUTHOR_EMAIL": "old@example.com",
"GIT_COMMITTER_NAME": "Old Name",
"GIT_COMMITTER_EMAIL": "old@example.com",
}
invalidate_user_provider_cache(_USER, "github")
assert _USER not in _gh_identity_cache
def test_clears_gh_identity_null_cache_for_github_provider(self):
"""When provider is 'github', the identity null-cache must also be cleared."""
_gh_identity_null_cache[_USER] = True
invalidate_user_provider_cache(_USER, "github")
assert _USER not in _gh_identity_null_cache
def test_does_not_clear_gh_identity_cache_for_other_providers(self):
"""When provider is NOT 'github', identity caches must be left alone."""
_gh_identity_cache[_USER] = {
"GIT_AUTHOR_NAME": "Some Name",
"GIT_AUTHOR_EMAIL": "some@example.com",
"GIT_COMMITTER_NAME": "Some Name",
"GIT_COMMITTER_EMAIL": "some@example.com",
}
invalidate_user_provider_cache(_USER, "some-other-provider")
assert _USER in _gh_identity_cache
class TestGetProviderToken:
@pytest.mark.asyncio(loop_scope="session")
@@ -129,8 +163,15 @@ class TestGetProviderToken:
assert result == "oauth-tok"
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth2_refresh_failure_returns_none(self):
"""On refresh failure, return None instead of caching a stale token."""
async def test_oauth2_refresh_failure_returns_none_without_null_cache(self):
"""On refresh failure, return None but do NOT cache in null_cache.
The user has credentials — they just couldn't be refreshed right now
(e.g. transient network error or event-loop mismatch in the copilot
executor). Caching a negative result would block all credential
lookups for 60 s even though the creds exist and may refresh fine
on the next attempt.
"""
oauth_creds = _make_oauth2_creds("stale-oauth-tok")
mock_manager = MagicMock()
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[oauth_creds])
@@ -141,6 +182,8 @@ class TestGetProviderToken:
# Stale tokens must NOT be returned — forces re-auth.
assert result is None
# Must NOT cache negative result when refresh failed — next call retries.
assert (_USER, _PROVIDER) not in _null_cache
@pytest.mark.asyncio(loop_scope="session")
async def test_no_credentials_caches_null_entry(self):
@@ -176,6 +219,96 @@ class TestGetProviderToken:
assert _NULL_CACHE_TTL < _TOKEN_CACHE_TTL
class TestThreadSafetyLocks:
"""Bug reproduction: shared AsyncRedisKeyedMutex across threads caused
'Future attached to a different loop' when copilot workers accessed
credentials from different event loops."""
@pytest.mark.asyncio(loop_scope="session")
async def test_store_locks_returns_per_thread_instance(self):
"""IntegrationCredentialsStore.locks() must return different instances
for different threads (via @thread_cached)."""
import asyncio
import concurrent.futures
from backend.integrations.credentials_store import IntegrationCredentialsStore
store = IntegrationCredentialsStore()
async def get_locks_id():
mock_redis = AsyncMock()
with patch(
"backend.integrations.credentials_store.get_redis_async",
return_value=mock_redis,
):
locks = await store.locks()
return id(locks)
# Get locks from main thread
main_id = await get_locks_id()
# Get locks from a worker thread
def run_in_thread():
loop = asyncio.new_event_loop()
try:
return loop.run_until_complete(get_locks_id())
finally:
loop.close()
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
worker_id = await asyncio.get_event_loop().run_in_executor(
pool, run_in_thread
)
assert main_id != worker_id, (
"Store.locks() returned the same instance across threads. "
"This would cause 'Future attached to a different loop' errors."
)
@pytest.mark.asyncio(loop_scope="session")
async def test_manager_delegates_to_store_locks(self):
"""IntegrationCredentialsManager.locks() should delegate to store."""
from backend.integrations.creds_manager import IntegrationCredentialsManager
manager = IntegrationCredentialsManager()
mock_redis = AsyncMock()
with patch(
"backend.integrations.credentials_store.get_redis_async",
return_value=mock_redis,
):
locks = await manager.locks()
# Should have gotten it from the store
assert locks is not None
class TestRefreshUnlockedPath:
"""Bug reproduction: copilot worker threads need lock-free refresh because
Redis-backed asyncio.Lock created on one event loop can't be used on another."""
@pytest.mark.asyncio(loop_scope="session")
async def test_refresh_if_needed_lock_false_skips_redis(self):
"""refresh_if_needed(lock=False) must not touch Redis locks at all."""
from backend.integrations.creds_manager import IntegrationCredentialsManager
manager = IntegrationCredentialsManager()
creds = _make_oauth2_creds()
mock_handler = MagicMock()
mock_handler.needs_refresh = MagicMock(return_value=False)
with patch(
"backend.integrations.creds_manager._get_provider_oauth_handler",
new_callable=AsyncMock,
return_value=mock_handler,
):
result = await manager.refresh_if_needed(_USER, creds, lock=False)
# Should return credentials without touching locks
assert result.id == creds.id
class TestGetIntegrationEnvVars:
@pytest.mark.asyncio(loop_scope="session")
async def test_injects_all_env_vars_for_provider(self):

View File

@@ -46,6 +46,16 @@ def _get_session_cache_key(session_id: str) -> str:
# ===================== Chat data models ===================== #
class ChatSessionMetadata(BaseModel):
"""Typed metadata stored in the ``metadata`` JSON column of ChatSession.
Add new session-level flags here instead of adding DB columns —
no migration required for new fields as long as a default is provided.
"""
dry_run: bool = False
class ChatMessage(BaseModel):
role: str
content: str | None = None
@@ -90,6 +100,12 @@ class ChatSessionInfo(BaseModel):
updated_at: datetime
successful_agent_runs: dict[str, int] = {}
successful_agent_schedules: dict[str, int] = {}
metadata: ChatSessionMetadata = ChatSessionMetadata()
@property
def dry_run(self) -> bool:
"""Convenience accessor for ``metadata.dry_run``."""
return self.metadata.dry_run
@classmethod
def from_db(cls, prisma_session: PrismaChatSession) -> Self:
@@ -103,6 +119,10 @@ class ChatSessionInfo(BaseModel):
prisma_session.successfulAgentSchedules, default={}
)
# Parse typed metadata from the JSON column.
raw_metadata = _parse_json_field(prisma_session.metadata, default={})
metadata = ChatSessionMetadata.model_validate(raw_metadata)
# Calculate usage from token counts.
# NOTE: Per-turn cache_read_tokens / cache_creation_tokens breakdown
# is lost after persistence — the DB only stores aggregate prompt and
@@ -128,6 +148,7 @@ class ChatSessionInfo(BaseModel):
updated_at=prisma_session.updatedAt,
successful_agent_runs=successful_agent_runs,
successful_agent_schedules=successful_agent_schedules,
metadata=metadata,
)
@@ -135,7 +156,7 @@ class ChatSession(ChatSessionInfo):
messages: list[ChatMessage]
@classmethod
def new(cls, user_id: str) -> Self:
def new(cls, user_id: str, *, dry_run: bool) -> Self:
return cls(
session_id=str(uuid.uuid4()),
user_id=user_id,
@@ -145,6 +166,7 @@ class ChatSession(ChatSessionInfo):
credentials={},
started_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
metadata=ChatSessionMetadata(dry_run=dry_run),
)
@classmethod
@@ -532,6 +554,7 @@ async def _save_session_to_db(
await db.create_chat_session(
session_id=session.session_id,
user_id=session.user_id,
metadata=session.metadata,
)
existing_message_count = 0
@@ -609,21 +632,27 @@ async def append_and_save_message(session_id: str, message: ChatMessage) -> Chat
return session
async def create_chat_session(user_id: str) -> ChatSession:
async def create_chat_session(user_id: str, *, dry_run: bool) -> ChatSession:
"""Create a new chat session and persist it.
Args:
user_id: The authenticated user ID.
dry_run: When True, run_block and run_agent tool calls in this
session are forced to use dry-run simulation mode.
Raises:
DatabaseError: If the database write fails. We fail fast to ensure
callers never receive a non-persisted session that only exists
in cache (which would be lost when the cache expires).
"""
session = ChatSession.new(user_id)
session = ChatSession.new(user_id, dry_run=dry_run)
# Create in database first - fail fast if this fails
try:
await chat_db().create_chat_session(
session_id=session.session_id,
user_id=user_id,
metadata=session.metadata,
)
except Exception as e:
logger.error(f"Failed to create session {session.session_id} in database: {e}")

View File

@@ -46,7 +46,7 @@ messages = [
@pytest.mark.asyncio(loop_scope="session")
async def test_chatsession_serialization_deserialization():
s = ChatSession.new(user_id="abc123")
s = ChatSession.new(user_id="abc123", dry_run=False)
s.messages = messages
s.usage = [Usage(prompt_tokens=100, completion_tokens=200, total_tokens=300)]
serialized = s.model_dump_json()
@@ -57,7 +57,7 @@ async def test_chatsession_serialization_deserialization():
@pytest.mark.asyncio(loop_scope="session")
async def test_chatsession_redis_storage(setup_test_user, test_user_id):
s = ChatSession.new(user_id=test_user_id)
s = ChatSession.new(user_id=test_user_id, dry_run=False)
s.messages = messages
s = await upsert_chat_session(s)
@@ -75,7 +75,7 @@ async def test_chatsession_redis_storage_user_id_mismatch(
setup_test_user, test_user_id
):
s = ChatSession.new(user_id=test_user_id)
s = ChatSession.new(user_id=test_user_id, dry_run=False)
s.messages = messages
s = await upsert_chat_session(s)
@@ -90,7 +90,7 @@ async def test_chatsession_db_storage(setup_test_user, test_user_id):
from backend.data.redis_client import get_redis_async
# Create session with messages including assistant message
s = ChatSession.new(user_id=test_user_id)
s = ChatSession.new(user_id=test_user_id, dry_run=False)
s.messages = messages # Contains user, assistant, and tool messages
assert s.session_id is not None, "Session id is not set"
# Upsert to save to both cache and DB
@@ -241,7 +241,7 @@ _raw_tc2 = {
def test_add_tool_call_appends_to_existing_assistant():
"""When the last assistant is from the current turn, tool_call is added to it."""
session = ChatSession.new(user_id="u")
session = ChatSession.new(user_id="u", dry_run=False)
session.messages = [
ChatMessage(role="user", content="hi"),
ChatMessage(role="assistant", content="working on it"),
@@ -254,7 +254,7 @@ def test_add_tool_call_appends_to_existing_assistant():
def test_add_tool_call_creates_assistant_when_none_exists():
"""When there's no current-turn assistant, a new one is created."""
session = ChatSession.new(user_id="u")
session = ChatSession.new(user_id="u", dry_run=False)
session.messages = [
ChatMessage(role="user", content="hi"),
]
@@ -267,7 +267,7 @@ def test_add_tool_call_creates_assistant_when_none_exists():
def test_add_tool_call_does_not_cross_user_boundary():
"""A user message acts as a boundary — previous assistant is not modified."""
session = ChatSession.new(user_id="u")
session = ChatSession.new(user_id="u", dry_run=False)
session.messages = [
ChatMessage(role="assistant", content="old turn"),
ChatMessage(role="user", content="new message"),
@@ -282,7 +282,7 @@ def test_add_tool_call_does_not_cross_user_boundary():
def test_add_tool_call_multiple_times():
"""Multiple long-running tool calls accumulate on the same assistant."""
session = ChatSession.new(user_id="u")
session = ChatSession.new(user_id="u", dry_run=False)
session.messages = [
ChatMessage(role="user", content="hi"),
ChatMessage(role="assistant", content="doing stuff"),
@@ -300,7 +300,7 @@ def test_add_tool_call_multiple_times():
def test_to_openai_messages_merges_split_assistants():
"""End-to-end: session with split assistants produces valid OpenAI messages."""
session = ChatSession.new(user_id="u")
session = ChatSession.new(user_id="u", dry_run=False)
session.messages = [
ChatMessage(role="user", content="build agent"),
ChatMessage(role="assistant", content="Let me build that"),
@@ -352,7 +352,7 @@ async def test_concurrent_saves_collision_detection(setup_test_user, test_user_i
import asyncio
# Create a session with initial messages
session = ChatSession.new(user_id=test_user_id)
session = ChatSession.new(user_id=test_user_id, dry_run=False)
for i in range(3):
session.messages.append(
ChatMessage(

View File

@@ -66,6 +66,7 @@ from pydantic import BaseModel, PrivateAttr
ToolName = Literal[
# Platform tools (must match keys in TOOL_REGISTRY)
"add_understanding",
"ask_question",
"bash_exec",
"browser_act",
"browser_navigate",
@@ -102,6 +103,7 @@ ToolName = Literal[
"web_fetch",
"write_workspace_file",
# SDK built-ins
"Agent",
"Edit",
"Glob",
"Grep",

View File

@@ -544,6 +544,7 @@ class TestApplyToolPermissions:
class TestSdkBuiltinToolNames:
def test_expected_builtins_present(self):
expected = {
"Agent",
"Read",
"Write",
"Edit",

View File

@@ -18,6 +18,18 @@ After `write_workspace_file`, embed the `download_url` in Markdown:
- Image: `![chart](workspace://file_id#image/png)`
- Video: `![recording](workspace://file_id#video/mp4)`
### Handling binary/image data in tool outputs — CRITICAL
When a tool output contains base64-encoded binary data (images, PDFs, etc.):
1. **NEVER** try to inline or render the base64 content in your response.
2. **Save** the data to workspace using `write_workspace_file` (pass the base64 data URI as content).
3. **Show** the result via the workspace download URL in Markdown: `![image](workspace://file_id#image/png)`.
### Passing large data between tools — CRITICAL
When tool outputs produce large text that you need to feed into another tool:
- **NEVER** copy-paste the full text into the next tool call argument.
- **Save** the output to a file (workspace or local), then use `@@agptfile:` references.
- This avoids token limits and ensures data integrity.
### File references — @@agptfile:
Pass large file content to tools by reference: `@@agptfile:<uri>[<start>-<end>]`
- `workspace://<file_id>` or `workspace:///<path>` — workspace files
@@ -107,6 +119,13 @@ Do not re-fetch or re-generate data you already have from prior tool calls.
After building the file, reference it with `@@agptfile:` in other tools:
`@@agptfile:/home/user/report.md`
### Web search best practices
- If 3 similar web searches don't return the specific data you need, conclude
it isn't publicly available and work with what you have.
- Prefer fewer, well-targeted searches over many variations of the same query.
- When spawning sub-agents for research, ensure each has a distinct
non-overlapping scope to avoid redundant searches.
### Sub-agent tasks
- When using the Task tool, NEVER set `run_in_background` to true.
All tasks must run in the foreground.
@@ -131,6 +150,11 @@ parent autopilot handles orchestration.
# E2B-only notes — E2B has full internet access so gh CLI works there.
# Not shown in local (bubblewrap) mode: --unshare-net blocks all network.
_E2B_TOOL_NOTES = """
### SDK tool-result files in E2B
When you `Read` an SDK tool-result file, it is automatically copied into the
sandbox so `bash_exec` can access it for further processing.
The exact sandbox path is shown in the `[Sandbox copy available at ...]` note.
### GitHub CLI (`gh`) and git
- If the user has connected their GitHub account, both `gh` and `git` are
pre-authenticated — use them directly without any manual login step.
@@ -196,19 +220,22 @@ def _build_storage_supplement(
- Files here **survive across sessions indefinitely**
### Moving files between storages
- **{file_move_name_1_to_2}**: Copy to persistent workspace
- **{file_move_name_2_to_1}**: Download for processing
- **{file_move_name_1_to_2}**: `write_workspace_file(filename="output.json", source_path="/path/to/local/file")`
- **{file_move_name_2_to_1}**: `read_workspace_file(path="tool-outputs/data.json", save_to_path="{working_dir}/data.json")`
### File persistence
Important files (code, configs, outputs) should be saved to workspace to ensure they persist.
### SDK tool-result files
When tool outputs are large, the SDK truncates them and saves the full output to
a local file under `~/.claude/projects/.../tool-results/`. To read these files,
always use `Read` (NOT `bash_exec`, NOT `read_workspace_file`).
These files are on the host filesystem — `bash_exec` runs in the sandbox and
CANNOT access them. `read_workspace_file` reads from cloud workspace storage,
where SDK tool-results are NOT stored.
a local file under `~/.claude/projects/.../tool-results/` (or `tool-outputs/`).
To read these files, use `Read` — it reads from the host filesystem.
### Large tool outputs saved to workspace
When a tool output contains `<tool-output-truncated workspace_path="...">`, the
full output is in workspace storage (NOT on the local filesystem). To access it:
- Use `read_workspace_file(path="...", offset=..., length=50000)` for reading sections.
- To process in the sandbox, use `read_workspace_file(path="...", save_to_path="{working_dir}/file.json")` first, then use `bash_exec` on the local copy.
{_SHARED_TOOL_NOTES}{extra_notes}"""

View File

@@ -0,0 +1,28 @@
"""Tests for agent generation guide — verifies clarification section."""
from pathlib import Path
class TestAgentGenerationGuideContainsClarifySection:
"""The agent generation guide must include the clarification section."""
def test_guide_includes_clarify_section(self):
guide_path = Path(__file__).parent / "sdk" / "agent_generation_guide.md"
content = guide_path.read_text(encoding="utf-8")
assert "Before or During Building" in content
def test_guide_mentions_find_block_for_clarification(self):
guide_path = Path(__file__).parent / "sdk" / "agent_generation_guide.md"
content = guide_path.read_text(encoding="utf-8")
clarify_section = content.split("Before or During Building")[1].split(
"### Workflow"
)[0]
assert "find_block" in clarify_section
def test_guide_mentions_ask_question_tool(self):
guide_path = Path(__file__).parent / "sdk" / "agent_generation_guide.md"
content = guide_path.read_text(encoding="utf-8")
clarify_section = content.split("Before or During Building")[1].split(
"### Workflow"
)[0]
assert "ask_question" in clarify_section

View File

@@ -161,8 +161,9 @@ async def reset_daily_usage(user_id: str, daily_token_limit: int = 0) -> bool:
daily_token_limit: The configured daily token limit. When positive,
the weekly counter is reduced by this amount.
Fails open: returns False if Redis is unavailable (consistent with
the fail-open design of this module).
Returns False if Redis is unavailable so the caller can handle
compensation (fail-closed for billed operations, unlike the read-only
rate-limit checks which fail-open).
"""
now = datetime.now(UTC)
try:

View File

@@ -70,6 +70,10 @@ class TestResetCopilotUsage:
with (
patch(f"{_MODULE}.config", _make_config(daily_token_limit=0)),
patch(f"{_MODULE}.settings", _mock_settings()),
patch(
f"{_MODULE}.get_global_rate_limits",
AsyncMock(return_value=(0, 12_500_000)),
),
):
with pytest.raises(HTTPException) as exc_info:
await reset_copilot_usage(user_id="user-1")
@@ -83,6 +87,10 @@ class TestResetCopilotUsage:
with (
patch(f"{_MODULE}.config", cfg),
patch(f"{_MODULE}.settings", _mock_settings()),
patch(
f"{_MODULE}.get_global_rate_limits",
AsyncMock(return_value=(2_500_000, 12_500_000)),
),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
patch(f"{_MODULE}.release_reset_lock", AsyncMock()) as mock_release,
@@ -112,6 +120,10 @@ class TestResetCopilotUsage:
with (
patch(f"{_MODULE}.config", cfg),
patch(f"{_MODULE}.settings", _mock_settings()),
patch(
f"{_MODULE}.get_global_rate_limits",
AsyncMock(return_value=(2_500_000, 12_500_000)),
),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
patch(f"{_MODULE}.release_reset_lock", AsyncMock()) as mock_release,
@@ -141,6 +153,10 @@ class TestResetCopilotUsage:
with (
patch(f"{_MODULE}.config", cfg),
patch(f"{_MODULE}.settings", _mock_settings()),
patch(
f"{_MODULE}.get_global_rate_limits",
AsyncMock(return_value=(2_500_000, 12_500_000)),
),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
patch(f"{_MODULE}.release_reset_lock", AsyncMock()),
@@ -171,6 +187,10 @@ class TestResetCopilotUsage:
with (
patch(f"{_MODULE}.config", cfg),
patch(f"{_MODULE}.settings", _mock_settings()),
patch(
f"{_MODULE}.get_global_rate_limits",
AsyncMock(return_value=(2_500_000, 12_500_000)),
),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=3)),
):
with pytest.raises(HTTPException) as exc_info:
@@ -208,6 +228,10 @@ class TestResetCopilotUsage:
with (
patch(f"{_MODULE}.config", cfg),
patch(f"{_MODULE}.settings", _mock_settings()),
patch(
f"{_MODULE}.get_global_rate_limits",
AsyncMock(return_value=(2_500_000, 12_500_000)),
),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
patch(f"{_MODULE}.release_reset_lock", AsyncMock()) as mock_release,
@@ -228,6 +252,10 @@ class TestResetCopilotUsage:
with (
patch(f"{_MODULE}.config", _make_config()),
patch(f"{_MODULE}.settings", _mock_settings()),
patch(
f"{_MODULE}.get_global_rate_limits",
AsyncMock(return_value=(2_500_000, 12_500_000)),
),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=None)),
):
with pytest.raises(HTTPException) as exc_info:
@@ -245,6 +273,10 @@ class TestResetCopilotUsage:
with (
patch(f"{_MODULE}.config", cfg),
patch(f"{_MODULE}.settings", _mock_settings()),
patch(
f"{_MODULE}.get_global_rate_limits",
AsyncMock(return_value=(2_500_000, 12_500_000)),
),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
patch(f"{_MODULE}.release_reset_lock", AsyncMock()),
@@ -275,6 +307,10 @@ class TestResetCopilotUsage:
with (
patch(f"{_MODULE}.config", cfg),
patch(f"{_MODULE}.settings", _mock_settings()),
patch(
f"{_MODULE}.get_global_rate_limits",
AsyncMock(return_value=(2_500_000, 12_500_000)),
),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
patch(f"{_MODULE}.release_reset_lock", AsyncMock()),

View File

@@ -3,25 +3,55 @@
You can create, edit, and customize agents directly. You ARE the brain —
generate the agent JSON yourself using block schemas, then validate and save.
### Clarifying — Before or During Building
Use `ask_question` whenever the user's intent is ambiguous — whether
that's before starting or midway through the workflow. Common moments:
- **Before building**: output format, delivery channel, data source, or
trigger is unspecified.
- **During block discovery**: multiple blocks could fit and the user
should choose.
- **During JSON generation**: a wiring decision depends on user
preference.
Steps:
1. Call `find_block` (or another discovery tool) to learn what the
platform actually supports for the ambiguous dimension.
2. Call `ask_question` with a concrete question listing the discovered
options (e.g. "The platform supports Gmail, Slack, and Google Docs —
which should the agent use for delivery?").
3. **Wait for the user's answer** before continuing.
**Skip this** when the goal already specifies all dimensions (e.g.
"scrape prices from Amazon and email me daily").
### Workflow for Creating/Editing Agents
1. **Discover blocks**: Call `find_block(query, include_schemas=true)` to
1. **If editing**: First narrow to the specific agent by UUID, then fetch its
graph: `find_library_agent(query="<agent_id>", include_graph=true)`. This
returns the full graph structure (nodes + links). **Never edit blindly**
always inspect the current graph first so you know exactly what to change.
Avoid using `include_graph=true` with broad keyword searches, as fetching
multiple graphs at once is expensive and consumes LLM context budget.
2. **Discover blocks**: Call `find_block(query, include_schemas=true)` to
search for relevant blocks. This returns block IDs, names, descriptions,
and full input/output schemas.
2. **Find library agents**: Call `find_library_agent` to discover reusable
3. **Find library agents**: Call `find_library_agent` to discover reusable
agents that can be composed as sub-agents via `AgentExecutorBlock`.
3. **Generate JSON**: Build the agent JSON using block schemas:
- Use block IDs from step 1 as `block_id` in nodes
4. **Generate/modify JSON**: Build or modify the agent JSON using block schemas:
- Use block IDs from step 2 as `block_id` in nodes
- Wire outputs to inputs using links
- Set design-time config in `input_default`
- Use `AgentInputBlock` for values the user provides at runtime
4. **Write to workspace**: Save the JSON to a workspace file so the user
- When editing, apply targeted changes and preserve unchanged parts
5. **Write to workspace**: Save the JSON to a workspace file so the user
can review it: `write_workspace_file(filename="agent.json", content=...)`
5. **Validate**: Call `validate_agent_graph` with the agent JSON to check
6. **Validate**: Call `validate_agent_graph` with the agent JSON to check
for errors
6. **Fix if needed**: Call `fix_agent_graph` to auto-fix common issues,
7. **Fix if needed**: Call `fix_agent_graph` to auto-fix common issues,
or fix manually based on the error descriptions. Iterate until valid.
7. **Save**: Call `create_agent` (new) or `edit_agent` (existing) with
8. **Save**: Call `create_agent` (new) or `edit_agent` (existing) with
the final `agent_json`
### Agent JSON Structure
@@ -74,8 +104,8 @@ These define the agent's interface — what it accepts and what it produces.
**AgentDropdownInputBlock** (ID: `655d6fdf-a334-421c-b733-520549c07cd1`):
- Specialized input block that presents a dropdown/select to the user
- Required `input_default` fields: `name` (str), `placeholder_values` (list of options, must have at least one)
- Optional: `title`, `description`, `value` (default selection)
- Required `input_default` fields: `name` (str)
- Optional: `options` (list of dropdown values; when omitted/empty, input behaves as free-text), `title`, `description`, `value` (default selection)
- Output: `result` — the user-selected value at runtime
- Use this instead of AgentInputBlock when the user should pick from a fixed set of options
@@ -230,6 +260,17 @@ real API calls, credentials, or credits:
3. **Iterate**: If the dry run reveals wiring issues or missing inputs, fix
the agent JSON and re-save before suggesting a real execution.
**Special block behaviour in dry-run mode:**
- **OrchestratorBlock** and **AgentExecutorBlock** execute for real so the
orchestrator can make LLM calls and agent executors can spawn child graphs.
Their downstream tool blocks and child-graph blocks are still simulated.
Note: real LLM inference calls are made (consuming API quota), even though
platform credits are not charged. Agent-mode iterations are capped at 1 in
dry-run to keep it fast.
- **MCPToolBlock** is simulated using the selected tool's name and JSON Schema
so the LLM can produce a realistic mock response without connecting to the
MCP server.
### Example: Simple AI Text Processor
A minimal agent with input, processing, and output:

View File

@@ -25,7 +25,7 @@ from backend.copilot.sdk.compaction import (
def _make_session() -> ChatSession:
return ChatSession.new(user_id="test-user")
return ChatSession.new(user_id="test-user", dry_run=False)
# ---------------------------------------------------------------------------

View File

@@ -2,14 +2,30 @@
from __future__ import annotations
from collections.abc import AsyncIterator
from unittest.mock import patch
from uuid import uuid4
import pytest
import pytest_asyncio
from backend.util import json
@pytest_asyncio.fixture(scope="session", loop_scope="session", name="server")
async def _server_noop() -> None:
"""No-op server stub — SDK tests don't need the full backend."""
return None
@pytest_asyncio.fixture(
scope="session", loop_scope="session", autouse=True, name="graph_cleanup"
)
async def _graph_cleanup_noop() -> AsyncIterator[None]:
"""No-op graph cleanup stub."""
yield
@pytest.fixture()
def mock_chat_config():
"""Mock ChatConfig so compact_transcript tests skip real config lookup."""

View File

@@ -8,6 +8,9 @@ SDK-internal paths (``~/.claude/projects/…/tool-results/``) are handled
by the separate ``Read`` MCP tool registered in ``tool_adapter.py``.
"""
import asyncio
import base64
import hashlib
import itertools
import json
import logging
@@ -28,6 +31,12 @@ from backend.copilot.context import (
logger = logging.getLogger(__name__)
# Default number of lines returned by ``read_file`` when the caller does not
# specify a limit. Also used as the threshold in ``bridge_to_sandbox`` to
# decide whether the model is requesting the full file (and thus whether the
# bridge copy is worthwhile).
_DEFAULT_READ_LIMIT = 2000
async def _check_sandbox_symlink_escape(
sandbox: Any,
@@ -89,7 +98,7 @@ def _get_sandbox_and_path(
return sandbox, remote
async def _sandbox_write(sandbox: Any, path: str, content: str) -> None:
async def _sandbox_write(sandbox: Any, path: str, content: str | bytes) -> None:
"""Write *content* to *path* inside the sandbox.
The E2B filesystem API (``sandbox.files.write``) and the command API
@@ -102,11 +111,14 @@ async def _sandbox_write(sandbox: Any, path: str, content: str) -> None:
To work around this, writes targeting ``/tmp`` are performed via
``tee`` through the command API, which runs as the sandbox ``user``
and can therefore always overwrite user-owned files.
*content* may be ``str`` (text) or ``bytes`` (binary). Both paths
are handled correctly: text is encoded to bytes for the base64 shell
pipe, and raw bytes are passed through without any encoding.
"""
if path == "/tmp" or path.startswith("/tmp/"):
import base64 as _b64
encoded = _b64.b64encode(content.encode()).decode()
raw = content.encode() if isinstance(content, str) else content
encoded = base64.b64encode(raw).decode()
result = await sandbox.commands.run(
f"echo {shlex.quote(encoded)} | base64 -d > {shlex.quote(path)}",
cwd=E2B_WORKDIR,
@@ -128,14 +140,25 @@ async def _handle_read_file(args: dict[str, Any]) -> dict[str, Any]:
"""Read lines from a sandbox file, falling back to the local host for SDK-internal paths."""
file_path: str = args.get("file_path", "")
offset: int = max(0, int(args.get("offset", 0)))
limit: int = max(1, int(args.get("limit", 2000)))
limit: int = max(1, int(args.get("limit", _DEFAULT_READ_LIMIT)))
if not file_path:
return _mcp("file_path is required", error=True)
# SDK-internal paths (tool-results, ephemeral working dir) stay on the host.
# SDK-internal paths (tool-results/tool-outputs, ephemeral working dir)
# stay on the host. When E2B is active, also copy the file into the
# sandbox so bash_exec can access it for further processing.
if _is_allowed_local(file_path):
return _read_local(file_path, offset, limit)
result = _read_local(file_path, offset, limit)
if not result.get("isError"):
sandbox = _get_sandbox()
if sandbox is not None:
annotation = await bridge_and_annotate(
sandbox, file_path, offset, limit
)
if annotation:
result["content"][0]["text"] += annotation
return result
result = _get_sandbox_and_path(file_path)
if isinstance(result, dict):
@@ -302,6 +325,103 @@ async def _handle_grep(args: dict[str, Any]) -> dict[str, Any]:
return _mcp(output if output else "No matches found.")
# Bridging: copy SDK-internal files into E2B sandbox
# Files larger than this are written to /home/user/ via sandbox.files.write()
# instead of /tmp/ via shell base64, to avoid shell argument length limits
# and E2B command timeouts. Base64 expands content by ~33%, so keep this
# well under the typical Linux ARG_MAX (128 KB).
_BRIDGE_SHELL_MAX_BYTES = 32 * 1024 # 32 KB
# Files larger than this are skipped entirely to avoid excessive transfer times.
_BRIDGE_SKIP_BYTES = 50 * 1024 * 1024 # 50 MB
async def bridge_to_sandbox(
sandbox: Any, file_path: str, offset: int, limit: int
) -> str | None:
"""Best-effort copy of a host-side SDK file into the E2B sandbox.
When the model reads an SDK-internal file (e.g. tool-results), it often
wants to process the data with bash. Copying the file into the sandbox
under a stable name lets ``bash_exec`` access it without extra steps.
Only copies when offset=0 and limit is large enough to indicate the model
wants the full file. Errors are logged but never propagated.
Returns the sandbox path on success, or ``None`` on skip/failure.
Size handling:
- <= 32 KB: written to ``/tmp/<hash>-<basename>`` via shell base64
(``_sandbox_write``). Kept small to stay within ARG_MAX.
- 32 KB - 50 MB: written to ``/home/user/<hash>-<basename>`` via
``sandbox.files.write()`` to avoid shell argument length limits.
- > 50 MB: skipped entirely with a warning.
The sandbox filename is prefixed with a short hash of the full source
path to avoid collisions when different source files share the same
basename (e.g. multiple ``result.json`` files).
"""
if offset != 0 or limit < _DEFAULT_READ_LIMIT:
return None
try:
expanded = os.path.realpath(os.path.expanduser(file_path))
basename = os.path.basename(expanded)
source_id = hashlib.sha256(expanded.encode()).hexdigest()[:12]
unique_name = f"{source_id}-{basename}"
file_size = os.path.getsize(expanded)
if file_size > _BRIDGE_SKIP_BYTES:
logger.warning(
"[E2B] Skipping bridge for large file (%d bytes): %s",
file_size,
basename,
)
return None
def _read_bytes() -> bytes:
with open(expanded, "rb") as fh:
return fh.read()
raw_content = await asyncio.to_thread(_read_bytes)
try:
text_content: str | None = raw_content.decode("utf-8")
except UnicodeDecodeError:
text_content = None
data: str | bytes = text_content if text_content is not None else raw_content
if file_size <= _BRIDGE_SHELL_MAX_BYTES:
sandbox_path = f"/tmp/{unique_name}"
await _sandbox_write(sandbox, sandbox_path, data)
else:
sandbox_path = f"/home/user/{unique_name}"
await sandbox.files.write(sandbox_path, data)
logger.info(
"[E2B] Bridged SDK file to sandbox: %s -> %s", basename, sandbox_path
)
return sandbox_path
except Exception:
logger.warning(
"[E2B] Failed to bridge SDK file to sandbox: %s",
file_path,
exc_info=True,
)
return None
async def bridge_and_annotate(
sandbox: Any, file_path: str, offset: int, limit: int
) -> str | None:
"""Bridge a host file to the sandbox and return a newline-prefixed annotation.
Combines ``bridge_to_sandbox`` with the standard annotation suffix so
callers don't need to duplicate the pattern. Returns a string like
``"\\n[Sandbox copy available at /tmp/abc-file.txt]"`` on success, or
``None`` if bridging was skipped or failed.
"""
sandbox_path = await bridge_to_sandbox(sandbox, file_path, offset, limit)
if sandbox_path is None:
return None
return f"\n[Sandbox copy available at {sandbox_path}]"
# Local read (for SDK-internal paths)

View File

@@ -3,6 +3,7 @@
Pure unit tests with no external dependencies (no E2B, no sandbox).
"""
import hashlib
import os
import shutil
from types import SimpleNamespace
@@ -13,12 +14,26 @@ import pytest
from backend.copilot.context import E2B_WORKDIR, SDK_PROJECTS_DIR, _current_project_dir
from .e2b_file_tools import (
_BRIDGE_SHELL_MAX_BYTES,
_BRIDGE_SKIP_BYTES,
_DEFAULT_READ_LIMIT,
_check_sandbox_symlink_escape,
_read_local,
_sandbox_write,
bridge_and_annotate,
bridge_to_sandbox,
resolve_sandbox_path,
)
def _expected_bridge_path(file_path: str, prefix: str = "/tmp") -> str:
"""Compute the expected sandbox path for a bridged file."""
expanded = os.path.realpath(os.path.expanduser(file_path))
basename = os.path.basename(expanded)
source_id = hashlib.sha256(expanded.encode()).hexdigest()[:12]
return f"{prefix}/{source_id}-{basename}"
# ---------------------------------------------------------------------------
# resolve_sandbox_path — sandbox path normalisation & boundary enforcement
# ---------------------------------------------------------------------------
@@ -91,9 +106,9 @@ class TestResolveSandboxPath:
# ---------------------------------------------------------------------------
# _read_local — host filesystem reads with allowlist enforcement
#
# In E2B mode, _read_local only allows tool-results paths (via
# is_allowed_local_path without sdk_cwd). Regular files live on the
# sandbox, not the host.
# In E2B mode, _read_local only allows tool-results/tool-outputs paths
# (via is_allowed_local_path without sdk_cwd). Regular files live on
# the sandbox, not the host.
# ---------------------------------------------------------------------------
@@ -119,7 +134,7 @@ class TestReadLocal:
)
token = _current_project_dir.set(encoded)
try:
result = _read_local(filepath, offset=0, limit=2000)
result = _read_local(filepath, offset=0, limit=_DEFAULT_READ_LIMIT)
assert result["isError"] is False
assert "line 1" in result["content"][0]["text"]
assert "line 2" in result["content"][0]["text"]
@@ -127,6 +142,25 @@ class TestReadLocal:
_current_project_dir.reset(token)
os.unlink(filepath)
def test_read_tool_outputs_file(self):
"""Reading a tool-outputs file should also succeed."""
encoded = "-tmp-copilot-e2b-test-read-outputs"
tool_outputs_dir = os.path.join(
SDK_PROJECTS_DIR, encoded, self._CONV_UUID, "tool-outputs"
)
os.makedirs(tool_outputs_dir, exist_ok=True)
filepath = os.path.join(tool_outputs_dir, "sdk-abc123.json")
with open(filepath, "w") as f:
f.write('{"data": "test"}\n')
token = _current_project_dir.set(encoded)
try:
result = _read_local(filepath, offset=0, limit=_DEFAULT_READ_LIMIT)
assert result["isError"] is False
assert "test" in result["content"][0]["text"]
finally:
_current_project_dir.reset(token)
shutil.rmtree(os.path.join(SDK_PROJECTS_DIR, encoded), ignore_errors=True)
def test_read_disallowed_path_blocked(self):
"""Reading /etc/passwd should be blocked by the allowlist."""
result = _read_local("/etc/passwd", offset=0, limit=10)
@@ -335,3 +369,199 @@ class TestSandboxWrite:
encoded_in_cmd = call_args.split("echo ")[1].split(" |")[0].strip("'")
decoded = base64.b64decode(encoded_in_cmd).decode()
assert decoded == content
# ---------------------------------------------------------------------------
# bridge_to_sandbox — copy SDK-internal files into E2B sandbox
# ---------------------------------------------------------------------------
def _make_bridge_sandbox() -> SimpleNamespace:
"""Build a sandbox mock suitable for bridge_to_sandbox tests."""
run_result = SimpleNamespace(stdout="", stderr="", exit_code=0)
commands = SimpleNamespace(run=AsyncMock(return_value=run_result))
files = SimpleNamespace(write=AsyncMock())
return SimpleNamespace(commands=commands, files=files)
class TestBridgeToSandbox:
@pytest.mark.asyncio
async def test_happy_path_small_file(self, tmp_path):
"""A small file is bridged to /tmp/<hash>-<basename> via _sandbox_write."""
f = tmp_path / "result.json"
f.write_text('{"ok": true}')
sandbox = _make_bridge_sandbox()
result = await bridge_to_sandbox(
sandbox, str(f), offset=0, limit=_DEFAULT_READ_LIMIT
)
expected = _expected_bridge_path(str(f))
assert result == expected
sandbox.commands.run.assert_called_once()
cmd = sandbox.commands.run.call_args[0][0]
assert "result.json" in cmd
sandbox.files.write.assert_not_called()
@pytest.mark.asyncio
async def test_skip_when_offset_nonzero(self, tmp_path):
"""Bridging is skipped when offset != 0 (partial read)."""
f = tmp_path / "data.txt"
f.write_text("content")
sandbox = _make_bridge_sandbox()
result = await bridge_to_sandbox(
sandbox, str(f), offset=10, limit=_DEFAULT_READ_LIMIT
)
assert result is None
sandbox.commands.run.assert_not_called()
sandbox.files.write.assert_not_called()
@pytest.mark.asyncio
async def test_skip_when_limit_too_small(self, tmp_path):
"""Bridging is skipped when limit < _DEFAULT_READ_LIMIT (partial read)."""
f = tmp_path / "data.txt"
f.write_text("content")
sandbox = _make_bridge_sandbox()
await bridge_to_sandbox(sandbox, str(f), offset=0, limit=100)
sandbox.commands.run.assert_not_called()
sandbox.files.write.assert_not_called()
@pytest.mark.asyncio
async def test_nonexistent_file_does_not_raise(self, tmp_path):
"""Bridging a non-existent file logs but does not propagate errors."""
sandbox = _make_bridge_sandbox()
await bridge_to_sandbox(
sandbox, str(tmp_path / "ghost.txt"), offset=0, limit=_DEFAULT_READ_LIMIT
)
sandbox.commands.run.assert_not_called()
sandbox.files.write.assert_not_called()
@pytest.mark.asyncio
async def test_sandbox_write_failure_returns_none(self, tmp_path):
"""If sandbox write fails, returns None (best-effort)."""
f = tmp_path / "data.txt"
f.write_text("content")
sandbox = _make_bridge_sandbox()
sandbox.commands.run.side_effect = RuntimeError("E2B timeout")
result = await bridge_to_sandbox(
sandbox, str(f), offset=0, limit=_DEFAULT_READ_LIMIT
)
assert result is None
@pytest.mark.asyncio
async def test_large_file_uses_files_api(self, tmp_path):
"""Files > 32 KB but <= 50 MB are written to /home/user/ via files.write."""
f = tmp_path / "big.json"
f.write_bytes(b"x" * (_BRIDGE_SHELL_MAX_BYTES + 1))
sandbox = _make_bridge_sandbox()
result = await bridge_to_sandbox(
sandbox, str(f), offset=0, limit=_DEFAULT_READ_LIMIT
)
expected = _expected_bridge_path(str(f), prefix="/home/user")
assert result == expected
sandbox.files.write.assert_called_once()
call_args = sandbox.files.write.call_args[0]
assert call_args[0] == expected
sandbox.commands.run.assert_not_called()
@pytest.mark.asyncio
async def test_small_binary_file_preserves_bytes(self, tmp_path):
"""A small binary file is bridged to /tmp via base64 without corruption."""
binary_data = bytes(range(256))
f = tmp_path / "image.png"
f.write_bytes(binary_data)
sandbox = _make_bridge_sandbox()
result = await bridge_to_sandbox(
sandbox, str(f), offset=0, limit=_DEFAULT_READ_LIMIT
)
expected = _expected_bridge_path(str(f))
assert result == expected
sandbox.commands.run.assert_called_once()
cmd = sandbox.commands.run.call_args[0][0]
assert "base64" in cmd
sandbox.files.write.assert_not_called()
@pytest.mark.asyncio
async def test_large_binary_file_writes_raw_bytes(self, tmp_path):
"""A large binary file is bridged to /home/user/ as raw bytes."""
binary_data = bytes(range(256)) * 200
f = tmp_path / "photo.jpg"
f.write_bytes(binary_data)
sandbox = _make_bridge_sandbox()
result = await bridge_to_sandbox(
sandbox, str(f), offset=0, limit=_DEFAULT_READ_LIMIT
)
expected = _expected_bridge_path(str(f), prefix="/home/user")
assert result == expected
sandbox.files.write.assert_called_once()
call_args = sandbox.files.write.call_args[0]
assert call_args[0] == expected
assert call_args[1] == binary_data
sandbox.commands.run.assert_not_called()
@pytest.mark.asyncio
async def test_very_large_file_skipped(self, tmp_path):
"""Files > 50 MB are skipped entirely."""
f = tmp_path / "huge.bin"
# Create a sparse file to avoid actually writing 50 MB
with open(f, "wb") as fh:
fh.seek(_BRIDGE_SKIP_BYTES + 1)
fh.write(b"\0")
sandbox = _make_bridge_sandbox()
result = await bridge_to_sandbox(
sandbox, str(f), offset=0, limit=_DEFAULT_READ_LIMIT
)
assert result is None
sandbox.commands.run.assert_not_called()
sandbox.files.write.assert_not_called()
# ---------------------------------------------------------------------------
# bridge_and_annotate — shared helper wrapping bridge_to_sandbox + annotation
# ---------------------------------------------------------------------------
class TestBridgeAndAnnotate:
@pytest.mark.asyncio
async def test_returns_annotation_on_success(self, tmp_path):
"""On success, returns a newline-prefixed annotation with the sandbox path."""
f = tmp_path / "data.json"
f.write_text('{"ok": true}')
sandbox = _make_bridge_sandbox()
annotation = await bridge_and_annotate(
sandbox, str(f), offset=0, limit=_DEFAULT_READ_LIMIT
)
expected_path = _expected_bridge_path(str(f))
assert annotation == f"\n[Sandbox copy available at {expected_path}]"
@pytest.mark.asyncio
async def test_returns_none_when_skipped(self, tmp_path):
"""When bridging is skipped (e.g. offset != 0), returns None."""
f = tmp_path / "data.json"
f.write_text("content")
sandbox = _make_bridge_sandbox()
annotation = await bridge_and_annotate(
sandbox, str(f), offset=10, limit=_DEFAULT_READ_LIMIT
)
assert annotation is None

View File

@@ -275,7 +275,7 @@ class TestCompactionE2E:
# --- Step 7: CompactionTracker receives PreCompact hook ---
tracker = CompactionTracker()
session = ChatSession.new(user_id="test-user")
session = ChatSession.new(user_id="test-user", dry_run=False)
tracker.on_compact(str(session_file))
# --- Step 8: Next SDK message arrives → emit_start ---
@@ -376,7 +376,7 @@ class TestCompactionE2E:
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
tracker = CompactionTracker()
session = ChatSession.new(user_id="test")
session = ChatSession.new(user_id="test", dry_run=False)
builder = TranscriptBuilder()
# --- First query with compaction ---

View File

@@ -0,0 +1,82 @@
"""SDK environment variable builder — importable without circular deps.
Extracted from ``service.py`` so that ``backend.blocks.orchestrator``
can reuse the same subscription / OpenRouter / direct-Anthropic logic
without pulling in the full copilot service module (which would create a
circular import through ``executor`` → ``credit`` → ``block_cost_config``).
"""
from __future__ import annotations
from backend.copilot.config import ChatConfig
from backend.copilot.sdk.subscription import validate_subscription
# ChatConfig is stateless (reads env vars) — a separate instance is fine.
# A singleton would require importing service.py which causes the circular dep
# this module was created to avoid.
config = ChatConfig()
def build_sdk_env(
session_id: str | None = None,
user_id: str | None = None,
sdk_cwd: str | None = None,
) -> dict[str, str]:
"""Build env vars for the SDK CLI subprocess.
Three modes (checked in order):
1. **Subscription** — clears all keys; CLI uses ``claude login`` auth.
2. **Direct Anthropic** — returns ``{}``; subprocess inherits
``ANTHROPIC_API_KEY`` from the parent environment.
3. **OpenRouter** (default) — overrides base URL and auth token to
route through the proxy, with Langfuse trace headers.
When *sdk_cwd* is provided, ``CLAUDE_CODE_TMPDIR`` is set so that
the CLI writes temp/sub-agent output inside the per-session workspace
directory rather than an inaccessible system temp path.
"""
# --- Mode 1: Claude Code subscription auth ---
if config.use_claude_code_subscription:
validate_subscription()
env: dict[str, str] = {
"ANTHROPIC_API_KEY": "",
"ANTHROPIC_AUTH_TOKEN": "",
"ANTHROPIC_BASE_URL": "",
}
if sdk_cwd:
env["CLAUDE_CODE_TMPDIR"] = sdk_cwd
return env
# --- Mode 2: Direct Anthropic (no proxy hop) ---
if not config.openrouter_active:
env = {}
if sdk_cwd:
env["CLAUDE_CODE_TMPDIR"] = sdk_cwd
return env
# --- Mode 3: OpenRouter proxy ---
base = (config.base_url or "").rstrip("/")
if base.endswith("/v1"):
base = base[:-3]
env = {
"ANTHROPIC_BASE_URL": base,
"ANTHROPIC_AUTH_TOKEN": config.api_key or "",
"ANTHROPIC_API_KEY": "", # force CLI to use AUTH_TOKEN
}
# Inject broadcast headers so OpenRouter forwards traces to Langfuse.
def _safe(v: str) -> str:
return v.replace("\r", "").replace("\n", "").strip()[:128]
parts = []
if session_id:
parts.append(f"x-session-id: {_safe(session_id)}")
if user_id:
parts.append(f"x-user-id: {_safe(user_id)}")
if parts:
env["ANTHROPIC_CUSTOM_HEADERS"] = "\n".join(parts)
if sdk_cwd:
env["CLAUDE_CODE_TMPDIR"] = sdk_cwd
return env

View File

@@ -0,0 +1,293 @@
"""Tests for build_sdk_env() — the SDK subprocess environment builder."""
from unittest.mock import patch
import pytest
from backend.copilot.config import ChatConfig
# ---------------------------------------------------------------------------
# Helpers — build a ChatConfig with explicit field values so tests don't
# depend on real environment variables.
# ---------------------------------------------------------------------------
def _make_config(**overrides) -> ChatConfig:
"""Create a ChatConfig with safe defaults, applying *overrides*."""
defaults = {
"use_claude_code_subscription": False,
"use_openrouter": False,
"api_key": None,
"base_url": None,
}
defaults.update(overrides)
return ChatConfig(**defaults)
# ---------------------------------------------------------------------------
# Mode 1 — Subscription auth
# ---------------------------------------------------------------------------
class TestBuildSdkEnvSubscription:
"""When ``use_claude_code_subscription`` is True, keys are blanked."""
@patch("backend.copilot.sdk.env.validate_subscription")
def test_returns_blanked_keys(self, mock_validate):
"""Subscription mode clears API_KEY, AUTH_TOKEN, and BASE_URL."""
cfg = _make_config(use_claude_code_subscription=True)
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env()
assert result == {
"ANTHROPIC_API_KEY": "",
"ANTHROPIC_AUTH_TOKEN": "",
"ANTHROPIC_BASE_URL": "",
}
mock_validate.assert_called_once()
@patch(
"backend.copilot.sdk.env.validate_subscription",
side_effect=RuntimeError("CLI not found"),
)
def test_propagates_validation_error(self, mock_validate):
"""If validate_subscription fails, the error bubbles up."""
cfg = _make_config(use_claude_code_subscription=True)
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
with pytest.raises(RuntimeError, match="CLI not found"):
build_sdk_env()
# ---------------------------------------------------------------------------
# Mode 2 — Direct Anthropic (no OpenRouter)
# ---------------------------------------------------------------------------
class TestBuildSdkEnvDirectAnthropic:
"""When OpenRouter is inactive, return empty dict (inherit parent env)."""
def test_returns_empty_dict_when_openrouter_inactive(self):
cfg = _make_config(use_openrouter=False)
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env()
assert result == {}
def test_returns_empty_dict_when_openrouter_flag_true_but_no_key(self):
"""OpenRouter flag is True but no api_key => openrouter_active is False."""
cfg = _make_config(use_openrouter=True, base_url="https://openrouter.ai/api/v1")
# Force api_key to None after construction (field_validator may pick up env vars)
object.__setattr__(cfg, "api_key", None)
assert not cfg.openrouter_active
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env()
assert result == {}
# ---------------------------------------------------------------------------
# Mode 3 — OpenRouter proxy
# ---------------------------------------------------------------------------
class TestBuildSdkEnvOpenRouter:
"""When OpenRouter is active, return proxy env vars."""
def _openrouter_config(self, **overrides):
defaults = {
"use_openrouter": True,
"api_key": "sk-or-test-key",
"base_url": "https://openrouter.ai/api/v1",
}
defaults.update(overrides)
return _make_config(**defaults)
def test_basic_openrouter_env(self):
cfg = self._openrouter_config()
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env()
assert result["ANTHROPIC_BASE_URL"] == "https://openrouter.ai/api"
assert result["ANTHROPIC_AUTH_TOKEN"] == "sk-or-test-key"
assert result["ANTHROPIC_API_KEY"] == ""
assert "ANTHROPIC_CUSTOM_HEADERS" not in result
def test_strips_trailing_v1(self):
"""The /v1 suffix is stripped from the base URL."""
cfg = self._openrouter_config(base_url="https://openrouter.ai/api/v1")
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env()
assert result["ANTHROPIC_BASE_URL"] == "https://openrouter.ai/api"
def test_strips_trailing_v1_and_slash(self):
"""Trailing slash before /v1 strip is handled."""
cfg = self._openrouter_config(base_url="https://openrouter.ai/api/v1/")
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env()
# rstrip("/") first, then remove /v1
assert result["ANTHROPIC_BASE_URL"] == "https://openrouter.ai/api"
def test_no_v1_suffix_left_alone(self):
"""A base URL without /v1 is used as-is."""
cfg = self._openrouter_config(base_url="https://custom-proxy.example.com")
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env()
assert result["ANTHROPIC_BASE_URL"] == "https://custom-proxy.example.com"
def test_session_id_header(self):
cfg = self._openrouter_config()
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env(session_id="sess-123")
assert "ANTHROPIC_CUSTOM_HEADERS" in result
assert "x-session-id: sess-123" in result["ANTHROPIC_CUSTOM_HEADERS"]
def test_user_id_header(self):
cfg = self._openrouter_config()
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env(user_id="user-456")
assert "x-user-id: user-456" in result["ANTHROPIC_CUSTOM_HEADERS"]
def test_both_headers(self):
cfg = self._openrouter_config()
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env(session_id="s1", user_id="u2")
headers = result["ANTHROPIC_CUSTOM_HEADERS"]
assert "x-session-id: s1" in headers
assert "x-user-id: u2" in headers
# They should be newline-separated
assert "\n" in headers
def test_header_sanitisation_strips_newlines(self):
"""Newlines/carriage-returns in header values are stripped."""
cfg = self._openrouter_config()
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env(session_id="bad\r\nvalue")
header_val = result["ANTHROPIC_CUSTOM_HEADERS"]
# The _safe helper removes \r and \n
assert "\r" not in header_val.split(": ", 1)[1]
assert "badvalue" in header_val
def test_header_value_truncated_to_128_chars(self):
"""Header values are truncated to 128 characters."""
cfg = self._openrouter_config()
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
long_id = "x" * 200
result = build_sdk_env(session_id=long_id)
# The value after "x-session-id: " should be at most 128 chars
header_line = result["ANTHROPIC_CUSTOM_HEADERS"]
value = header_line.split(": ", 1)[1]
assert len(value) == 128
# ---------------------------------------------------------------------------
# Mode priority
# ---------------------------------------------------------------------------
class TestBuildSdkEnvModePriority:
"""Subscription mode takes precedence over OpenRouter."""
@patch("backend.copilot.sdk.env.validate_subscription")
def test_subscription_overrides_openrouter(self, mock_validate):
cfg = _make_config(
use_claude_code_subscription=True,
use_openrouter=True,
api_key="sk-or-key",
base_url="https://openrouter.ai/api/v1",
)
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env()
# Should get subscription result, not OpenRouter
assert result == {
"ANTHROPIC_API_KEY": "",
"ANTHROPIC_AUTH_TOKEN": "",
"ANTHROPIC_BASE_URL": "",
}
# ---------------------------------------------------------------------------
# CLAUDE_CODE_TMPDIR integration
# ---------------------------------------------------------------------------
class TestClaudeCodeTmpdir:
"""Verify build_sdk_env() sets CLAUDE_CODE_TMPDIR from *sdk_cwd*."""
def test_tmpdir_set_when_sdk_cwd_is_truthy(self):
"""CLAUDE_CODE_TMPDIR is set to sdk_cwd when sdk_cwd is truthy."""
cfg = _make_config(use_openrouter=False)
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env(sdk_cwd="/tmp/copilot-workspace")
assert result["CLAUDE_CODE_TMPDIR"] == "/tmp/copilot-workspace"
def test_tmpdir_not_set_when_sdk_cwd_is_none(self):
"""CLAUDE_CODE_TMPDIR is NOT in the env when sdk_cwd is None."""
cfg = _make_config(use_openrouter=False)
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env(sdk_cwd=None)
assert "CLAUDE_CODE_TMPDIR" not in result
def test_tmpdir_not_set_when_sdk_cwd_is_empty_string(self):
"""CLAUDE_CODE_TMPDIR is NOT in the env when sdk_cwd is empty string."""
cfg = _make_config(use_openrouter=False)
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env(sdk_cwd="")
assert "CLAUDE_CODE_TMPDIR" not in result
@patch("backend.copilot.sdk.env.validate_subscription")
def test_tmpdir_set_in_subscription_mode(self, mock_validate):
"""CLAUDE_CODE_TMPDIR is set even in subscription mode."""
cfg = _make_config(use_claude_code_subscription=True)
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env(sdk_cwd="/tmp/sub-workspace")
assert result["CLAUDE_CODE_TMPDIR"] == "/tmp/sub-workspace"
assert result["ANTHROPIC_API_KEY"] == ""

View File

@@ -38,7 +38,7 @@ class TestFlattenAssistantContent:
def test_tool_use_blocks(self):
blocks = [{"type": "tool_use", "name": "read_file", "input": {}}]
assert _flatten_assistant_content(blocks) == "[tool_use: read_file]"
assert _flatten_assistant_content(blocks) == ""
def test_mixed_blocks(self):
blocks = [
@@ -47,19 +47,22 @@ class TestFlattenAssistantContent:
]
result = _flatten_assistant_content(blocks)
assert "Let me read that." in result
assert "[tool_use: Read]" in result
# tool_use blocks are dropped entirely to prevent model mimicry
assert "Read" not in result
def test_raw_strings(self):
assert _flatten_assistant_content(["hello", "world"]) == "hello\nworld"
def test_unknown_block_type_preserved_as_placeholder(self):
def test_unknown_block_type_dropped(self):
blocks = [
{"type": "text", "text": "See this image:"},
{"type": "image", "source": {"type": "base64", "data": "..."}},
]
result = _flatten_assistant_content(blocks)
assert "See this image:" in result
assert "[__image__]" in result
# Unknown block types are dropped to prevent model mimicry
assert "[__image__]" not in result
assert "base64" not in result
def test_empty(self):
assert _flatten_assistant_content([]) == ""
@@ -279,7 +282,8 @@ class TestTranscriptToMessages:
messages = _transcript_to_messages(content)
assert len(messages) == 2
assert "Let me check." in messages[0]["content"]
assert "[tool_use: read_file]" in messages[0]["content"]
# tool_use blocks are dropped entirely to prevent model mimicry
assert "read_file" not in messages[0]["content"]
assert messages[1]["content"] == "file contents"

View File

@@ -49,22 +49,22 @@ def test_format_assistant_tool_calls():
)
]
result = _format_conversation_context(msgs)
assert result is not None
assert 'You called tool: search({"q": "test"})' in result
# Assistant with no content and tool_calls omitted produces no lines
assert result is None
def test_format_tool_result():
msgs = [ChatMessage(role="tool", content='{"result": "ok"}')]
result = _format_conversation_context(msgs)
assert result is not None
assert 'Tool result: {"result": "ok"}' in result
assert 'Tool output: {"result": "ok"}' in result
def test_format_tool_result_none_content():
msgs = [ChatMessage(role="tool", content=None)]
result = _format_conversation_context(msgs)
assert result is not None
assert "Tool result: " in result
assert "Tool output: " in result
def test_format_full_conversation():
@@ -84,8 +84,8 @@ def test_format_full_conversation():
assert result is not None
assert "User: find agents" in result
assert "You responded: I'll search for agents." in result
assert "You called tool: find_agents" in result
assert "Tool result:" in result
# tool_calls are omitted to prevent model mimicry
assert "Tool output:" in result
assert "You responded: Found Agent1." in result

View File

@@ -27,6 +27,7 @@ from backend.copilot.response_model import (
StreamError,
StreamFinish,
StreamFinishStep,
StreamHeartbeat,
StreamStart,
StreamStartStep,
StreamTextDelta,
@@ -76,6 +77,12 @@ class SDKResponseAdapter:
# Open the first step (matches non-SDK: StreamStart then StreamStartStep)
responses.append(StreamStartStep())
self.step_open = True
elif sdk_message.subtype == "task_progress":
# Emit a heartbeat so publish_chunk is called during long
# sub-agent runs. Without this, the Redis stream and meta
# key TTLs expire during gaps where no real chunks are
# produced (task_progress events were previously silent).
responses.append(StreamHeartbeat())
elif isinstance(sdk_message, AssistantMessage):
# Flush any SDK built-in tool calls that didn't get a UserMessage

View File

@@ -18,6 +18,7 @@ from backend.copilot.response_model import (
StreamError,
StreamFinish,
StreamFinishStep,
StreamHeartbeat,
StreamStart,
StreamStartStep,
StreamTextDelta,
@@ -28,6 +29,7 @@ from backend.copilot.response_model import (
StreamToolOutputAvailable,
)
from .compaction import compaction_events
from .response_adapter import SDKResponseAdapter
from .tool_adapter import MCP_TOOL_PREFIX
from .tool_adapter import _pending_tool_outputs as _pto
@@ -59,6 +61,14 @@ def test_system_non_init_emits_nothing():
assert results == []
def test_task_progress_emits_heartbeat():
"""task_progress events emit a StreamHeartbeat to keep Redis TTL alive."""
adapter = _adapter()
results = adapter.convert_message(SystemMessage(subtype="task_progress", data={}))
assert len(results) == 1
assert isinstance(results[0], StreamHeartbeat)
# -- AssistantMessage with TextBlock -----------------------------------------
@@ -680,3 +690,102 @@ def test_already_resolved_tool_skipped_in_user_message():
assert (
len(output_events) == 0
), "Already-resolved tool should not emit duplicate output"
# -- _end_text_if_open before compaction -------------------------------------
def test_end_text_if_open_emits_text_end_before_finish_step():
"""StreamTextEnd must be emitted before StreamFinishStep during compaction.
When ``emit_end_if_ready`` fires compaction events while a text block is
still open, ``_end_text_if_open`` must close it first. If StreamFinishStep
arrives before StreamTextEnd, the Vercel AI SDK clears ``activeTextParts``
and raises "Received text-end for missing text part".
"""
adapter = _adapter()
# Open a text block by processing an AssistantMessage with text
msg = AssistantMessage(content=[TextBlock(text="partial response")], model="test")
adapter.convert_message(msg)
assert adapter.has_started_text
assert not adapter.has_ended_text
# Simulate what service.py does before yielding compaction events
pre_close: list[StreamBaseResponse] = []
adapter._end_text_if_open(pre_close)
combined = pre_close + list(compaction_events("Compacted transcript"))
text_end_idx = next(
(i for i, e in enumerate(combined) if isinstance(e, StreamTextEnd)), None
)
finish_step_idx = next(
(i for i, e in enumerate(combined) if isinstance(e, StreamFinishStep)), None
)
assert text_end_idx is not None, "StreamTextEnd must be present"
assert finish_step_idx is not None, "StreamFinishStep must be present"
assert text_end_idx < finish_step_idx, (
f"StreamTextEnd (idx={text_end_idx}) must precede "
f"StreamFinishStep (idx={finish_step_idx}) — otherwise the Vercel AI SDK "
"clears activeTextParts before text-end arrives"
)
def test_step_open_must_reset_after_compaction_finish_step():
"""Adapter step_open must be reset when compaction emits StreamFinishStep.
Compaction events bypass the adapter, so service.py must explicitly clear
step_open after yielding a StreamFinishStep from compaction. Without this,
the next AssistantMessage skips StreamStartStep because the adapter still
thinks a step is open.
"""
adapter = _adapter()
# Open a step + text block via an AssistantMessage
msg = AssistantMessage(content=[TextBlock(text="thinking...")], model="test")
adapter.convert_message(msg)
assert adapter.step_open is True
# Simulate what service.py does: close text, then check compaction events
pre_close: list[StreamBaseResponse] = []
adapter._end_text_if_open(pre_close)
events = list(compaction_events("Compacted transcript"))
if any(isinstance(ev, StreamFinishStep) for ev in events):
adapter.step_open = False
assert (
adapter.step_open is False
), "step_open must be False after compaction emits StreamFinishStep"
# Next AssistantMessage must open a new step
msg2 = AssistantMessage(content=[TextBlock(text="continued")], model="test")
results = adapter.convert_message(msg2)
assert any(
isinstance(r, StreamStartStep) for r in results
), "A new StreamStartStep must be emitted after compaction closed the step"
def test_end_text_if_open_no_op_when_no_text_open():
"""_end_text_if_open emits nothing when no text block is open."""
adapter = _adapter()
results: list[StreamBaseResponse] = []
adapter._end_text_if_open(results)
assert results == []
def test_end_text_if_open_no_op_after_text_already_ended():
"""_end_text_if_open emits nothing when the text block is already closed."""
adapter = _adapter()
msg = AssistantMessage(content=[TextBlock(text="hello")], model="test")
adapter.convert_message(msg)
# Close it once
first: list[StreamBaseResponse] = []
adapter._end_text_if_open(first)
assert len(first) == 1
assert isinstance(first[0], StreamTextEnd)
# Second call must be a no-op
second: list[StreamBaseResponse] = []
adapter._end_text_if_open(second)
assert second == []

View File

@@ -904,14 +904,14 @@ class TestTranscriptEdgeCases:
assert restored[1]["content"] == "Second"
def test_flatten_assistant_with_only_tool_use(self):
"""Assistant message with only tool_use blocks (no text)."""
"""Assistant message with only tool_use blocks (no text) flattens to empty."""
blocks = [
{"type": "tool_use", "name": "bash", "input": {"cmd": "ls"}},
{"type": "tool_use", "name": "read", "input": {"path": "/f"}},
]
result = _flatten_assistant_content(blocks)
assert "[tool_use: bash]" in result
assert "[tool_use: read]" in result
# tool_use blocks are dropped entirely to prevent model mimicry
assert result == ""
def test_flatten_tool_result_nested_image(self):
"""Tool result containing image blocks uses placeholder."""
@@ -1010,7 +1010,7 @@ def _make_sdk_patches(
(f"{_SVC}.create_security_hooks", dict(return_value=MagicMock())),
(f"{_SVC}.get_copilot_tool_names", dict(return_value=[])),
(f"{_SVC}.get_sdk_disallowed_tools", dict(return_value=[])),
(f"{_SVC}._build_sdk_env", dict(return_value=None)),
(f"{_SVC}.build_sdk_env", dict(return_value={})),
(f"{_SVC}._resolve_sdk_model", dict(return_value=None)),
(f"{_SVC}.set_execution_context", {}),
(
@@ -1414,3 +1414,261 @@ class TestStreamChatCompletionRetryIntegration:
# Verify user-friendly message (not raw SDK text)
assert "Authentication" in errors[0].errorText
assert any(isinstance(e, StreamStart) for e in events)
@pytest.mark.asyncio
async def test_result_message_prompt_too_long_triggers_compaction(self):
"""CLI returns ResultMessage(subtype="error") with "Prompt is too long".
When the Claude CLI rejects the prompt pre-API (model=<synthetic>,
duration_api_ms=0), it sends a ResultMessage with is_error=True
instead of raising a Python exception. The retry loop must still
detect this as a context-length error and trigger compaction.
"""
import contextlib
from claude_agent_sdk import ResultMessage
from backend.copilot.response_model import StreamError, StreamStart
from backend.copilot.sdk.service import stream_chat_completion_sdk
session = self._make_session()
success_result = self._make_result_message()
attempt_count = [0]
error_result = ResultMessage(
subtype="error",
result="Prompt is too long",
duration_ms=100,
duration_api_ms=0,
is_error=True,
num_turns=0,
session_id="test-session-id",
)
def _client_factory(*args, **kwargs):
attempt_count[0] += 1
if attempt_count[0] == 1:
# First attempt: CLI returns error ResultMessage
return self._make_client_mock(result_message=error_result)
# Second attempt (after compaction): succeeds
return self._make_client_mock(result_message=success_result)
original_transcript = _build_transcript(
[("user", "prior question"), ("assistant", "prior answer")]
)
compacted_transcript = _build_transcript(
[("user", "[summary]"), ("assistant", "summary reply")]
)
patches = _make_sdk_patches(
session,
original_transcript=original_transcript,
compacted_transcript=compacted_transcript,
client_side_effect=_client_factory,
)
events = []
with contextlib.ExitStack() as stack:
for target, kwargs in patches:
stack.enter_context(patch(target, **kwargs))
async for event in stream_chat_completion_sdk(
session_id="test-session-id",
message="hello",
is_user_message=True,
user_id="test-user",
session=session,
):
events.append(event)
assert attempt_count[0] == 2, (
f"Expected 2 SDK attempts (CLI error ResultMessage "
f"should trigger compaction retry), got {attempt_count[0]}"
)
errors = [e for e in events if isinstance(e, StreamError)]
assert not errors, f"Unexpected StreamError: {errors}"
assert any(isinstance(e, StreamStart) for e in events)
@pytest.mark.asyncio
async def test_result_message_success_subtype_prompt_too_long_triggers_compaction(
self,
):
"""CLI returns ResultMessage(subtype="success") with result="Prompt is too long".
The SDK internally compacts but the transcript is still too long. It
returns subtype="success" (process completed) with result="Prompt is
too long" (the actual rejection message). The retry loop must detect
this as a context-length error and trigger compaction — the subtype
"success" must not fool it into treating this as a real response.
"""
import contextlib
from claude_agent_sdk import ResultMessage
from backend.copilot.response_model import StreamError, StreamStart
from backend.copilot.sdk.service import stream_chat_completion_sdk
session = self._make_session()
success_result = self._make_result_message()
attempt_count = [0]
error_result = ResultMessage(
subtype="success",
result="Prompt is too long",
duration_ms=100,
duration_api_ms=0,
is_error=False,
num_turns=1,
session_id="test-session-id",
)
def _client_factory(*args, **kwargs):
attempt_count[0] += 1
async def _receive_error():
yield error_result
async def _receive_success():
yield success_result
client = MagicMock()
client._transport = MagicMock()
client._transport.write = AsyncMock()
client.query = AsyncMock()
if attempt_count[0] == 1:
client.receive_response = _receive_error
else:
client.receive_response = _receive_success
cm = AsyncMock()
cm.__aenter__.return_value = client
cm.__aexit__.return_value = None
return cm
original_transcript = _build_transcript(
[("user", "prior question"), ("assistant", "prior answer")]
)
compacted_transcript = _build_transcript(
[("user", "[summary]"), ("assistant", "summary reply")]
)
patches = _make_sdk_patches(
session,
original_transcript=original_transcript,
compacted_transcript=compacted_transcript,
client_side_effect=_client_factory,
)
events = []
with contextlib.ExitStack() as stack:
for target, kwargs in patches:
stack.enter_context(patch(target, **kwargs))
async for event in stream_chat_completion_sdk(
session_id="test-session-id",
message="hello",
is_user_message=True,
user_id="test-user",
session=session,
):
events.append(event)
assert attempt_count[0] == 2, (
f"Expected 2 SDK attempts (subtype='success' with 'Prompt is too long' "
f"result should trigger compaction retry), got {attempt_count[0]}"
)
errors = [e for e in events if isinstance(e, StreamError)]
assert not errors, f"Unexpected StreamError: {errors}"
assert any(isinstance(e, StreamStart) for e in events)
@pytest.mark.asyncio
async def test_assistant_message_error_content_prompt_too_long_triggers_compaction(
self,
):
"""AssistantMessage.error="invalid_request" with content "Prompt is too long".
The SDK returns error type "invalid_request" but puts the actual
rejection message ("Prompt is too long") in the content blocks.
The retry loop must detect this via content inspection (sdk_error
being set confirms it's an error message, not user content).
"""
import contextlib
from claude_agent_sdk import AssistantMessage, ResultMessage, TextBlock
from backend.copilot.response_model import StreamError, StreamStart
from backend.copilot.sdk.service import stream_chat_completion_sdk
session = self._make_session()
success_result = self._make_result_message()
attempt_count = [0]
def _client_factory(*args, **kwargs):
attempt_count[0] += 1
async def _receive_error():
# SDK returns invalid_request with "Prompt is too long" in content.
# ResultMessage.result is a non-PTL value ("done") to isolate
# the AssistantMessage content detection path exclusively.
yield AssistantMessage(
content=[TextBlock(text="Prompt is too long")],
model="<synthetic>",
error="invalid_request",
)
yield ResultMessage(
subtype="success",
result="done",
duration_ms=100,
duration_api_ms=0,
is_error=False,
num_turns=1,
session_id="test-session-id",
)
async def _receive_success():
yield success_result
client = MagicMock()
client._transport = MagicMock()
client._transport.write = AsyncMock()
client.query = AsyncMock()
if attempt_count[0] == 1:
client.receive_response = _receive_error
else:
client.receive_response = _receive_success
cm = AsyncMock()
cm.__aenter__.return_value = client
cm.__aexit__.return_value = None
return cm
original_transcript = _build_transcript(
[("user", "prior question"), ("assistant", "prior answer")]
)
compacted_transcript = _build_transcript(
[("user", "[summary]"), ("assistant", "summary reply")]
)
patches = _make_sdk_patches(
session,
original_transcript=original_transcript,
compacted_transcript=compacted_transcript,
client_side_effect=_client_factory,
)
events = []
with contextlib.ExitStack() as stack:
for target, kwargs in patches:
stack.enter_context(patch(target, **kwargs))
async for event in stream_chat_completion_sdk(
session_id="test-session-id",
message="hello",
is_user_message=True,
user_id="test-user",
session=session,
):
events.append(event)
assert attempt_count[0] == 2, (
f"Expected 2 SDK attempts (AssistantMessage error content 'Prompt is "
f"too long' should trigger compaction retry), got {attempt_count[0]}"
)
errors = [e for e in events if isinstance(e, StreamError)]
assert not errors, f"Unexpected StreamError: {errors}"
assert any(isinstance(e, StreamStart) for e in events)

View File

@@ -22,6 +22,38 @@ from .tool_adapter import (
logger = logging.getLogger(__name__)
# The SDK CLI uses "Task" in older versions and "Agent" in v2.x+.
# Shared across all sessions — used by security hooks for sub-agent detection.
_SUBAGENT_TOOLS: frozenset[str] = frozenset({"Task", "Agent"})
# Unicode ranges stripped by _sanitize():
# - BiDi overrides (U+202A-U+202E, U+2066-U+2069) can trick reviewers
# into misreading code/logs.
# - Zero-width characters (U+200B-U+200F, U+FEFF) can hide content.
_BIDI_AND_ZW_CHARS = set(
chr(c)
for r in (range(0x202A, 0x202F), range(0x2066, 0x206A), range(0x200B, 0x2010))
for c in r
) | {"\ufeff"}
def _sanitize(value: str, max_len: int = 200) -> str:
"""Strip control characters and truncate for safe logging.
Removes C0 (U+0000-U+001F), DEL (U+007F), C1 (U+0080-U+009F),
Unicode BiDi overrides, and zero-width characters to prevent
log injection and visual spoofing.
"""
cleaned = "".join(
c
for c in value
if c >= " "
and c != "\x7f"
and not ("\x80" <= c <= "\x9f")
and c not in _BIDI_AND_ZW_CHARS
)
return cleaned[:max_len]
def _deny(reason: str) -> dict[str, Any]:
"""Return a hook denial response."""
@@ -136,11 +168,13 @@ def create_security_hooks(
- PostToolUse: Log successful tool executions
- PostToolUseFailure: Log and handle failed tool executions
- PreCompact: Log context compaction events (SDK handles compaction automatically)
- SubagentStart: Log sub-agent lifecycle start
- SubagentStop: Log sub-agent lifecycle end
Args:
user_id: Current user ID for isolation validation
sdk_cwd: SDK working directory for workspace-scoped tool validation
max_subtasks: Maximum concurrent Task (sub-agent) spawns allowed per session
max_subtasks: Maximum concurrent sub-agent spawns allowed per session
on_compact: Callback invoked when SDK starts compacting context.
Receives the transcript_path from the hook input.
@@ -151,9 +185,19 @@ def create_security_hooks(
from claude_agent_sdk import HookMatcher
from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput
# Per-session tracking for Task sub-agent concurrency.
# Per-session tracking for sub-agent concurrency.
# Set of tool_use_ids that consumed a slot — len() is the active count.
task_tool_use_ids: set[str] = set()
#
# LIMITATION: For background (async) agents the SDK returns the
# Agent/Task tool immediately with {isAsync: true}, which triggers
# PostToolUse and releases the slot while the agent is still running.
# SubagentStop fires later when the background process finishes but
# does not currently hold a slot. This means the concurrency limit
# only gates *launches*, not true concurrent execution. To fix this
# we would need to track background agent_ids separately and release
# in SubagentStop, but the SDK does not guarantee SubagentStop fires
# for every background agent (e.g. on session abort).
subagent_tool_use_ids: set[str] = set()
async def pre_tool_use_hook(
input_data: HookInput,
@@ -165,29 +209,22 @@ def create_security_hooks(
tool_name = cast(str, input_data.get("tool_name", ""))
tool_input = cast(dict[str, Any], input_data.get("tool_input", {}))
# Rate-limit Task (sub-agent) spawns per session
if tool_name == "Task":
# Block background task execution first — denied calls
# should not consume a subtask slot.
if tool_input.get("run_in_background"):
logger.info(f"[SDK] Blocked background Task, user={user_id}")
return cast(
SyncHookJSONOutput,
_deny(
"Background task execution is not supported. "
"Run tasks in the foreground instead "
"(remove the run_in_background parameter)."
),
)
if len(task_tool_use_ids) >= max_subtasks:
# Rate-limit sub-agent spawns per session.
# The SDK CLI renamed "Task" → "Agent" in v2.x; handle both.
if tool_name in _SUBAGENT_TOOLS:
# Background agents are allowed — the SDK returns immediately
# with {isAsync: true} and the model polls via TaskOutput.
# Still count them against the concurrency limit.
if len(subagent_tool_use_ids) >= max_subtasks:
logger.warning(
f"[SDK] Task limit reached ({max_subtasks}), user={user_id}"
f"[SDK] Sub-agent limit reached ({max_subtasks}), "
f"user={user_id}"
)
return cast(
SyncHookJSONOutput,
_deny(
f"Maximum {max_subtasks} concurrent sub-tasks. "
"Wait for running sub-tasks to finish, "
f"Maximum {max_subtasks} concurrent sub-agents. "
"Wait for running sub-agents to finish, "
"or continue in the main conversation."
),
)
@@ -208,20 +245,20 @@ def create_security_hooks(
if result:
return cast(SyncHookJSONOutput, result)
# Reserve the Task slot only after all validations pass
if tool_name == "Task" and tool_use_id is not None:
task_tool_use_ids.add(tool_use_id)
# Reserve the sub-agent slot only after all validations pass
if tool_name in _SUBAGENT_TOOLS and tool_use_id is not None:
subagent_tool_use_ids.add(tool_use_id)
logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}")
return cast(SyncHookJSONOutput, {})
def _release_task_slot(tool_name: str, tool_use_id: str | None) -> None:
"""Release a Task concurrency slot if one was reserved."""
if tool_name == "Task" and tool_use_id in task_tool_use_ids:
task_tool_use_ids.discard(tool_use_id)
def _release_subagent_slot(tool_name: str, tool_use_id: str | None) -> None:
"""Release a sub-agent concurrency slot if one was reserved."""
if tool_name in _SUBAGENT_TOOLS and tool_use_id in subagent_tool_use_ids:
subagent_tool_use_ids.discard(tool_use_id)
logger.info(
"[SDK] Task slot released, active=%d/%d, user=%s",
len(task_tool_use_ids),
"[SDK] Sub-agent slot released, active=%d/%d, user=%s",
len(subagent_tool_use_ids),
max_subtasks,
user_id,
)
@@ -241,13 +278,14 @@ def create_security_hooks(
_ = context
tool_name = cast(str, input_data.get("tool_name", ""))
_release_task_slot(tool_name, tool_use_id)
_release_subagent_slot(tool_name, tool_use_id)
is_builtin = not tool_name.startswith(MCP_TOOL_PREFIX)
safe_tool_use_id = _sanitize(str(tool_use_id or ""), max_len=12)
logger.info(
"[SDK] PostToolUse: %s (builtin=%s, tool_use_id=%s)",
tool_name,
is_builtin,
(tool_use_id or "")[:12],
safe_tool_use_id,
)
# Stash output for SDK built-in tools so the response adapter can
@@ -256,7 +294,7 @@ def create_security_hooks(
if is_builtin:
tool_response = input_data.get("tool_response")
if tool_response is not None:
resp_preview = str(tool_response)[:100]
resp_preview = _sanitize(str(tool_response), max_len=100)
logger.info(
"[SDK] Stashing builtin output for %s (%d chars): %s...",
tool_name,
@@ -280,13 +318,17 @@ def create_security_hooks(
"""Log failed tool executions for debugging."""
_ = context
tool_name = cast(str, input_data.get("tool_name", ""))
error = input_data.get("error", "Unknown error")
error = _sanitize(str(input_data.get("error", "Unknown error")))
safe_tool_use_id = _sanitize(str(tool_use_id or ""))
logger.warning(
f"[SDK] Tool failed: {tool_name}, error={error}, "
f"user={user_id}, tool_use_id={tool_use_id}"
"[SDK] Tool failed: %s, error=%s, user=%s, tool_use_id=%s",
tool_name,
error,
user_id,
safe_tool_use_id,
)
_release_task_slot(tool_name, tool_use_id)
_release_subagent_slot(tool_name, tool_use_id)
return cast(SyncHookJSONOutput, {})
@@ -301,20 +343,17 @@ def create_security_hooks(
This hook provides visibility into when compaction happens.
"""
_ = context, tool_use_id
trigger = input_data.get("trigger", "auto")
trigger = _sanitize(str(input_data.get("trigger", "auto")), max_len=50)
# Sanitize untrusted input: strip control chars for logging AND
# for the value passed downstream. read_compacted_entries()
# validates against _projects_base() as defence-in-depth, but
# sanitizing here prevents log injection and rejects obviously
# malformed paths early.
transcript_path = (
str(input_data.get("transcript_path", ""))
.replace("\n", "")
.replace("\r", "")
transcript_path = _sanitize(
str(input_data.get("transcript_path", "")), max_len=500
)
logger.info(
"[SDK] Context compaction triggered: %s, user=%s, "
"transcript_path=%s",
"[SDK] Context compaction triggered: %s, user=%s, transcript_path=%s",
trigger,
user_id,
transcript_path,
@@ -323,6 +362,44 @@ def create_security_hooks(
on_compact(transcript_path)
return cast(SyncHookJSONOutput, {})
async def subagent_start_hook(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Log when a sub-agent starts execution."""
_ = context, tool_use_id
agent_id = _sanitize(str(input_data.get("agent_id", "?")))
agent_type = _sanitize(str(input_data.get("agent_type", "?")))
logger.info(
"[SDK] SubagentStart: agent_id=%s, type=%s, user=%s",
agent_id,
agent_type,
user_id,
)
return cast(SyncHookJSONOutput, {})
async def subagent_stop_hook(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Log when a sub-agent stops."""
_ = context, tool_use_id
agent_id = _sanitize(str(input_data.get("agent_id", "?")))
agent_type = _sanitize(str(input_data.get("agent_type", "?")))
transcript = _sanitize(
str(input_data.get("agent_transcript_path", "")), max_len=500
)
logger.info(
"[SDK] SubagentStop: agent_id=%s, type=%s, user=%s, transcript=%s",
agent_id,
agent_type,
user_id,
transcript,
)
return cast(SyncHookJSONOutput, {})
hooks: dict[str, Any] = {
"PreToolUse": [HookMatcher(matcher="*", hooks=[pre_tool_use_hook])],
"PostToolUse": [HookMatcher(matcher="*", hooks=[post_tool_use_hook])],
@@ -330,6 +407,8 @@ def create_security_hooks(
HookMatcher(matcher="*", hooks=[post_tool_failure_hook])
],
"PreCompact": [HookMatcher(matcher="*", hooks=[pre_compact_hook])],
"SubagentStart": [HookMatcher(matcher="*", hooks=[subagent_start_hook])],
"SubagentStop": [HookMatcher(matcher="*", hooks=[subagent_stop_hook])],
}
return hooks

View File

@@ -5,13 +5,18 @@ They validate that the security hooks correctly block unauthorized paths,
tool access, and dangerous input patterns.
"""
import logging
import os
import pytest
from backend.copilot.context import _current_project_dir
from .security_hooks import _validate_tool_access, _validate_user_isolation
from .security_hooks import (
_validate_tool_access,
_validate_user_isolation,
create_security_hooks,
)
SDK_CWD = "/tmp/copilot-abc123"
@@ -132,8 +137,20 @@ def test_read_tool_results_allowed():
_current_project_dir.reset(token)
def test_read_tool_outputs_allowed():
"""tool-outputs/ paths should be allowed, same as tool-results/."""
home = os.path.expanduser("~")
path = f"{home}/.claude/projects/-tmp-copilot-abc123/a1b2c3d4-e5f6-7890-abcd-ef1234567890/tool-outputs/12345.txt"
token = _current_project_dir.set("-tmp-copilot-abc123")
try:
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
assert result == {}
finally:
_current_project_dir.reset(token)
def test_read_claude_projects_settings_json_denied():
"""SDK-internal artifacts like settings.json are NOT accessible — only tool-results/ is."""
"""SDK-internal artifacts like settings.json are NOT accessible — only tool-results/tool-outputs is."""
home = os.path.expanduser("~")
path = f"{home}/.claude/projects/-tmp-copilot-abc123/settings.json"
token = _current_project_dir.set("-tmp-copilot-abc123")
@@ -220,8 +237,6 @@ def test_bash_builtin_blocked_message_clarity():
@pytest.fixture()
def _hooks():
"""Create security hooks and return (pre, post, post_failure) handlers."""
from .security_hooks import create_security_hooks
hooks = create_security_hooks(user_id="u1", sdk_cwd=SDK_CWD, max_subtasks=2)
pre = hooks["PreToolUse"][0].hooks[0]
post = hooks["PostToolUse"][0].hooks[0]
@@ -231,16 +246,15 @@ def _hooks():
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_task_background_blocked(_hooks):
"""Task with run_in_background=true must be denied."""
async def test_task_background_allowed(_hooks):
"""Task with run_in_background=true is allowed (SDK handles async lifecycle)."""
pre, _, _ = _hooks
result = await pre(
{"tool_name": "Task", "tool_input": {"run_in_background": True, "prompt": "x"}},
tool_use_id=None,
tool_use_id="tu-bg-1",
context={},
)
assert _is_denied(result)
assert "foreground" in _reason(result).lower()
assert not _is_denied(result)
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@@ -354,3 +368,303 @@ async def test_task_slot_released_on_failure(_hooks):
context={},
)
assert not _is_denied(result)
# ---------------------------------------------------------------------------
# "Agent" tool name (SDK v2.x+ renamed "Task" → "Agent")
# ---------------------------------------------------------------------------
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_agent_background_allowed(_hooks):
"""Agent with run_in_background=true is allowed (SDK handles async lifecycle)."""
pre, _, _ = _hooks
result = await pre(
{
"tool_name": "Agent",
"tool_input": {"run_in_background": True, "prompt": "x"},
},
tool_use_id="tu-agent-bg-1",
context={},
)
assert not _is_denied(result)
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_agent_foreground_allowed(_hooks):
"""Agent without run_in_background should be allowed."""
pre, _, _ = _hooks
result = await pre(
{"tool_name": "Agent", "tool_input": {"prompt": "do stuff"}},
tool_use_id="tu-agent-1",
context={},
)
assert not _is_denied(result)
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_background_agent_counts_against_limit(_hooks):
"""Background agents still consume concurrency slots."""
pre, _, _ = _hooks
# Two background agents fill the limit
for i in range(2):
result = await pre(
{
"tool_name": "Agent",
"tool_input": {"run_in_background": True, "prompt": "bg"},
},
tool_use_id=f"tu-bglimit-{i}",
context={},
)
assert not _is_denied(result)
# Third (background or foreground) should be denied
result = await pre(
{
"tool_name": "Agent",
"tool_input": {"run_in_background": True, "prompt": "over"},
},
tool_use_id="tu-bglimit-2",
context={},
)
assert _is_denied(result)
assert "Maximum" in _reason(result)
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_agent_limit_enforced(_hooks):
"""Agent spawns beyond max_subtasks should be denied."""
pre, _, _ = _hooks
# First two should pass
for i in range(2):
result = await pre(
{"tool_name": "Agent", "tool_input": {"prompt": "ok"}},
tool_use_id=f"tu-agent-limit-{i}",
context={},
)
assert not _is_denied(result)
# Third should be denied (limit=2)
result = await pre(
{"tool_name": "Agent", "tool_input": {"prompt": "over limit"}},
tool_use_id="tu-agent-limit-2",
context={},
)
assert _is_denied(result)
assert "Maximum" in _reason(result)
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_agent_slot_released_on_completion(_hooks):
"""Completing an Agent should free a slot so new Agents can be spawned."""
pre, post, _ = _hooks
# Fill both slots
for i in range(2):
result = await pre(
{"tool_name": "Agent", "tool_input": {"prompt": "ok"}},
tool_use_id=f"tu-agent-comp-{i}",
context={},
)
assert not _is_denied(result)
# Third should be denied — at capacity
result = await pre(
{"tool_name": "Agent", "tool_input": {"prompt": "over"}},
tool_use_id="tu-agent-comp-2",
context={},
)
assert _is_denied(result)
# Complete first agent — frees a slot
await post(
{"tool_name": "Agent", "tool_input": {}},
tool_use_id="tu-agent-comp-0",
context={},
)
# Now a new Agent should be allowed
result = await pre(
{"tool_name": "Agent", "tool_input": {"prompt": "after release"}},
tool_use_id="tu-agent-comp-3",
context={},
)
assert not _is_denied(result)
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_agent_slot_released_on_failure(_hooks):
"""A failed Agent should also free its concurrency slot."""
pre, _, post_failure = _hooks
# Fill both slots
for i in range(2):
result = await pre(
{"tool_name": "Agent", "tool_input": {"prompt": "ok"}},
tool_use_id=f"tu-agent-fail-{i}",
context={},
)
assert not _is_denied(result)
# At capacity
result = await pre(
{"tool_name": "Agent", "tool_input": {"prompt": "over"}},
tool_use_id="tu-agent-fail-2",
context={},
)
assert _is_denied(result)
# Fail first agent — should free a slot
await post_failure(
{"tool_name": "Agent", "tool_input": {}, "error": "something broke"},
tool_use_id="tu-agent-fail-0",
context={},
)
# New Agent should be allowed
result = await pre(
{"tool_name": "Agent", "tool_input": {"prompt": "after failure"}},
tool_use_id="tu-agent-fail-3",
context={},
)
assert not _is_denied(result)
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_mixed_task_agent_share_slots(_hooks):
"""Task and Agent share the same concurrency pool."""
pre, post, _ = _hooks
# Fill one slot with Task, one with Agent
result = await pre(
{"tool_name": "Task", "tool_input": {"prompt": "ok"}},
tool_use_id="tu-mix-task",
context={},
)
assert not _is_denied(result)
result = await pre(
{"tool_name": "Agent", "tool_input": {"prompt": "ok"}},
tool_use_id="tu-mix-agent",
context={},
)
assert not _is_denied(result)
# Third (either name) should be denied
result = await pre(
{"tool_name": "Agent", "tool_input": {"prompt": "over"}},
tool_use_id="tu-mix-over",
context={},
)
assert _is_denied(result)
# Release the Task slot
await post(
{"tool_name": "Task", "tool_input": {}},
tool_use_id="tu-mix-task",
context={},
)
# Now an Agent should be allowed
result = await pre(
{"tool_name": "Agent", "tool_input": {"prompt": "after task release"}},
tool_use_id="tu-mix-new",
context={},
)
assert not _is_denied(result)
# ---------------------------------------------------------------------------
# SubagentStart / SubagentStop hooks
# ---------------------------------------------------------------------------
@pytest.fixture()
def _subagent_hooks():
"""Create hooks and return (subagent_start, subagent_stop) handlers."""
hooks = create_security_hooks(user_id="u1", sdk_cwd=SDK_CWD, max_subtasks=2)
start = hooks["SubagentStart"][0].hooks[0]
stop = hooks["SubagentStop"][0].hooks[0]
return start, stop
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_subagent_start_hook_returns_empty(_subagent_hooks):
"""SubagentStart hook should return an empty dict (logging only)."""
start, _ = _subagent_hooks
result = await start(
{"agent_id": "sa-123", "agent_type": "research"},
tool_use_id=None,
context={},
)
assert result == {}
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_subagent_stop_hook_returns_empty(_subagent_hooks):
"""SubagentStop hook should return an empty dict (logging only)."""
_, stop = _subagent_hooks
result = await stop(
{
"agent_id": "sa-123",
"agent_type": "research",
"agent_transcript_path": "/tmp/transcript.txt",
},
tool_use_id=None,
context={},
)
assert result == {}
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_subagent_hooks_sanitize_inputs(_subagent_hooks, caplog):
"""SubagentStart/Stop should sanitize control chars from inputs."""
start, stop = _subagent_hooks
# Inject control characters (C0, DEL, C1, BiDi overrides, zero-width)
# — hook should not raise AND logs must be clean
with caplog.at_level(logging.DEBUG, logger="backend.copilot.sdk.security_hooks"):
result = await start(
{
"agent_id": "sa\n-injected\r\x00\x7f",
"agent_type": "safe\x80_type\x9f\ttab",
},
tool_use_id=None,
context={},
)
assert result == {}
# Control chars must be stripped from the logged values
for record in caplog.records:
assert "\x00" not in record.message
assert "\r" not in record.message
assert "\n" not in record.message
assert "\x7f" not in record.message
assert "\x80" not in record.message
assert "\x9f" not in record.message
assert "safe_type" in caplog.text
caplog.clear()
with caplog.at_level(logging.DEBUG, logger="backend.copilot.sdk.security_hooks"):
result = await stop(
{
"agent_id": "sa\n-injected\x7f",
"agent_type": "type\r\x80\x9f",
"agent_transcript_path": "/tmp/\x00malicious\npath\u202a\u200b",
},
tool_use_id=None,
context={},
)
assert result == {}
for record in caplog.records:
assert "\x00" not in record.message
assert "\r" not in record.message
assert "\n" not in record.message
assert "\x7f" not in record.message
assert "\u202a" not in record.message
assert "\u200b" not in record.message
assert "/tmp/maliciouspath" in caplog.text

View File

@@ -59,11 +59,14 @@ from ..response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamFinishStep,
StreamHeartbeat,
StreamStart,
StreamStartStep,
StreamStatus,
StreamTextDelta,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
StreamUsage,
)
@@ -77,15 +80,13 @@ from ..tools.e2b_sandbox import get_or_create_sandbox, pause_sandbox_direct
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
from ..tracking import track_user_message
from .compaction import CompactionTracker, filter_compaction_messages
from .env import build_sdk_env # noqa: F401 — re-export for backward compat
from .response_adapter import SDKResponseAdapter
from .security_hooks import create_security_hooks
from .subscription import validate_subscription as _validate_claude_code_subscription
from .tool_adapter import (
cancel_pending_tool_tasks,
create_copilot_mcp_server,
get_copilot_tool_names,
get_sdk_disallowed_tools,
pre_launch_tool_call,
reset_stash_event,
reset_tool_failure_counters,
set_execution_context,
@@ -115,9 +116,10 @@ _MAX_STREAM_ATTEMPTS = 3
# Hard circuit breaker: abort the stream if the model sends this many
# consecutive tool calls with empty parameters (a sign of context
# saturation or serialization failure). Empty input ({}) is never
# legitimate — even one is suspicious, three is conclusive.
_EMPTY_TOOL_CALL_LIMIT = 3
# saturation or serialization failure). The MCP wrapper now returns
# guidance on the first empty call, giving the model a chance to
# self-correct. The limit is generous to allow recovery attempts.
_EMPTY_TOOL_CALL_LIMIT = 5
# User-facing error shown when the empty-tool-call circuit breaker trips.
_CIRCUIT_BREAKER_ERROR_MSG = (
@@ -567,60 +569,6 @@ def _resolve_sdk_model() -> str | None:
return model
def _build_sdk_env(
session_id: str | None = None,
user_id: str | None = None,
) -> dict[str, str]:
"""Build env vars for the SDK CLI subprocess.
Three modes (checked in order):
1. **Subscription** — clears all keys; CLI uses `claude login` auth.
2. **Direct Anthropic** — returns `{}`; subprocess inherits
`ANTHROPIC_API_KEY` from the parent environment.
3. **OpenRouter** (default) — overrides base URL and auth token to
route through the proxy, with Langfuse trace headers.
"""
# --- Mode 1: Claude Code subscription auth ---
if config.use_claude_code_subscription:
_validate_claude_code_subscription()
return {
"ANTHROPIC_API_KEY": "",
"ANTHROPIC_AUTH_TOKEN": "",
"ANTHROPIC_BASE_URL": "",
}
# --- Mode 2: Direct Anthropic (no proxy hop) ---
# `openrouter_active` checks the flag *and* credential presence.
if not config.openrouter_active:
return {}
# --- Mode 3: OpenRouter proxy ---
# Strip /v1 suffix — SDK expects the base URL without a version path.
base = (config.base_url or "").rstrip("/")
if base.endswith("/v1"):
base = base[:-3]
env: dict[str, str] = {
"ANTHROPIC_BASE_URL": base,
"ANTHROPIC_AUTH_TOKEN": config.api_key or "",
"ANTHROPIC_API_KEY": "", # force CLI to use AUTH_TOKEN
}
# Inject broadcast headers so OpenRouter forwards traces to Langfuse.
def _safe(v: str) -> str:
"""Sanitise a header value: strip newlines/whitespace and cap length."""
return v.replace("\r", "").replace("\n", "").strip()[:128]
parts = []
if session_id:
parts.append(f"x-session-id: {_safe(session_id)}")
if user_id:
parts.append(f"x-user-id: {_safe(user_id)}")
if parts:
env["ANTHROPIC_CUSTOM_HEADERS"] = "\n".join(parts)
return env
def _make_sdk_cwd(session_id: str) -> str:
"""Create a safe, session-specific working directory path.
@@ -800,15 +748,11 @@ def _format_conversation_context(messages: list[ChatMessage]) -> str | None:
elif msg.role == "assistant":
if msg.content:
lines.append(f"You responded: {msg.content}")
if msg.tool_calls:
for tc in msg.tool_calls:
func = tc.get("function", {})
tool_name = func.get("name", "unknown")
tool_args = func.get("arguments", "")
lines.append(f"You called tool: {tool_name}({tool_args})")
# Omit tool_calls — any text representation gets mimicked
# by the model. Tool results below provide the context.
elif msg.role == "tool":
content = msg.content or ""
lines.append(f"Tool result: {content}")
lines.append(f"Tool output: {content[:500]}")
if not lines:
return None
@@ -1268,6 +1212,14 @@ async def _run_stream_attempt(
consecutive_empty_tool_calls = 0
# --- Intermediate persistence tracking ---
# Flush session messages to DB periodically so page reloads show progress
# during long-running turns (see incident d2f7cba3: 82-min turn lost on refresh).
_last_flush_time = time.monotonic()
_msgs_since_flush = 0
_FLUSH_INTERVAL_SECONDS = 30.0
_FLUSH_MESSAGE_THRESHOLD = 10
# Use manual __aenter__/__aexit__ instead of ``async with`` so we can
# suppress SDK cleanup errors that occur when the SSE client disconnects
# mid-stream. GeneratorExit causes the SDK's ``__aexit__`` to run in a
@@ -1354,6 +1306,27 @@ async def _run_stream_attempt(
error_preview,
)
# Intercept prompt-too-long errors surfaced as
# AssistantMessage.error (not as a Python exception).
# Re-raise so the outer retry loop can compact the
# transcript and retry with reduced context.
# Check both error_text and error_preview: sdk_error
# being set confirms this is an error message (not user
# content), so checking content is safe. The actual
# error description (e.g. "Prompt is too long") may be
# in the content, not the error type field
# (e.g. error="invalid_request", content="Prompt is
# too long").
if _is_prompt_too_long(Exception(error_text)) or _is_prompt_too_long(
Exception(error_preview)
):
logger.warning(
"%s Prompt-too-long detected via AssistantMessage "
"error — raising for retry",
ctx.log_prefix,
)
raise RuntimeError("Prompt is too long")
# Intercept transient API errors (socket closed,
# ECONNRESET) — replace the raw message with a
# user-friendly error text and use the retryable
@@ -1381,28 +1354,17 @@ async def _run_stream_attempt(
ended_with_stream_error = True
break
# Parallel tool execution: pre-launch every ToolUseBlock as an
# asyncio.Task the moment its AssistantMessage arrives. The SDK
# sends one AssistantMessage per tool call when issuing parallel
# calls, so each message is pre-launched independently. The MCP
# handlers will await the already-running task instead of executing
# fresh, making all concurrent tool calls run in parallel.
#
# Also determine if the message is a tool-only batch (all content
# Determine if the message is a tool-only batch (all content
# items are ToolUseBlocks) — such messages have no text output yet,
# so we skip the wait_for_stash flush below.
#
# Note: parallel execution of tools is handled natively by the
# SDK CLI via readOnlyHint annotations on tool definitions.
is_tool_only = False
if isinstance(sdk_msg, AssistantMessage) and sdk_msg.content:
is_tool_only = True
# NOTE: Pre-launches are sequential (each await completes
# file-ref expansion before the next starts). This is fine
# since expansion is typically sub-ms; a future optimisation
# could gather all pre-launches concurrently.
for tool_use in sdk_msg.content:
if isinstance(tool_use, ToolUseBlock):
await pre_launch_tool_call(tool_use.name, tool_use.input)
else:
is_tool_only = False
is_tool_only = all(
isinstance(item, ToolUseBlock) for item in sdk_msg.content
)
# Race-condition fix: SDK hooks (PostToolUse) are
# executed asynchronously via start_soon() — the next
@@ -1459,6 +1421,16 @@ async def _run_stream_attempt(
sdk_msg.result or "(no error message provided)",
)
# Check for prompt-too-long regardless of subtype — the
# SDK may return subtype="success" with result="Prompt is
# too long" when the CLI rejects the prompt before calling
# the API (cost_usd=0, no tokens consumed). If we only
# check the "error" subtype path, the stream appears to
# complete normally, the synthetic error text is stored
# in the transcript, and the session grows without bound.
if _is_prompt_too_long(RuntimeError(sdk_msg.result or "")):
raise RuntimeError("Prompt is too long")
# Capture token usage from ResultMessage.
# Anthropic reports cached tokens separately:
# input_tokens = uncached only
@@ -1490,6 +1462,23 @@ async def _run_stream_attempt(
# Emit compaction end if SDK finished compacting.
# Sync TranscriptBuilder with the CLI's active context.
compact_result = await ctx.compaction.emit_end_if_ready(ctx.session)
if compact_result.events:
# Compaction events end with StreamFinishStep, which maps to
# Vercel AI SDK's "finish-step" — that clears activeTextParts.
# Close any open text block BEFORE the compaction events so
# the text-end arrives before finish-step, preventing
# "text-end for missing text part" errors on the frontend.
pre_close: list[StreamBaseResponse] = []
state.adapter._end_text_if_open(pre_close)
# Compaction events bypass the adapter, so sync step state
# when a StreamFinishStep is present — otherwise the adapter
# will skip StreamStartStep on the next AssistantMessage.
if any(
isinstance(ev, StreamFinishStep) for ev in compact_result.events
):
state.adapter.step_open = False
for r in pre_close:
yield r
for ev in compact_result.events:
yield ev
entries_replaced = False
@@ -1536,6 +1525,34 @@ async def _run_stream_attempt(
model=sdk_msg.model,
)
# --- Intermediate persistence ---
# Flush session messages to DB periodically so page reloads
# show progress during long-running turns.
_msgs_since_flush += 1
now = time.monotonic()
if (
_msgs_since_flush >= _FLUSH_MESSAGE_THRESHOLD
or (now - _last_flush_time) >= _FLUSH_INTERVAL_SECONDS
):
try:
await asyncio.shield(upsert_chat_session(ctx.session))
logger.debug(
"%s Intermediate flush: %d messages "
"(msgs_since=%d, elapsed=%.1fs)",
ctx.log_prefix,
len(ctx.session.messages),
_msgs_since_flush,
now - _last_flush_time,
)
except Exception as flush_err:
logger.warning(
"%s Intermediate flush failed: %s",
ctx.log_prefix,
flush_err,
)
_last_flush_time = now
_msgs_since_flush = 0
if acc.stream_completed:
break
finally:
@@ -1867,7 +1884,10 @@ async def stream_chat_completion_sdk(
)
# Fail fast when no API credentials are available at all.
sdk_env = _build_sdk_env(session_id=session_id, user_id=user_id)
# sdk_cwd routes the CLI's temp dir into the per-session workspace
# so sub-agent output files land inside sdk_cwd (see build_sdk_env).
sdk_env = build_sdk_env(session_id=session_id, user_id=user_id, sdk_cwd=sdk_cwd)
if not config.api_key and not config.use_claude_code_subscription:
raise RuntimeError(
"No API key configured. Set OPEN_ROUTER_API_KEY, "
@@ -2062,13 +2082,22 @@ async def stream_chat_completion_sdk(
try:
async for event in _run_stream_attempt(stream_ctx, state):
if not isinstance(event, StreamHeartbeat):
if not isinstance(
event,
(
StreamHeartbeat,
# Compaction UI events are cosmetic and must not
# block retry — they're emitted before the SDK
# query on compacted attempts.
StreamStartStep,
StreamFinishStep,
StreamToolInputStart,
StreamToolInputAvailable,
StreamToolOutputAvailable,
),
):
events_yielded += 1
yield event
# Cancel any pre-launched tasks that were never dispatched
# by the SDK (e.g. edge-case SDK behaviour changes). Symmetric
# with the three error-path await cancel_pending_tool_tasks() calls.
await cancel_pending_tool_tasks()
break # Stream completed — exit retry loop
except asyncio.CancelledError:
logger.warning(
@@ -2077,9 +2106,6 @@ async def stream_chat_completion_sdk(
attempt + 1,
_MAX_STREAM_ATTEMPTS,
)
# Cancel any pre-launched tasks so they don't continue executing
# against a rolled-back or abandoned session.
await cancel_pending_tool_tasks()
raise
except _HandledStreamError as exc:
# _run_stream_attempt already yielded a StreamError and
@@ -2111,8 +2137,6 @@ async def stream_chat_completion_sdk(
retryable=True,
)
ended_with_stream_error = True
# Cancel any pre-launched tasks from the failed attempt.
await cancel_pending_tool_tasks()
break
except Exception as e:
stream_err = e
@@ -2129,9 +2153,6 @@ async def stream_chat_completion_sdk(
exc_info=True,
)
session.messages = session.messages[:pre_attempt_msg_count]
# Cancel any pre-launched tasks from the failed attempt so they
# don't continue executing against the rolled-back session.
await cancel_pending_tool_tasks()
if events_yielded > 0:
# Events were already sent to the frontend and cannot be
# unsent. Retrying would produce duplicate/inconsistent

View File

@@ -392,7 +392,7 @@ class TestFlattenThinkingBlocks:
assert result == ""
def test_mixed_thinking_text_tool(self):
"""Mixed blocks: only text and tool_use survive flattening."""
"""Mixed blocks: only text survives flattening; thinking and tool_use dropped."""
blocks = [
{"type": "thinking", "thinking": "hmm", "signature": "sig"},
{"type": "redacted_thinking", "data": "xyz"},
@@ -403,7 +403,8 @@ class TestFlattenThinkingBlocks:
assert "hmm" not in result
assert "xyz" not in result
assert "I'll read the file." in result
assert "[tool_use: Read]" in result
# tool_use blocks are dropped entirely to prevent model mimicry
assert "Read" not in result
# ---------------------------------------------------------------------------

View File

@@ -14,6 +14,7 @@ from contextvars import ContextVar
from typing import TYPE_CHECKING, Any
from claude_agent_sdk import create_sdk_mcp_server, tool
from mcp.types import ToolAnnotations
from backend.copilot.context import (
_current_permissions,
@@ -37,7 +38,7 @@ from backend.copilot.tools import TOOL_REGISTRY
from backend.copilot.tools.base import BaseTool
from backend.util.truncate import truncate
from .e2b_file_tools import E2B_FILE_TOOL_NAMES, E2B_FILE_TOOLS
from .e2b_file_tools import E2B_FILE_TOOL_NAMES, E2B_FILE_TOOLS, bridge_and_annotate
if TYPE_CHECKING:
from e2b import AsyncSandbox
@@ -53,14 +54,6 @@ _MCP_MAX_CHARS = 500_000
MCP_SERVER_NAME = "copilot"
MCP_TOOL_PREFIX = f"mcp__{MCP_SERVER_NAME}__"
# Map from tool_name -> Queue of pre-launched (task, args) pairs.
# Initialised per-session in set_execution_context() so concurrent sessions
# never share the same dict.
_TaskQueueItem = tuple[asyncio.Task[dict[str, Any]], dict[str, Any]]
_tool_task_queues: ContextVar[dict[str, asyncio.Queue[_TaskQueueItem]] | None] = (
ContextVar("_tool_task_queues", default=None)
)
# Stash for MCP tool outputs before the SDK potentially truncates them.
# Keyed by tool_name → full output string. Consumed (popped) by the
# response adapter when it builds StreamToolOutputAvailable.
@@ -115,7 +108,6 @@ def set_execution_context(
_current_permissions.set(permissions)
_pending_tool_outputs.set({})
_stash_event.set(asyncio.Event())
_tool_task_queues.set({})
_consecutive_tool_failures.set({})
@@ -132,48 +124,6 @@ def reset_stash_event() -> None:
event.clear()
async def cancel_pending_tool_tasks() -> None:
"""Cancel all queued pre-launched tasks for the current execution context.
Call this when a stream attempt aborts (error, cancellation) to prevent
pre-launched tasks from continuing to execute against a rolled-back session.
Tasks that are already done are skipped; in-flight tasks are cancelled and
awaited so that any cleanup (``finally`` blocks, DB rollbacks) completes
before the next retry starts.
"""
queues = _tool_task_queues.get()
if not queues:
return
cancelled_tasks: list[asyncio.Task] = []
for tool_name, queue in list(queues.items()):
cancelled = 0
while not queue.empty():
task, _args = queue.get_nowait()
if not task.done():
task.cancel()
cancelled_tasks.append(task)
cancelled += 1
if cancelled:
logger.debug(
"Cancelled %d pre-launched task(s) for tool '%s'", cancelled, tool_name
)
queues.clear()
# Await all cancelled tasks so their cleanup (finally blocks, DB rollbacks)
# completes before the next retry attempt starts new pre-launches.
# Use a timeout to prevent hanging indefinitely if a task's cleanup is stuck.
if cancelled_tasks:
try:
await asyncio.wait_for(
asyncio.gather(*cancelled_tasks, return_exceptions=True),
timeout=5.0,
)
except TimeoutError:
logger.warning(
"Timed out waiting for %d cancelled task(s) to clean up",
len(cancelled_tasks),
)
def reset_tool_failure_counters() -> None:
"""Reset all tool-level circuit breaker counters.
@@ -249,10 +199,6 @@ async def wait_for_stash(timeout: float = 2.0) -> bool:
Uses ``asyncio.Event.wait()`` so it returns the instant the hook signals —
the timeout is purely a safety net for the case where the hook never fires.
Returns ``True`` if the stash signal was received, ``False`` on timeout.
The 2.0 s default was chosen to accommodate slower tool startup in cloud
sandboxes while still failing fast when the hook genuinely will not fire.
With the parallel pre-launch path, hooks typically fire well under 1 ms.
"""
event = _stash_event.get(None)
if event is None:
@@ -271,95 +217,13 @@ async def wait_for_stash(timeout: float = 2.0) -> bool:
return False
async def pre_launch_tool_call(tool_name: str, args: dict[str, Any]) -> None:
"""Pre-launch a tool as a background task so parallel calls run concurrently.
Called when an AssistantMessage with ToolUseBlocks is received, before the
SDK dispatches the MCP tool/call requests. The tool_handler will await the
pre-launched task instead of executing fresh.
The tool_name may include an MCP prefix (e.g. ``mcp__copilot__run_block``);
the prefix is stripped automatically before looking up the tool.
Ordering guarantee: the Claude Agent SDK dispatches MCP ``tools/call`` requests
in the same order as the ToolUseBlocks appear in the AssistantMessage.
Pre-launched tasks are queued FIFO per tool name, so the N-th handler for a
given tool name dequeues the N-th pre-launched task — result and args always
correspond when the SDK preserves order (which it does in the current SDK).
"""
queues = _tool_task_queues.get()
if queues is None:
return
# Strip the MCP server prefix (e.g. "mcp__copilot__") to get the bare tool name.
# Use removeprefix so tool names that themselves contain "__" are handled correctly.
bare_name = tool_name.removeprefix(MCP_TOOL_PREFIX)
base_tool = TOOL_REGISTRY.get(bare_name)
if base_tool is None:
return
user_id, session = get_execution_context()
if session is None:
return
# Expand @@agptfile: references before launching the task.
# The _truncating wrapper (which normally handles expansion) runs AFTER
# pre_launch_tool_call — the pre-launched task would otherwise receive raw
# @@agptfile: tokens and fail to resolve them inside _execute_tool_sync.
# Use _build_input_schema (same path as _truncating) for schema-aware expansion.
input_schema: dict[str, Any] | None
try:
input_schema = _build_input_schema(base_tool)
except Exception:
input_schema = None # schema unavailable — skip schema-aware expansion
try:
args = await expand_file_refs_in_args(
args, user_id, session, input_schema=input_schema
)
except FileRefExpansionError as exc:
logger.warning(
"pre_launch_tool_call: @@agptfile expansion failed for %s: %s — skipping pre-launch",
bare_name,
exc,
)
return
task = asyncio.create_task(_execute_tool_sync(base_tool, user_id, session, args))
# Log unhandled exceptions so "Task exception was never retrieved" warnings
# do not pollute stderr when a task is pre-launched but never dequeued.
task.add_done_callback(
lambda t, name=bare_name: (
logger.warning(
"Pre-launched task for %s raised unhandled: %s",
name,
t.exception(),
)
if not t.cancelled() and t.exception()
else None
)
)
if bare_name not in queues:
queues[bare_name] = asyncio.Queue[_TaskQueueItem]()
# Store (task, args) so the handler can log a warning if the SDK dispatches
# calls in a different order than the ToolUseBlocks appeared in the message.
queues[bare_name].put_nowait((task, args))
async def _execute_tool_sync(
base_tool: BaseTool,
user_id: str | None,
session: ChatSession,
args: dict[str, Any],
) -> dict[str, Any]:
"""Execute a tool synchronously and return MCP-formatted response.
Note: ``@@agptfile:`` expansion should be performed by the caller before
invoking this function. For the normal (non-parallel) path it is handled
by the ``_truncating`` wrapper; for the pre-launched parallel path it is
handled in :func:`pre_launch_tool_call` before the task is created.
"""
"""Execute a tool synchronously and return MCP-formatted response."""
effective_id = f"sdk-{uuid.uuid4().hex[:12]}"
result = await base_tool.execute(
user_id=user_id,
@@ -455,83 +319,7 @@ def create_tool_handler(base_tool: BaseTool):
"""
async def tool_handler(args: dict[str, Any]) -> dict[str, Any]:
"""Execute the wrapped tool and return MCP-formatted response.
If a pre-launched task exists (from parallel tool pre-launch in the
message loop), await it instead of executing fresh.
"""
queues = _tool_task_queues.get()
if queues and base_tool.name in queues:
queue = queues[base_tool.name]
if not queue.empty():
task, launch_args = queue.get_nowait()
# Sanity-check: warn if the args don't match — this can happen
# if the SDK dispatches tool calls in a different order than the
# ToolUseBlocks appeared in the AssistantMessage (unlikely but
# could occur in future SDK versions or with SDK bugs).
# We compare full values (not just keys) so that two run_block
# calls with different block_id values are caught even though
# both have the same key set.
if launch_args != args:
logger.warning(
"Pre-launched task for %s: arg mismatch "
"(launch_keys=%s, call_keys=%s) — cancelling "
"pre-launched task and falling back to direct execution",
base_tool.name,
(
sorted(launch_args.keys())
if isinstance(launch_args, dict)
else type(launch_args).__name__
),
(
sorted(args.keys())
if isinstance(args, dict)
else type(args).__name__
),
)
if not task.done():
task.cancel()
# Await cancellation to prevent duplicate concurrent
# execution for blocks with side effects.
try:
await task
except (asyncio.CancelledError, Exception):
pass
# Fall through to the direct-execution path below.
else:
# Args match — await the pre-launched task.
try:
result = await task
except asyncio.CancelledError:
# Re-raise: CancelledError may be propagating from the
# outer streaming loop being cancelled — swallowing it
# would mask the cancellation and prevent proper cleanup.
logger.warning(
"Pre-launched tool %s was cancelled — re-raising",
base_tool.name,
)
raise
except Exception as e:
logger.error(
"Pre-launched tool %s failed: %s",
base_tool.name,
e,
exc_info=True,
)
return _mcp_error(
f"Failed to execute {base_tool.name}. "
"Check server logs for details."
)
# Pre-truncate the result so the _truncating wrapper (which
# wraps this handler) receives an already-within-budget
# value. _truncating handles stashing — we must NOT stash
# here or the output will be appended twice to the FIFO
# queue and pop_pending_tool_output would return a duplicate
# entry on the second call for the same tool.
return truncate(result, _MCP_MAX_CHARS)
# No pre-launched task — execute directly (fallback for non-parallel calls).
"""Execute the wrapped tool and return MCP-formatted response."""
user_id, session = get_execution_context()
if session is None:
@@ -599,7 +387,16 @@ async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
selected = list(itertools.islice(f, offset, offset + limit))
# Cleanup happens in _cleanup_sdk_tool_results after session ends;
# don't delete here — the SDK may read in multiple chunks.
return _mcp_ok("".join(selected))
#
# When E2B is active, also copy the file into the sandbox so
# bash_exec can process it (the model often uses Read then bash).
text = "".join(selected)
sandbox = _current_sandbox.get(None)
if sandbox is not None:
annotation = await bridge_and_annotate(sandbox, resolved, offset, limit)
if annotation:
text += annotation
return _mcp_ok(text)
except FileNotFoundError:
return _mcp_err(f"File not found: {file_path}")
except Exception as e:
@@ -648,9 +445,19 @@ def _text_from_mcp_result(result: dict[str, Any]) -> str:
)
_PARALLEL_ANNOTATION = ToolAnnotations(readOnlyHint=True)
def create_copilot_mcp_server(*, use_e2b: bool = False):
"""Create an in-process MCP server configuration for CoPilot tools.
All tools are annotated with ``readOnlyHint=True`` so the SDK CLI
dispatches concurrent tool calls in parallel rather than sequentially.
This is a deliberate override: even side-effect tools use the hint
because the MCP tools are already individually sandboxed and the
pre-launch duplicate-execution bug (SECRT-2204) is worse than
sequential dispatch.
When *use_e2b* is True, five additional MCP file tools are registered
that route directly to the E2B sandbox filesystem, and the caller should
disable the corresponding SDK built-in tools via
@@ -668,6 +475,28 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
Applied once to every registered tool."""
async def wrapper(args: dict[str, Any]) -> dict[str, Any]:
# Empty tool args = model's output was truncated by the API's
# max_tokens limit. Instead of letting the tool fail with a
# confusing error (and eventually tripping the circuit breaker),
# return clear guidance so the model can self-correct.
if not args and input_schema and input_schema.get("required"):
logger.warning(
"[MCP] %s called with empty args (likely output "
"token truncation) — returning guidance",
tool_name,
)
return _mcp_error(
f"Your call to {tool_name} had empty arguments — "
f"this means your previous response was too long and "
f"the tool call input was truncated by the API. "
f"To fix this: break your work into smaller steps. "
f"For large content, first write it to a file using "
f"bash_exec with cat >> (append section by section), "
f"then pass it via @@agptfile:filename reference. "
f"Do NOT retry with the same approach — it will "
f"be truncated again."
)
# Circuit breaker: stop infinite retry loops with identical args.
# Use the original (pre-expansion) args for fingerprinting so
# check and record always use the same key — @@agptfile:
@@ -718,24 +547,35 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
for tool_name, base_tool in TOOL_REGISTRY.items():
handler = create_tool_handler(base_tool)
schema = _build_input_schema(base_tool)
# All tools annotated readOnlyHint=True to enable parallel dispatch.
# The SDK CLI uses this hint to dispatch concurrent tool calls in
# parallel rather than sequentially. Side-effect safety is ensured
# by the tool implementations themselves (idempotency, credentials).
decorated = tool(
tool_name,
base_tool.description,
schema,
annotations=_PARALLEL_ANNOTATION,
)(_truncating(handler, tool_name, input_schema=schema))
sdk_tools.append(decorated)
# E2B file tools replace SDK built-in Read/Write/Edit/Glob/Grep.
if use_e2b:
for name, desc, schema, handler in E2B_FILE_TOOLS:
decorated = tool(name, desc, schema)(_truncating(handler, name))
decorated = tool(
name,
desc,
schema,
annotations=_PARALLEL_ANNOTATION,
)(_truncating(handler, name))
sdk_tools.append(decorated)
# Read tool for SDK-truncated tool results (always needed).
# Read tool for SDK-truncated tool results (always needed, read-only).
read_tool = tool(
_READ_TOOL_NAME,
_READ_TOOL_DESCRIPTION,
_READ_TOOL_SCHEMA,
annotations=_PARALLEL_ANNOTATION,
)(_truncating(_read_file_handler, _READ_TOOL_NAME))
sdk_tools.append(read_tool)
@@ -750,13 +590,14 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
# Security hooks validate that file paths stay within sdk_cwd.
# Bash is NOT included — use the sandboxed MCP bash_exec tool instead,
# which provides kernel-level network isolation via unshare --net.
# Task allows spawning sub-agents (rate-limited by security hooks).
# Task/Agent allows spawning sub-agents (rate-limited by security hooks).
# The CLI renamed "Task" → "Agent" in v2.x; both are listed for compat.
# WebSearch uses Brave Search via Anthropic's API — safe, no SSRF risk.
# TodoWrite manages the task checklist shown in the UI — no security concern.
# In E2B mode, all five are disabled — MCP equivalents provide direct sandbox
# access. read_file also handles local tool-results and ephemeral reads.
_SDK_BUILTIN_FILE_TOOLS = ["Read", "Write", "Edit", "Glob", "Grep"]
_SDK_BUILTIN_ALWAYS = ["Task", "WebSearch", "TodoWrite"]
_SDK_BUILTIN_ALWAYS = ["Task", "Agent", "WebSearch", "TodoWrite"]
_SDK_BUILTIN_TOOLS = [*_SDK_BUILTIN_FILE_TOOLS, *_SDK_BUILTIN_ALWAYS]
# SDK built-in tools that must be explicitly blocked.

View File

@@ -1,22 +1,21 @@
"""Tests for tool_adapter helpers: truncation, stash, context vars, parallel pre-launch."""
"""Tests for tool_adapter: truncation, stash, context vars, readOnlyHint annotations."""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, MagicMock
import pytest
from mcp.types import ToolAnnotations
from backend.copilot.context import get_sdk_cwd
from backend.copilot.response_model import StreamToolOutputAvailable
from backend.copilot.sdk.file_ref import FileRefExpansionError
from backend.util.truncate import truncate
from .tool_adapter import (
_MCP_MAX_CHARS,
SDK_DISALLOWED_TOOLS,
_text_from_mcp_result,
cancel_pending_tool_tasks,
create_tool_handler,
pop_pending_tool_output,
pre_launch_tool_call,
reset_stash_event,
set_execution_context,
stash_pending_tool_output,
@@ -244,7 +243,7 @@ class TestTruncationAndStashIntegration:
# ---------------------------------------------------------------------------
# Parallel pre-launch infrastructure
# create_tool_handler (direct execution, no pre-launch)
# ---------------------------------------------------------------------------
@@ -277,169 +276,18 @@ def _init_ctx(session=None):
)
class TestPreLaunchToolCall:
"""Tests for pre_launch_tool_call and the queue-based parallel dispatch."""
class TestCreateToolHandler:
"""Tests for create_tool_handler — direct tool execution."""
@pytest.fixture(autouse=True)
def _init(self):
_init_ctx(session=_make_mock_session())
@pytest.mark.asyncio
async def test_unknown_tool_is_silently_ignored(self):
"""pre_launch_tool_call does nothing for tools not in TOOL_REGISTRY."""
# Should not raise even if the tool name is completely unknown
await pre_launch_tool_call("nonexistent_tool", {})
@pytest.mark.asyncio
async def test_mcp_prefix_stripped_before_registry_lookup(self):
"""mcp__copilot__run_block is looked up as 'run_block'."""
mock_tool = _make_mock_tool("run_block")
with patch(
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
{"run_block": mock_tool},
):
await pre_launch_tool_call("mcp__copilot__run_block", {"block_id": "b1"})
# The task was enqueued — mock_tool.execute should be called once
# (may not complete immediately but should start)
await asyncio.sleep(0) # yield to event loop
mock_tool.execute.assert_awaited_once()
@pytest.mark.asyncio
async def test_bare_tool_name_without_prefix(self):
"""Tool names without __ separator are looked up as-is."""
mock_tool = _make_mock_tool("run_block")
with patch(
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
{"run_block": mock_tool},
):
await pre_launch_tool_call("run_block", {"block_id": "b1"})
await asyncio.sleep(0)
mock_tool.execute.assert_awaited_once()
@pytest.mark.asyncio
async def test_task_enqueued_fifo_for_same_tool(self):
"""Two pre-launched calls for the same tool name are enqueued FIFO."""
results = []
async def slow_execute(*args, **kwargs):
results.append(len(results))
return StreamToolOutputAvailable(
toolCallId="id",
output=str(len(results) - 1),
toolName="t",
success=True,
)
mock_tool = _make_mock_tool("t")
mock_tool.execute = AsyncMock(side_effect=slow_execute)
with patch(
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
{"t": mock_tool},
):
await pre_launch_tool_call("t", {"n": 1})
await pre_launch_tool_call("t", {"n": 2})
await asyncio.sleep(0)
assert mock_tool.execute.await_count == 2
@pytest.mark.asyncio
async def test_file_ref_expansion_failure_skips_pre_launch(self):
"""When @@agptfile: expansion fails, pre_launch_tool_call skips the task.
The handler should then fall back to direct execution (which will also
fail with a proper MCP error via _truncating's own expansion).
"""
mock_tool = _make_mock_tool("run_block", output="should-not-execute")
with (
patch(
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
{"run_block": mock_tool},
),
patch(
"backend.copilot.sdk.tool_adapter.expand_file_refs_in_args",
AsyncMock(side_effect=FileRefExpansionError("@@agptfile:missing.txt")),
),
):
# Should not raise — expansion failure is handled gracefully
await pre_launch_tool_call("run_block", {"text": "@@agptfile:missing.txt"})
await asyncio.sleep(0)
# No task was pre-launched — execute was not called
mock_tool.execute.assert_not_awaited()
class TestCreateToolHandlerParallel:
"""Tests for create_tool_handler using pre-launched tasks."""
@pytest.fixture(autouse=True)
def _init(self):
_init_ctx(session=_make_mock_session())
@pytest.mark.asyncio
async def test_handler_uses_prelaunched_task(self):
"""Handler pops and awaits the pre-launched task rather than re-executing."""
mock_tool = _make_mock_tool("run_block", output="pre-launched result")
with patch(
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
{"run_block": mock_tool},
):
await pre_launch_tool_call("run_block", {"block_id": "b1"})
await asyncio.sleep(0) # let task start
handler = create_tool_handler(mock_tool)
result = await handler({"block_id": "b1"})
assert result["isError"] is False
text = result["content"][0]["text"]
assert "pre-launched result" in text
# Should only have been called once (the pre-launched task), not twice
mock_tool.execute.assert_awaited_once()
@pytest.mark.asyncio
async def test_handler_does_not_double_stash_for_prelaunched_task(self):
"""Pre-launched task result must NOT be stashed by tool_handler directly.
The _truncating wrapper wraps tool_handler and handles stashing after
tool_handler returns. If tool_handler also stashed, the output would be
appended twice to the FIFO queue and pop_pending_tool_output would return
a duplicate on the second call.
This test calls tool_handler directly (without _truncating) and asserts
that nothing was stashed — confirming stashing is deferred to _truncating.
"""
mock_tool = _make_mock_tool("run_block", output="stash-me")
with patch(
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
{"run_block": mock_tool},
):
await pre_launch_tool_call("run_block", {"block_id": "b1"})
await asyncio.sleep(0)
handler = create_tool_handler(mock_tool)
result = await handler({"block_id": "b1"})
assert result["isError"] is False
assert "stash-me" in result["content"][0]["text"]
# tool_handler must NOT stash — _truncating (which wraps handler) does it.
# Calling pop here (without going through _truncating) should return None.
not_stashed = pop_pending_tool_output("run_block")
assert not_stashed is None, (
"tool_handler must not stash directly — _truncating handles stashing "
"to prevent double-stash in the FIFO queue"
)
@pytest.mark.asyncio
async def test_handler_falls_back_when_queue_empty(self):
"""When no pre-launched task exists, handler executes directly."""
async def test_handler_executes_tool_directly(self):
"""Handler executes the tool and returns MCP-formatted result."""
mock_tool = _make_mock_tool("run_block", output="direct result")
# Don't call pre_launch_tool_call — queue is empty
handler = create_tool_handler(mock_tool)
result = await handler({"block_id": "b1"})
@@ -449,104 +297,9 @@ class TestCreateToolHandlerParallel:
mock_tool.execute.assert_awaited_once()
@pytest.mark.asyncio
async def test_handler_cancelled_error_propagates(self):
"""CancelledError from a pre-launched task is re-raised to preserve cancellation semantics."""
async def test_handler_returns_error_on_no_session(self):
"""When session is None, handler returns MCP error."""
mock_tool = _make_mock_tool("run_block")
mock_tool.execute = AsyncMock(side_effect=asyncio.CancelledError())
with patch(
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
{"run_block": mock_tool},
):
await pre_launch_tool_call("run_block", {"block_id": "b1"})
await asyncio.sleep(0)
handler = create_tool_handler(mock_tool)
with pytest.raises(asyncio.CancelledError):
await handler({"block_id": "b1"})
@pytest.mark.asyncio
async def test_handler_exception_returns_mcp_error(self):
"""Exception from a pre-launched task is caught and returned as MCP error."""
mock_tool = _make_mock_tool("run_block")
mock_tool.execute = AsyncMock(side_effect=RuntimeError("block exploded"))
with patch(
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
{"run_block": mock_tool},
):
await pre_launch_tool_call("run_block", {"block_id": "b1"})
await asyncio.sleep(0)
handler = create_tool_handler(mock_tool)
result = await handler({"block_id": "b1"})
assert result["isError"] is True
assert "Failed to execute run_block" in result["content"][0]["text"]
@pytest.mark.asyncio
async def test_two_same_tool_calls_dispatched_in_order(self):
"""Two pre-launched tasks for the same tool are consumed in FIFO order."""
call_order = []
async def execute_with_tag(*args, **kwargs):
tag = kwargs.get("block_id", "?")
call_order.append(tag)
return StreamToolOutputAvailable(
toolCallId="id", output=f"out-{tag}", toolName="run_block", success=True
)
mock_tool = _make_mock_tool("run_block")
mock_tool.execute = AsyncMock(side_effect=execute_with_tag)
with patch(
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
{"run_block": mock_tool},
):
await pre_launch_tool_call("run_block", {"block_id": "first"})
await pre_launch_tool_call("run_block", {"block_id": "second"})
await asyncio.sleep(0)
handler = create_tool_handler(mock_tool)
r1 = await handler({"block_id": "first"})
r2 = await handler({"block_id": "second"})
assert "out-first" in r1["content"][0]["text"]
assert "out-second" in r2["content"][0]["text"]
assert call_order == [
"first",
"second",
], f"Expected FIFO dispatch order but got {call_order}"
@pytest.mark.asyncio
async def test_arg_mismatch_falls_back_to_direct_execution(self):
"""When pre-launched args differ from SDK args, handler cancels pre-launched
task and falls back to direct execution with the correct args."""
mock_tool = _make_mock_tool("run_block", output="direct-result")
with patch(
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
{"run_block": mock_tool},
):
# Pre-launch with args {"block_id": "wrong"}
await pre_launch_tool_call("run_block", {"block_id": "wrong"})
await asyncio.sleep(0)
# SDK dispatches with different args
handler = create_tool_handler(mock_tool)
result = await handler({"block_id": "correct"})
assert result["isError"] is False
# The tool was called twice: once by pre-launch (wrong args), once by
# direct fallback (correct args). The result should come from the
# direct execution path.
assert mock_tool.execute.await_count == 2
@pytest.mark.asyncio
async def test_no_session_falls_back_gracefully(self):
"""When session is None and no pre-launched task, handler returns MCP error."""
mock_tool = _make_mock_tool("run_block")
# session=None means get_execution_context returns (user_id, None)
set_execution_context(user_id="u", session=None, sandbox=None) # type: ignore[arg-type]
handler = create_tool_handler(mock_tool)
@@ -555,220 +308,406 @@ class TestCreateToolHandlerParallel:
assert result["isError"] is True
assert "session" in result["content"][0]["text"].lower()
# ---------------------------------------------------------------------------
# cancel_pending_tool_tasks
# ---------------------------------------------------------------------------
class TestCancelPendingToolTasks:
"""Tests for cancel_pending_tool_tasks — the stream-abort cleanup helper."""
@pytest.fixture(autouse=True)
def _init(self):
_init_ctx(session=_make_mock_session())
@pytest.mark.asyncio
async def test_cancels_queued_tasks(self):
"""Queued tasks are cancelled and the queue is cleared."""
ran = False
async def never_run(*_args, **_kwargs):
nonlocal ran
await asyncio.sleep(10) # long enough to still be pending
ran = True
async def test_handler_returns_error_on_exception(self):
"""Exception from tool execution is caught and returned as MCP error."""
mock_tool = _make_mock_tool("run_block")
mock_tool.execute = AsyncMock(side_effect=never_run)
mock_tool.execute = AsyncMock(side_effect=RuntimeError("block exploded"))
with patch(
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
{"run_block": mock_tool},
):
await pre_launch_tool_call("run_block", {"block_id": "b1"})
await asyncio.sleep(0) # let task start
await cancel_pending_tool_tasks()
await asyncio.sleep(0) # let cancellation propagate
handler = create_tool_handler(mock_tool)
result = await handler({"block_id": "b1"})
assert not ran, "Task should have been cancelled before completing"
assert result["isError"] is True
assert "Failed to execute run_block" in result["content"][0]["text"]
@pytest.mark.asyncio
async def test_noop_when_no_tasks_queued(self):
"""cancel_pending_tool_tasks does not raise when queues are empty."""
await cancel_pending_tool_tasks() # should not raise
async def test_handler_executes_once_per_call(self):
"""Each handler call executes the tool exactly once — no duplicate execution."""
mock_tool = _make_mock_tool("run_block", output="single-execution")
@pytest.mark.asyncio
async def test_handler_does_not_find_cancelled_task(self):
"""After cancel, tool_handler falls back to direct execution."""
mock_tool = _make_mock_tool("run_block", output="direct-fallback")
handler = create_tool_handler(mock_tool)
await handler({"block_id": "b1"})
await handler({"block_id": "b2"})
with patch(
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
{"run_block": mock_tool},
):
await pre_launch_tool_call("run_block", {"block_id": "b1"})
await asyncio.sleep(0)
await cancel_pending_tool_tasks()
# Queue is now empty — handler should execute directly
handler = create_tool_handler(mock_tool)
result = await handler({"block_id": "b1"})
assert result["isError"] is False
assert "direct-fallback" in result["content"][0]["text"]
# ---------------------------------------------------------------------------
# Concurrent / parallel pre-launch scenarios
# ---------------------------------------------------------------------------
class TestAllParallelToolsPrelaunchedIndependently:
"""Simulate SDK sending N separate AssistantMessages for the same tool concurrently."""
@pytest.fixture(autouse=True)
def _init(self):
_init_ctx(session=_make_mock_session())
@pytest.mark.asyncio
async def test_all_parallel_tools_prelaunched_independently(self):
"""5 pre-launches for the same tool all enqueue independently and run concurrently.
Each task sleeps for PER_TASK_S seconds. If they ran sequentially the total
wall time would be ~5*PER_TASK_S. Running concurrently it should finish in
roughly PER_TASK_S (plus scheduling overhead).
"""
PER_TASK_S = 0.05
N = 5
started: list[int] = []
finished: list[int] = []
async def slow_execute(*args, **kwargs):
idx = len(started)
started.append(idx)
await asyncio.sleep(PER_TASK_S)
finished.append(idx)
return StreamToolOutputAvailable(
toolCallId=f"id-{idx}",
output=f"result-{idx}",
toolName="bash_exec",
success=True,
)
mock_tool = _make_mock_tool("bash_exec")
mock_tool.execute = AsyncMock(side_effect=slow_execute)
with patch(
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
{"bash_exec": mock_tool},
):
for i in range(N):
await pre_launch_tool_call("bash_exec", {"cmd": f"echo {i}"})
# Measure only the concurrent execution window, not pre-launch overhead.
# Starting the timer here avoids false failures on slow CI runners where
# the pre_launch_tool_call setup takes longer than the concurrent sleep.
t0 = asyncio.get_running_loop().time()
await asyncio.sleep(PER_TASK_S * 2)
elapsed = asyncio.get_running_loop().time() - t0
assert mock_tool.execute.await_count == N
assert len(finished) == N
# Wall time of the sleep window should be well under N * PER_TASK_S
# (sequential would be ~0.25s; concurrent finishes in ~PER_TASK_S = 0.05s)
assert elapsed < N * PER_TASK_S, (
f"Expected concurrent execution (<{N * PER_TASK_S:.2f}s) "
f"but sleep window took {elapsed:.2f}s"
)
class TestHandlerReturnsResultFromCorrectPrelaunchedTask:
"""Pop pre-launched tasks in order and verify each returns its own result."""
@pytest.fixture(autouse=True)
def _init(self):
_init_ctx(session=_make_mock_session())
@pytest.mark.asyncio
async def test_handler_returns_result_from_correct_prelaunched_task(self):
"""Two pre-launches for the same tool: first handler gets first result, second gets second."""
async def execute_with_cmd(*args, **kwargs):
cmd = kwargs.get("cmd", "?")
return StreamToolOutputAvailable(
toolCallId="id",
output=f"output-for-{cmd}",
toolName="bash_exec",
success=True,
)
mock_tool = _make_mock_tool("bash_exec")
mock_tool.execute = AsyncMock(side_effect=execute_with_cmd)
with patch(
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
{"bash_exec": mock_tool},
):
await pre_launch_tool_call("bash_exec", {"cmd": "alpha"})
await pre_launch_tool_call("bash_exec", {"cmd": "beta"})
await asyncio.sleep(0) # let both tasks start
handler = create_tool_handler(mock_tool)
r1 = await handler({"cmd": "alpha"})
r2 = await handler({"cmd": "beta"})
text1 = r1["content"][0]["text"]
text2 = r2["content"][0]["text"]
assert "output-for-alpha" in text1, f"Expected alpha result, got: {text1}"
assert "output-for-beta" in text2, f"Expected beta result, got: {text2}"
assert mock_tool.execute.await_count == 2
class TestFiveConcurrentPrelaunchAllComplete:
"""Pre-launch 5 tasks; consume all 5 via handlers; assert all succeed."""
# ---------------------------------------------------------------------------
# Regression tests: bugs fixed by removing pre-launch mechanism
#
# Each test class includes a _buggy_handler fixture that reproduces the old
# pre-launch implementation inline. Tests run against BOTH the buggy handler
# (xfail — proves the bug exists) and the current clean handler (must pass).
# ---------------------------------------------------------------------------
def _make_execute_fn(tool_name: str = "run_block"):
"""Return (execute_fn, call_log) — execute_fn records every call."""
call_log: list[dict] = []
async def execute_fn(*args, **kwargs):
call_log.append(kwargs)
return StreamToolOutputAvailable(
toolCallId=f"id-{len(call_log)}",
output=f"result-{len(call_log)}",
toolName=tool_name,
success=True,
)
return execute_fn, call_log
async def _buggy_prelaunch_handler(mock_tool, pre_launch_args, dispatch_args):
"""Simulate the OLD buggy pre-launch flow.
1. pre_launch_tool_call fires _execute_tool_sync with pre_launch_args
2. SDK dispatches handler with dispatch_args
3. Handler compares args — on mismatch, cancels + re-executes (BUG)
Returns the handler result.
"""
from backend.copilot.sdk.tool_adapter import _execute_tool_sync
user_id, session = "user-1", _make_mock_session()
# Step 1: pre-launch fires immediately (speculative)
task = asyncio.create_task(
_execute_tool_sync(mock_tool, user_id, session, pre_launch_args)
)
await asyncio.sleep(0) # let task start
# Step 2: SDK dispatches with (potentially different) args
if pre_launch_args != dispatch_args:
# Arg mismatch path: cancel pre-launched task + re-execute
if not task.done():
task.cancel()
try:
await task
except (asyncio.CancelledError, Exception):
pass
# Fall through to direct execution (duplicate!)
return await _execute_tool_sync(mock_tool, user_id, session, dispatch_args)
else:
return await task
class TestBug1DuplicateExecution:
"""Bug 1 (SECRT-2204): arg mismatch causes duplicate execution.
Pre-launch fires with raw args, SDK dispatches with normalised args.
Mismatch → cancel (too late) + re-execute → 2 API calls.
"""
@pytest.fixture(autouse=True)
def _init(self):
_init_ctx(session=_make_mock_session())
@pytest.mark.xfail(reason="Old pre-launch code causes duplicate execution")
@pytest.mark.asyncio
async def test_five_concurrent_prelaunch_all_complete(self):
"""All 5 pre-launched tasks complete and return successful results."""
N = 5
call_count = 0
async def test_old_code_duplicates_on_arg_mismatch(self):
"""OLD CODE: pre-launch with args A, dispatch with args B → 2 calls."""
execute_fn, call_log = _make_execute_fn()
mock_tool = _make_mock_tool("run_block")
mock_tool.execute = AsyncMock(side_effect=execute_fn)
async def counting_execute(*args, **kwargs):
nonlocal call_count
call_count += 1
n = call_count
pre_launch_args = {"block_id": "b1", "input_data": {"title": "Test"}}
dispatch_args = {
"block_id": "b1",
"input_data": {"title": "Test", "priority": None},
}
await _buggy_prelaunch_handler(mock_tool, pre_launch_args, dispatch_args)
# BUG: pre-launch executed once + fallback executed again = 2
assert len(call_log) == 1, (
f"Expected 1 execution but got {len(call_log)}"
f"duplicate execution bug!"
)
@pytest.mark.asyncio
async def test_current_code_no_duplicate(self):
"""FIXED: handler executes exactly once regardless of arg shape."""
execute_fn, call_log = _make_execute_fn()
mock_tool = _make_mock_tool("run_block")
mock_tool.execute = AsyncMock(side_effect=execute_fn)
handler = create_tool_handler(mock_tool)
await handler({"block_id": "b1", "input_data": {"title": "Test"}})
assert len(call_log) == 1, f"Expected 1 execution but got {len(call_log)}"
class TestBug2FIFODesync:
"""Bug 2: FIFO desync when security hook denies a tool.
Pre-launch queues [task_A, task_B]. Tool A denied (no MCP dispatch).
Tool B's handler dequeues task_A → returns wrong result.
"""
@pytest.fixture(autouse=True)
def _init(self):
_init_ctx(session=_make_mock_session())
@pytest.mark.xfail(reason="Old FIFO queue returns wrong result on denial")
@pytest.mark.asyncio
async def test_old_code_fifo_desync_on_denial(self):
"""OLD CODE: denied tool's task stays in queue, next tool gets wrong result."""
from backend.copilot.sdk.tool_adapter import _execute_tool_sync
call_log: list[str] = []
async def tagged_execute(*args, **kwargs):
tag = kwargs.get("block_id", "?")
call_log.append(tag)
return StreamToolOutputAvailable(
toolCallId=f"id-{n}",
output=f"done-{n}",
toolName="bash_exec",
toolCallId="id",
output=f"result-for-{tag}",
toolName="run_block",
success=True,
)
mock_tool = _make_mock_tool("bash_exec")
mock_tool.execute = AsyncMock(side_effect=counting_execute)
mock_tool = _make_mock_tool("run_block")
mock_tool.execute = AsyncMock(side_effect=tagged_execute)
user_id, session = "user-1", _make_mock_session()
with patch(
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
{"bash_exec": mock_tool},
):
for i in range(N):
await pre_launch_tool_call("bash_exec", {"cmd": f"task-{i}"})
# Simulate old FIFO queue
queue: asyncio.Queue = asyncio.Queue()
await asyncio.sleep(0) # let all tasks start
# Pre-launch for tool A and tool B
task_a = asyncio.create_task(
_execute_tool_sync(mock_tool, user_id, session, {"block_id": "A"})
)
task_b = asyncio.create_task(
_execute_tool_sync(mock_tool, user_id, session, {"block_id": "B"})
)
queue.put_nowait(task_a)
queue.put_nowait(task_b)
await asyncio.sleep(0) # let both tasks run
handler = create_tool_handler(mock_tool)
results = []
for i in range(N):
results.append(await handler({"cmd": f"task-{i}"}))
# Tool A is DENIED by security hook — no MCP dispatch, no dequeue
# Tool B's handler dequeues from FIFO → gets task_A!
dequeued_task = queue.get_nowait()
result = await dequeued_task
result_text = result["content"][0]["text"]
assert (
mock_tool.execute.await_count == N
), f"Expected {N} execute calls, got {mock_tool.execute.await_count}"
for i, result in enumerate(results):
assert result["isError"] is False, f"Result {i} should not be an error"
text = result["content"][0]["text"]
assert "done-" in text, f"Result {i} missing expected output: {text}"
# BUG: handler for B got task_A's result
assert "result-for-B" in result_text, (
f"Expected result for B but got: {result_text}"
f"FIFO desync: B got A's result!"
)
@pytest.mark.asyncio
async def test_current_code_no_fifo_desync(self):
"""FIXED: each handler call executes independently, no shared queue."""
call_log: list[str] = []
async def tagged_execute(*args, **kwargs):
tag = kwargs.get("block_id", "?")
call_log.append(tag)
return StreamToolOutputAvailable(
toolCallId="id",
output=f"result-for-{tag}",
toolName="run_block",
success=True,
)
mock_tool = _make_mock_tool("run_block")
mock_tool.execute = AsyncMock(side_effect=tagged_execute)
handler = create_tool_handler(mock_tool)
# Tool A denied (never called). Tool B dispatched normally.
result_b = await handler({"block_id": "B"})
assert "result-for-B" in result_b["content"][0]["text"]
assert call_log == ["B"]
class TestBug3CancelRace:
"""Bug 3: cancel race — task completes before cancel arrives.
Pre-launch fires fast HTTP call (< 1s). By the time handler detects
mismatch and calls task.cancel(), the API call already completed.
Side effect (Linear issue created) is irreversible.
"""
@pytest.fixture(autouse=True)
def _init(self):
_init_ctx(session=_make_mock_session())
@pytest.mark.xfail(reason="Old code: cancel arrives after task completes")
@pytest.mark.asyncio
async def test_old_code_cancel_arrives_too_late(self):
"""OLD CODE: fast task completes before cancel, side effect persists."""
side_effects: list[str] = []
async def fast_execute_with_side_effect(*args, **kwargs):
# Side effect happens immediately (like an HTTP POST to Linear)
side_effects.append("created-issue")
return StreamToolOutputAvailable(
toolCallId="id",
output="issue-created",
toolName="run_block",
success=True,
)
mock_tool = _make_mock_tool("run_block")
mock_tool.execute = AsyncMock(side_effect=fast_execute_with_side_effect)
# Pre-launch fires immediately
pre_launch_args = {"block_id": "b1"}
dispatch_args = {"block_id": "b1", "extra": "normalised"}
await _buggy_prelaunch_handler(mock_tool, pre_launch_args, dispatch_args)
# BUG: side effect happened TWICE (pre-launch + fallback)
assert len(side_effects) == 1, (
f"Expected 1 side effect but got {len(side_effects)}"
f"cancel race: pre-launch completed before cancel!"
)
@pytest.mark.asyncio
async def test_current_code_single_side_effect(self):
"""FIXED: no speculative execution, exactly 1 side effect per call."""
side_effects: list[str] = []
async def execute_with_side_effect(*args, **kwargs):
side_effects.append("created-issue")
return StreamToolOutputAvailable(
toolCallId="id",
output="issue-created",
toolName="run_block",
success=True,
)
mock_tool = _make_mock_tool("run_block")
mock_tool.execute = AsyncMock(side_effect=execute_with_side_effect)
handler = create_tool_handler(mock_tool)
await handler({"block_id": "b1"})
assert len(side_effects) == 1
# ---------------------------------------------------------------------------
# readOnlyHint annotations
# ---------------------------------------------------------------------------
class TestReadOnlyAnnotations:
"""Tests that all tools get readOnlyHint=True for parallel dispatch."""
def test_parallel_annotation_constant(self):
"""_PARALLEL_ANNOTATION is a ToolAnnotations with readOnlyHint=True."""
from .tool_adapter import _PARALLEL_ANNOTATION
assert isinstance(_PARALLEL_ANNOTATION, ToolAnnotations)
assert _PARALLEL_ANNOTATION.readOnlyHint is True
# ---------------------------------------------------------------------------
# SDK_DISALLOWED_TOOLS
# ---------------------------------------------------------------------------
class TestSDKDisallowedTools:
"""Verify that dangerous SDK built-in tools are in the disallowed list."""
def test_bash_tool_is_disallowed(self):
assert "Bash" in SDK_DISALLOWED_TOOLS
def test_webfetch_tool_is_disallowed(self):
"""WebFetch is disallowed due to SSRF risk."""
assert "WebFetch" in SDK_DISALLOWED_TOOLS
# ---------------------------------------------------------------------------
# _read_file_handler — bridge_and_annotate integration
# ---------------------------------------------------------------------------
class TestReadFileHandlerBridge:
"""Verify that _read_file_handler calls bridge_and_annotate when a sandbox is active."""
@pytest.fixture(autouse=True)
def _init_context(self):
set_execution_context(
user_id="test",
session=None, # type: ignore[arg-type]
sandbox=None,
sdk_cwd="/tmp/copilot-bridge-test",
)
@pytest.mark.asyncio
async def test_bridge_called_when_sandbox_active(self, tmp_path, monkeypatch):
"""When a sandbox is set, bridge_and_annotate is called and its annotation appended."""
from backend.copilot.context import _current_sandbox
from .tool_adapter import _read_file_handler
test_file = tmp_path / "tool-results" / "data.json"
test_file.parent.mkdir(parents=True, exist_ok=True)
test_file.write_text('{"ok": true}\n')
monkeypatch.setattr(
"backend.copilot.sdk.tool_adapter.is_allowed_local_path",
lambda path, cwd: True,
)
fake_sandbox = object()
token = _current_sandbox.set(fake_sandbox) # type: ignore[arg-type]
try:
bridge_calls: list[tuple] = []
async def fake_bridge_and_annotate(sandbox, file_path, offset, limit):
bridge_calls.append((sandbox, file_path, offset, limit))
return "\n[Sandbox copy available at /tmp/abc-data.json]"
monkeypatch.setattr(
"backend.copilot.sdk.tool_adapter.bridge_and_annotate",
fake_bridge_and_annotate,
)
result = await _read_file_handler(
{"file_path": str(test_file), "offset": 0, "limit": 2000}
)
assert result["isError"] is False
assert len(bridge_calls) == 1
assert bridge_calls[0][0] is fake_sandbox
assert "/tmp/abc-data.json" in result["content"][0]["text"]
finally:
_current_sandbox.reset(token)
@pytest.mark.asyncio
async def test_bridge_not_called_without_sandbox(self, tmp_path, monkeypatch):
"""When no sandbox is set, bridge_and_annotate is not called."""
from .tool_adapter import _read_file_handler
test_file = tmp_path / "tool-results" / "data.json"
test_file.parent.mkdir(parents=True, exist_ok=True)
test_file.write_text('{"ok": true}\n')
monkeypatch.setattr(
"backend.copilot.sdk.tool_adapter.is_allowed_local_path",
lambda path, cwd: True,
)
bridge_calls: list[tuple] = []
async def fake_bridge_and_annotate(sandbox, file_path, offset, limit):
bridge_calls.append((sandbox, file_path, offset, limit))
return "\n[Sandbox copy available at /tmp/abc-data.json]"
monkeypatch.setattr(
"backend.copilot.sdk.tool_adapter.bridge_and_annotate",
fake_bridge_and_annotate,
)
result = await _read_file_handler(
{"file_path": str(test_file), "offset": 0, "limit": 2000}
)
assert result["isError"] is False
assert len(bridge_calls) == 0
assert "Sandbox copy" not in result["content"][0]["text"]

View File

@@ -43,6 +43,10 @@ STRIPPABLE_TYPES = frozenset(
{"progress", "file-history-snapshot", "queue-operation", "summary", "pr-link"}
)
# Thinking block types that can be stripped from non-last assistant entries.
# The Anthropic API only requires these in the *last* assistant message.
_THINKING_BLOCK_TYPES = frozenset({"thinking", "redacted_thinking"})
@dataclass
class TranscriptDownload:
@@ -450,6 +454,83 @@ def _build_meta_storage_path(user_id: str, session_id: str, backend: object) ->
)
def strip_stale_thinking_blocks(content: str) -> str:
"""Remove thinking/redacted_thinking blocks from non-last assistant entries.
The Anthropic API only requires thinking blocks in the **last** assistant
message to be value-identical to the original response. Older assistant
entries carry stale thinking blocks that consume significant tokens
(often 10-50K each) without providing useful context for ``--resume``.
Stripping them before upload prevents the CLI from triggering compaction
every turn just to compress away the stale thinking bloat.
"""
lines = content.strip().split("\n")
if not lines:
return content
parsed: list[tuple[str, dict | None]] = []
for line in lines:
parsed.append((line, json.loads(line, fallback=None)))
# Reverse scan to find the last assistant message ID and index.
last_asst_msg_id: str | None = None
last_asst_idx: int | None = None
for i in range(len(parsed) - 1, -1, -1):
_line, entry = parsed[i]
if not isinstance(entry, dict):
continue
msg = entry.get("message", {})
if msg.get("role") == "assistant":
last_asst_msg_id = msg.get("id")
last_asst_idx = i
break
if last_asst_idx is None:
return content
result_lines: list[str] = []
stripped_count = 0
for i, (line, entry) in enumerate(parsed):
if not isinstance(entry, dict):
result_lines.append(line)
continue
msg = entry.get("message", {})
# Only strip from assistant entries that are NOT the last turn.
# Use msg_id matching when available; fall back to index for entries
# without an id field.
is_last_turn = (
last_asst_msg_id is not None and msg.get("id") == last_asst_msg_id
) or (last_asst_msg_id is None and i == last_asst_idx)
if (
msg.get("role") == "assistant"
and not is_last_turn
and isinstance(msg.get("content"), list)
):
content_blocks = msg["content"]
filtered = [
b
for b in content_blocks
if not (isinstance(b, dict) and b.get("type") in _THINKING_BLOCK_TYPES)
]
if len(filtered) < len(content_blocks):
stripped_count += len(content_blocks) - len(filtered)
entry = {**entry, "message": {**msg, "content": filtered}}
result_lines.append(json.dumps(entry, separators=(",", ":")))
continue
result_lines.append(line)
if stripped_count:
logger.info(
"[Transcript] Stripped %d stale thinking block(s) from non-last entries",
stripped_count,
)
return "\n".join(result_lines) + "\n"
async def upload_transcript(
user_id: str,
session_id: str,
@@ -472,6 +553,9 @@ async def upload_transcript(
# Strip metadata entries (progress, file-history-snapshot, etc.)
# Note: SDK-built transcripts shouldn't have these, but strip for safety
stripped = strip_progress_entries(content)
# Strip stale thinking blocks from older assistant entries — these consume
# significant tokens and trigger unnecessary CLI compaction every turn.
stripped = strip_stale_thinking_blocks(stripped)
if not validate_transcript(stripped):
# Log entry types for debugging — helps identify why validation failed
entry_types = [
@@ -605,9 +689,6 @@ COMPACT_MSG_ID_PREFIX = "msg_compact_"
ENTRY_TYPE_MESSAGE = "message"
_THINKING_BLOCK_TYPES = frozenset({"thinking", "redacted_thinking"})
def _flatten_assistant_content(blocks: list) -> str:
"""Flatten assistant content blocks into a single plain-text string.
@@ -633,11 +714,14 @@ def _flatten_assistant_content(blocks: list) -> str:
if btype == "text":
parts.append(block.get("text", ""))
elif btype == "tool_use":
parts.append(f"[tool_use: {block.get('name', '?')}]")
# Drop tool_use entirely — any text representation gets
# mimicked by the model as plain text instead of actual
# structured tool calls. The tool results (in the
# following user/tool_result entry) provide sufficient
# context about what happened.
continue
else:
# Preserve non-text blocks (e.g. image) as placeholders.
# Use __prefix__ to distinguish from literal user text.
parts.append(f"[__{btype}__]")
continue
elif isinstance(block, str):
parts.append(block)
return "\n".join(parts) if parts else ""

View File

@@ -13,6 +13,7 @@ from .transcript import (
delete_transcript,
read_compacted_entries,
strip_progress_entries,
strip_stale_thinking_blocks,
validate_transcript,
write_transcript_to_tempfile,
)
@@ -1200,3 +1201,170 @@ class TestCleanupStaleProjectDirs:
removed = cleanup_stale_project_dirs(encoded_cwd="some-other-project")
assert removed == 0
assert non_copilot.exists()
# ---------------------------------------------------------------------------
# strip_stale_thinking_blocks
# ---------------------------------------------------------------------------
class TestStripStaleThinkingBlocks:
"""Tests for strip_stale_thinking_blocks — removes thinking/redacted_thinking
blocks from non-last assistant entries to reduce transcript bloat."""
def _asst_entry(
self, msg_id: str, content: list, uuid: str = "u1", parent: str = ""
) -> dict:
return {
"type": "assistant",
"uuid": uuid,
"parentUuid": parent,
"message": {
"role": "assistant",
"id": msg_id,
"type": "message",
"content": content,
},
}
def _user_entry(self, text: str, uuid: str = "u0", parent: str = "") -> dict:
return {
"type": "user",
"uuid": uuid,
"parentUuid": parent,
"message": {"role": "user", "content": text},
}
def test_strips_thinking_from_older_assistant(self) -> None:
"""Thinking blocks in non-last assistant entries should be removed."""
old_asst = self._asst_entry(
"msg_old",
[
{"type": "thinking", "thinking": "deep thoughts..."},
{"type": "text", "text": "hello"},
{"type": "redacted_thinking", "data": "secret"},
],
uuid="a1",
)
new_asst = self._asst_entry(
"msg_new",
[
{"type": "thinking", "thinking": "latest thoughts"},
{"type": "text", "text": "world"},
],
uuid="a2",
parent="a1",
)
content = _make_jsonl(old_asst, new_asst)
result = strip_stale_thinking_blocks(content)
lines = [json.loads(ln) for ln in result.strip().split("\n")]
# Old assistant should have thinking blocks stripped
old_content = lines[0]["message"]["content"]
assert len(old_content) == 1
assert old_content[0]["type"] == "text"
# New (last) assistant should be untouched
new_content = lines[1]["message"]["content"]
assert len(new_content) == 2
assert new_content[0]["type"] == "thinking"
assert new_content[1]["type"] == "text"
def test_preserves_last_assistant_thinking(self) -> None:
"""The last assistant entry's thinking blocks must be preserved."""
entry = self._asst_entry(
"msg_only",
[
{"type": "thinking", "thinking": "must keep"},
{"type": "text", "text": "response"},
],
)
content = _make_jsonl(entry)
result = strip_stale_thinking_blocks(content)
lines = [json.loads(ln) for ln in result.strip().split("\n")]
assert len(lines[0]["message"]["content"]) == 2
def test_no_assistant_entries_returns_unchanged(self) -> None:
"""Transcripts with only user entries should pass through unchanged."""
user = self._user_entry("hello")
content = _make_jsonl(user)
assert strip_stale_thinking_blocks(content) == content
def test_empty_content_returns_unchanged(self) -> None:
assert strip_stale_thinking_blocks("") == ""
def test_multiple_turns_strips_all_but_last(self) -> None:
"""With 3 assistant turns, only the last keeps thinking blocks."""
entries = [
self._asst_entry(
"msg_1",
[
{"type": "thinking", "thinking": "t1"},
{"type": "text", "text": "a1"},
],
uuid="a1",
),
self._user_entry("q2", uuid="u2", parent="a1"),
self._asst_entry(
"msg_2",
[
{"type": "thinking", "thinking": "t2"},
{"type": "text", "text": "a2"},
],
uuid="a2",
parent="u2",
),
self._user_entry("q3", uuid="u3", parent="a2"),
self._asst_entry(
"msg_3",
[
{"type": "thinking", "thinking": "t3"},
{"type": "text", "text": "a3"},
],
uuid="a3",
parent="u3",
),
]
content = _make_jsonl(*entries)
result = strip_stale_thinking_blocks(content)
lines = [json.loads(ln) for ln in result.strip().split("\n")]
# msg_1: thinking stripped
assert len(lines[0]["message"]["content"]) == 1
assert lines[0]["message"]["content"][0]["type"] == "text"
# msg_2: thinking stripped
assert len(lines[2]["message"]["content"]) == 1
# msg_3 (last): thinking preserved
assert len(lines[4]["message"]["content"]) == 2
assert lines[4]["message"]["content"][0]["type"] == "thinking"
def test_same_msg_id_multi_entry_turn(self) -> None:
"""Multiple entries sharing the same message.id (same turn) are preserved."""
entries = [
self._asst_entry(
"msg_old",
[{"type": "thinking", "thinking": "old"}],
uuid="a1",
),
self._asst_entry(
"msg_last",
[{"type": "thinking", "thinking": "t_part1"}],
uuid="a2",
parent="a1",
),
self._asst_entry(
"msg_last",
[{"type": "text", "text": "response"}],
uuid="a3",
parent="a2",
),
]
content = _make_jsonl(*entries)
result = strip_stale_thinking_blocks(content)
lines = [json.loads(ln) for ln in result.strip().split("\n")]
# Old entry stripped
assert lines[0]["message"]["content"] == []
# Both entries of last turn (msg_last) preserved
assert lines[1]["message"]["content"][0]["type"] == "thinking"
assert lines[2]["message"]["content"][0]["type"] == "text"

View File

@@ -30,7 +30,7 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
if not cfg.claude_agent_use_resume:
return pytest.skip("CLAUDE_AGENT_USE_RESUME is not enabled, skipping test")
session = await create_chat_session(test_user_id)
session = await create_chat_session(test_user_id, dry_run=False)
session = await upsert_chat_session(session)
# --- Turn 1: send a message with a unique keyword ---

View File

@@ -221,9 +221,21 @@ async def create_session(
return session
_meta_ttl_refresh_at: dict[str, float] = {}
"""Tracks the last time the session meta key TTL was refreshed.
Used by `publish_chunk` to avoid refreshing on every single chunk
(expensive). Refreshes at most once every 60 seconds per session.
"""
_META_TTL_REFRESH_INTERVAL = 60 # seconds
async def publish_chunk(
turn_id: str,
chunk: StreamBaseResponse,
*,
session_id: str | None = None,
) -> str:
"""Publish a chunk to Redis Stream.
@@ -232,6 +244,9 @@ async def publish_chunk(
Args:
turn_id: Turn ID (per-turn UUID) identifying the stream
chunk: The stream response chunk to publish
session_id: Chat session ID — when provided, the session meta key
TTL is refreshed periodically to prevent expiration during
long-running turns (see SECRT-2178).
Returns:
The Redis Stream message ID
@@ -265,6 +280,23 @@ async def publish_chunk(
# Set TTL on stream to match session metadata TTL
await redis.expire(stream_key, config.stream_ttl)
# Periodically refresh session-related TTLs so they don't expire
# during long-running turns. Without this, turns exceeding stream_ttl
# (default 1h) lose their "running" status and stream data, making
# the session invisible to the resume endpoint (empty on page reload).
# Both meta key AND stream key are refreshed: the stream key's expire
# above only fires when publish_chunk is called, but during long
# sub-agent gaps (task_progress events don't produce chunks), neither
# key gets refreshed.
if session_id:
now = time.perf_counter()
last_refresh = _meta_ttl_refresh_at.get(session_id, 0)
if now - last_refresh >= _META_TTL_REFRESH_INTERVAL:
meta_key = _get_session_meta_key(session_id)
await redis.expire(meta_key, config.stream_ttl)
await redis.expire(stream_key, config.stream_ttl)
_meta_ttl_refresh_at[session_id] = now
total_time = (time.perf_counter() - start_time) * 1000
# Only log timing for significant chunks or slow operations
if (
@@ -331,7 +363,7 @@ async def stream_and_publish(
async for event in stream:
if turn_id and not isinstance(event, (StreamFinish, StreamError)):
try:
await publish_chunk(turn_id, event)
await publish_chunk(turn_id, event, session_id=session_id)
except (RedisError, ConnectionError, OSError):
if not publish_failed_once:
publish_failed_once = True
@@ -800,6 +832,9 @@ async def mark_session_completed(
# Atomic compare-and-swap: only update if status is "running"
result = await redis.eval(COMPLETE_SESSION_SCRIPT, 1, meta_key, status) # type: ignore[misc]
# Clean up the in-memory TTL refresh tracker to prevent unbounded growth.
_meta_ttl_refresh_at.pop(session_id, None)
if result == 0:
logger.debug(f"Session {session_id} already completed/failed, skipping")
return False

View File

@@ -10,6 +10,7 @@ from backend.copilot.tracking import track_tool_called
from .add_understanding import AddUnderstandingTool
from .agent_browser import BrowserActTool, BrowserNavigateTool, BrowserScreenshotTool
from .agent_output import AgentOutputTool
from .ask_question import AskQuestionTool
from .base import BaseTool
from .bash_exec import BashExecTool
from .connect_integration import ConnectIntegrationTool
@@ -55,6 +56,7 @@ logger = logging.getLogger(__name__)
# Single source of truth for all tools
TOOL_REGISTRY: dict[str, BaseTool] = {
"add_understanding": AddUnderstandingTool(),
"ask_question": AskQuestionTool(),
"create_agent": CreateAgentTool(),
"customize_agent": CustomizeAgentTool(),
"edit_agent": EditAgentTool(),

View File

@@ -68,6 +68,9 @@ class AddUnderstandingTool(BaseTool):
Each call merges new data with existing understanding:
- String fields are overwritten if provided
- List fields are appended (with deduplication)
Note: This tool accepts **kwargs because its parameters are derived
dynamically from the BusinessUnderstandingInput model schema.
"""
session_id = session.session_id
@@ -77,23 +80,21 @@ class AddUnderstandingTool(BaseTool):
session_id=session_id,
)
# Build input model from kwargs (only include fields defined in the model)
valid_fields = set(BusinessUnderstandingInput.model_fields.keys())
filtered = {k: v for k, v in kwargs.items() if k in valid_fields}
# Check if any data was provided
if not any(v is not None for v in kwargs.values()):
if not any(v is not None for v in filtered.values()):
return ErrorResponse(
message="Please provide at least one field to update.",
session_id=session_id,
)
# Build input model from kwargs (only include fields defined in the model)
valid_fields = set(BusinessUnderstandingInput.model_fields.keys())
input_data = BusinessUnderstandingInput(
**{k: v for k, v in kwargs.items() if k in valid_fields}
)
input_data = BusinessUnderstandingInput(**filtered)
# Track which fields were updated
updated_fields = [
k for k, v in kwargs.items() if k in valid_fields and v is not None
]
updated_fields = [k for k, v in filtered.items() if v is not None]
# Upsert with merge
understanding = await understanding_db().upsert_business_understanding(

View File

@@ -180,12 +180,14 @@ async def _save_browser_state(
"""
try:
# Gather state in parallel
(rc_url, url_out, _), (rc_ck, ck_out, _), (rc_ls, ls_out, _) = (
await asyncio.gather(
_run(session_name, "get", "url", timeout=10),
_run(session_name, "cookies", "get", "--json", timeout=10),
_run(session_name, "storage", "local", "--json", timeout=10),
)
(
(rc_url, url_out, _),
(rc_ck, ck_out, _),
(rc_ls, ls_out, _),
) = await asyncio.gather(
_run(session_name, "get", "url", timeout=10),
_run(session_name, "cookies", "get", "--json", timeout=10),
_run(session_name, "storage", "local", "--json", timeout=10),
)
state = {
@@ -448,6 +450,8 @@ class BrowserNavigateTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
url: str = "",
wait_for: str = "networkidle",
**kwargs: Any,
) -> ToolResponseBase:
"""Navigate to *url*, wait for the page to settle, and return a snapshot.
@@ -456,8 +460,8 @@ class BrowserNavigateTool(BaseTool):
Note: for slow SPAs that never fully idle, the snapshot may reflect a
partially-loaded state (the wait is best-effort).
"""
url: str = (kwargs.get("url") or "").strip()
wait_for: str = kwargs.get("wait_for") or "networkidle"
url = url.strip()
wait_for = wait_for or "networkidle"
session_name = session.session_id
if not url:
@@ -612,6 +616,10 @@ class BrowserActTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
action: str = "",
target: str = "",
value: str = "",
direction: str = "down",
**kwargs: Any,
) -> ToolResponseBase:
"""Perform a browser action and return an updated page snapshot.
@@ -620,10 +628,10 @@ class BrowserActTool(BaseTool):
``agent-browser``, waits for the page to settle, and returns the
accessibility-tree snapshot so the LLM can plan the next step.
"""
action: str = (kwargs.get("action") or "").strip()
target: str = (kwargs.get("target") or "").strip()
value: str = (kwargs.get("value") or "").strip()
direction: str = (kwargs.get("direction") or "down").strip()
action = action.strip()
target = target.strip()
value = value.strip()
direction = direction.strip()
session_name = session.session_id
if not action:
@@ -777,6 +785,8 @@ class BrowserScreenshotTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
annotate: bool | str = True,
filename: str = "screenshot.png",
**kwargs: Any,
) -> ToolResponseBase:
"""Capture a PNG screenshot and upload it to the workspace.
@@ -786,12 +796,12 @@ class BrowserScreenshotTool(BaseTool):
Returns a :class:`BrowserScreenshotResponse` with the workspace
``file_id`` the LLM should pass to ``read_workspace_file``.
"""
raw_annotate = kwargs.get("annotate", True)
raw_annotate = annotate
if isinstance(raw_annotate, str):
annotate = raw_annotate.strip().lower() in {"1", "true", "yes", "on"}
else:
annotate = bool(raw_annotate)
filename: str = (kwargs.get("filename") or "screenshot.png").strip()
filename = filename.strip()
session_name = session.session_id
# Restore browser state from cloud if this is a different pod

View File

@@ -411,7 +411,12 @@ class AgentOutputTool(BaseTool):
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""Execute the agent_output tool."""
"""Execute the agent_output tool.
Note: This tool accepts **kwargs and delegates to AgentOutputInput
for validation because the parameter set has cross-field validators
defined in the Pydantic model.
"""
session_id = session.session_id
# Parse and validate input

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
import asyncio
import logging
from typing import TYPE_CHECKING, Literal
@@ -9,7 +10,7 @@ if TYPE_CHECKING:
from backend.api.features.library.model import LibraryAgent
from backend.api.features.store.model import StoreAgent, StoreAgentDetails
from backend.data.db_accessors import library_db, store_db
from backend.data.db_accessors import graph_db, library_db, store_db
from backend.util.exceptions import DatabaseError, NotFoundError
from .models import (
@@ -34,12 +35,13 @@ async def search_agents(
source: SearchSource,
session_id: str | None = None,
user_id: str | None = None,
include_graph: bool = False,
) -> ToolResponseBase:
"""Search for agents in marketplace or user library."""
if source == "marketplace":
return await _search_marketplace(query, session_id)
else:
return await _search_library(query, session_id, user_id)
return await _search_library(query, session_id, user_id, include_graph)
async def _search_marketplace(query: str, session_id: str | None) -> ToolResponseBase:
@@ -105,7 +107,10 @@ async def _search_marketplace(query: str, session_id: str | None) -> ToolRespons
async def _search_library(
query: str, session_id: str | None, user_id: str | None
query: str,
session_id: str | None,
user_id: str | None,
include_graph: bool = False,
) -> ToolResponseBase:
"""Search user's library agents, with direct UUID lookup fallback."""
if not user_id:
@@ -149,6 +154,10 @@ async def _search_library(
session_id=session_id,
)
truncation_notice: str | None = None
if include_graph and agents:
truncation_notice = await _enrich_agents_with_graph(agents, user_id)
if not agents:
if not query:
return NoResultsResponse(
@@ -182,13 +191,17 @@ async def _search_library(
else:
title = f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} in your library for '{query}'"
message = (
"Found agents in the user's library. You can provide a link to view "
"an agent at: /library/agents/{agent_id}. Use agent_output to get "
"execution results, or run_agent to execute. Let the user know we can "
"create a custom agent for them based on their needs."
)
if truncation_notice:
message = f"{message}\n\nNote: {truncation_notice}"
return AgentsFoundResponse(
message=(
"Found agents in the user's library. You can provide a link to view "
"an agent at: /library/agents/{agent_id}. Use agent_output to get "
"execution results, or run_agent to execute. Let the user know we can "
"create a custom agent for them based on their needs."
),
message=message,
title=title,
agents=agents,
count=len(agents),
@@ -196,6 +209,81 @@ async def _search_library(
)
_MAX_GRAPH_FETCHES = 10
_GRAPH_FETCH_TIMEOUT = 15 # seconds
async def _enrich_agents_with_graph(
agents: list[AgentInfo], user_id: str
) -> str | None:
"""Fetch and attach full Graph (nodes + links) to each agent in-place.
Only the first ``_MAX_GRAPH_FETCHES`` agents with a ``graph_id`` are
enriched. If some agents are skipped, a truncation notice is returned
so the caller can surface it to the copilot.
Graphs are fetched with ``for_export=True`` so that credentials, API keys,
and other secrets in ``input_default`` are stripped before the data reaches
the LLM context.
Returns a truncation notice string when some agents were skipped, or
``None`` when all eligible agents were enriched.
"""
with_graph_id = [a for a in agents if a.graph_id]
fetchable = with_graph_id[:_MAX_GRAPH_FETCHES]
if not fetchable:
return None
gdb = graph_db()
async def _fetch(agent: AgentInfo) -> None:
graph_id = agent.graph_id
if not graph_id:
return
try:
graph = await gdb.get_graph(
graph_id,
version=agent.graph_version,
user_id=user_id,
for_export=True,
)
if graph is None:
logger.warning("Graph not found for agent %s", graph_id)
agent.graph = graph
except Exception as e:
logger.warning("Failed to fetch graph for agent %s: %s", graph_id, e)
try:
await asyncio.wait_for(
asyncio.gather(*[_fetch(a) for a in fetchable]),
timeout=_GRAPH_FETCH_TIMEOUT,
)
except asyncio.TimeoutError:
logger.warning(
"include_graph: timed out after %ds fetching graphs", _GRAPH_FETCH_TIMEOUT
)
skipped = len(with_graph_id) - len(fetchable)
if skipped > 0:
logger.warning(
"include_graph: fetched graphs for %d/%d agents "
"(_MAX_GRAPH_FETCHES=%d, %d skipped)",
len(fetchable),
len(with_graph_id),
_MAX_GRAPH_FETCHES,
skipped,
)
return (
f"Graph data included for {len(fetchable)} of "
f"{len(with_graph_id)} eligible agents (limit: {_MAX_GRAPH_FETCHES}). "
f"To fetch graphs for remaining agents, narrow your search to a "
f"specific agent by UUID."
)
return None
def _marketplace_agent_to_info(agent: StoreAgent | StoreAgentDetails) -> AgentInfo:
"""Convert a marketplace agent (StoreAgent or StoreAgentDetails) to an AgentInfo."""
return AgentInfo(

View File

@@ -1,11 +1,12 @@
"""Tests for agent search direct lookup functionality."""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from .agent_search import search_agents
from .models import AgentsFoundResponse, NoResultsResponse
from .agent_search import _enrich_agents_with_graph, search_agents
from .models import AgentInfo, AgentsFoundResponse, NoResultsResponse
_TEST_USER_ID = "test-user-agent-search"
@@ -133,10 +134,10 @@ class TestMarketplaceSlugLookup:
class TestLibraryUUIDLookup:
"""Tests for UUID direct lookup in library search."""
@pytest.mark.asyncio(loop_scope="session")
async def test_uuid_lookup_found_by_graph_id(self):
"""UUID query matching a graph_id returns the agent directly."""
agent_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"
@staticmethod
def _make_mock_library_agent(
agent_id: str = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d",
) -> MagicMock:
mock_agent = MagicMock()
mock_agent.id = "lib-agent-id"
mock_agent.name = "My Library Agent"
@@ -150,6 +151,13 @@ class TestLibraryUUIDLookup:
mock_agent.graph_version = 1
mock_agent.input_schema = {}
mock_agent.output_schema = {}
return mock_agent
@pytest.mark.asyncio(loop_scope="session")
async def test_uuid_lookup_found_by_graph_id(self):
"""UUID query matching a graph_id returns the agent directly."""
agent_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"
mock_agent = self._make_mock_library_agent(agent_id)
mock_lib_db = MagicMock()
mock_lib_db.get_library_agent_by_graph_id = AsyncMock(return_value=mock_agent)
@@ -168,3 +176,427 @@ class TestLibraryUUIDLookup:
assert isinstance(response, AgentsFoundResponse)
assert response.count == 1
assert response.agents[0].name == "My Library Agent"
@pytest.mark.asyncio(loop_scope="session")
async def test_include_graph_fetches_graph(self):
"""include_graph=True attaches BaseGraph to agent results."""
from backend.data.graph import BaseGraph
agent_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"
mock_agent = self._make_mock_library_agent(agent_id)
mock_lib_db = MagicMock()
mock_lib_db.get_library_agent_by_graph_id = AsyncMock(return_value=mock_agent)
fake_graph = BaseGraph(id=agent_id, name="My Library Agent", description="test")
mock_graph_db = MagicMock()
mock_graph_db.get_graph = AsyncMock(return_value=fake_graph)
with (
patch(
"backend.copilot.tools.agent_search.library_db",
return_value=mock_lib_db,
),
patch(
"backend.copilot.tools.agent_search.graph_db",
return_value=mock_graph_db,
),
):
response = await search_agents(
query=agent_id,
source="library",
session_id="s",
user_id=_TEST_USER_ID,
include_graph=True,
)
assert isinstance(response, AgentsFoundResponse)
assert response.agents[0].graph is not None
assert response.agents[0].graph.id == agent_id
mock_graph_db.get_graph.assert_awaited_once_with(
agent_id,
version=1,
user_id=_TEST_USER_ID,
for_export=True,
)
@pytest.mark.asyncio(loop_scope="session")
async def test_include_graph_false_skips_fetch(self):
"""include_graph=False (default) does not fetch graph data."""
agent_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"
mock_agent = self._make_mock_library_agent(agent_id)
mock_lib_db = MagicMock()
mock_lib_db.get_library_agent_by_graph_id = AsyncMock(return_value=mock_agent)
mock_graph_db = MagicMock()
mock_graph_db.get_graph = AsyncMock()
with (
patch(
"backend.copilot.tools.agent_search.library_db",
return_value=mock_lib_db,
),
patch(
"backend.copilot.tools.agent_search.graph_db",
return_value=mock_graph_db,
),
):
response = await search_agents(
query=agent_id,
source="library",
session_id="s",
user_id=_TEST_USER_ID,
include_graph=False,
)
assert isinstance(response, AgentsFoundResponse)
assert response.agents[0].graph is None
mock_graph_db.get_graph.assert_not_awaited()
@pytest.mark.asyncio(loop_scope="session")
async def test_include_graph_handles_fetch_failure(self):
"""include_graph=True still returns agents when graph fetch fails."""
agent_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"
mock_agent = self._make_mock_library_agent(agent_id)
mock_lib_db = MagicMock()
mock_lib_db.get_library_agent_by_graph_id = AsyncMock(return_value=mock_agent)
mock_graph_db = MagicMock()
mock_graph_db.get_graph = AsyncMock(side_effect=Exception("DB down"))
with (
patch(
"backend.copilot.tools.agent_search.library_db",
return_value=mock_lib_db,
),
patch(
"backend.copilot.tools.agent_search.graph_db",
return_value=mock_graph_db,
),
):
response = await search_agents(
query=agent_id,
source="library",
session_id="s",
user_id=_TEST_USER_ID,
include_graph=True,
)
assert isinstance(response, AgentsFoundResponse)
assert response.agents[0].graph is None
@pytest.mark.asyncio(loop_scope="session")
async def test_include_graph_handles_none_return(self):
"""include_graph=True handles get_graph returning None."""
agent_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"
mock_agent = self._make_mock_library_agent(agent_id)
mock_lib_db = MagicMock()
mock_lib_db.get_library_agent_by_graph_id = AsyncMock(return_value=mock_agent)
mock_graph_db = MagicMock()
mock_graph_db.get_graph = AsyncMock(return_value=None)
with (
patch(
"backend.copilot.tools.agent_search.library_db",
return_value=mock_lib_db,
),
patch(
"backend.copilot.tools.agent_search.graph_db",
return_value=mock_graph_db,
),
):
response = await search_agents(
query=agent_id,
source="library",
session_id="s",
user_id=_TEST_USER_ID,
include_graph=True,
)
assert isinstance(response, AgentsFoundResponse)
assert response.agents[0].graph is None
class TestEnrichAgentsWithGraph:
"""Tests for _enrich_agents_with_graph edge cases."""
@staticmethod
def _make_mock_library_agent(
agent_id: str = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d",
graph_id: str | None = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d",
) -> MagicMock:
mock_agent = MagicMock()
mock_agent.id = f"lib-{agent_id[:8]}"
mock_agent.name = f"Agent {agent_id[:8]}"
mock_agent.description = "A library agent"
mock_agent.creator_name = "testuser"
mock_agent.status.value = "HEALTHY"
mock_agent.can_access_graph = True
mock_agent.has_external_trigger = False
mock_agent.new_output = False
mock_agent.graph_id = graph_id
mock_agent.graph_version = 1
mock_agent.input_schema = {}
mock_agent.output_schema = {}
return mock_agent
@pytest.mark.asyncio(loop_scope="session")
async def test_truncation_surfaces_in_response(self):
"""When >_MAX_GRAPH_FETCHES agents have graphs, the response contains a truncation notice."""
from backend.copilot.tools.agent_search import _MAX_GRAPH_FETCHES
from backend.data.graph import BaseGraph
agent_count = _MAX_GRAPH_FETCHES + 5
mock_agents = []
for i in range(agent_count):
uid = f"a1b2c3d4-e5f6-4a7b-8c9d-{i:012d}"
mock_agents.append(self._make_mock_library_agent(uid, uid))
mock_lib_db = MagicMock()
mock_search_results = MagicMock()
mock_search_results.agents = mock_agents
mock_lib_db.list_library_agents = AsyncMock(return_value=mock_search_results)
fake_graph = BaseGraph(id="x", name="g", description="d")
mock_gdb = MagicMock()
mock_gdb.get_graph = AsyncMock(return_value=fake_graph)
with (
patch(
"backend.copilot.tools.agent_search.library_db",
return_value=mock_lib_db,
),
patch(
"backend.copilot.tools.agent_search.graph_db",
return_value=mock_gdb,
),
):
response = await search_agents(
query="",
source="library",
session_id="s",
user_id=_TEST_USER_ID,
include_graph=True,
)
assert isinstance(response, AgentsFoundResponse)
assert mock_gdb.get_graph.await_count == _MAX_GRAPH_FETCHES
enriched = [a for a in response.agents if a.graph is not None]
assert len(enriched) == _MAX_GRAPH_FETCHES
assert "Graph data included for" in response.message
assert str(_MAX_GRAPH_FETCHES) in response.message
@pytest.mark.asyncio(loop_scope="session")
async def test_mixed_graph_id_presence(self):
"""Agents without graph_id are skipped during enrichment."""
from backend.data.graph import BaseGraph
agent_with = self._make_mock_library_agent(
"aaaa0000-0000-0000-0000-000000000001",
"aaaa0000-0000-0000-0000-000000000001",
)
agent_without = self._make_mock_library_agent(
"bbbb0000-0000-0000-0000-000000000002",
graph_id=None,
)
mock_lib_db = MagicMock()
mock_search_results = MagicMock()
mock_search_results.agents = [agent_with, agent_without]
mock_lib_db.list_library_agents = AsyncMock(return_value=mock_search_results)
fake_graph = BaseGraph(
id="aaaa0000-0000-0000-0000-000000000001", name="g", description="d"
)
mock_gdb = MagicMock()
mock_gdb.get_graph = AsyncMock(return_value=fake_graph)
with (
patch(
"backend.copilot.tools.agent_search.library_db",
return_value=mock_lib_db,
),
patch(
"backend.copilot.tools.agent_search.graph_db",
return_value=mock_gdb,
),
):
response = await search_agents(
query="",
source="library",
session_id="s",
user_id=_TEST_USER_ID,
include_graph=True,
)
assert isinstance(response, AgentsFoundResponse)
assert len(response.agents) == 2
assert response.agents[0].graph is not None
assert response.agents[1].graph is None
mock_gdb.get_graph.assert_awaited_once()
@pytest.mark.asyncio(loop_scope="session")
async def test_partial_failure_across_multiple_agents(self):
"""When some graph fetches fail, successful ones still have graphs attached."""
from backend.data.graph import BaseGraph
id_ok = "aaaa0000-0000-0000-0000-000000000001"
id_fail = "bbbb0000-0000-0000-0000-000000000002"
agent_ok = self._make_mock_library_agent(id_ok, id_ok)
agent_fail = self._make_mock_library_agent(id_fail, id_fail)
mock_lib_db = MagicMock()
mock_search_results = MagicMock()
mock_search_results.agents = [agent_ok, agent_fail]
mock_lib_db.list_library_agents = AsyncMock(return_value=mock_search_results)
fake_graph = BaseGraph(id=id_ok, name="g", description="d")
async def _side_effect(graph_id, **kwargs):
if graph_id == id_fail:
raise Exception("DB error")
return fake_graph
mock_gdb = MagicMock()
mock_gdb.get_graph = AsyncMock(side_effect=_side_effect)
with (
patch(
"backend.copilot.tools.agent_search.library_db",
return_value=mock_lib_db,
),
patch(
"backend.copilot.tools.agent_search.graph_db",
return_value=mock_gdb,
),
):
response = await search_agents(
query="",
source="library",
session_id="s",
user_id=_TEST_USER_ID,
include_graph=True,
)
assert isinstance(response, AgentsFoundResponse)
assert response.agents[0].graph is not None
assert response.agents[0].graph.id == id_ok
assert response.agents[1].graph is None
@pytest.mark.asyncio(loop_scope="session")
async def test_keyword_search_with_include_graph(self):
"""include_graph works via keyword search (non-UUID path)."""
from backend.data.graph import BaseGraph
agent_id = "cccc0000-0000-0000-0000-000000000003"
mock_agent = self._make_mock_library_agent(agent_id, agent_id)
mock_lib_db = MagicMock()
mock_search_results = MagicMock()
mock_search_results.agents = [mock_agent]
mock_lib_db.list_library_agents = AsyncMock(return_value=mock_search_results)
fake_graph = BaseGraph(id=agent_id, name="g", description="d")
mock_gdb = MagicMock()
mock_gdb.get_graph = AsyncMock(return_value=fake_graph)
with (
patch(
"backend.copilot.tools.agent_search.library_db",
return_value=mock_lib_db,
),
patch(
"backend.copilot.tools.agent_search.graph_db",
return_value=mock_gdb,
),
):
response = await search_agents(
query="email",
source="library",
session_id="s",
user_id=_TEST_USER_ID,
include_graph=True,
)
assert isinstance(response, AgentsFoundResponse)
assert response.agents[0].graph is not None
assert response.agents[0].graph.id == agent_id
mock_gdb.get_graph.assert_awaited_once()
@pytest.mark.asyncio(loop_scope="session")
async def test_timeout_preserves_successful_fetches(self):
"""On timeout, agents that already fetched their graph keep the result."""
fast_agent = AgentInfo(
id="a1",
name="Fast",
description="d",
source="library",
graph_id="fast-graph",
)
slow_agent = AgentInfo(
id="a2",
name="Slow",
description="d",
source="library",
graph_id="slow-graph",
)
fake_graph = MagicMock()
fake_graph.id = "graph-1"
async def mock_get_graph(
graph_id, *, version=None, user_id=None, for_export=False
):
if graph_id == "fast-graph":
return fake_graph
await asyncio.sleep(999)
return MagicMock()
mock_gdb = MagicMock()
mock_gdb.get_graph = AsyncMock(side_effect=mock_get_graph)
with (
patch("backend.copilot.tools.agent_search.graph_db", return_value=mock_gdb),
patch("backend.copilot.tools.agent_search._GRAPH_FETCH_TIMEOUT", 0.1),
):
await _enrich_agents_with_graph([fast_agent, slow_agent], _TEST_USER_ID)
assert fast_agent.graph is fake_graph
assert slow_agent.graph is None
@pytest.mark.asyncio(loop_scope="session")
async def test_enrich_success(self):
"""All agents get their graphs when no timeout occurs."""
agent = AgentInfo(
id="a1", name="Test", description="d", source="library", graph_id="g1"
)
fake_graph = MagicMock()
fake_graph.id = "graph-1"
mock_gdb = MagicMock()
mock_gdb.get_graph = AsyncMock(return_value=fake_graph)
with patch(
"backend.copilot.tools.agent_search.graph_db", return_value=mock_gdb
):
result = await _enrich_agents_with_graph([agent], _TEST_USER_ID)
assert agent.graph is fake_graph
assert result is None
@pytest.mark.asyncio(loop_scope="session")
async def test_enrich_skips_agents_without_graph_id(self):
"""Agents without graph_id are not fetched."""
agent_no_id = AgentInfo(
id="a1", name="Test", description="d", source="library", graph_id=None
)
mock_gdb = MagicMock()
mock_gdb.get_graph = AsyncMock()
with patch(
"backend.copilot.tools.agent_search.graph_db", return_value=mock_gdb
):
result = await _enrich_agents_with_graph([agent_no_id], _TEST_USER_ID)
mock_gdb.get_graph.assert_not_called()
assert result is None

Some files were not shown because too many files have changed in this diff Show More