Compare commits


13 Commits

Author SHA1 Message Date
Bentlybro
77ebcfe55d ci: regenerate openapi.json after platform-linking schema changes 2026-04-02 15:54:21 +00:00
Bentlybro
db029f0b49 refactor(backend): split platform-linking routes into models, auth, routes, chat_proxy
- Split 628-line routes.py into 4 files under ~300 lines each:
  models.py (Pydantic models), auth.py (bot API key validation),
  routes.py (linking routes), chat_proxy.py (bot chat endpoints)
- Move all inline imports to top-level
- Read PLATFORM_BOT_API_KEY per-request instead of module-level global
- Read PLATFORM_LINK_BASE_URL per-request in create_link_token
- Use backend Settings().config.enable_auth for dev-mode bypass
  instead of raw ENV env var
- Add Path validation (max_length=64, regex) on token path params
- Update test imports to match new module paths
2026-04-01 11:38:56 +00:00
Bentlybro
a268291585 fix: Address review round 2 — all 4 blockers
1. **Broken bot auth (CRITICAL)** — Replaced `Depends(lambda req: ...)`
   with proper `async def get_bot_api_key(request: Request)` dependency.
   The lambda pattern caused FastAPI to interpret `req` as a required
   query parameter, making all bot endpoints return 422.

2. **Timing-safe key comparison** — Switched from `!=` to
   `hmac.compare_digest()` to prevent timing side-channel attacks.

3. **Removed dead code** — Deleted the unused `verify_bot_api_key`
   function that had a stub body passing all requests.

4. **TOCTOU race in confirm** — Wrapped `PlatformLink.prisma().create()`
   in try/except for unique constraint violations. Concurrent requests
   with different tokens for the same identity now get a clean 409
   instead of an unhandled 500.

Also: regenerated openapi.json (removes spurious `req` query parameter
that was leaking from the broken lambda pattern).
2026-03-31 16:03:08 +00:00
Bentlybro
9caca6d899 fix: correct unsubscribe_from_session parameter (subscriber_queue not user_id) 2026-03-31 15:56:09 +00:00
Bentlybro
481f704317 feat: Bot chat proxy — CoPilot streaming via bot API key
Adds two bot-facing endpoints that let the bot service call CoPilot
on behalf of linked users without needing their JWT:

- POST /api/platform-linking/chat/session — create a CoPilot session
- POST /api/platform-linking/chat/stream — send message + stream SSE response

Both authenticated via X-Bot-API-Key header. The bot passes user_id
from the platform-linking resolve step. Backend reuses the existing
CoPilot pipeline (enqueue_copilot_turn → Redis stream → SSE).

This is the last missing piece — the bot can now:
1. Resolve platform user → AutoGPT account (existing)
2. Create a chat session (new)
3. Stream CoPilot responses back to the user (new)
2026-03-31 15:48:31 +00:00
Bentlybro
781224f8d5 fix: use model_validate for invalid platform test (no type ignore) 2026-03-31 14:24:19 +00:00
Bentlybro
8a91bd8a99 fix: suppress pyright error on intentionally invalid test input 2026-03-31 14:22:55 +00:00
Bentlybro
7bf10c605a fix: Address review feedback on platform linking PR
Fixes all blockers and should-fix items from the automated review:

## Blockers fixed

1. **Bot-facing endpoint auth** — Added X-Bot-API-Key header auth for
   all bot-facing endpoints (POST /tokens, GET /tokens/{token}/status,
   POST /resolve). Configurable via PLATFORM_BOT_API_KEY env var.
   Dev mode allows keyless access.

2. **Race condition in confirm_link_token** — Replaced 5 separate DB
   calls with atomic update_many (WHERE token=X AND usedAt=NULL).
   If another request consumed the token first, returns 410 cleanly
   instead of a 500 unique constraint violation.

3. **Test coverage** — Added tests for: Platform enum, bot API key
   auth (valid/invalid/missing), request validation (empty IDs,
   too-long IDs, invalid platforms), response model typing.

## Should-fix items addressed

- **Enum duplication** — Created Python Platform(str, Enum) that
  mirrors the Prisma PlatformType. Used throughout request models
  for validation instead of bare strings.
- **Hardcoded base URL** — Now configurable via PLATFORM_LINK_BASE_URL
  env var. Defaults to https://platform.agpt.co/link.
- **Redundant @@index([token])** — Removed from schema and migration.
  The @unique already creates an index.
- **Literal status type** — LinkTokenStatusResponse.status is now
  Literal['pending', 'linked', 'expired'] not bare str.
- **delete_link returns dict** — Now returns DeleteLinkResponse model.
- **Input length validation** — Added min_length=1, max_length=255
  to platform_user_id and platform_username fields.
- **Token flooding** — Existing pending tokens are invalidated before
  creating a new one (only 1 active token per platform identity).
- **Resolve leaks user info** — Removed platform_username from
  resolve response (only returns user_id + linked boolean).
- **Error messages** — Sanitized to not expose internal IDs.
- **Logger** — Switched from f-strings to lazy %s formatting.
2026-03-31 14:10:53 +00:00
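The atomic token-consumption pattern from blocker 2 can be demonstrated with stdlib sqlite3 standing in for Prisma's `update_many`: a single conditional UPDATE either claims the token or matches zero rows, with no read-then-write gap for a concurrent request to slip into.

```python
import sqlite3


def consume_token(conn: sqlite3.Connection, token: str) -> bool:
    # One statement plays the role of `update_many(where={token, usedAt:
    # None}, ...)`: the WHERE clause is checked and the row updated
    # atomically by the database.
    cur = conn.execute(
        "UPDATE link_tokens SET used_at = CURRENT_TIMESTAMP "
        "WHERE token = ? AND used_at IS NULL",
        (token,),
    )
    conn.commit()
    # rowcount == 1: this request won the race; 0: token already
    # consumed (the route would return 410 Gone in that case).
    return cur.rowcount == 1


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE link_tokens (token TEXT PRIMARY KEY, used_at TEXT)")
conn.execute("INSERT INTO link_tokens (token) VALUES ('abc')")
```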
Bentlybro
7778de1d7b fix: import ordering (isort) 2026-03-31 13:00:41 +00:00
Bentlybro
015a626379 chore: update openapi.json with platform-linking endpoints 2026-03-31 12:41:49 +00:00
Bentlybro
be6501f10e fix: use Model.prisma() pattern for pyright compatibility
Switched from backend.data.db.get_prisma() to the standard
PlatformLink.prisma() / PlatformLinkToken.prisma() pattern
used throughout the codebase.
2026-03-31 12:25:26 +00:00
Bentlybro
32f6ef0a45 style: lint + format platform linking routes 2026-03-31 12:14:06 +00:00
Bentlybro
c7d5c1c844 feat: Platform bot linking API for multi-platform CoPilot
Adds the account linking system that enables CoPilot to work across
multiple chat platforms (Discord, Telegram, Slack, Teams, etc.) via
the Vercel Chat SDK.

## What this adds

### Database (Prisma + migration)
- PlatformType enum (DISCORD, TELEGRAM, SLACK, TEAMS, WHATSAPP, GITHUB, LINEAR)
- PlatformLink model - maps platform user IDs to AutoGPT accounts
  - Unique constraint on (platform, platformUserId)
  - One AutoGPT user can link multiple platforms
- PlatformLinkToken model - one-time tokens for the linking flow

### API endpoints (/api/platform-linking)

Bot-facing (called by the bot service):
- POST /tokens - Create a link token for an unlinked platform user
- GET /tokens/{token}/status - Check if linking is complete
- POST /resolve - Resolve platform identity → AutoGPT user ID

User-facing (JWT auth required):
- POST /tokens/{token}/confirm - Complete the link (user logs in first)
- GET /links - List all linked platform identities
- DELETE /links/{link_id} - Unlink a platform identity

## Linking flow
1. User messages bot on Discord/Telegram/etc
2. Bot calls POST /resolve → not linked
3. Bot calls POST /tokens → gets link URL
4. Bot sends user the link
5. User clicks → logs in to AutoGPT → frontend calls POST /confirm
6. Bot detects link on next message (or polls /status)
2026-03-31 12:10:15 +00:00
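The bot side of this flow can be sketched transport-agnostically. The injected `api` callable and the `link_url` field name are assumptions for illustration; the `linked` and `user_id` response fields come from the commit messages above.

```python
def handle_message(api, platform: str, platform_user_id: str) -> str:
    """Steps 1-4 of the linking flow, from the bot's point of view."""
    # Step 2: resolve platform identity -> AutoGPT account.
    resolved = api(
        "POST",
        "/api/platform-linking/resolve",
        {"platform": platform, "platform_user_id": platform_user_id},
    )
    if resolved["linked"]:
        return f"linked as {resolved['user_id']}"
    # Step 3: not linked, so mint a one-time link token.
    token = api(
        "POST",
        "/api/platform-linking/tokens",
        {"platform": platform, "platform_user_id": platform_user_id},
    )
    # Step 4: send the user the URL; they confirm after logging in.
    return f"please link your account: {token['link_url']}"
```

A real bot would implement `api` with an HTTP client that sets the `X-Bot-API-Key` header on every call.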
134 changed files with 3454 additions and 8770 deletions

View File

@@ -1 +0,0 @@
../.claude/skills

View File

@@ -1,6 +1,6 @@
# AutoGPT Platform Contribution Guide
This guide provides context for coding agents when updating the **autogpt_platform** folder.
This guide provides context for Codex when updating the **autogpt_platform** folder.
## Directory overview

View File

@@ -1 +0,0 @@
@AGENTS.md

View File

@@ -1,120 +0,0 @@
# AutoGPT Platform
This file provides guidance to coding agents when working with code in this repository.
## Repository Overview
AutoGPT Platform is a monorepo containing:
- **Backend** (`backend`): Python FastAPI server with async support
- **Frontend** (`frontend`): Next.js React application
- **Shared Libraries** (`autogpt_libs`): Common Python utilities
## Component Documentation
- **Backend**: See @backend/AGENTS.md for backend-specific commands, architecture, and development tasks
- **Frontend**: See @frontend/AGENTS.md for frontend-specific commands, architecture, and development patterns
## Key Concepts
1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
2. **Blocks**: Reusable components in `backend/backend/blocks/` that perform specific tasks
3. **Integrations**: OAuth and API connections stored per user
4. **Store**: Marketplace for sharing agent templates
5. **Virus Scanning**: ClamAV integration for file upload security
### Environment Configuration
#### Configuration Files
- **Backend**: `backend/.env.default` (defaults) → `backend/.env` (user overrides)
- **Frontend**: `frontend/.env.default` (defaults) → `frontend/.env` (user overrides)
- **Platform**: `.env.default` (Supabase/shared defaults) → `.env` (user overrides)
#### Docker Environment Loading Order
1. `.env.default` files provide base configuration (tracked in git)
2. `.env` files provide user-specific overrides (gitignored)
3. Docker Compose `environment:` sections provide service-specific overrides
4. Shell environment variables have highest precedence
#### Key Points
- All services use hardcoded defaults in docker-compose files (no `${VARIABLE}` substitutions)
- The `env_file` directive loads variables INTO containers at runtime
- Backend/Frontend services use YAML anchors for consistent configuration
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
### Branching Strategy
- **`dev`** is the main development branch. All PRs should target `dev`.
- **`master`** is the production branch. Only used for production releases.
### Creating Pull Requests
- Create the PR against the `dev` branch of the repository.
- **Split PRs by concern** — each PR should have a single clear purpose. For example, "usage tracking" and "credit charging" should be separate PRs even if related. Combining multiple concerns makes it harder for reviewers to understand what belongs to what.
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)
- Use conventional commit messages (see below)
- **Structure the PR description with Why / What / How** — Why: the motivation (what problem it solves, what's broken/missing without it); What: high-level summary of changes; How: approach, key implementation details, or architecture decisions. Reviewers need all three to judge whether the approach fits the problem.
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description
- Always use `--body-file` to pass PR body — avoids shell interpretation of backticks and special characters:
```bash
PR_BODY=$(mktemp)
cat > "$PR_BODY" << 'PREOF'
## Summary
- use `backticks` freely here
PREOF
gh pr create --title "..." --body-file "$PR_BODY" --base dev
rm "$PR_BODY"
```
- Run the repository's pre-commit hooks to ensure code quality.
### Test-Driven Development (TDD)
When fixing a bug or adding a feature, follow a test-first approach:
1. **Write a failing test first** — create a test that reproduces the bug or validates the new behavior, marked with `@pytest.mark.xfail` (backend) or `.fixme` (Playwright). Run it to confirm it fails for the right reason.
2. **Implement the fix/feature** — write the minimal code to make the test pass.
3. **Remove the xfail marker** — once the test passes, remove the `xfail`/`.fixme` annotation and run the full test suite to confirm nothing else broke.
This ensures every change is covered by a test and that the test actually validates the intended behavior.
### Reviewing/Revising Pull Requests
Use `/pr-review` to review a PR or `/pr-address` to address comments.
When fetching comments manually:
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate` — top-level reviews
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments --paginate` — inline review comments (always paginate to avoid missing comments beyond page 1)
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` — PR conversation comments
### Conventional Commits
Use this format for commit messages and Pull Request titles:
**Conventional Commit Types:**
- `feat`: Introduces a new feature to the codebase
- `fix`: Patches a bug in the codebase
- `refactor`: Code change that neither fixes a bug nor adds a feature; also applies to removing features
- `ci`: Changes to CI configuration
- `docs`: Documentation-only changes
- `dx`: Improvements to the developer experience
**Recommended Base Scopes:**
- `platform`: Changes affecting both frontend and backend
- `frontend`
- `backend`
- `infra`
- `blocks`: Modifications/additions of individual blocks
**Subscope Examples:**
- `backend/executor`
- `backend/db`
- `frontend/builder` (includes changes to the block UI component)
- `infra/prod`
Use these scopes and subscopes for clarity and consistency in commit messages.

View File

@@ -1 +1,120 @@
@AGENTS.md
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Repository Overview
AutoGPT Platform is a monorepo containing:
- **Backend** (`backend`): Python FastAPI server with async support
- **Frontend** (`frontend`): Next.js React application
- **Shared Libraries** (`autogpt_libs`): Common Python utilities
## Component Documentation
- **Backend**: See @backend/CLAUDE.md for backend-specific commands, architecture, and development tasks
- **Frontend**: See @frontend/CLAUDE.md for frontend-specific commands, architecture, and development patterns
## Key Concepts
1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
2. **Blocks**: Reusable components in `backend/backend/blocks/` that perform specific tasks
3. **Integrations**: OAuth and API connections stored per user
4. **Store**: Marketplace for sharing agent templates
5. **Virus Scanning**: ClamAV integration for file upload security
### Environment Configuration
#### Configuration Files
- **Backend**: `backend/.env.default` (defaults) → `backend/.env` (user overrides)
- **Frontend**: `frontend/.env.default` (defaults) → `frontend/.env` (user overrides)
- **Platform**: `.env.default` (Supabase/shared defaults) → `.env` (user overrides)
#### Docker Environment Loading Order
1. `.env.default` files provide base configuration (tracked in git)
2. `.env` files provide user-specific overrides (gitignored)
3. Docker Compose `environment:` sections provide service-specific overrides
4. Shell environment variables have highest precedence
#### Key Points
- All services use hardcoded defaults in docker-compose files (no `${VARIABLE}` substitutions)
- The `env_file` directive loads variables INTO containers at runtime
- Backend/Frontend services use YAML anchors for consistent configuration
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
### Branching Strategy
- **`dev`** is the main development branch. All PRs should target `dev`.
- **`master`** is the production branch. Only used for production releases.
### Creating Pull Requests
- Create the PR against the `dev` branch of the repository.
- **Split PRs by concern** — each PR should have a single clear purpose. For example, "usage tracking" and "credit charging" should be separate PRs even if related. Combining multiple concerns makes it harder for reviewers to understand what belongs to what.
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)
- Use conventional commit messages (see below)
- **Structure the PR description with Why / What / How** — Why: the motivation (what problem it solves, what's broken/missing without it); What: high-level summary of changes; How: approach, key implementation details, or architecture decisions. Reviewers need all three to judge whether the approach fits the problem.
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description
- Always use `--body-file` to pass PR body — avoids shell interpretation of backticks and special characters:
```bash
PR_BODY=$(mktemp)
cat > "$PR_BODY" << 'PREOF'
## Summary
- use `backticks` freely here
PREOF
gh pr create --title "..." --body-file "$PR_BODY" --base dev
rm "$PR_BODY"
```
- Run the repository's pre-commit hooks to ensure code quality.
### Test-Driven Development (TDD)
When fixing a bug or adding a feature, follow a test-first approach:
1. **Write a failing test first** — create a test that reproduces the bug or validates the new behavior, marked with `@pytest.mark.xfail` (backend) or `.fixme` (Playwright). Run it to confirm it fails for the right reason.
2. **Implement the fix/feature** — write the minimal code to make the test pass.
3. **Remove the xfail marker** — once the test passes, remove the `xfail`/`.fixme` annotation and run the full test suite to confirm nothing else broke.
This ensures every change is covered by a test and that the test actually validates the intended behavior.
### Reviewing/Revising Pull Requests
Use `/pr-review` to review a PR or `/pr-address` to address comments.
When fetching comments manually:
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate` — top-level reviews
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments --paginate` — inline review comments (always paginate to avoid missing comments beyond page 1)
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` — PR conversation comments
### Conventional Commits
Use this format for commit messages and Pull Request titles:
**Conventional Commit Types:**
- `feat`: Introduces a new feature to the codebase
- `fix`: Patches a bug in the codebase
- `refactor`: Code change that neither fixes a bug nor adds a feature; also applies to removing features
- `ci`: Changes to CI configuration
- `docs`: Documentation-only changes
- `dx`: Improvements to the developer experience
**Recommended Base Scopes:**
- `platform`: Changes affecting both frontend and backend
- `frontend`
- `backend`
- `infra`
- `blocks`: Modifications/additions of individual blocks
**Subscope Examples:**
- `backend/executor`
- `backend/db`
- `frontend/builder` (includes changes to the block UI component)
- `infra/prod`
Use these scopes and subscopes for clarity and consistency in commit messages.

View File

@@ -178,7 +178,6 @@ SMTP_USERNAME=
SMTP_PASSWORD=
# Business & Marketing Tools
AGENTMAIL_API_KEY=
APOLLO_API_KEY=
ENRICHLAYER_API_KEY=
AYRSHARE_API_KEY=

View File

@@ -1,227 +0,0 @@
# Backend
This file provides guidance to coding agents when working with the backend.
## Essential Commands
To run something with Python package dependencies you MUST use `poetry run ...`.
```bash
# Install dependencies
poetry install
# Run database migrations
poetry run prisma migrate dev
# Start all services (database, redis, rabbitmq, clamav)
docker compose up -d
# Run the backend as a whole
poetry run app
# Run tests
poetry run test
# Run specific test
poetry run pytest path/to/test_file.py::test_function_name
# Run block tests (tests that validate all blocks work correctly)
poetry run pytest backend/blocks/test/test_block.py -xvs
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
# Lint and format
# prefer `format` to auto-fix what it can and report only the errors that can't be autofixed
poetry run format # Black + isort
poetry run lint # ruff
```
More details can be found in @TESTING.md
### Creating/Updating Snapshots
When you first write a test or when the expected output changes:
```bash
poetry run pytest path/to/test.py --snapshot-update
```
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
## Architecture
- **API Layer**: FastAPI with REST and WebSocket endpoints
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
- **Queue System**: RabbitMQ for async task processing
- **Execution Engine**: Separate executor service processes agent workflows
- **Authentication**: JWT-based with Supabase integration
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
## Code Style
- **Top-level imports only** — no local/inner imports (lazy imports only for heavy optional deps like `openpyxl`)
- **Absolute imports** — use `from backend.module import ...` for cross-package imports. Single-dot relative (`from .sibling import ...`) is acceptable for sibling modules within the same package (e.g., blocks). Avoid double-dot relative imports (`from ..parent import ...`) — use the absolute path instead
- **No duck typing** — no `hasattr`/`getattr`/`isinstance` for type dispatch; use typed interfaces/unions/protocols
- **Pydantic models** over dataclass/namedtuple/dict for structured data
- **No linter suppressors** — no `# type: ignore`, `# noqa`, `# pyright: ignore`; fix the type/code
- **List comprehensions** over manual loop-and-append
- **Early return** — guard clauses first, avoid deep nesting
- **f-strings vs printf syntax in log statements** — Use `%s` for deferred interpolation in `debug` statements, f-strings elsewhere for readability: `logger.debug("Processing %s items", count)`, `logger.info(f"Processing {count} items")`
- **Sanitize error paths** — `os.path.basename()` in error messages to avoid leaking directory structure
- **TOCTOU awareness** — avoid check-then-act patterns for file access and credit charging
- **`Security()` vs `Depends()`** — use `Security()` for auth deps to get proper OpenAPI security spec
- **Redis pipelines** — `transaction=True` for atomicity on multi-step operations
- **`max(0, value)` guards** — for computed values that should never be negative
- **SSE protocol** — `data:` lines for frontend-parsed events (must match Zod schema), `: comment` lines for heartbeats/status
- **File length** — keep files under ~300 lines; if a file grows beyond this, split by responsibility (e.g. extract helpers, models, or a sub-module into a new file). Never keep appending to a long file.
- **Function length** — keep functions under ~40 lines; extract named helpers when a function grows longer. Long functions are a sign of mixed concerns, not complexity.
- **Top-down ordering** — define the main/public function or class first, then the helpers it uses below. A reader should encounter high-level logic before implementation details.
## Testing Approach
- Uses pytest with snapshot testing for API responses
- Test files are colocated with source files (`*_test.py`)
- Mock at boundaries — mock where the symbol is **used**, not where it's **defined**
- After refactoring, update mock targets to match new module paths
- Use `AsyncMock` for async functions (`from unittest.mock import AsyncMock`)
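The "mock where used" rule above can be shown self-contained with two stand-in modules: because the consumer binds the symbol at import time, only patching the consumer's copy takes effect (module names here are hypothetical).

```python
import asyncio
import sys
import types
from unittest.mock import AsyncMock, patch

# Stand-in for a definitions module (e.g. where fetch is *defined*).
defs = types.ModuleType("defs")

async def fetch() -> str:
    return "real"

defs.fetch = fetch
sys.modules["defs"] = defs

# Stand-in for a consumer module that did `from defs import fetch`:
# it holds its own reference to the function object.
consumer = types.ModuleType("consumer")
consumer.fetch = defs.fetch
sys.modules["consumer"] = consumer


async def call_through_consumer() -> str:
    return await consumer.fetch()


# Patch the *use site* ("consumer.fetch"), not the definition site.
with patch("consumer.fetch", new=AsyncMock(return_value="mocked")):
    patched = asyncio.run(call_through_consumer())

unpatched = asyncio.run(call_through_consumer())
```

Patching `"defs.fetch"` instead would leave `consumer.fetch` pointing at the real function, which is exactly the bug the guideline warns about.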
### Test-Driven Development (TDD)
When fixing a bug or adding a feature, write the test **before** the implementation:
```python
# 1. Write a failing test marked xfail
@pytest.mark.xfail(reason="Bug #1234: widget crashes on empty input")
def test_widget_handles_empty_input():
    result = widget.process("")
    assert result == Widget.EMPTY_RESULT
# 2. Run it — confirm it fails (XFAIL)
# poetry run pytest path/to/test.py::test_widget_handles_empty_input -xvs
# 3. Implement the fix
# 4. Remove xfail, run again — confirm it passes
def test_widget_handles_empty_input():
    result = widget.process("")
    assert result == Widget.EMPTY_RESULT
```
This catches regressions and proves the fix actually works. **Every bug fix should include a test that would have caught it.**
## Database Schema
Key models (defined in `schema.prisma`):
- `User`: Authentication and profile data
- `AgentGraph`: Workflow definitions with version control
- `AgentGraphExecution`: Execution history and results
- `AgentNode`: Individual nodes in a workflow
- `StoreListing`: Marketplace listings for sharing agents
## Environment Configuration
- **Backend**: `.env.default` (defaults) → `.env` (user overrides)
## Common Development Tasks
### Adding a new block
Follow the comprehensive [Block SDK Guide](@../../docs/platform/block-sdk-guide.md) which covers:
- Provider configuration with `ProviderBuilder`
- Block schema definition
- Authentication (API keys, OAuth, webhooks)
- Testing and validation
- File organization
Quick steps:
1. Create new file in `backend/blocks/`
2. Configure provider using `ProviderBuilder` in `_config.py`
3. Inherit from `Block` base class
4. Define input/output schemas using `BlockSchema`
5. Implement async `run` method
6. Generate unique block ID using `uuid.uuid4()`
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
Note: when adding several new blocks, analyze their interfaces and consider whether they would compose well together in a graph-based editor or struggle to connect productively (e.g., do the inputs and outputs tie together well?).
If you get pushback or hit complex block conditions, check the new_blocks guide in the docs.
#### Handling files in blocks with `store_media_file()`
When blocks need to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. The `return_format` parameter determines what you get back:
| Format | Use When | Returns |
|--------|----------|---------|
| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) |
| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) |
| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs |
**Examples:**
```python
# INPUT: Need to process file locally with ffmpeg
local_path = await store_media_file(
    file=input_data.video,
    execution_context=execution_context,
    return_format="for_local_processing",
)
# local_path = "video.mp4" - use with Path/ffmpeg/etc

# INPUT: Need to send to external API like Replicate
image_b64 = await store_media_file(
    file=input_data.image,
    execution_context=execution_context,
    return_format="for_external_api",
)
# image_b64 = "data:image/png;base64,iVBORw0..." - send to API

# OUTPUT: Returning result from block
result_url = await store_media_file(
    file=generated_image_url,
    execution_context=execution_context,
    return_format="for_block_output",
)
yield "image_url", result_url
# In CoPilot: result_url = "workspace://abc123"
# In graphs: result_url = "data:image/png;base64,..."
```
**Key points:**
- `for_block_output` is the ONLY format that auto-adapts to execution context
- Always use `for_block_output` for block outputs unless you have a specific reason not to
- Never hardcode workspace checks - let `for_block_output` handle it
### Modifying the API
1. Update route in `backend/api/features/`
2. Add/update Pydantic models in same directory
3. Write tests alongside the route file
4. Run `poetry run test` to verify
## Workspace & Media Files
**Read [Workspace & Media Architecture](../../docs/platform/workspace-media-architecture.md) when:**
- Working on CoPilot file upload/download features
- Building blocks that handle `MediaFileType` inputs/outputs
- Modifying `WorkspaceManager` or `store_media_file()`
- Debugging file persistence or virus scanning issues
Covers: `WorkspaceManager` (persistent storage with session scoping), `store_media_file()` (media normalization pipeline), and responsibility boundaries for virus scanning and persistence.
## Security Implementation
### Cache Protection Middleware
- Located in `backend/api/middleware/security.py`
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
- Uses an allow list approach - only explicitly permitted paths can be cached
- Cacheable paths include: static assets (`static/*`, `_next/static/*`), health checks, public store pages, documentation
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
- Applied to both main API server and external API applications

View File

@@ -1 +1,227 @@
@AGENTS.md
# CLAUDE.md - Backend
This file provides guidance to Claude Code when working with the backend.
## Essential Commands
To run something with Python package dependencies you MUST use `poetry run ...`.
```bash
# Install dependencies
poetry install
# Run database migrations
poetry run prisma migrate dev
# Start all services (database, redis, rabbitmq, clamav)
docker compose up -d
# Run the backend as a whole
poetry run app
# Run tests
poetry run test
# Run specific test
poetry run pytest path/to/test_file.py::test_function_name
# Run block tests (tests that validate all blocks work correctly)
poetry run pytest backend/blocks/test/test_block.py -xvs
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
# Lint and format
# prefer `format` to auto-fix what it can and report only the errors that can't be autofixed
poetry run format # Black + isort
poetry run lint # ruff
```
More details can be found in @TESTING.md
### Creating/Updating Snapshots
When you first write a test or when the expected output changes:
```bash
poetry run pytest path/to/test.py --snapshot-update
```
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
## Architecture
- **API Layer**: FastAPI with REST and WebSocket endpoints
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
- **Queue System**: RabbitMQ for async task processing
- **Execution Engine**: Separate executor service processes agent workflows
- **Authentication**: JWT-based with Supabase integration
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
## Code Style
- **Top-level imports only** — no local/inner imports (lazy imports only for heavy optional deps like `openpyxl`)
- **Absolute imports** — use `from backend.module import ...` for cross-package imports. Single-dot relative (`from .sibling import ...`) is acceptable for sibling modules within the same package (e.g., blocks). Avoid double-dot relative imports (`from ..parent import ...`) — use the absolute path instead
- **No duck typing** — no `hasattr`/`getattr`/`isinstance` for type dispatch; use typed interfaces/unions/protocols
- **Pydantic models** over dataclass/namedtuple/dict for structured data
- **No linter suppressors** — no `# type: ignore`, `# noqa`, `# pyright: ignore`; fix the type/code
- **List comprehensions** over manual loop-and-append
- **Early return** — guard clauses first, avoid deep nesting
- **f-strings vs printf syntax in log statements** — Use `%s` for deferred interpolation in `debug` statements, f-strings elsewhere for readability: `logger.debug("Processing %s items", count)`, `logger.info(f"Processing {count} items")`
- **Sanitize error paths** — `os.path.basename()` in error messages to avoid leaking directory structure
- **TOCTOU awareness** — avoid check-then-act patterns for file access and credit charging
- **`Security()` vs `Depends()`** — use `Security()` for auth deps to get proper OpenAPI security spec
- **Redis pipelines** — `transaction=True` for atomicity on multi-step operations
- **`max(0, value)` guards** — for computed values that should never be negative
- **SSE protocol** — `data:` lines for frontend-parsed events (must match Zod schema), `: comment` lines for heartbeats/status
- **File length** — keep files under ~300 lines; if a file grows beyond this, split by responsibility (e.g. extract helpers, models, or a sub-module into a new file). Never keep appending to a long file.
- **Function length** — keep functions under ~40 lines; extract named helpers when a function grows longer. Long functions are a sign of mixed concerns, not complexity.
- **Top-down ordering** — define the main/public function or class first, then the helpers it uses below. A reader should encounter high-level logic before implementation details.
## Testing Approach
- Uses pytest with snapshot testing for API responses
- Test files are colocated with source files (`*_test.py`)
- Mock at boundaries — mock where the symbol is **used**, not where it's **defined**
- After refactoring, update mock targets to match new module paths
- Use `AsyncMock` for async functions (`from unittest.mock import AsyncMock`)
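The "mock where used" rule can be sketched with two stand-in namespaces (classes standing in for modules here so the example is self-contained; in real tests the patch target is a dotted path such as `backend.api.features.chat.routes.create_chat_session`):

```python
import asyncio
from unittest.mock import AsyncMock, patch


class services:
    """Stand-in for the module where the symbol is *defined*."""

    @staticmethod
    async def fetch_data() -> str:
        return "real"


class routes:
    """Stand-in for the module where the symbol is *used*."""

    # Imported at module load time: routes holds its own reference.
    fetch_data = services.fetch_data

    @staticmethod
    async def handler() -> str:
        return await routes.fetch_data()


async def call_with_mock() -> str:
    # Patch where the symbol is used (routes), not where it is defined (services).
    with patch.object(routes, "fetch_data", new_callable=AsyncMock, return_value="mocked"):
        return await routes.handler()
```

Patching `services.fetch_data` instead would leave `routes.fetch_data` pointing at the original function, and the mock would never be hit.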
### Test-Driven Development (TDD)
When fixing a bug or adding a feature, write the test **before** the implementation:
```python
# 1. Write a failing test marked xfail
@pytest.mark.xfail(reason="Bug #1234: widget crashes on empty input")
def test_widget_handles_empty_input():
    result = widget.process("")
    assert result == Widget.EMPTY_RESULT

# 2. Run it — confirm it fails (XFAIL)
# poetry run pytest path/to/test.py::test_widget_handles_empty_input -xvs

# 3. Implement the fix

# 4. Remove xfail, run again — confirm it passes
def test_widget_handles_empty_input():
    result = widget.process("")
    assert result == Widget.EMPTY_RESULT
```
This catches regressions and proves the fix actually works. **Every bug fix should include a test that would have caught it.**
## Database Schema
Key models (defined in `schema.prisma`):
- `User`: Authentication and profile data
- `AgentGraph`: Workflow definitions with version control
- `AgentGraphExecution`: Execution history and results
- `AgentNode`: Individual nodes in a workflow
- `StoreListing`: Marketplace listings for sharing agents
## Environment Configuration
- **Backend**: `.env.default` (defaults) → `.env` (user overrides)
## Common Development Tasks
### Adding a new block
Follow the comprehensive [Block SDK Guide](@../../docs/content/platform/block-sdk-guide.md) which covers:
- Provider configuration with `ProviderBuilder`
- Block schema definition
- Authentication (API keys, OAuth, webhooks)
- Testing and validation
- File organization
Quick steps:
1. Create new file in `backend/blocks/`
2. Configure provider using `ProviderBuilder` in `_config.py`
3. Inherit from `Block` base class
4. Define input/output schemas using `BlockSchema`
5. Implement async `run` method
6. Generate unique block ID using `uuid.uuid4()`
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
Note: when creating several new blocks, analyze their interfaces together and consider whether they would compose well in a graph-based editor or struggle to connect productively (e.g. do one block's outputs tie cleanly into another's inputs?). If you get pushback or hit complex block conditions, check the new_blocks guide in the docs.
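A self-contained sketch of the shape these steps describe. The real `Block` and `BlockSchema` base classes come from the backend SDK and carry more machinery (provider config, credentials, test fixtures — see the Block SDK Guide); the minimal stand-ins below exist only so the example runs in isolation:

```python
import uuid
from dataclasses import dataclass


class BlockSchema:
    """Stand-in for the SDK's BlockSchema (normally a Pydantic model)."""


class Block:
    """Stand-in for the SDK's Block base class."""

    def __init__(self, id: str, input_schema: type, output_schema: type):
        self.id = id
        self.input_schema = input_schema
        self.output_schema = output_schema


class EchoBlock(Block):
    @dataclass
    class Input(BlockSchema):  # step 4: input schema
        text: str

    @dataclass
    class Output(BlockSchema):  # step 4: output schema
        result: str

    def __init__(self):
        super().__init__(
            id=str(uuid.uuid4()),  # step 6: unique block ID
            input_schema=EchoBlock.Input,
            output_schema=EchoBlock.Output,
        )

    async def run(self, input_data: "EchoBlock.Input"):
        # Step 5: async run yields (output_name, value) pairs.
        yield "result", input_data.text.upper()
```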
#### Handling files in blocks with `store_media_file()`
When blocks need to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. The `return_format` parameter determines what you get back:
| Format | Use When | Returns |
|--------|----------|---------|
| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) |
| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) |
| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs |
**Examples:**
```python
# INPUT: Need to process file locally with ffmpeg
local_path = await store_media_file(
    file=input_data.video,
    execution_context=execution_context,
    return_format="for_local_processing",
)
# local_path = "video.mp4" - use with Path/ffmpeg/etc

# INPUT: Need to send to external API like Replicate
image_b64 = await store_media_file(
    file=input_data.image,
    execution_context=execution_context,
    return_format="for_external_api",
)
# image_b64 = "data:image/png;base64,iVBORw0..." - send to API

# OUTPUT: Returning result from block
result_url = await store_media_file(
    file=generated_image_url,
    execution_context=execution_context,
    return_format="for_block_output",
)
yield "image_url", result_url
# In CoPilot: result_url = "workspace://abc123"
# In graphs: result_url = "data:image/png;base64,..."
```
**Key points:**
- `for_block_output` is the ONLY format that auto-adapts to execution context
- Always use `for_block_output` for block outputs unless you have a specific reason not to
- Never hardcode workspace checks - let `for_block_output` handle it
### Modifying the API
1. Update route in `backend/api/features/`
2. Add/update Pydantic models in same directory
3. Write tests alongside the route file
4. Run `poetry run test` to verify
## Workspace & Media Files
**Read [Workspace & Media Architecture](../../docs/platform/workspace-media-architecture.md) when:**
- Working on CoPilot file upload/download features
- Building blocks that handle `MediaFileType` inputs/outputs
- Modifying `WorkspaceManager` or `store_media_file()`
- Debugging file persistence or virus scanning issues
Covers: `WorkspaceManager` (persistent storage with session scoping), `store_media_file()` (media normalization pipeline), and responsibility boundaries for virus scanning and persistence.
## Security Implementation
### Cache Protection Middleware
- Located in `backend/api/middleware/security.py`
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
- Uses an allow list approach - only explicitly permitted paths can be cached
- Cacheable paths include: static assets (`static/*`, `_next/static/*`), health checks, public store pages, documentation
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
- Applied to both main API server and external API applications
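The allow-list decision reduces to a pure function, sketched here with hypothetical prefixes (the real middleware and its `CACHEABLE_PATHS` live in `backend/api/middleware/security.py` and may list different paths):

```python
# Hypothetical allow-list for illustration; see CACHEABLE_PATHS in the real middleware.
CACHEABLE_PATHS = ("/static/", "/_next/static/", "/health", "/docs")


def cache_control_for(path: str) -> str:
    """Return the Cache-Control header value for a response path.

    Default-deny: everything is uncacheable unless its path is explicitly
    allow-listed, so auth tokens, API keys, and user data never land in
    browser or proxy caches.
    """
    if any(path.startswith(prefix) for prefix in CACHEABLE_PATHS):
        return "public, max-age=3600"
    return "no-store, no-cache, must-revalidate, private"
```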


@@ -31,10 +31,7 @@ from backend.data.model import (
UserPasswordCredentials,
is_sdk_default,
)
from backend.integrations.credentials_store import (
is_system_credential,
provider_matches,
)
from backend.integrations.credentials_store import provider_matches
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
@@ -621,11 +618,6 @@ async def delete_credential(
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
)
if is_system_credential(cred_id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="System-managed credentials cannot be deleted",
)
creds = await creds_manager.store.get_creds_by_id(auth.user_id, cred_id)
if not creds:
raise HTTPException(


@@ -72,7 +72,7 @@ class RunAgentRequest(BaseModel):
def _create_ephemeral_session(user_id: str) -> ChatSession:
"""Create an ephemeral session for stateless API requests."""
return ChatSession.new(user_id, dry_run=False)
return ChatSession.new(user_id)
@tools_router.post(


@@ -11,7 +11,7 @@ from autogpt_libs import auth
from fastapi import APIRouter, HTTPException, Query, Response, Security
from fastapi.responses import StreamingResponse
from prisma.models import UserWorkspaceFile
from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic import BaseModel, Field, field_validator
from backend.copilot import service as chat_service
from backend.copilot import stream_registry
@@ -20,7 +20,6 @@ from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_
from backend.copilot.model import (
ChatMessage,
ChatSession,
ChatSessionMetadata,
append_and_save_message,
create_chat_session,
delete_chat_session,
@@ -113,25 +112,12 @@ class StreamChatRequest(BaseModel):
) # Workspace file IDs attached to this message
class CreateSessionRequest(BaseModel):
"""Request model for creating a new chat session.
``dry_run`` is a **top-level** field — do not nest it inside ``metadata``.
Extra/unknown fields are rejected (422) to prevent silent mis-use.
"""
model_config = ConfigDict(extra="forbid")
dry_run: bool = False
class CreateSessionResponse(BaseModel):
"""Response model containing information on a newly created chat session."""
id: str
created_at: str
user_id: str | None
metadata: ChatSessionMetadata = ChatSessionMetadata()
class ActiveStreamInfo(BaseModel):
@@ -152,7 +138,6 @@ class SessionDetailResponse(BaseModel):
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
total_prompt_tokens: int = 0
total_completion_tokens: int = 0
metadata: ChatSessionMetadata = ChatSessionMetadata()
class SessionSummaryResponse(BaseModel):
@@ -263,7 +248,6 @@ async def list_sessions(
)
async def create_session(
user_id: Annotated[str, Security(auth.get_user_id)],
request: CreateSessionRequest | None = None,
) -> CreateSessionResponse:
"""
Create a new chat session.
@@ -272,28 +256,22 @@ async def create_session(
Args:
user_id: The authenticated user ID parsed from the JWT (required).
request: Optional request body. When provided, ``dry_run=True``
forces run_block and run_agent calls to use dry-run simulation.
Returns:
CreateSessionResponse: Details of the created session.
"""
dry_run = request.dry_run if request else False
logger.info(
f"Creating session with user_id: "
f"...{user_id[-8:] if len(user_id) > 8 else '<redacted>'}"
f"{', dry_run=True' if dry_run else ''}"
)
session = await create_chat_session(user_id, dry_run=dry_run)
session = await create_chat_session(user_id)
return CreateSessionResponse(
id=session.session_id,
created_at=session.started_at.isoformat(),
user_id=session.user_id,
metadata=session.metadata,
)
@@ -442,7 +420,6 @@ async def get_session(
active_stream=active_stream_info,
total_prompt_tokens=total_prompt,
total_completion_tokens=total_completion,
metadata=session.metadata,
)
@@ -1197,7 +1174,7 @@ async def health_check() -> dict:
)
# Create and retrieve session to verify full data layer
session = await create_chat_session(health_check_user_id, dry_run=False)
session = await create_chat_session(health_check_user_id)
await get_chat_session(session.session_id, health_check_user_id)
return {


@@ -469,60 +469,3 @@ def test_suggested_prompts_empty_prompts(
assert response.status_code == 200
assert response.json() == {"themes": []}
# ─── Create session: dry_run contract ─────────────────────────────────
def _mock_create_chat_session(mocker: pytest_mock.MockerFixture):
"""Mock create_chat_session to return a fake session."""
from backend.copilot.model import ChatSession
async def _fake_create(user_id: str, *, dry_run: bool):
return ChatSession.new(user_id, dry_run=dry_run)
return mocker.patch(
"backend.api.features.chat.routes.create_chat_session",
new_callable=AsyncMock,
side_effect=_fake_create,
)
def test_create_session_dry_run_true(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""Sending ``{"dry_run": true}`` sets metadata.dry_run to True."""
_mock_create_chat_session(mocker)
response = client.post("/sessions", json={"dry_run": True})
assert response.status_code == 200
assert response.json()["metadata"]["dry_run"] is True
def test_create_session_dry_run_default_false(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""Empty body defaults dry_run to False."""
_mock_create_chat_session(mocker)
response = client.post("/sessions")
assert response.status_code == 200
assert response.json()["metadata"]["dry_run"] is False
def test_create_session_rejects_nested_metadata(
test_user_id: str,
) -> None:
"""Sending ``{"metadata": {"dry_run": true}}`` must return 422, not silently
default to ``dry_run=False``. This guards against the common mistake of
nesting dry_run inside metadata instead of providing it at the top level."""
response = client.post(
"/sessions",
json={"metadata": {"dry_run": True}},
)
assert response.status_code == 422


@@ -40,15 +40,11 @@ from backend.data.onboarding import OnboardingStep, complete_onboarding_step
from backend.data.user import get_user_integrations
from backend.executor.utils import add_graph_execution
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
from backend.integrations.credentials_store import (
is_system_credential,
provider_matches,
)
from backend.integrations.credentials_store import provider_matches
from backend.integrations.creds_manager import (
IntegrationCredentialsManager,
create_mcp_oauth_handler,
)
from backend.integrations.managed_credentials import ensure_managed_credentials
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks import get_webhook_manager
@@ -114,7 +110,6 @@ class CredentialsMetaResponse(BaseModel):
default=None,
description="Host pattern for host-scoped or MCP server URL for MCP credentials",
)
is_managed: bool = False
@model_validator(mode="before")
@classmethod
@@ -153,7 +148,6 @@ def to_meta_response(cred: Credentials) -> CredentialsMetaResponse:
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
host=CredentialsMetaResponse.get_host(cred),
is_managed=cred.is_managed,
)
@@ -230,9 +224,6 @@ async def callback(
async def list_credentials(
user_id: Annotated[str, Security(get_user_id)],
) -> list[CredentialsMetaResponse]:
# Fire-and-forget: provision missing managed credentials in the background.
# The credential appears on the next page load; listing is never blocked.
asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store))
credentials = await creds_manager.store.get_all_creds(user_id)
return [
@@ -247,7 +238,6 @@ async def list_credentials_by_provider(
],
user_id: Annotated[str, Security(get_user_id)],
) -> list[CredentialsMetaResponse]:
asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store))
credentials = await creds_manager.store.get_creds_by_provider(user_id, provider)
return [
@@ -342,11 +332,6 @@ async def delete_credentials(
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
)
if is_system_credential(cred_id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="System-managed credentials cannot be deleted",
)
creds = await creds_manager.store.get_creds_by_id(user_id, cred_id)
if not creds:
raise HTTPException(
@@ -357,11 +342,6 @@ async def delete_credentials(
status_code=status.HTTP_404_NOT_FOUND,
detail="Credentials not found",
)
if creds.is_managed:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="AutoGPT-managed credentials cannot be deleted",
)
try:
await remove_all_webhooks_for_credentials(user_id, creds, force)


@@ -1,7 +1,6 @@
"""Tests for credentials API security: no secret leakage, SDK defaults filtered."""
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, patch
import fastapi
import fastapi.testclient
@@ -277,294 +276,3 @@ class TestCreateCredentialNoSecretInResponse:
assert resp.status_code == 403
mock_mgr.create.assert_not_called()
class TestManagedCredentials:
"""AutoGPT-managed credentials cannot be deleted by users."""
def test_delete_is_managed_returns_403(self):
cred = APIKeyCredentials(
id="managed-cred-1",
provider="agent_mail",
title="AgentMail (managed by AutoGPT)",
api_key=SecretStr("sk-managed-key"),
is_managed=True,
)
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.store.get_creds_by_id = AsyncMock(return_value=cred)
resp = client.request("DELETE", "/agent_mail/credentials/managed-cred-1")
assert resp.status_code == 403
assert "AutoGPT-managed" in resp.json()["detail"]
def test_list_credentials_includes_is_managed_field(self):
managed = APIKeyCredentials(
id="managed-1",
provider="agent_mail",
title="AgentMail (managed)",
api_key=SecretStr("sk-key"),
is_managed=True,
)
regular = APIKeyCredentials(
id="regular-1",
provider="openai",
title="My Key",
api_key=SecretStr("sk-key"),
)
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.store.get_all_creds = AsyncMock(return_value=[managed, regular])
resp = client.get("/credentials")
assert resp.status_code == 200
data = resp.json()
managed_cred = next(c for c in data if c["id"] == "managed-1")
regular_cred = next(c for c in data if c["id"] == "regular-1")
assert managed_cred["is_managed"] is True
assert regular_cred["is_managed"] is False
# ---------------------------------------------------------------------------
# Managed credential provisioning infrastructure
# ---------------------------------------------------------------------------
def _make_managed_cred(
provider: str = "agent_mail", pod_id: str = "pod-abc"
) -> APIKeyCredentials:
return APIKeyCredentials(
id="managed-auto",
provider=provider,
title="AgentMail (managed by AutoGPT)",
api_key=SecretStr("sk-pod-key"),
is_managed=True,
metadata={"pod_id": pod_id},
)
def _make_store_mock(**kwargs) -> MagicMock:
"""Create a store mock with a working async ``locks()`` context manager."""
@asynccontextmanager
async def _noop_locked(key):
yield
locks_obj = MagicMock()
locks_obj.locked = _noop_locked
store = MagicMock(**kwargs)
store.locks = AsyncMock(return_value=locks_obj)
return store
class TestEnsureManagedCredentials:
"""Unit tests for the ensure/cleanup helpers in managed_credentials.py."""
@pytest.mark.asyncio
async def test_provisions_when_missing(self):
"""Provider.provision() is called when no managed credential exists."""
from backend.integrations.managed_credentials import (
_PROVIDERS,
_provisioned_users,
ensure_managed_credentials,
)
cred = _make_managed_cred()
provider = MagicMock()
provider.provider_name = "test_provider"
provider.is_available = AsyncMock(return_value=True)
provider.provision = AsyncMock(return_value=cred)
store = _make_store_mock()
store.has_managed_credential = AsyncMock(return_value=False)
store.add_managed_credential = AsyncMock()
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["test_provider"] = provider
_provisioned_users.pop("user-1", None)
try:
await ensure_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
_provisioned_users.pop("user-1", None)
provider.provision.assert_awaited_once_with("user-1")
store.add_managed_credential.assert_awaited_once_with("user-1", cred)
@pytest.mark.asyncio
async def test_skips_when_already_exists(self):
"""Provider.provision() is NOT called when managed credential exists."""
from backend.integrations.managed_credentials import (
_PROVIDERS,
_provisioned_users,
ensure_managed_credentials,
)
provider = MagicMock()
provider.provider_name = "test_provider"
provider.is_available = AsyncMock(return_value=True)
provider.provision = AsyncMock()
store = _make_store_mock()
store.has_managed_credential = AsyncMock(return_value=True)
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["test_provider"] = provider
_provisioned_users.pop("user-1", None)
try:
await ensure_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
_provisioned_users.pop("user-1", None)
provider.provision.assert_not_awaited()
@pytest.mark.asyncio
async def test_skips_when_unavailable(self):
"""Provider.provision() is NOT called when provider is not available."""
from backend.integrations.managed_credentials import (
_PROVIDERS,
_provisioned_users,
ensure_managed_credentials,
)
provider = MagicMock()
provider.provider_name = "test_provider"
provider.is_available = AsyncMock(return_value=False)
provider.provision = AsyncMock()
store = _make_store_mock()
store.has_managed_credential = AsyncMock()
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["test_provider"] = provider
_provisioned_users.pop("user-1", None)
try:
await ensure_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
_provisioned_users.pop("user-1", None)
provider.provision.assert_not_awaited()
store.has_managed_credential.assert_not_awaited()
@pytest.mark.asyncio
async def test_provision_failure_does_not_propagate(self):
"""A failed provision is logged but does not raise."""
from backend.integrations.managed_credentials import (
_PROVIDERS,
_provisioned_users,
ensure_managed_credentials,
)
provider = MagicMock()
provider.provider_name = "test_provider"
provider.is_available = AsyncMock(return_value=True)
provider.provision = AsyncMock(side_effect=RuntimeError("boom"))
store = _make_store_mock()
store.has_managed_credential = AsyncMock(return_value=False)
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["test_provider"] = provider
_provisioned_users.pop("user-1", None)
try:
await ensure_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
_provisioned_users.pop("user-1", None)
# No exception raised — provisioning failure is swallowed.
class TestCleanupManagedCredentials:
"""Unit tests for cleanup_managed_credentials."""
@pytest.mark.asyncio
async def test_calls_deprovision_for_managed_creds(self):
from backend.integrations.managed_credentials import (
_PROVIDERS,
cleanup_managed_credentials,
)
cred = _make_managed_cred()
provider = MagicMock()
provider.provider_name = "agent_mail"
provider.deprovision = AsyncMock()
store = MagicMock()
store.get_all_creds = AsyncMock(return_value=[cred])
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["agent_mail"] = provider
try:
await cleanup_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
provider.deprovision.assert_awaited_once_with("user-1", cred)
@pytest.mark.asyncio
async def test_skips_non_managed_creds(self):
from backend.integrations.managed_credentials import (
_PROVIDERS,
cleanup_managed_credentials,
)
regular = _make_api_key_cred()
provider = MagicMock()
provider.provider_name = "openai"
provider.deprovision = AsyncMock()
store = MagicMock()
store.get_all_creds = AsyncMock(return_value=[regular])
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["openai"] = provider
try:
await cleanup_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
provider.deprovision.assert_not_awaited()
@pytest.mark.asyncio
async def test_deprovision_failure_does_not_propagate(self):
from backend.integrations.managed_credentials import (
_PROVIDERS,
cleanup_managed_credentials,
)
cred = _make_managed_cred()
provider = MagicMock()
provider.provider_name = "agent_mail"
provider.deprovision = AsyncMock(side_effect=RuntimeError("boom"))
store = MagicMock()
store.get_all_creds = AsyncMock(return_value=[cred])
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["agent_mail"] = provider
try:
await cleanup_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
# No exception raised — cleanup failure is swallowed.


@@ -12,7 +12,6 @@ Tests cover:
5. Complete OAuth flow end-to-end
"""
import asyncio
import base64
import hashlib
import secrets
@@ -59,27 +58,14 @@ async def test_user(server, test_user_id: str):
yield test_user_id
# Cleanup - delete in correct order due to foreign key constraints.
# Wrap in try/except because the event loop or Prisma engine may already
# be closed during session teardown on Python 3.12+.
try:
await asyncio.gather(
PrismaOAuthAccessToken.prisma().delete_many(where={"userId": test_user_id}),
PrismaOAuthRefreshToken.prisma().delete_many(
where={"userId": test_user_id}
),
PrismaOAuthAuthorizationCode.prisma().delete_many(
where={"userId": test_user_id}
),
)
await asyncio.gather(
PrismaOAuthApplication.prisma().delete_many(
where={"ownerId": test_user_id}
),
PrismaUser.prisma().delete(where={"id": test_user_id}),
)
except RuntimeError:
pass
# Cleanup - delete in correct order due to foreign key constraints
await PrismaOAuthAccessToken.prisma().delete_many(where={"userId": test_user_id})
await PrismaOAuthRefreshToken.prisma().delete_many(where={"userId": test_user_id})
await PrismaOAuthAuthorizationCode.prisma().delete_many(
where={"userId": test_user_id}
)
await PrismaOAuthApplication.prisma().delete_many(where={"ownerId": test_user_id})
await PrismaUser.prisma().delete(where={"id": test_user_id})
@pytest_asyncio.fixture


@@ -0,0 +1 @@
# Platform bot linking API


@@ -0,0 +1,35 @@
"""Bot API key authentication for platform linking endpoints."""
import hmac
import os
from fastapi import HTTPException, Request
from backend.util.settings import Settings
async def get_bot_api_key(request: Request) -> str | None:
"""Extract the bot API key from the X-Bot-API-Key header."""
return request.headers.get("x-bot-api-key")
def check_bot_api_key(api_key: str | None) -> None:
"""Validate the bot API key. Uses constant-time comparison.
Reads the key from env on each call so rotated secrets take effect
without restarting the process.
"""
configured_key = os.getenv("PLATFORM_BOT_API_KEY", "")
if not configured_key:
settings = Settings()
if settings.config.enable_auth:
raise HTTPException(
status_code=503,
detail="Bot API key not configured.",
)
# Auth disabled (local dev) — allow without key
return
if not api_key or not hmac.compare_digest(api_key, configured_key):
raise HTTPException(status_code=401, detail="Invalid bot API key.")


@@ -0,0 +1,170 @@
"""
Bot Chat Proxy endpoints.
Allows the bot service to send messages to CoPilot on behalf of
linked users, authenticated via bot API key.
"""
import asyncio
import logging
from uuid import uuid4
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from prisma.models import PlatformLink
from backend.copilot import stream_registry
from backend.copilot.executor.utils import enqueue_copilot_turn
from backend.copilot.model import (
ChatMessage,
append_and_save_message,
create_chat_session,
get_chat_session,
)
from backend.copilot.response_model import StreamFinish
from .auth import check_bot_api_key, get_bot_api_key
from .models import BotChatRequest, BotChatSessionResponse
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post(
"/chat/session",
response_model=BotChatSessionResponse,
summary="Create a CoPilot session for a linked user (bot-facing)",
)
async def bot_create_session(
request: BotChatRequest,
x_bot_api_key: str | None = Depends(get_bot_api_key),
) -> BotChatSessionResponse:
"""Creates a new CoPilot chat session on behalf of a linked user."""
check_bot_api_key(x_bot_api_key)
link = await PlatformLink.prisma().find_first(where={"userId": request.user_id})
if not link:
raise HTTPException(status_code=404, detail="User has no platform links.")
session = await create_chat_session(request.user_id)
return BotChatSessionResponse(session_id=session.session_id)
@router.post(
"/chat/stream",
summary="Stream a CoPilot response for a linked user (bot-facing)",
)
async def bot_chat_stream(
request: BotChatRequest,
x_bot_api_key: str | None = Depends(get_bot_api_key),
):
"""
Send a message to CoPilot on behalf of a linked user and stream
the response back as Server-Sent Events.
The bot authenticates with its API key — no user JWT needed.
"""
check_bot_api_key(x_bot_api_key)
user_id = request.user_id
# Verify user has a platform link
link = await PlatformLink.prisma().find_first(where={"userId": user_id})
if not link:
raise HTTPException(status_code=404, detail="User has no platform links.")
# Get or create session
session_id = request.session_id
if session_id:
session = await get_chat_session(session_id, user_id)
if not session:
raise HTTPException(status_code=404, detail="Session not found.")
else:
session = await create_chat_session(user_id)
session_id = session.session_id
# Save user message
message = ChatMessage(role="user", content=request.message)
await append_and_save_message(session_id, message)
# Create a turn and enqueue
turn_id = str(uuid4())
await stream_registry.create_session(
session_id=session_id,
user_id=user_id,
tool_call_id="chat_stream",
tool_name="chat",
turn_id=turn_id,
)
subscribe_from_id = "0-0"
await enqueue_copilot_turn(
session_id=session_id,
user_id=user_id,
message=request.message,
turn_id=turn_id,
is_user_message=True,
)
logger.info(
"Bot chat: user ...%s, session %s, turn %s",
user_id[-8:],
session_id,
turn_id,
)
async def event_generator():
subscriber_queue = None
try:
subscriber_queue = await stream_registry.subscribe_to_session(
session_id=session_id,
user_id=user_id,
last_message_id=subscribe_from_id,
)
if subscriber_queue is None:
yield StreamFinish().to_sse()
yield "data: [DONE]\n\n"
return
while True:
try:
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
if isinstance(chunk, str):
yield chunk
else:
yield chunk.to_sse()
if isinstance(chunk, StreamFinish) or (
isinstance(chunk, str) and "[DONE]" in chunk
):
break
except asyncio.TimeoutError:
yield ": keepalive\n\n"
except Exception:
logger.exception("Bot chat stream error for session %s", session_id)
yield 'data: {"type": "error", "content": "Stream error"}\n\n'
yield "data: [DONE]\n\n"
finally:
if subscriber_queue is not None:
await stream_registry.unsubscribe_from_session(
session_id=session_id,
subscriber_queue=subscriber_queue,
)
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Session-Id": session_id,
},
)


@@ -0,0 +1,107 @@
"""Pydantic models for the platform bot linking API."""
from datetime import datetime
from enum import Enum
from typing import Literal
from pydantic import BaseModel, Field
class Platform(str, Enum):
"""Supported platform types (mirrors Prisma PlatformType)."""
DISCORD = "DISCORD"
TELEGRAM = "TELEGRAM"
SLACK = "SLACK"
TEAMS = "TEAMS"
WHATSAPP = "WHATSAPP"
GITHUB = "GITHUB"
LINEAR = "LINEAR"
# ── Request Models ─────────────────────────────────────────────────────
class CreateLinkTokenRequest(BaseModel):
"""Request from the bot service to create a linking token."""
platform: Platform = Field(description="Platform name")
platform_user_id: str = Field(
description="The user's ID on the platform",
min_length=1,
max_length=255,
)
platform_username: str | None = Field(
default=None,
description="Display name (best effort)",
max_length=255,
)
channel_id: str | None = Field(
default=None,
description="Channel ID for sending confirmation back",
max_length=255,
)
class ResolveRequest(BaseModel):
"""Resolve a platform identity to an AutoGPT user."""
platform: Platform
platform_user_id: str = Field(min_length=1, max_length=255)
class BotChatRequest(BaseModel):
"""Request from the bot to chat as a linked user."""
user_id: str = Field(description="The linked AutoGPT user ID")
message: str = Field(
description="The user's message", min_length=1, max_length=32000
)
session_id: str | None = Field(
default=None,
description="Existing chat session ID. If omitted, a new session is created.",
)
# ── Response Models ────────────────────────────────────────────────────
class LinkTokenResponse(BaseModel):
token: str
expires_at: datetime
link_url: str
class LinkTokenStatusResponse(BaseModel):
status: Literal["pending", "linked", "expired"]
user_id: str | None = None
class ResolveResponse(BaseModel):
linked: bool
user_id: str | None = None
class PlatformLinkInfo(BaseModel):
id: str
platform: str
platform_user_id: str
platform_username: str | None
linked_at: datetime
class ConfirmLinkResponse(BaseModel):
success: bool
platform: str
platform_user_id: str
platform_username: str | None
class DeleteLinkResponse(BaseModel):
success: bool
class BotChatSessionResponse(BaseModel):
"""Returned when creating a new session via the bot proxy."""
session_id: str


@@ -0,0 +1,340 @@
"""
Platform Bot Linking API routes.
Enables linking external chat platform identities (Discord, Telegram, Slack, etc.)
to AutoGPT user accounts. Used by the multi-platform CoPilot bot.
Flow:
1. Bot calls POST /api/platform-linking/tokens to create a link token
for an unlinked platform user.
2. Bot sends the user a link: {frontend}/link/{token}
3. User clicks the link, logs in to AutoGPT, and the frontend calls
POST /api/platform-linking/tokens/{token}/confirm to complete the link.
4. Bot can poll GET /api/platform-linking/tokens/{token}/status or just
check on next message via GET /api/platform-linking/resolve.
"""
import logging
import os
import secrets
from datetime import datetime, timedelta, timezone
from typing import Annotated
from autogpt_libs import auth
from fastapi import APIRouter, Depends, HTTPException, Path, Security
from prisma.models import PlatformLink, PlatformLinkToken
from .auth import check_bot_api_key, get_bot_api_key
from .models import (
ConfirmLinkResponse,
CreateLinkTokenRequest,
DeleteLinkResponse,
LinkTokenResponse,
LinkTokenStatusResponse,
PlatformLinkInfo,
ResolveRequest,
ResolveResponse,
)
logger = logging.getLogger(__name__)
router = APIRouter()
LINK_TOKEN_EXPIRY_MINUTES = 30
# Path parameter with validation for link tokens
TokenPath = Annotated[
str,
Path(max_length=64, pattern=r"^[A-Za-z0-9_-]+$"),
]
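The `max_length=64` / `[A-Za-z0-9_-]` constraints line up with what `secrets.token_urlsafe(32)` (used below to mint tokens) produces: 32 random bytes, base64url-encoded with padding stripped, i.e. exactly 43 characters from that alphabet. A quick sanity check:

```python
import re
import secrets

token = secrets.token_urlsafe(32)
# 32 random bytes -> base64url without padding: ceil(32 * 4 / 3) = 43 chars
assert len(token) == 43
assert re.fullmatch(r"[A-Za-z0-9_-]+", token)
```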
# ── Bot-facing endpoints (API key auth) ───────────────────────────────
@router.post(
"/tokens",
response_model=LinkTokenResponse,
summary="Create a link token for an unlinked platform user",
)
async def create_link_token(
request: CreateLinkTokenRequest,
x_bot_api_key: str | None = Depends(get_bot_api_key),
) -> LinkTokenResponse:
"""
Called by the bot service when it encounters an unlinked user.
Generates a one-time token the user can use to link their account.
"""
check_bot_api_key(x_bot_api_key)
platform = request.platform.value
# Check if already linked
existing = await PlatformLink.prisma().find_first(
where={
"platform": platform,
"platformUserId": request.platform_user_id,
}
)
if existing:
raise HTTPException(
status_code=409,
detail="This platform account is already linked.",
)
# Invalidate any existing pending tokens for this user
await PlatformLinkToken.prisma().update_many(
where={
"platform": platform,
"platformUserId": request.platform_user_id,
"usedAt": None,
},
data={"usedAt": datetime.now(timezone.utc)},
)
# Generate token
token = secrets.token_urlsafe(32)
expires_at = datetime.now(timezone.utc) + timedelta(
minutes=LINK_TOKEN_EXPIRY_MINUTES
)
await PlatformLinkToken.prisma().create(
data={
"token": token,
"platform": platform,
"platformUserId": request.platform_user_id,
"platformUsername": request.platform_username,
"channelId": request.channel_id,
"expiresAt": expires_at,
}
)
logger.info(
"Created link token for %s (expires %s)",
platform,
expires_at.isoformat(),
)
link_base_url = os.getenv(
"PLATFORM_LINK_BASE_URL", "https://platform.agpt.co/link"
)
link_url = f"{link_base_url}/{token}"
return LinkTokenResponse(
token=token,
expires_at=expires_at,
link_url=link_url,
)
@router.get(
"/tokens/{token}/status",
response_model=LinkTokenStatusResponse,
summary="Check if a link token has been consumed",
)
async def get_link_token_status(
token: TokenPath,
x_bot_api_key: str | None = Depends(get_bot_api_key),
) -> LinkTokenStatusResponse:
"""
Called by the bot service to check if a user has completed linking.
"""
check_bot_api_key(x_bot_api_key)
link_token = await PlatformLinkToken.prisma().find_unique(where={"token": token})
if not link_token:
raise HTTPException(status_code=404, detail="Token not found")
if link_token.usedAt is not None:
# Token was used — find the linked account
link = await PlatformLink.prisma().find_first(
where={
"platform": link_token.platform,
"platformUserId": link_token.platformUserId,
}
)
return LinkTokenStatusResponse(
status="linked",
user_id=link.userId if link else None,
)
if link_token.expiresAt.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc):
return LinkTokenStatusResponse(status="expired")
return LinkTokenStatusResponse(status="pending")
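The `.replace(tzinfo=timezone.utc)` in the expiry check is not decorative: if the stored timestamp comes back naive (as the route assumes it can), Python refuses to compare it against an aware `datetime.now(timezone.utc)`. A minimal illustration:

```python
from datetime import datetime, timezone

stored = datetime(2026, 4, 1, 12, 0)                # naive timestamp
cutoff = datetime(2026, 6, 1, tzinfo=timezone.utc)  # aware "now"

try:
    expired = stored < cutoff
except TypeError:
    expired = None  # comparing naive and aware datetimes raises TypeError

assert expired is None
# Attaching UTC first, as the route does, makes the comparison legal:
assert stored.replace(tzinfo=timezone.utc) < cutoff
```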
@router.post(
"/resolve",
response_model=ResolveResponse,
summary="Resolve a platform identity to an AutoGPT user",
)
async def resolve_platform_user(
request: ResolveRequest,
x_bot_api_key: str | None = Depends(get_bot_api_key),
) -> ResolveResponse:
"""
Called by the bot service on every incoming message to check if
the platform user has a linked AutoGPT account.
"""
check_bot_api_key(x_bot_api_key)
link = await PlatformLink.prisma().find_first(
where={
"platform": request.platform.value,
"platformUserId": request.platform_user_id,
}
)
if not link:
return ResolveResponse(linked=False)
return ResolveResponse(linked=True, user_id=link.userId)
# ── User-facing endpoints (JWT auth) ──────────────────────────────────
@router.post(
"/tokens/{token}/confirm",
response_model=ConfirmLinkResponse,
dependencies=[Security(auth.requires_user)],
summary="Confirm a link token (user must be authenticated)",
)
async def confirm_link_token(
token: TokenPath,
user_id: Annotated[str, Security(auth.get_user_id)],
) -> ConfirmLinkResponse:
"""
Called by the frontend when the user clicks the link and is logged in.
Consumes the token and creates the platform link.
Uses atomic update_many to prevent race conditions on double-click.
"""
link_token = await PlatformLinkToken.prisma().find_unique(where={"token": token})
if not link_token:
raise HTTPException(status_code=404, detail="Token not found.")
if link_token.usedAt is not None:
raise HTTPException(status_code=410, detail="This link has already been used.")
if link_token.expiresAt.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc):
raise HTTPException(status_code=410, detail="This link has expired.")
# Atomically mark token as used (only if still unused)
updated = await PlatformLinkToken.prisma().update_many(
where={"token": token, "usedAt": None},
data={"usedAt": datetime.now(timezone.utc)},
)
if updated == 0:
raise HTTPException(status_code=410, detail="This link has already been used.")
# Check if this platform identity is already linked
existing = await PlatformLink.prisma().find_first(
where={
"platform": link_token.platform,
"platformUserId": link_token.platformUserId,
}
)
if existing:
detail = (
"This platform account is already linked to your account."
if existing.userId == user_id
else "This platform account is already linked to another user."
)
raise HTTPException(status_code=409, detail=detail)
# Create the link — catch unique constraint race condition
try:
await PlatformLink.prisma().create(
data={
"userId": user_id,
"platform": link_token.platform,
"platformUserId": link_token.platformUserId,
"platformUsername": link_token.platformUsername,
}
)
except Exception as exc:
if "unique" in str(exc).lower():
raise HTTPException(
status_code=409,
detail="This platform account was just linked by another request.",
) from exc
raise
logger.info(
"Linked %s:%s to user ...%s",
link_token.platform,
link_token.platformUserId,
user_id[-8:],
)
return ConfirmLinkResponse(
success=True,
platform=link_token.platform,
platform_user_id=link_token.platformUserId,
platform_username=link_token.platformUsername,
)
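The atomic `update_many(where={"token": ..., "usedAt": None})` above is a compare-and-set: of two concurrent confirms, exactly one observes an update count of 1. An in-memory sketch of the pattern (a hypothetical stand-in for the Prisma call):

```python
tokens = {"tok_abc": {"used_at": None}}


def consume(token: str) -> int:
    """Return the number of rows flipped from unused to used (0 or 1)."""
    row = tokens.get(token)
    if row is not None and row["used_at"] is None:
        row["used_at"] = "2026-04-01T00:00:00Z"
        return 1
    return 0
```

The route maps `updated == 0` to a 410, so a double-click or concurrent retry gets a clean error instead of creating a second link.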
@router.get(
"/links",
response_model=list[PlatformLinkInfo],
dependencies=[Security(auth.requires_user)],
summary="List all platform links for the authenticated user",
)
async def list_my_links(
user_id: Annotated[str, Security(auth.get_user_id)],
) -> list[PlatformLinkInfo]:
"""Returns all platform identities linked to the current user's account."""
links = await PlatformLink.prisma().find_many(
where={"userId": user_id},
order={"linkedAt": "desc"},
)
return [
PlatformLinkInfo(
id=link.id,
platform=link.platform,
platform_user_id=link.platformUserId,
platform_username=link.platformUsername,
linked_at=link.linkedAt,
)
for link in links
]
@router.delete(
"/links/{link_id}",
response_model=DeleteLinkResponse,
dependencies=[Security(auth.requires_user)],
summary="Unlink a platform identity",
)
async def delete_link(
link_id: str,
user_id: Annotated[str, Security(auth.get_user_id)],
) -> DeleteLinkResponse:
"""
Removes a platform link. The user will need to re-link if they
want to use the bot on that platform again.
"""
link = await PlatformLink.prisma().find_unique(where={"id": link_id})
if not link:
raise HTTPException(status_code=404, detail="Link not found.")
if link.userId != user_id:
raise HTTPException(status_code=403, detail="Not your link.")
await PlatformLink.prisma().delete(where={"id": link_id})
logger.info(
"Unlinked %s:%s from user ...%s",
link.platform,
link.platformUserId,
user_id[-8:],
)
return DeleteLinkResponse(success=True)


@@ -0,0 +1,138 @@
"""Tests for platform bot linking API routes."""
from unittest.mock import patch
import pytest
from fastapi import HTTPException
from backend.api.features.platform_linking.auth import check_bot_api_key
from backend.api.features.platform_linking.models import (
ConfirmLinkResponse,
CreateLinkTokenRequest,
DeleteLinkResponse,
LinkTokenStatusResponse,
Platform,
ResolveRequest,
)
class TestPlatformEnum:
def test_all_platforms_exist(self):
assert Platform.DISCORD.value == "DISCORD"
assert Platform.TELEGRAM.value == "TELEGRAM"
assert Platform.SLACK.value == "SLACK"
assert Platform.TEAMS.value == "TEAMS"
assert Platform.WHATSAPP.value == "WHATSAPP"
assert Platform.GITHUB.value == "GITHUB"
assert Platform.LINEAR.value == "LINEAR"
class TestBotApiKeyAuth:
@patch.dict("os.environ", {"PLATFORM_BOT_API_KEY": ""}, clear=False)
@patch("backend.api.features.platform_linking.auth.Settings")
def test_no_key_configured_allows_when_auth_disabled(self, mock_settings_cls):
mock_settings_cls.return_value.config.enable_auth = False
check_bot_api_key(None)
@patch.dict("os.environ", {"PLATFORM_BOT_API_KEY": ""}, clear=False)
@patch("backend.api.features.platform_linking.auth.Settings")
def test_no_key_configured_rejects_when_auth_enabled(self, mock_settings_cls):
mock_settings_cls.return_value.config.enable_auth = True
with pytest.raises(HTTPException) as exc_info:
check_bot_api_key(None)
assert exc_info.value.status_code == 503
@patch.dict("os.environ", {"PLATFORM_BOT_API_KEY": "secret123"}, clear=False)
def test_valid_key(self):
check_bot_api_key("secret123")
@patch.dict("os.environ", {"PLATFORM_BOT_API_KEY": "secret123"}, clear=False)
def test_invalid_key_rejected(self):
with pytest.raises(HTTPException) as exc_info:
check_bot_api_key("wrong")
assert exc_info.value.status_code == 401
@patch.dict("os.environ", {"PLATFORM_BOT_API_KEY": "secret123"}, clear=False)
def test_missing_key_rejected(self):
with pytest.raises(HTTPException) as exc_info:
check_bot_api_key(None)
assert exc_info.value.status_code == 401
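Per the commit log, the key comparison moved from `!=` to `hmac.compare_digest`, whose runtime does not depend on how many leading characters match, which closes the timing side channel:

```python
import hmac

# Timing-safe equality: constant-time with respect to matching prefixes
assert hmac.compare_digest("secret123", "secret123")
assert not hmac.compare_digest("secret123", "secret124")
assert not hmac.compare_digest("secret123", "wrong")
```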
class TestCreateLinkTokenRequest:
def test_valid_request(self):
req = CreateLinkTokenRequest(
platform=Platform.DISCORD,
platform_user_id="353922987235213313",
)
assert req.platform == Platform.DISCORD
assert req.platform_user_id == "353922987235213313"
def test_empty_platform_user_id_rejected(self):
from pydantic import ValidationError
with pytest.raises(ValidationError):
CreateLinkTokenRequest(
platform=Platform.DISCORD,
platform_user_id="",
)
def test_too_long_platform_user_id_rejected(self):
from pydantic import ValidationError
with pytest.raises(ValidationError):
CreateLinkTokenRequest(
platform=Platform.DISCORD,
platform_user_id="x" * 256,
)
def test_invalid_platform_rejected(self):
from pydantic import ValidationError
with pytest.raises(ValidationError):
CreateLinkTokenRequest.model_validate(
{"platform": "INVALID", "platform_user_id": "123"}
)
class TestResolveRequest:
def test_valid_request(self):
req = ResolveRequest(
platform=Platform.TELEGRAM,
platform_user_id="123456789",
)
assert req.platform == Platform.TELEGRAM
def test_empty_id_rejected(self):
from pydantic import ValidationError
with pytest.raises(ValidationError):
ResolveRequest(
platform=Platform.SLACK,
platform_user_id="",
)
class TestResponseModels:
def test_link_token_status_literal(self):
resp = LinkTokenStatusResponse(status="pending")
assert resp.status == "pending"
resp = LinkTokenStatusResponse(status="linked", user_id="abc")
assert resp.status == "linked"
resp = LinkTokenStatusResponse(status="expired")
assert resp.status == "expired"
def test_confirm_link_response(self):
resp = ConfirmLinkResponse(
success=True,
platform="DISCORD",
platform_user_id="123",
platform_username="testuser",
)
assert resp.success is True
def test_delete_link_response(self):
resp = DeleteLinkResponse(success=True)
assert resp.success is True


@@ -30,6 +30,8 @@ import backend.api.features.library.routes
import backend.api.features.mcp.routes as mcp_routes
import backend.api.features.oauth
import backend.api.features.otto.routes
import backend.api.features.platform_linking.chat_proxy
import backend.api.features.platform_linking.routes
import backend.api.features.postmark.postmark
import backend.api.features.store.model
import backend.api.features.store.routes
@@ -118,11 +120,6 @@ async def lifespan_context(app: fastapi.FastAPI):
AutoRegistry.patch_integrations()
# Register managed credential providers (e.g. AgentMail)
from backend.integrations.managed_providers import register_all
register_all()
await backend.data.block.initialize_blocks()
await backend.data.user.migrate_and_encrypt_user_integrations()
@@ -366,6 +363,16 @@ app.include_router(
tags=["oauth"],
prefix="/api/oauth",
)
app.include_router(
backend.api.features.platform_linking.routes.router,
tags=["platform-linking"],
prefix="/api/platform-linking",
)
app.include_router(
backend.api.features.platform_linking.chat_proxy.router,
tags=["platform-linking"],
prefix="/api/platform-linking",
)
app.mount("/external-api", external_api)


@@ -146,21 +146,6 @@ class AutoPilotBlock(Block):
advanced=True,
)
dry_run: bool = SchemaField(
description=(
"When enabled, run_block and run_agent tool calls in this "
"autopilot session are forced to use dry-run simulation mode. "
"No real API calls, side effects, or credits are consumed "
"by those tools. Useful for testing agent wiring and "
"previewing outputs. "
"Only applies when creating a new session (session_id is empty). "
"When reusing an existing session_id, the session's original "
"dry_run setting is preserved."
),
default=False,
advanced=True,
)
# timeout_seconds removed: the SDK manages its own heartbeat-based
# timeouts internally; wrapping with asyncio.timeout corrupts the
# SDK's internal stream (see service.py CRITICAL comment).
@@ -247,11 +232,11 @@ class AutoPilotBlock(Block):
},
)
async def create_session(self, user_id: str, *, dry_run: bool) -> str:
async def create_session(self, user_id: str) -> str:
"""Create a new chat session and return its ID (mockable for tests)."""
from backend.copilot.model import create_chat_session # avoid circular import
session = await create_chat_session(user_id, dry_run=dry_run)
session = await create_chat_session(user_id)
return session.session_id
async def execute_copilot(
@@ -382,9 +367,7 @@ class AutoPilotBlock(Block):
# even if the downstream stream fails (avoids orphaned sessions).
sid = input_data.session_id
if not sid:
sid = await self.create_session(
execution_context.user_id, dry_run=input_data.dry_run
)
sid = await self.create_session(execution_context.user_id)
# NOTE: No asyncio.timeout() here — the SDK manages its own
# heartbeat-based timeouts internally. Wrapping with asyncio.timeout


@@ -1,6 +1,5 @@
import asyncio
import base64
import re
from abc import ABC
from email import encoders
from email.mime.base import MIMEBase
@@ -9,7 +8,7 @@ from email.mime.text import MIMEText
from email.policy import SMTP
from email.utils import getaddresses, parseaddr
from pathlib import Path
from typing import List, Literal, Optional, Protocol, runtime_checkable
from typing import List, Literal, Optional
from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
@@ -43,52 +42,8 @@ NO_WRAP_POLICY = SMTP.clone(max_line_length=0)
def serialize_email_recipients(recipients: list[str]) -> str:
"""Serialize recipients list to comma-separated string.
Strips leading/trailing whitespace from each address to keep MIME
headers clean (mirrors the strip done in ``validate_email_recipients``).
"""
return ", ".join(addr.strip() for addr in recipients)
# RFC 5322 simplified pattern: local@domain where domain has at least one dot
_EMAIL_RE = re.compile(r"^[^@\s]+@[^@\s]+\.[^@\s]+$")
def validate_email_recipients(recipients: list[str], field_name: str = "to") -> None:
"""Validate that all recipients are plausible email addresses.
Raises ``ValueError`` with a user-friendly message listing every
invalid entry so the caller (or LLM) can correct them in one pass.
"""
invalid = [addr for addr in recipients if not _EMAIL_RE.match(addr.strip())]
if invalid:
formatted = ", ".join(f"'{a}'" for a in invalid)
raise ValueError(
f"Invalid email address(es) in '{field_name}': {formatted}. "
f"Each entry must be a valid email address (e.g. user@example.com)."
)
@runtime_checkable
class HasRecipients(Protocol):
to: list[str]
cc: list[str]
bcc: list[str]
def validate_all_recipients(input_data: HasRecipients) -> None:
"""Validate to/cc/bcc recipient fields on an input namespace.
Calls ``validate_email_recipients`` for ``to`` (required) and
``cc``/``bcc`` (when non-empty), raising ``ValueError`` on the
first field that contains an invalid address.
"""
validate_email_recipients(input_data.to, "to")
if input_data.cc:
validate_email_recipients(input_data.cc, "cc")
if input_data.bcc:
validate_email_recipients(input_data.bcc, "bcc")
"""Serialize recipients list to comma-separated string."""
return ", ".join(recipients)
def _make_mime_text(
@@ -145,16 +100,14 @@ async def create_mime_message(
) -> str:
"""Create a MIME message with attachments and return base64-encoded raw message."""
validate_all_recipients(input_data)
message = MIMEMultipart()
message["to"] = serialize_email_recipients(input_data.to)
message["subject"] = input_data.subject
if input_data.cc:
message["cc"] = serialize_email_recipients(input_data.cc)
message["cc"] = ", ".join(input_data.cc)
if input_data.bcc:
message["bcc"] = serialize_email_recipients(input_data.bcc)
message["bcc"] = ", ".join(input_data.bcc)
# Use the new helper function with content_type if available
content_type = getattr(input_data, "content_type", None)
@@ -1214,15 +1167,13 @@ async def _build_reply_message(
references.append(headers["message-id"])
# Create MIME message
validate_all_recipients(input_data)
msg = MIMEMultipart()
if input_data.to:
msg["To"] = serialize_email_recipients(input_data.to)
msg["To"] = ", ".join(input_data.to)
if input_data.cc:
msg["Cc"] = serialize_email_recipients(input_data.cc)
msg["Cc"] = ", ".join(input_data.cc)
if input_data.bcc:
msg["Bcc"] = serialize_email_recipients(input_data.bcc)
msg["Bcc"] = ", ".join(input_data.bcc)
msg["Subject"] = subject
if headers.get("message-id"):
msg["In-Reply-To"] = headers["message-id"]
@@ -1734,16 +1685,13 @@ To: {original_to}
else:
body = f"{forward_header}\n\n{original_body}"
# Validate all recipient lists before building the MIME message
validate_all_recipients(input_data)
# Create MIME message
msg = MIMEMultipart()
msg["To"] = serialize_email_recipients(input_data.to)
msg["To"] = ", ".join(input_data.to)
if input_data.cc:
msg["Cc"] = serialize_email_recipients(input_data.cc)
msg["Cc"] = ", ".join(input_data.cc)
if input_data.bcc:
msg["Bcc"] = serialize_email_recipients(input_data.bcc)
msg["Bcc"] = ", ".join(input_data.bcc)
msg["Subject"] = subject
# Add body with proper content type


@@ -724,9 +724,6 @@ def convert_openai_tool_fmt_to_anthropic(
def extract_openai_reasoning(response) -> str | None:
"""Extract reasoning from OpenAI-compatible response if available."""
"""Note: This will likely not working since the reasoning is not present in another Response API"""
if not response.choices:
logger.warning("LLM response has empty choices in extract_openai_reasoning")
return None
reasoning = None
choice = response.choices[0]
if hasattr(choice, "reasoning") and getattr(choice, "reasoning", None):
@@ -742,9 +739,6 @@ def extract_openai_reasoning(response) -> str | None:
def extract_openai_tool_calls(response) -> list[ToolContentBlock] | None:
"""Extract tool calls from OpenAI-compatible response."""
if not response.choices:
logger.warning("LLM response has empty choices in extract_openai_tool_calls")
return None
if response.choices[0].message.tool_calls:
return [
ToolContentBlock(
@@ -978,8 +972,6 @@ async def llm_call(
response_format=response_format, # type: ignore
max_tokens=max_tokens,
)
if not response.choices:
raise ValueError("Groq returned empty choices in response")
return LLMResponse(
raw_response=response.choices[0].message,
prompt=prompt,
@@ -1039,8 +1031,12 @@ async def llm_call(
parallel_tool_calls=parallel_tool_calls_param,
)
# If there's no response, raise an error
if not response.choices:
raise ValueError(f"OpenRouter returned empty choices: {response}")
if response:
raise ValueError(f"OpenRouter error: {response}")
else:
raise ValueError("No response from OpenRouter.")
tool_calls = extract_openai_tool_calls(response)
reasoning = extract_openai_reasoning(response)
@@ -1077,8 +1073,12 @@ async def llm_call(
parallel_tool_calls=parallel_tool_calls_param,
)
# If there's no response, raise an error
if not response.choices:
raise ValueError(f"Llama API returned empty choices: {response}")
if response:
raise ValueError(f"Llama API error: {response}")
else:
raise ValueError("No response from Llama API.")
tool_calls = extract_openai_tool_calls(response)
reasoning = extract_openai_reasoning(response)
@@ -1108,8 +1108,6 @@ async def llm_call(
messages=prompt, # type: ignore
max_tokens=max_tokens,
)
if not completion.choices:
raise ValueError("AI/ML API returned empty choices in response")
return LLMResponse(
raw_response=completion.choices[0].message,
@@ -1146,9 +1144,6 @@ async def llm_call(
parallel_tool_calls=parallel_tool_calls_param,
)
if not response.choices:
raise ValueError(f"v0 API returned empty choices: {response}")
tool_calls = extract_openai_tool_calls(response)
reasoning = extract_openai_reasoning(response)
@@ -2016,19 +2011,6 @@ class AIConversationBlock(AIBlockBase):
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
has_messages = any(
isinstance(m, dict)
and isinstance(m.get("content"), str)
and bool(m["content"].strip())
for m in (input_data.messages or [])
)
has_prompt = bool(input_data.prompt and input_data.prompt.strip())
if not has_messages and not has_prompt:
raise ValueError(
"Cannot call LLM with no messages and no prompt. "
"Provide at least one message or a non-empty prompt."
)
response = await self.llm_call(
AIStructuredResponseGeneratorBlock.Input(
prompt=input_data.prompt,

File diff suppressed because it is too large.


@@ -488,154 +488,6 @@ class TestLLMStatsTracking:
assert outputs["response"] == {"result": "test"}
class TestAIConversationBlockValidation:
"""Test that AIConversationBlock validates inputs before calling the LLM."""
@pytest.mark.asyncio
async def test_empty_messages_and_empty_prompt_raises_error(self):
"""Empty messages with no prompt should raise ValueError, not a cryptic API error."""
block = llm.AIConversationBlock()
input_data = llm.AIConversationBlock.Input(
messages=[],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
with pytest.raises(ValueError, match="no messages and no prompt"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
@pytest.mark.asyncio
async def test_empty_messages_with_prompt_succeeds(self):
"""Empty messages but a non-empty prompt should proceed without error."""
block = llm.AIConversationBlock()
async def mock_llm_call(input_data, credentials):
return {"response": "OK"}
with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
input_data = llm.AIConversationBlock.Input(
messages=[],
prompt="Hello, how are you?",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
outputs = {}
async for name, data in block.run(
input_data, credentials=llm.TEST_CREDENTIALS
):
outputs[name] = data
assert outputs["response"] == "OK"
@pytest.mark.asyncio
async def test_nonempty_messages_with_empty_prompt_succeeds(self):
"""Non-empty messages with no prompt should proceed without error."""
block = llm.AIConversationBlock()
async def mock_llm_call(input_data, credentials):
return {"response": "response from conversation"}
with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
input_data = llm.AIConversationBlock.Input(
messages=[{"role": "user", "content": "Hello"}],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
outputs = {}
async for name, data in block.run(
input_data, credentials=llm.TEST_CREDENTIALS
):
outputs[name] = data
assert outputs["response"] == "response from conversation"
@pytest.mark.asyncio
async def test_messages_with_empty_content_raises_error(self):
"""Messages with empty content strings should be treated as no messages."""
block = llm.AIConversationBlock()
input_data = llm.AIConversationBlock.Input(
messages=[{"role": "user", "content": ""}],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
with pytest.raises(ValueError, match="no messages and no prompt"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
@pytest.mark.asyncio
async def test_messages_with_whitespace_content_raises_error(self):
"""Messages with whitespace-only content should be treated as no messages."""
block = llm.AIConversationBlock()
input_data = llm.AIConversationBlock.Input(
messages=[{"role": "user", "content": " "}],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
with pytest.raises(ValueError, match="no messages and no prompt"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
@pytest.mark.asyncio
async def test_messages_with_none_entry_raises_error(self):
"""Messages list containing None should be treated as no messages."""
block = llm.AIConversationBlock()
input_data = llm.AIConversationBlock.Input(
messages=[None],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
with pytest.raises(ValueError, match="no messages and no prompt"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
@pytest.mark.asyncio
async def test_messages_with_empty_dict_raises_error(self):
"""Messages list containing empty dict should be treated as no messages."""
block = llm.AIConversationBlock()
input_data = llm.AIConversationBlock.Input(
messages=[{}],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
with pytest.raises(ValueError, match="no messages and no prompt"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
@pytest.mark.asyncio
async def test_messages_with_none_content_raises_error(self):
"""Messages with content=None should not crash with AttributeError."""
block = llm.AIConversationBlock()
input_data = llm.AIConversationBlock.Input(
messages=[{"role": "user", "content": None}],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
with pytest.raises(ValueError, match="no messages and no prompt"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
class TestAITextSummarizerValidation:
"""Test that AITextSummarizerBlock validates LLM responses are strings."""


@@ -1,87 +0,0 @@
"""Tests for empty-choices guard in extract_openai_tool_calls() and extract_openai_reasoning()."""
from unittest.mock import MagicMock
from backend.blocks.llm import extract_openai_reasoning, extract_openai_tool_calls
class TestExtractOpenaiToolCallsEmptyChoices:
"""extract_openai_tool_calls() must return None when choices is empty."""
def test_returns_none_for_empty_choices(self):
response = MagicMock()
response.choices = []
assert extract_openai_tool_calls(response) is None
def test_returns_none_for_none_choices(self):
response = MagicMock()
response.choices = None
assert extract_openai_tool_calls(response) is None
def test_returns_tool_calls_when_choices_present(self):
tool = MagicMock()
tool.id = "call_1"
tool.type = "function"
tool.function.name = "my_func"
tool.function.arguments = '{"a": 1}'
message = MagicMock()
message.tool_calls = [tool]
choice = MagicMock()
choice.message = message
response = MagicMock()
response.choices = [choice]
result = extract_openai_tool_calls(response)
assert result is not None
assert len(result) == 1
assert result[0].function.name == "my_func"
def test_returns_none_when_no_tool_calls(self):
message = MagicMock()
message.tool_calls = None
choice = MagicMock()
choice.message = message
response = MagicMock()
response.choices = [choice]
assert extract_openai_tool_calls(response) is None
class TestExtractOpenaiReasoningEmptyChoices:
"""extract_openai_reasoning() must return None when choices is empty."""
def test_returns_none_for_empty_choices(self):
response = MagicMock()
response.choices = []
assert extract_openai_reasoning(response) is None
def test_returns_none_for_none_choices(self):
response = MagicMock()
response.choices = None
assert extract_openai_reasoning(response) is None
def test_returns_reasoning_from_choice(self):
choice = MagicMock()
choice.reasoning = "Step-by-step reasoning"
choice.message = MagicMock(spec=[]) # no 'reasoning' attr on message
response = MagicMock(spec=[]) # no 'reasoning' attr on response
response.choices = [choice]
result = extract_openai_reasoning(response)
assert result == "Step-by-step reasoning"
def test_returns_none_when_no_reasoning(self):
choice = MagicMock(spec=[]) # no 'reasoning' attr
choice.message = MagicMock(spec=[]) # no 'reasoning' attr
response = MagicMock(spec=[]) # no 'reasoning' attr
response.choices = [choice]
result = extract_openai_reasoning(response)
assert result is None


@@ -1074,7 +1074,6 @@ async def test_orchestrator_uses_customized_name_for_blocks():
mock_node.block_id = StoreValueBlock().id
mock_node.metadata = {"customized_name": "My Custom Tool Name"}
mock_node.block = StoreValueBlock()
mock_node.input_default = {}
# Create a mock link
mock_link = MagicMock(spec=Link)
@@ -1106,7 +1105,6 @@ async def test_orchestrator_falls_back_to_block_name():
mock_node.block_id = StoreValueBlock().id
mock_node.metadata = {} # No customized_name
mock_node.block = StoreValueBlock()
mock_node.input_default = {}
# Create a mock link
mock_link = MagicMock(spec=Link)


@@ -1,202 +0,0 @@
"""Tests for ExecutionMode enum and provider validation in the orchestrator.
Covers:
- ExecutionMode enum members exist and have stable values
- EXTENDED_THINKING provider validation (anthropic/open_router allowed, others rejected)
- EXTENDED_THINKING model-name validation (must start with "claude")
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.blocks.llm import LlmModel
from backend.blocks.orchestrator import ExecutionMode, OrchestratorBlock
# ---------------------------------------------------------------------------
# ExecutionMode enum integrity
# ---------------------------------------------------------------------------
class TestExecutionModeEnum:
"""Guard against accidental renames or removals of enum members."""
def test_built_in_exists(self):
assert hasattr(ExecutionMode, "BUILT_IN")
assert ExecutionMode.BUILT_IN.value == "built_in"
def test_extended_thinking_exists(self):
assert hasattr(ExecutionMode, "EXTENDED_THINKING")
assert ExecutionMode.EXTENDED_THINKING.value == "extended_thinking"
def test_exactly_two_members(self):
"""If a new mode is added, this test should be updated intentionally."""
assert set(ExecutionMode.__members__.keys()) == {
"BUILT_IN",
"EXTENDED_THINKING",
}
def test_string_enum(self):
"""ExecutionMode is a str enum so it serialises cleanly to JSON."""
assert isinstance(ExecutionMode.BUILT_IN, str)
assert isinstance(ExecutionMode.EXTENDED_THINKING, str)
def test_round_trip_from_value(self):
"""Constructing from the string value should return the same member."""
assert ExecutionMode("built_in") is ExecutionMode.BUILT_IN
assert ExecutionMode("extended_thinking") is ExecutionMode.EXTENDED_THINKING
# ---------------------------------------------------------------------------
# Provider validation (inline in OrchestratorBlock.run)
# ---------------------------------------------------------------------------
def _make_model_stub(provider: str, value: str):
"""Create a lightweight stub that behaves like LlmModel for validation."""
metadata = MagicMock()
metadata.provider = provider
stub = MagicMock()
stub.metadata = metadata
stub.value = value
return stub
class TestExtendedThinkingProviderValidation:
"""The orchestrator rejects EXTENDED_THINKING for non-Anthropic providers."""
def test_anthropic_provider_accepted(self):
"""provider='anthropic' + claude model should not raise."""
model = _make_model_stub("anthropic", "claude-opus-4-6")
provider = model.metadata.provider
model_name = model.value
assert provider in ("anthropic", "open_router")
assert model_name.startswith("claude")
def test_open_router_provider_accepted(self):
"""provider='open_router' + claude model should not raise."""
model = _make_model_stub("open_router", "claude-sonnet-4-6")
provider = model.metadata.provider
model_name = model.value
assert provider in ("anthropic", "open_router")
assert model_name.startswith("claude")
def test_openai_provider_rejected(self):
"""provider='openai' should be rejected for EXTENDED_THINKING."""
model = _make_model_stub("openai", "gpt-4o")
provider = model.metadata.provider
assert provider not in ("anthropic", "open_router")
def test_groq_provider_rejected(self):
model = _make_model_stub("groq", "llama-3.3-70b-versatile")
provider = model.metadata.provider
assert provider not in ("anthropic", "open_router")
def test_non_claude_model_rejected_even_if_anthropic_provider(self):
"""A hypothetical non-Claude model with provider='anthropic' is rejected."""
model = _make_model_stub("anthropic", "not-a-claude-model")
model_name = model.value
assert not model_name.startswith("claude")
def test_real_gpt4o_model_rejected(self):
"""Verify a real LlmModel enum member (GPT4O) fails the provider check."""
model = LlmModel.GPT4O
provider = model.metadata.provider
assert provider not in ("anthropic", "open_router")
def test_real_claude_model_passes(self):
"""Verify a real LlmModel enum member (CLAUDE_4_6_SONNET) passes."""
model = LlmModel.CLAUDE_4_6_SONNET
provider = model.metadata.provider
model_name = model.value
assert provider in ("anthropic", "open_router")
assert model_name.startswith("claude")
# ---------------------------------------------------------------------------
# Integration-style: exercise the validation branch via OrchestratorBlock.run
# ---------------------------------------------------------------------------
def _make_input_data(model, execution_mode=ExecutionMode.EXTENDED_THINKING):
"""Build a minimal MagicMock that satisfies OrchestratorBlock.run's early path."""
inp = MagicMock()
inp.execution_mode = execution_mode
inp.model = model
inp.prompt = "test"
inp.sys_prompt = ""
inp.conversation_history = []
inp.last_tool_output = None
inp.prompt_values = {}
return inp
async def _collect_run_outputs(block, input_data, **kwargs):
"""Exhaust the OrchestratorBlock.run async generator, collecting outputs."""
outputs = []
async for item in block.run(input_data, **kwargs):
outputs.append(item)
return outputs
class TestExtendedThinkingValidationRaisesInBlock:
"""Call OrchestratorBlock.run far enough to trigger the ValueError."""
@pytest.mark.asyncio
async def test_non_anthropic_provider_raises_valueerror(self):
"""EXTENDED_THINKING + openai provider raises ValueError."""
block = OrchestratorBlock()
input_data = _make_input_data(model=LlmModel.GPT4O)
with (
patch.object(
block,
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=[],
),
pytest.raises(ValueError, match="Anthropic-compatible"),
):
await _collect_run_outputs(
block,
input_data,
credentials=MagicMock(),
graph_id="g",
node_id="n",
graph_exec_id="ge",
node_exec_id="ne",
user_id="u",
graph_version=1,
execution_context=MagicMock(),
execution_processor=MagicMock(),
)
@pytest.mark.asyncio
async def test_non_claude_model_with_anthropic_provider_raises(self):
"""A model with anthropic provider but non-claude name raises ValueError."""
block = OrchestratorBlock()
fake_model = _make_model_stub("anthropic", "not-a-claude-model")
input_data = _make_input_data(model=fake_model)
with (
patch.object(
block,
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=[],
),
pytest.raises(ValueError, match="only supports Claude models"),
):
await _collect_run_outputs(
block,
input_data,
credentials=MagicMock(),
graph_id="g",
node_id="n",
graph_exec_id="ge",
node_exec_id="ne",
user_id="u",
graph_version=1,
execution_context=MagicMock(),
execution_processor=MagicMock(),
)
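The branch these tests exercise can be summarised as a standalone sketch. The helper name is hypothetical (the real check lives inline in `OrchestratorBlock.run`); the error messages mirror the substrings the `pytest.raises(match=...)` assertions above expect:

```python
# Hypothetical standalone form of the inline EXTENDED_THINKING validation.
_ALLOWED_PROVIDERS = ("anthropic", "open_router")


def validate_extended_thinking(provider: str, model_name: str) -> None:
    """Raise ValueError unless the model can serve EXTENDED_THINKING mode."""
    if provider not in _ALLOWED_PROVIDERS:
        raise ValueError(
            "EXTENDED_THINKING requires an Anthropic-compatible provider, "
            f"got {provider!r}"
        )
    if not model_name.startswith("claude"):
        raise ValueError(
            f"EXTENDED_THINKING only supports Claude models, got {model_name!r}"
        )
```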

@@ -9,16 +9,12 @@ shared tool registry as the SDK path.
import asyncio
import logging
import uuid
from collections.abc import AsyncGenerator, Sequence
from dataclasses import dataclass, field
from functools import partial
from typing import Any, cast
from collections.abc import AsyncGenerator
from typing import Any
import orjson
from langfuse import propagate_attributes
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolParam
from backend.copilot.context import set_execution_context
from backend.copilot.model import (
ChatMessage,
ChatSession,
@@ -52,17 +48,7 @@ from backend.copilot.token_tracking import persist_and_record_usage
from backend.copilot.tools import execute_tool, get_available_tools
from backend.copilot.tracking import track_user_message
from backend.util.exceptions import NotFoundError
from backend.util.prompt import (
compress_context,
estimate_token_count,
estimate_token_count_str,
)
from backend.util.tool_call_loop import (
LLMLoopResponse,
LLMToolCall,
ToolCallResult,
tool_call_loop,
)
from backend.util.prompt import compress_context
logger = logging.getLogger(__name__)
@@ -73,247 +59,6 @@ _background_tasks: set[asyncio.Task[Any]] = set()
_MAX_TOOL_ROUNDS = 30
@dataclass
class _BaselineStreamState:
"""Mutable state shared between the tool-call loop callbacks.
Extracted from ``stream_chat_completion_baseline`` so that the callbacks
can be module-level functions instead of deeply nested closures.
"""
pending_events: list[StreamBaseResponse] = field(default_factory=list)
assistant_text: str = ""
text_block_id: str = field(default_factory=lambda: str(uuid.uuid4()))
text_started: bool = False
turn_prompt_tokens: int = 0
turn_completion_tokens: int = 0
async def _baseline_llm_caller(
messages: list[dict[str, Any]],
tools: Sequence[Any],
*,
state: _BaselineStreamState,
) -> LLMLoopResponse:
"""Stream an OpenAI-compatible response and collect results.
Extracted from ``stream_chat_completion_baseline`` for readability.
"""
state.pending_events.append(StreamStartStep())
round_text = ""
try:
client = _get_openai_client()
typed_messages = cast(list[ChatCompletionMessageParam], messages)
if tools:
typed_tools = cast(list[ChatCompletionToolParam], tools)
response = await client.chat.completions.create(
model=config.model,
messages=typed_messages,
tools=typed_tools,
stream=True,
stream_options={"include_usage": True},
)
else:
response = await client.chat.completions.create(
model=config.model,
messages=typed_messages,
stream=True,
stream_options={"include_usage": True},
)
tool_calls_by_index: dict[int, dict[str, str]] = {}
async for chunk in response:
if chunk.usage:
state.turn_prompt_tokens += chunk.usage.prompt_tokens or 0
state.turn_completion_tokens += chunk.usage.completion_tokens or 0
delta = chunk.choices[0].delta if chunk.choices else None
if not delta:
continue
if delta.content:
if not state.text_started:
state.pending_events.append(StreamTextStart(id=state.text_block_id))
state.text_started = True
round_text += delta.content
state.pending_events.append(
StreamTextDelta(id=state.text_block_id, delta=delta.content)
)
if delta.tool_calls:
for tc in delta.tool_calls:
idx = tc.index
if idx not in tool_calls_by_index:
tool_calls_by_index[idx] = {
"id": "",
"name": "",
"arguments": "",
}
entry = tool_calls_by_index[idx]
if tc.id:
entry["id"] = tc.id
if tc.function and tc.function.name:
entry["name"] = tc.function.name
if tc.function and tc.function.arguments:
entry["arguments"] += tc.function.arguments
# Close text block
if state.text_started:
state.pending_events.append(StreamTextEnd(id=state.text_block_id))
state.text_started = False
state.text_block_id = str(uuid.uuid4())
finally:
# Always persist partial text so the session history stays consistent,
# even when the stream is interrupted by an exception.
state.assistant_text += round_text
# Always emit StreamFinishStep to match the StreamStartStep,
# even if an exception occurred during streaming.
state.pending_events.append(StreamFinishStep())
# Convert to shared format
llm_tool_calls = [
LLMToolCall(
id=tc["id"],
name=tc["name"],
arguments=tc["arguments"] or "{}",
)
for tc in tool_calls_by_index.values()
]
return LLMLoopResponse(
response_text=round_text or None,
tool_calls=llm_tool_calls,
raw_response=None, # Not needed for baseline conversation updater
prompt_tokens=0, # Tracked via state accumulators
completion_tokens=0,
)
async def _baseline_tool_executor(
tool_call: LLMToolCall,
tools: Sequence[Any],
*,
state: _BaselineStreamState,
user_id: str | None,
session: ChatSession,
) -> ToolCallResult:
"""Execute a tool via the copilot tool registry.
Extracted from ``stream_chat_completion_baseline`` for readability.
"""
tool_call_id = tool_call.id
tool_name = tool_call.name
raw_args = tool_call.arguments or "{}"
try:
tool_args = orjson.loads(raw_args)
except orjson.JSONDecodeError as parse_err:
parse_error = f"Invalid JSON arguments for tool '{tool_name}': {parse_err}"
logger.warning("[Baseline] %s", parse_error)
state.pending_events.append(
StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
output=parse_error,
success=False,
)
)
return ToolCallResult(
tool_call_id=tool_call_id,
tool_name=tool_name,
content=parse_error,
is_error=True,
)
state.pending_events.append(
StreamToolInputStart(toolCallId=tool_call_id, toolName=tool_name)
)
state.pending_events.append(
StreamToolInputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
input=tool_args,
)
)
try:
result: StreamToolOutputAvailable = await execute_tool(
tool_name=tool_name,
parameters=tool_args,
user_id=user_id,
session=session,
tool_call_id=tool_call_id,
)
state.pending_events.append(result)
tool_output = (
result.output if isinstance(result.output, str) else str(result.output)
)
return ToolCallResult(
tool_call_id=tool_call_id,
tool_name=tool_name,
content=tool_output,
)
except Exception as e:
error_output = f"Tool execution error: {e}"
logger.error(
"[Baseline] Tool %s failed: %s",
tool_name,
error_output,
exc_info=True,
)
state.pending_events.append(
StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
output=error_output,
success=False,
)
)
return ToolCallResult(
tool_call_id=tool_call_id,
tool_name=tool_name,
content=error_output,
is_error=True,
)
def _baseline_conversation_updater(
messages: list[dict[str, Any]],
response: LLMLoopResponse,
tool_results: list[ToolCallResult] | None = None,
) -> None:
"""Update OpenAI message list with assistant response + tool results.
Extracted from ``stream_chat_completion_baseline`` for readability.
"""
if tool_results:
# Build assistant message with tool_calls
assistant_msg: dict[str, Any] = {"role": "assistant"}
if response.response_text:
assistant_msg["content"] = response.response_text
assistant_msg["tool_calls"] = [
{
"id": tc.id,
"type": "function",
"function": {"name": tc.name, "arguments": tc.arguments},
}
for tc in response.tool_calls
]
messages.append(assistant_msg)
for tr in tool_results:
messages.append(
{
"role": "tool",
"tool_call_id": tr.tool_call_id,
"content": tr.content,
}
)
else:
if response.response_text:
messages.append({"role": "assistant", "content": response.response_text})
async def _update_title_async(
session_id: str, message: str, user_id: str | None
) -> None:
@@ -458,9 +203,6 @@ async def stream_chat_completion_baseline(
tools = get_available_tools()
# Propagate execution context so tool handlers can read session-level flags.
set_execution_context(user_id, session)
yield StreamStart(messageId=message_id, sessionId=session_id)
# Propagate user/session context to Langfuse so all LLM calls within
@@ -477,32 +219,191 @@ async def stream_chat_completion_baseline(
except Exception:
logger.warning("[Baseline] Langfuse trace context setup failed")
assistant_text = ""
text_block_id = str(uuid.uuid4())
text_started = False
step_open = False
# Token usage accumulators — populated from streaming chunks
turn_prompt_tokens = 0
turn_completion_tokens = 0
_stream_error = False # Track whether an error occurred during streaming
state = _BaselineStreamState()
# Bind extracted module-level callbacks to this request's state/session
# using functools.partial so they satisfy the Protocol signatures.
_bound_llm_caller = partial(_baseline_llm_caller, state=state)
_bound_tool_executor = partial(
_baseline_tool_executor, state=state, user_id=user_id, session=session
)
try:
loop_result = None
async for loop_result in tool_call_loop(
messages=openai_messages,
tools=tools,
llm_call=_bound_llm_caller,
execute_tool=_bound_tool_executor,
update_conversation=_baseline_conversation_updater,
max_iterations=_MAX_TOOL_ROUNDS,
):
# Drain buffered events after each iteration (real-time streaming)
for evt in state.pending_events:
yield evt
state.pending_events.clear()
for _round in range(_MAX_TOOL_ROUNDS):
# Open a new step for each LLM round
yield StreamStartStep()
step_open = True
if loop_result and not loop_result.finished_naturally:
# Stream a response from the model
create_kwargs: dict[str, Any] = dict(
model=config.model,
messages=openai_messages,
stream=True,
stream_options={"include_usage": True},
)
if tools:
create_kwargs["tools"] = tools
response = await _get_openai_client().chat.completions.create(**create_kwargs) # type: ignore[arg-type] # dynamic kwargs
# Accumulate streamed response (text + tool calls)
round_text = ""
tool_calls_by_index: dict[int, dict[str, str]] = {}
async for chunk in response:
# Capture token usage from the streaming chunk.
# OpenRouter normalises all providers into OpenAI format
# where prompt_tokens already includes cached tokens
# (unlike Anthropic's native API). Use += to sum all
# tool-call rounds since each API call is independent.
# NOTE: stream_options={"include_usage": True} is not
# universally supported — some providers (Mistral, Llama
# via OpenRouter) always return chunk.usage=None. When
# that happens, tokens stay 0 and the tiktoken fallback
# below activates. Fail-open: one round is estimated.
if chunk.usage:
turn_prompt_tokens += chunk.usage.prompt_tokens or 0
turn_completion_tokens += chunk.usage.completion_tokens or 0
delta = chunk.choices[0].delta if chunk.choices else None
if not delta:
continue
# Text content
if delta.content:
if not text_started:
yield StreamTextStart(id=text_block_id)
text_started = True
round_text += delta.content
yield StreamTextDelta(id=text_block_id, delta=delta.content)
# Tool call fragments (streamed incrementally)
if delta.tool_calls:
for tc in delta.tool_calls:
idx = tc.index
if idx not in tool_calls_by_index:
tool_calls_by_index[idx] = {
"id": "",
"name": "",
"arguments": "",
}
entry = tool_calls_by_index[idx]
if tc.id:
entry["id"] = tc.id
if tc.function and tc.function.name:
entry["name"] = tc.function.name
if tc.function and tc.function.arguments:
entry["arguments"] += tc.function.arguments
# Close text block if we had one this round
if text_started:
yield StreamTextEnd(id=text_block_id)
text_started = False
text_block_id = str(uuid.uuid4())
# Accumulate text for session persistence
assistant_text += round_text
# No tool calls -> model is done
if not tool_calls_by_index:
yield StreamFinishStep()
step_open = False
break
# Close step before tool execution
yield StreamFinishStep()
step_open = False
# Append the assistant message with tool_calls to context.
assistant_msg: dict[str, Any] = {"role": "assistant"}
if round_text:
assistant_msg["content"] = round_text
assistant_msg["tool_calls"] = [
{
"id": tc["id"],
"type": "function",
"function": {
"name": tc["name"],
"arguments": tc["arguments"] or "{}",
},
}
for tc in tool_calls_by_index.values()
]
openai_messages.append(assistant_msg)
# Execute each tool call and stream events
for tc in tool_calls_by_index.values():
tool_call_id = tc["id"]
tool_name = tc["name"]
raw_args = tc["arguments"] or "{}"
try:
tool_args = orjson.loads(raw_args)
except orjson.JSONDecodeError as parse_err:
parse_error = (
f"Invalid JSON arguments for tool '{tool_name}': {parse_err}"
)
logger.warning("[Baseline] %s", parse_error)
yield StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
output=parse_error,
success=False,
)
openai_messages.append(
{
"role": "tool",
"tool_call_id": tool_call_id,
"content": parse_error,
}
)
continue
yield StreamToolInputStart(toolCallId=tool_call_id, toolName=tool_name)
yield StreamToolInputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
input=tool_args,
)
# Execute via shared tool registry
try:
result: StreamToolOutputAvailable = await execute_tool(
tool_name=tool_name,
parameters=tool_args,
user_id=user_id,
session=session,
tool_call_id=tool_call_id,
)
yield result
tool_output = (
result.output
if isinstance(result.output, str)
else str(result.output)
)
except Exception as e:
error_output = f"Tool execution error: {e}"
logger.error(
"[Baseline] Tool %s failed: %s",
tool_name,
error_output,
exc_info=True,
)
yield StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
output=error_output,
success=False,
)
tool_output = error_output
# Append tool result to context for next round
openai_messages.append(
{
"role": "tool",
"tool_call_id": tool_call_id,
"content": tool_output,
}
)
else:
# for-loop exhausted without break -> tool-round limit hit
limit_msg = (
f"Exceeded {_MAX_TOOL_ROUNDS} tool-call rounds "
"without a final response."
@@ -517,28 +418,11 @@ async def stream_chat_completion_baseline(
_stream_error = True
error_msg = str(e) or type(e).__name__
logger.error("[Baseline] Streaming error: %s", error_msg, exc_info=True)
# Close any open text block. The llm_caller's finally block
# already appended StreamFinishStep to pending_events, so we must
# insert StreamTextEnd *before* StreamFinishStep to preserve the
# protocol ordering:
# StreamStartStep -> StreamTextStart -> ...deltas... ->
# StreamTextEnd -> StreamFinishStep
# Appending (or yielding directly) would place it after
# StreamFinishStep, violating the protocol.
if state.text_started:
# Find the last StreamFinishStep and insert before it.
insert_pos = len(state.pending_events)
for i in range(len(state.pending_events) - 1, -1, -1):
if isinstance(state.pending_events[i], StreamFinishStep):
insert_pos = i
break
state.pending_events.insert(
insert_pos, StreamTextEnd(id=state.text_block_id)
)
# Drain pending events in correct order
for evt in state.pending_events:
yield evt
state.pending_events.clear()
# Close any open text/step before emitting error
if text_started:
yield StreamTextEnd(id=text_block_id)
if step_open:
yield StreamFinishStep()
yield StreamError(errorText=error_msg, code="baseline_error")
# Still persist whatever we got
finally:
@@ -558,21 +442,26 @@ async def stream_chat_completion_baseline(
# Skip fallback when an error occurred and no output was produced —
# charging rate-limit tokens for completely failed requests is unfair.
if (
state.turn_prompt_tokens == 0
and state.turn_completion_tokens == 0
and not (_stream_error and not state.assistant_text)
turn_prompt_tokens == 0
and turn_completion_tokens == 0
and not (_stream_error and not assistant_text)
):
state.turn_prompt_tokens = max(
from backend.util.prompt import (
estimate_token_count,
estimate_token_count_str,
)
turn_prompt_tokens = max(
estimate_token_count(openai_messages, model=config.model), 1
)
state.turn_completion_tokens = estimate_token_count_str(
state.assistant_text, model=config.model
turn_completion_tokens = estimate_token_count_str(
assistant_text, model=config.model
)
logger.info(
"[Baseline] No streaming usage reported; estimated tokens: "
"prompt=%d, completion=%d",
state.turn_prompt_tokens,
state.turn_completion_tokens,
turn_prompt_tokens,
turn_completion_tokens,
)
# Persist token usage to session and record for rate limiting.
@@ -582,15 +471,15 @@ async def stream_chat_completion_baseline(
await persist_and_record_usage(
session=session,
user_id=user_id,
prompt_tokens=state.turn_prompt_tokens,
completion_tokens=state.turn_completion_tokens,
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
log_prefix="[Baseline]",
)
# Persist assistant response
if state.assistant_text:
if assistant_text:
session.messages.append(
ChatMessage(role="assistant", content=state.assistant_text)
ChatMessage(role="assistant", content=assistant_text)
)
try:
await upsert_chat_session(session)
@@ -602,11 +491,11 @@ async def stream_chat_completion_baseline(
# aclose() — doing so raises RuntimeError on client disconnect.
# On GeneratorExit the client is already gone, so unreachable yields
# are harmless; on normal completion they reach the SSE stream.
if state.turn_prompt_tokens > 0 or state.turn_completion_tokens > 0:
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
yield StreamUsage(
prompt_tokens=state.turn_prompt_tokens,
completion_tokens=state.turn_completion_tokens,
total_tokens=state.turn_prompt_tokens + state.turn_completion_tokens,
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
total_tokens=turn_prompt_tokens + turn_completion_tokens,
)
yield StreamFinish()
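The streaming loop above leans on a fragment-accumulation pattern worth isolating: OpenAI-style streaming delivers each tool call as partial deltas keyed by index, where `id`/`name` arrive once and the JSON `arguments` string arrives in pieces that must be concatenated before parsing. A minimal sketch, with plain dicts standing in for the SDK's delta objects:

```python
def accumulate_tool_calls(deltas: list[dict]) -> dict[int, dict[str, str]]:
    """Merge streamed tool-call fragments into complete calls, keyed by index."""
    calls: dict[int, dict[str, str]] = {}
    for tc in deltas:
        entry = calls.setdefault(
            tc["index"], {"id": "", "name": "", "arguments": ""}
        )
        if tc.get("id"):
            entry["id"] = tc["id"]
        if tc.get("name"):
            entry["name"] = tc["name"]
        if tc.get("arguments"):
            # Arguments stream as JSON text fragments; concatenate, then
            # parse only once the stream for this round is exhausted.
            entry["arguments"] += tc["arguments"]
    return calls


# Example: one call whose arguments arrive split across two chunks.
fragments = [
    {"index": 0, "id": "call_1", "name": "find_block"},
    {"index": 0, "arguments": '{"query": '},
    {"index": 0, "arguments": '"email"}'},
]
merged = accumulate_tool_calls(fragments)
```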

@@ -31,7 +31,7 @@ async def test_baseline_multi_turn(setup_test_user, test_user_id):
if not api_key:
pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
session = await create_chat_session(test_user_id, dry_run=False)
session = await create_chat_session(test_user_id)
session = await upsert_chat_session(session)
# --- Turn 1: send a message with a unique keyword ---

@@ -178,7 +178,7 @@ class ChatConfig(BaseSettings):
Single source of truth for "will the SDK route through OpenRouter?".
Checks the flag *and* that ``api_key`` + a valid ``base_url`` are
present — mirrors the fallback logic in ``build_sdk_env``.
present — mirrors the fallback logic in ``_build_sdk_env``.
"""
if not self.use_openrouter:
return False

@@ -18,13 +18,7 @@ from prisma.types import (
from backend.data import db
from backend.util.json import SafeJson, sanitize_string
from .model import (
ChatMessage,
ChatSession,
ChatSessionInfo,
ChatSessionMetadata,
invalidate_session_cache,
)
from .model import ChatMessage, ChatSession, ChatSessionInfo, invalidate_session_cache
logger = logging.getLogger(__name__)
@@ -41,7 +35,6 @@ async def get_chat_session(session_id: str) -> ChatSession | None:
async def create_chat_session(
session_id: str,
user_id: str,
metadata: ChatSessionMetadata | None = None,
) -> ChatSessionInfo:
"""Create a new chat session in the database."""
data = ChatSessionCreateInput(
@@ -50,7 +43,6 @@ async def create_chat_session(
credentials=SafeJson({}),
successfulAgentRuns=SafeJson({}),
successfulAgentSchedules=SafeJson({}),
metadata=SafeJson((metadata or ChatSessionMetadata()).model_dump()),
)
prisma_session = await PrismaChatSession.prisma().create(data=data)
return ChatSessionInfo.from_db(prisma_session)
@@ -65,12 +57,7 @@ async def update_chat_session(
total_completion_tokens: int | None = None,
title: str | None = None,
) -> ChatSession | None:
"""Update a chat session's mutable fields.
Note: ``metadata`` (which includes ``dry_run``) is intentionally omitted —
it is set once at creation time and treated as immutable for the lifetime
of the session.
"""
"""Update a chat session's metadata."""
data: ChatSessionUpdateInput = {"updatedAt": datetime.now(UTC)}
if credentials is not None:

@@ -46,16 +46,6 @@ def _get_session_cache_key(session_id: str) -> str:
# ===================== Chat data models ===================== #
class ChatSessionMetadata(BaseModel):
"""Typed metadata stored in the ``metadata`` JSON column of ChatSession.
Add new session-level flags here instead of adding DB columns —
no migration required for new fields as long as a default is provided.
"""
dry_run: bool = False
class ChatMessage(BaseModel):
role: str
content: str | None = None
@@ -100,12 +90,6 @@ class ChatSessionInfo(BaseModel):
updated_at: datetime
successful_agent_runs: dict[str, int] = {}
successful_agent_schedules: dict[str, int] = {}
metadata: ChatSessionMetadata = ChatSessionMetadata()
@property
def dry_run(self) -> bool:
"""Convenience accessor for ``metadata.dry_run``."""
return self.metadata.dry_run
@classmethod
def from_db(cls, prisma_session: PrismaChatSession) -> Self:
@@ -119,10 +103,6 @@ class ChatSessionInfo(BaseModel):
prisma_session.successfulAgentSchedules, default={}
)
# Parse typed metadata from the JSON column.
raw_metadata = _parse_json_field(prisma_session.metadata, default={})
metadata = ChatSessionMetadata.model_validate(raw_metadata)
# Calculate usage from token counts.
# NOTE: Per-turn cache_read_tokens / cache_creation_tokens breakdown
# is lost after persistence — the DB only stores aggregate prompt and
@@ -148,7 +128,6 @@ class ChatSessionInfo(BaseModel):
updated_at=prisma_session.updatedAt,
successful_agent_runs=successful_agent_runs,
successful_agent_schedules=successful_agent_schedules,
metadata=metadata,
)
@@ -156,7 +135,7 @@ class ChatSession(ChatSessionInfo):
messages: list[ChatMessage]
@classmethod
def new(cls, user_id: str, *, dry_run: bool) -> Self:
def new(cls, user_id: str) -> Self:
return cls(
session_id=str(uuid.uuid4()),
user_id=user_id,
@@ -166,7 +145,6 @@ class ChatSession(ChatSessionInfo):
credentials={},
started_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
metadata=ChatSessionMetadata(dry_run=dry_run),
)
@classmethod
@@ -554,7 +532,6 @@ async def _save_session_to_db(
await db.create_chat_session(
session_id=session.session_id,
user_id=session.user_id,
metadata=session.metadata,
)
existing_message_count = 0
@@ -632,27 +609,21 @@ async def append_and_save_message(session_id: str, message: ChatMessage) -> Chat
return session
async def create_chat_session(user_id: str, *, dry_run: bool) -> ChatSession:
async def create_chat_session(user_id: str) -> ChatSession:
"""Create a new chat session and persist it.
Args:
user_id: The authenticated user ID.
dry_run: When True, run_block and run_agent tool calls in this
session are forced to use dry-run simulation mode.
Raises:
DatabaseError: If the database write fails. We fail fast to ensure
callers never receive a non-persisted session that only exists
in cache (which would be lost when the cache expires).
"""
session = ChatSession.new(user_id, dry_run=dry_run)
session = ChatSession.new(user_id)
# Create in database first - fail fast if this fails
try:
await chat_db().create_chat_session(
session_id=session.session_id,
user_id=user_id,
metadata=session.metadata,
)
except Exception as e:
logger.error(f"Failed to create session {session.session_id} in database: {e}")

@@ -46,7 +46,7 @@ messages = [
@pytest.mark.asyncio(loop_scope="session")
async def test_chatsession_serialization_deserialization():
s = ChatSession.new(user_id="abc123", dry_run=False)
s = ChatSession.new(user_id="abc123")
s.messages = messages
s.usage = [Usage(prompt_tokens=100, completion_tokens=200, total_tokens=300)]
serialized = s.model_dump_json()
@@ -57,7 +57,7 @@ async def test_chatsession_serialization_deserialization():
@pytest.mark.asyncio(loop_scope="session")
async def test_chatsession_redis_storage(setup_test_user, test_user_id):
s = ChatSession.new(user_id=test_user_id, dry_run=False)
s = ChatSession.new(user_id=test_user_id)
s.messages = messages
s = await upsert_chat_session(s)
@@ -75,7 +75,7 @@ async def test_chatsession_redis_storage_user_id_mismatch(
setup_test_user, test_user_id
):
s = ChatSession.new(user_id=test_user_id, dry_run=False)
s = ChatSession.new(user_id=test_user_id)
s.messages = messages
s = await upsert_chat_session(s)
@@ -90,7 +90,7 @@ async def test_chatsession_db_storage(setup_test_user, test_user_id):
from backend.data.redis_client import get_redis_async
# Create session with messages including assistant message
s = ChatSession.new(user_id=test_user_id, dry_run=False)
s = ChatSession.new(user_id=test_user_id)
s.messages = messages # Contains user, assistant, and tool messages
assert s.session_id is not None, "Session id is not set"
# Upsert to save to both cache and DB
@@ -241,7 +241,7 @@ _raw_tc2 = {
def test_add_tool_call_appends_to_existing_assistant():
"""When the last assistant is from the current turn, tool_call is added to it."""
session = ChatSession.new(user_id="u", dry_run=False)
session = ChatSession.new(user_id="u")
session.messages = [
ChatMessage(role="user", content="hi"),
ChatMessage(role="assistant", content="working on it"),
@@ -254,7 +254,7 @@ def test_add_tool_call_appends_to_existing_assistant():
def test_add_tool_call_creates_assistant_when_none_exists():
"""When there's no current-turn assistant, a new one is created."""
session = ChatSession.new(user_id="u", dry_run=False)
session = ChatSession.new(user_id="u")
session.messages = [
ChatMessage(role="user", content="hi"),
]
@@ -267,7 +267,7 @@ def test_add_tool_call_creates_assistant_when_none_exists():
def test_add_tool_call_does_not_cross_user_boundary():
"""A user message acts as a boundary — previous assistant is not modified."""
session = ChatSession.new(user_id="u", dry_run=False)
session = ChatSession.new(user_id="u")
session.messages = [
ChatMessage(role="assistant", content="old turn"),
ChatMessage(role="user", content="new message"),
@@ -282,7 +282,7 @@ def test_add_tool_call_does_not_cross_user_boundary():
def test_add_tool_call_multiple_times():
"""Multiple long-running tool calls accumulate on the same assistant."""
session = ChatSession.new(user_id="u", dry_run=False)
session = ChatSession.new(user_id="u")
session.messages = [
ChatMessage(role="user", content="hi"),
ChatMessage(role="assistant", content="doing stuff"),
@@ -300,7 +300,7 @@ def test_add_tool_call_multiple_times():
def test_to_openai_messages_merges_split_assistants():
"""End-to-end: session with split assistants produces valid OpenAI messages."""
session = ChatSession.new(user_id="u", dry_run=False)
session = ChatSession.new(user_id="u")
session.messages = [
ChatMessage(role="user", content="build agent"),
ChatMessage(role="assistant", content="Let me build that"),
@@ -352,7 +352,7 @@ async def test_concurrent_saves_collision_detection(setup_test_user, test_user_i
import asyncio
# Create a session with initial messages
session = ChatSession.new(user_id=test_user_id, dry_run=False)
session = ChatSession.new(user_id=test_user_id)
for i in range(3):
session.messages.append(
ChatMessage(

@@ -107,13 +107,6 @@ Do not re-fetch or re-generate data you already have from prior tool calls.
After building the file, reference it with `@@agptfile:` in other tools:
`@@agptfile:/home/user/report.md`
### Web search best practices
- If 3 similar web searches don't return the specific data you need, conclude
it isn't publicly available and work with what you have.
- Prefer fewer, well-targeted searches over many variations of the same query.
- When spawning sub-agents for research, ensure each has a distinct
non-overlapping scope to avoid redundant searches.
### Sub-agent tasks
- When using the Task tool, NEVER set `run_in_background` to true.
All tasks must run in the foreground.

@@ -1,21 +0,0 @@
"""Tests for agent generation guide — verifies clarification section."""
from pathlib import Path
class TestAgentGenerationGuideContainsClarifySection:
"""The agent generation guide must include the clarification section."""
def test_guide_includes_clarify_before_building(self):
guide_path = Path(__file__).parent / "sdk" / "agent_generation_guide.md"
content = guide_path.read_text(encoding="utf-8")
assert "Clarifying Before Building" in content
def test_guide_mentions_find_block_for_clarification(self):
guide_path = Path(__file__).parent / "sdk" / "agent_generation_guide.md"
content = guide_path.read_text(encoding="utf-8")
# find_block must appear in the clarification section (before the workflow)
clarify_section = content.split("Clarifying Before Building")[1].split(
"### Workflow"
)[0]
assert "find_block" in clarify_section

View File

@@ -3,21 +3,6 @@
You can create, edit, and customize agents directly. You ARE the brain —
generate the agent JSON yourself using block schemas, then validate and save.
### Clarifying Before Building
Before starting the workflow below, check whether the user's goal is
**ambiguous** — missing the output format, delivery channel, data source,
or trigger. If so:
1. Call `find_block` with a query targeting the ambiguous dimension to
discover what the platform actually supports.
2. Ask the user **one concrete question** grounded in the discovered
options (e.g. "The platform supports Gmail, Slack, and Google Docs —
which should the agent use for delivery?").
3. **Wait for the user's answer** before proceeding.
**Skip this** when the goal already specifies all dimensions (e.g.
"scrape prices from Amazon and email me daily").
### Workflow for Creating/Editing Agents
1. **Discover blocks**: Call `find_block(query, include_schemas=true)` to

View File

@@ -25,7 +25,7 @@ from backend.copilot.sdk.compaction import (
def _make_session() -> ChatSession:
return ChatSession.new(user_id="test-user", dry_run=False)
return ChatSession.new(user_id="test-user")
# ---------------------------------------------------------------------------

View File

@@ -275,7 +275,7 @@ class TestCompactionE2E:
# --- Step 7: CompactionTracker receives PreCompact hook ---
tracker = CompactionTracker()
session = ChatSession.new(user_id="test-user", dry_run=False)
session = ChatSession.new(user_id="test-user")
tracker.on_compact(str(session_file))
# --- Step 8: Next SDK message arrives → emit_start ---
@@ -376,7 +376,7 @@ class TestCompactionE2E:
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
tracker = CompactionTracker()
session = ChatSession.new(user_id="test", dry_run=False)
session = ChatSession.new(user_id="test")
builder = TranscriptBuilder()
# --- First query with compaction ---

View File

@@ -1,68 +0,0 @@
"""SDK environment variable builder — importable without circular deps.
Extracted from ``service.py`` so that ``backend.blocks.orchestrator``
can reuse the same subscription / OpenRouter / direct-Anthropic logic
without pulling in the full copilot service module (which would create a
circular import through ``executor`` → ``credit`` → ``block_cost_config``).
"""
from __future__ import annotations
from backend.copilot.config import ChatConfig
from backend.copilot.sdk.subscription import validate_subscription
# ChatConfig is stateless (reads env vars) — a separate instance is fine.
# A singleton would require importing service.py which causes the circular dep
# this module was created to avoid.
config = ChatConfig()
def build_sdk_env(
session_id: str | None = None,
user_id: str | None = None,
) -> dict[str, str]:
"""Build env vars for the SDK CLI subprocess.
Three modes (checked in order):
1. **Subscription** — clears all keys; CLI uses ``claude login`` auth.
2. **Direct Anthropic** — returns ``{}``; subprocess inherits
``ANTHROPIC_API_KEY`` from the parent environment.
3. **OpenRouter** (default) — overrides base URL and auth token to
route through the proxy, with Langfuse trace headers.
"""
# --- Mode 1: Claude Code subscription auth ---
if config.use_claude_code_subscription:
validate_subscription()
return {
"ANTHROPIC_API_KEY": "",
"ANTHROPIC_AUTH_TOKEN": "",
"ANTHROPIC_BASE_URL": "",
}
# --- Mode 2: Direct Anthropic (no proxy hop) ---
if not config.openrouter_active:
return {}
# --- Mode 3: OpenRouter proxy ---
base = (config.base_url or "").rstrip("/")
if base.endswith("/v1"):
base = base[:-3]
env: dict[str, str] = {
"ANTHROPIC_BASE_URL": base,
"ANTHROPIC_AUTH_TOKEN": config.api_key or "",
"ANTHROPIC_API_KEY": "", # force CLI to use AUTH_TOKEN
}
# Inject custom trace headers so OpenRouter forwards traces to Langfuse.
def _safe(v: str) -> str:
return v.replace("\r", "").replace("\n", "").strip()[:128]
parts = []
if session_id:
parts.append(f"x-session-id: {_safe(session_id)}")
if user_id:
parts.append(f"x-user-id: {_safe(user_id)}")
if parts:
env["ANTHROPIC_CUSTOM_HEADERS"] = "\n".join(parts)
return env

View File
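The `/v1`-stripping and header-sanitisation logic in the module deleted above (its logic moves into `service.py` as `_build_sdk_env`) can be exercised standalone. A minimal sketch with hypothetical helper names, mirroring the lines in the diff:

```python
def normalize_base_url(url: str) -> str:
    """Strip a trailing slash, then a trailing /v1 version path."""
    base = (url or "").rstrip("/")
    if base.endswith("/v1"):
        base = base[:-3]
    return base


def safe_header_value(value: str) -> str:
    """Drop CR/LF (header-injection guard), trim, and cap at 128 chars."""
    return value.replace("\r", "").replace("\n", "").strip()[:128]


# Mirrors the behaviours the build_sdk_env tests assert below:
print(normalize_base_url("https://openrouter.ai/api/v1/"))  # https://openrouter.ai/api
print(safe_header_value("bad\r\nvalue"))  # badvalue
```

Note the order matters: `rstrip("/")` runs first so `.../api/v1/` and `.../api/v1` normalise identically, while a URL without a `/v1` suffix passes through unchanged.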

@@ -1,242 +0,0 @@
"""Tests for build_sdk_env() — the SDK subprocess environment builder."""
from unittest.mock import patch
import pytest
from backend.copilot.config import ChatConfig
# ---------------------------------------------------------------------------
# Helpers — build a ChatConfig with explicit field values so tests don't
# depend on real environment variables.
# ---------------------------------------------------------------------------
def _make_config(**overrides) -> ChatConfig:
"""Create a ChatConfig with safe defaults, applying *overrides*."""
defaults = {
"use_claude_code_subscription": False,
"use_openrouter": False,
"api_key": None,
"base_url": None,
}
defaults.update(overrides)
return ChatConfig(**defaults)
# ---------------------------------------------------------------------------
# Mode 1 — Subscription auth
# ---------------------------------------------------------------------------
class TestBuildSdkEnvSubscription:
"""When ``use_claude_code_subscription`` is True, keys are blanked."""
@patch("backend.copilot.sdk.env.validate_subscription")
def test_returns_blanked_keys(self, mock_validate):
"""Subscription mode clears API_KEY, AUTH_TOKEN, and BASE_URL."""
cfg = _make_config(use_claude_code_subscription=True)
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env()
assert result == {
"ANTHROPIC_API_KEY": "",
"ANTHROPIC_AUTH_TOKEN": "",
"ANTHROPIC_BASE_URL": "",
}
mock_validate.assert_called_once()
@patch(
"backend.copilot.sdk.env.validate_subscription",
side_effect=RuntimeError("CLI not found"),
)
def test_propagates_validation_error(self, mock_validate):
"""If validate_subscription fails, the error bubbles up."""
cfg = _make_config(use_claude_code_subscription=True)
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
with pytest.raises(RuntimeError, match="CLI not found"):
build_sdk_env()
# ---------------------------------------------------------------------------
# Mode 2 — Direct Anthropic (no OpenRouter)
# ---------------------------------------------------------------------------
class TestBuildSdkEnvDirectAnthropic:
"""When OpenRouter is inactive, return empty dict (inherit parent env)."""
def test_returns_empty_dict_when_openrouter_inactive(self):
cfg = _make_config(use_openrouter=False)
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env()
assert result == {}
def test_returns_empty_dict_when_openrouter_flag_true_but_no_key(self):
"""OpenRouter flag is True but no api_key => openrouter_active is False."""
cfg = _make_config(use_openrouter=True, base_url="https://openrouter.ai/api/v1")
# Force api_key to None after construction (field_validator may pick up env vars)
object.__setattr__(cfg, "api_key", None)
assert not cfg.openrouter_active
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env()
assert result == {}
# ---------------------------------------------------------------------------
# Mode 3 — OpenRouter proxy
# ---------------------------------------------------------------------------
class TestBuildSdkEnvOpenRouter:
"""When OpenRouter is active, return proxy env vars."""
def _openrouter_config(self, **overrides):
defaults = {
"use_openrouter": True,
"api_key": "sk-or-test-key",
"base_url": "https://openrouter.ai/api/v1",
}
defaults.update(overrides)
return _make_config(**defaults)
def test_basic_openrouter_env(self):
cfg = self._openrouter_config()
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env()
assert result["ANTHROPIC_BASE_URL"] == "https://openrouter.ai/api"
assert result["ANTHROPIC_AUTH_TOKEN"] == "sk-or-test-key"
assert result["ANTHROPIC_API_KEY"] == ""
assert "ANTHROPIC_CUSTOM_HEADERS" not in result
def test_strips_trailing_v1(self):
"""The /v1 suffix is stripped from the base URL."""
cfg = self._openrouter_config(base_url="https://openrouter.ai/api/v1")
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env()
assert result["ANTHROPIC_BASE_URL"] == "https://openrouter.ai/api"
def test_strips_trailing_v1_and_slash(self):
"""Trailing slash before /v1 strip is handled."""
cfg = self._openrouter_config(base_url="https://openrouter.ai/api/v1/")
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env()
# rstrip("/") first, then remove /v1
assert result["ANTHROPIC_BASE_URL"] == "https://openrouter.ai/api"
def test_no_v1_suffix_left_alone(self):
"""A base URL without /v1 is used as-is."""
cfg = self._openrouter_config(base_url="https://custom-proxy.example.com")
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env()
assert result["ANTHROPIC_BASE_URL"] == "https://custom-proxy.example.com"
def test_session_id_header(self):
cfg = self._openrouter_config()
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env(session_id="sess-123")
assert "ANTHROPIC_CUSTOM_HEADERS" in result
assert "x-session-id: sess-123" in result["ANTHROPIC_CUSTOM_HEADERS"]
def test_user_id_header(self):
cfg = self._openrouter_config()
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env(user_id="user-456")
assert "x-user-id: user-456" in result["ANTHROPIC_CUSTOM_HEADERS"]
def test_both_headers(self):
cfg = self._openrouter_config()
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env(session_id="s1", user_id="u2")
headers = result["ANTHROPIC_CUSTOM_HEADERS"]
assert "x-session-id: s1" in headers
assert "x-user-id: u2" in headers
# They should be newline-separated
assert "\n" in headers
def test_header_sanitisation_strips_newlines(self):
"""Newlines/carriage-returns in header values are stripped."""
cfg = self._openrouter_config()
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env(session_id="bad\r\nvalue")
header_val = result["ANTHROPIC_CUSTOM_HEADERS"]
# The _safe helper removes \r and \n
assert "\r" not in header_val.split(": ", 1)[1]
assert "badvalue" in header_val
def test_header_value_truncated_to_128_chars(self):
"""Header values are truncated to 128 characters."""
cfg = self._openrouter_config()
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
long_id = "x" * 200
result = build_sdk_env(session_id=long_id)
# The value after "x-session-id: " should be at most 128 chars
header_line = result["ANTHROPIC_CUSTOM_HEADERS"]
value = header_line.split(": ", 1)[1]
assert len(value) == 128
# ---------------------------------------------------------------------------
# Mode priority
# ---------------------------------------------------------------------------
class TestBuildSdkEnvModePriority:
"""Subscription mode takes precedence over OpenRouter."""
@patch("backend.copilot.sdk.env.validate_subscription")
def test_subscription_overrides_openrouter(self, mock_validate):
cfg = _make_config(
use_claude_code_subscription=True,
use_openrouter=True,
api_key="sk-or-key",
base_url="https://openrouter.ai/api/v1",
)
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env()
# Should get subscription result, not OpenRouter
assert result == {
"ANTHROPIC_API_KEY": "",
"ANTHROPIC_AUTH_TOKEN": "",
"ANTHROPIC_BASE_URL": "",
}

View File

@@ -38,7 +38,7 @@ class TestFlattenAssistantContent:
def test_tool_use_blocks(self):
blocks = [{"type": "tool_use", "name": "read_file", "input": {}}]
assert _flatten_assistant_content(blocks) == ""
assert _flatten_assistant_content(blocks) == "[tool_use: read_file]"
def test_mixed_blocks(self):
blocks = [
@@ -47,22 +47,19 @@ class TestFlattenAssistantContent:
]
result = _flatten_assistant_content(blocks)
assert "Let me read that." in result
# tool_use blocks are dropped entirely to prevent model mimicry
assert "Read" not in result
assert "[tool_use: Read]" in result
def test_raw_strings(self):
assert _flatten_assistant_content(["hello", "world"]) == "hello\nworld"
def test_unknown_block_type_dropped(self):
def test_unknown_block_type_preserved_as_placeholder(self):
blocks = [
{"type": "text", "text": "See this image:"},
{"type": "image", "source": {"type": "base64", "data": "..."}},
]
result = _flatten_assistant_content(blocks)
assert "See this image:" in result
# Unknown block types are dropped to prevent model mimicry
assert "[__image__]" not in result
assert "base64" not in result
assert "[__image__]" in result
def test_empty(self):
assert _flatten_assistant_content([]) == ""
@@ -282,8 +279,7 @@ class TestTranscriptToMessages:
messages = _transcript_to_messages(content)
assert len(messages) == 2
assert "Let me check." in messages[0]["content"]
# tool_use blocks are dropped entirely to prevent model mimicry
assert "read_file" not in messages[0]["content"]
assert "[tool_use: read_file]" in messages[0]["content"]
assert messages[1]["content"] == "file contents"

View File
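The placeholder semantics these updated tests assert — `tool_use` blocks kept as `[tool_use: <name>]`, unknown block types kept as `[__<type>__]`, raw strings passed through — can be sketched as a standalone flattener. This is a hypothetical simplification inferred from the assertions, not the repo's actual `_flatten_assistant_content` (which also drops thinking blocks, per the tests further down):

```python
def flatten_assistant_content(blocks) -> str:
    """Flatten assistant content blocks into plain text.

    text -> its text; tool_use -> "[tool_use: <name>]"; raw strings pass
    through; thinking blocks are dropped; anything else -> "[__<type>__]".
    """
    lines = []
    for block in blocks:
        if isinstance(block, str):
            lines.append(block)
        elif block.get("type") == "text":
            lines.append(block.get("text", ""))
        elif block.get("type") == "tool_use":
            lines.append(f"[tool_use: {block.get('name', 'unknown')}]")
        elif block.get("type") in ("thinking", "redacted_thinking"):
            continue  # thinking content never reaches the restored transcript
        else:
            lines.append(f"[__{block.get('type', 'unknown')}__]")
    return "\n".join(lines)
```

The design trade-off visible in the diff: the previous revision dropped `tool_use` entirely to prevent model mimicry; this one restores a name-only placeholder so the transcript keeps the call sequence without exposing arguments.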

@@ -49,22 +49,22 @@ def test_format_assistant_tool_calls():
)
]
result = _format_conversation_context(msgs)
# Assistant with no content and tool_calls omitted produces no lines
assert result is None
assert result is not None
assert 'You called tool: search({"q": "test"})' in result
def test_format_tool_result():
msgs = [ChatMessage(role="tool", content='{"result": "ok"}')]
result = _format_conversation_context(msgs)
assert result is not None
assert 'Tool output: {"result": "ok"}' in result
assert 'Tool result: {"result": "ok"}' in result
def test_format_tool_result_none_content():
msgs = [ChatMessage(role="tool", content=None)]
result = _format_conversation_context(msgs)
assert result is not None
assert "Tool output: " in result
assert "Tool result: " in result
def test_format_full_conversation():
@@ -84,8 +84,8 @@ def test_format_full_conversation():
assert result is not None
assert "User: find agents" in result
assert "You responded: I'll search for agents." in result
# tool_calls are omitted to prevent model mimicry
assert "Tool output:" in result
assert "You called tool: find_agents" in result
assert "Tool result:" in result
assert "You responded: Found Agent1." in result

View File

@@ -27,7 +27,6 @@ from backend.copilot.response_model import (
StreamError,
StreamFinish,
StreamFinishStep,
StreamHeartbeat,
StreamStart,
StreamStartStep,
StreamTextDelta,
@@ -77,12 +76,6 @@ class SDKResponseAdapter:
# Open the first step (matches non-SDK: StreamStart then StreamStartStep)
responses.append(StreamStartStep())
self.step_open = True
elif sdk_message.subtype == "task_progress":
# Emit a heartbeat so publish_chunk is called during long
# sub-agent runs. Without this, the Redis stream and meta
# key TTLs expire during gaps where no real chunks are
# produced (task_progress events were previously silent).
responses.append(StreamHeartbeat())
elif isinstance(sdk_message, AssistantMessage):
# Flush any SDK built-in tool calls that didn't get a UserMessage

View File

@@ -18,7 +18,6 @@ from backend.copilot.response_model import (
StreamError,
StreamFinish,
StreamFinishStep,
StreamHeartbeat,
StreamStart,
StreamStartStep,
StreamTextDelta,
@@ -60,14 +59,6 @@ def test_system_non_init_emits_nothing():
assert results == []
def test_task_progress_emits_heartbeat():
"""task_progress events emit a StreamHeartbeat to keep Redis TTL alive."""
adapter = _adapter()
results = adapter.convert_message(SystemMessage(subtype="task_progress", data={}))
assert len(results) == 1
assert isinstance(results[0], StreamHeartbeat)
# -- AssistantMessage with TextBlock -----------------------------------------

View File

@@ -904,14 +904,14 @@ class TestTranscriptEdgeCases:
assert restored[1]["content"] == "Second"
def test_flatten_assistant_with_only_tool_use(self):
"""Assistant message with only tool_use blocks (no text) flattens to empty."""
"""Assistant message with only tool_use blocks (no text)."""
blocks = [
{"type": "tool_use", "name": "bash", "input": {"cmd": "ls"}},
{"type": "tool_use", "name": "read", "input": {"path": "/f"}},
]
result = _flatten_assistant_content(blocks)
# tool_use blocks are dropped entirely to prevent model mimicry
assert result == ""
assert "[tool_use: bash]" in result
assert "[tool_use: read]" in result
def test_flatten_tool_result_nested_image(self):
"""Tool result containing image blocks uses placeholder."""
@@ -1010,7 +1010,7 @@ def _make_sdk_patches(
(f"{_SVC}.create_security_hooks", dict(return_value=MagicMock())),
(f"{_SVC}.get_copilot_tool_names", dict(return_value=[])),
(f"{_SVC}.get_sdk_disallowed_tools", dict(return_value=[])),
(f"{_SVC}.build_sdk_env", dict(return_value=None)),
(f"{_SVC}._build_sdk_env", dict(return_value=None)),
(f"{_SVC}._resolve_sdk_model", dict(return_value=None)),
(f"{_SVC}.set_execution_context", {}),
(
@@ -1414,76 +1414,3 @@ class TestStreamChatCompletionRetryIntegration:
# Verify user-friendly message (not raw SDK text)
assert "Authentication" in errors[0].errorText
assert any(isinstance(e, StreamStart) for e in events)
@pytest.mark.asyncio
async def test_result_message_prompt_too_long_triggers_compaction(self):
"""CLI returns ResultMessage(subtype="error") with "Prompt is too long".
When the Claude CLI rejects the prompt pre-API (model=<synthetic>,
duration_api_ms=0), it sends a ResultMessage with is_error=True
instead of raising a Python exception. The retry loop must still
detect this as a context-length error and trigger compaction.
"""
import contextlib
from claude_agent_sdk import ResultMessage
from backend.copilot.response_model import StreamError, StreamStart
from backend.copilot.sdk.service import stream_chat_completion_sdk
session = self._make_session()
success_result = self._make_result_message()
attempt_count = [0]
error_result = ResultMessage(
subtype="error",
result="Prompt is too long",
duration_ms=100,
duration_api_ms=0,
is_error=True,
num_turns=0,
session_id="test-session-id",
)
def _client_factory(*args, **kwargs):
attempt_count[0] += 1
if attempt_count[0] == 1:
# First attempt: CLI returns error ResultMessage
return self._make_client_mock(result_message=error_result)
# Second attempt (after compaction): succeeds
return self._make_client_mock(result_message=success_result)
original_transcript = _build_transcript(
[("user", "prior question"), ("assistant", "prior answer")]
)
compacted_transcript = _build_transcript(
[("user", "[summary]"), ("assistant", "summary reply")]
)
patches = _make_sdk_patches(
session,
original_transcript=original_transcript,
compacted_transcript=compacted_transcript,
client_side_effect=_client_factory,
)
events = []
with contextlib.ExitStack() as stack:
for target, kwargs in patches:
stack.enter_context(patch(target, **kwargs))
async for event in stream_chat_completion_sdk(
session_id="test-session-id",
message="hello",
is_user_message=True,
user_id="test-user",
session=session,
):
events.append(event)
assert attempt_count[0] == 2, (
f"Expected 2 SDK attempts (CLI error ResultMessage "
f"should trigger compaction retry), got {attempt_count[0]}"
)
errors = [e for e in events if isinstance(e, StreamError)]
assert not errors, f"Unexpected StreamError: {errors}"
assert any(isinstance(e, StreamStart) for e in events)

View File

@@ -313,7 +313,8 @@ def create_security_hooks(
.replace("\r", "")
)
logger.info(
"[SDK] Context compaction triggered: %s, user=%s, transcript_path=%s",
"[SDK] Context compaction triggered: %s, user=%s, "
"transcript_path=%s",
trigger,
user_id,
transcript_path,

View File

@@ -11,11 +11,7 @@ import pytest
from backend.copilot.context import _current_project_dir
from .security_hooks import (
_validate_tool_access,
_validate_user_isolation,
create_security_hooks,
)
from .security_hooks import _validate_tool_access, _validate_user_isolation
SDK_CWD = "/tmp/copilot-abc123"
@@ -224,6 +220,8 @@ def test_bash_builtin_blocked_message_clarity():
@pytest.fixture()
def _hooks():
"""Create security hooks and return (pre, post, post_failure) handlers."""
from .security_hooks import create_security_hooks
hooks = create_security_hooks(user_id="u1", sdk_cwd=SDK_CWD, max_subtasks=2)
pre = hooks["PreToolUse"][0].hooks[0]
post = hooks["PostToolUse"][0].hooks[0]

View File

@@ -59,14 +59,11 @@ from ..response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamFinishStep,
StreamHeartbeat,
StreamStart,
StreamStartStep,
StreamStatus,
StreamTextDelta,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
StreamUsage,
)
@@ -80,13 +77,15 @@ from ..tools.e2b_sandbox import get_or_create_sandbox, pause_sandbox_direct
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
from ..tracking import track_user_message
from .compaction import CompactionTracker, filter_compaction_messages
from .env import build_sdk_env # noqa: F401 — re-export for backward compat
from .response_adapter import SDKResponseAdapter
from .security_hooks import create_security_hooks
from .subscription import validate_subscription as _validate_claude_code_subscription
from .tool_adapter import (
cancel_pending_tool_tasks,
create_copilot_mcp_server,
get_copilot_tool_names,
get_sdk_disallowed_tools,
pre_launch_tool_call,
reset_stash_event,
reset_tool_failure_counters,
set_execution_context,
@@ -116,10 +115,9 @@ _MAX_STREAM_ATTEMPTS = 3
# Hard circuit breaker: abort the stream if the model sends this many
# consecutive tool calls with empty parameters (a sign of context
# saturation or serialization failure). The MCP wrapper now returns
# guidance on the first empty call, giving the model a chance to
# self-correct. The limit is generous to allow recovery attempts.
_EMPTY_TOOL_CALL_LIMIT = 5
# saturation or serialization failure). Empty input ({}) is never
# legitimate — even one is suspicious, three is conclusive.
_EMPTY_TOOL_CALL_LIMIT = 3
# User-facing error shown when the empty-tool-call circuit breaker trips.
_CIRCUIT_BREAKER_ERROR_MSG = (
@@ -569,6 +567,60 @@ def _resolve_sdk_model() -> str | None:
return model
def _build_sdk_env(
session_id: str | None = None,
user_id: str | None = None,
) -> dict[str, str]:
"""Build env vars for the SDK CLI subprocess.
Three modes (checked in order):
1. **Subscription** — clears all keys; CLI uses `claude login` auth.
2. **Direct Anthropic** — returns `{}`; subprocess inherits
`ANTHROPIC_API_KEY` from the parent environment.
3. **OpenRouter** (default) — overrides base URL and auth token to
route through the proxy, with Langfuse trace headers.
"""
# --- Mode 1: Claude Code subscription auth ---
if config.use_claude_code_subscription:
_validate_claude_code_subscription()
return {
"ANTHROPIC_API_KEY": "",
"ANTHROPIC_AUTH_TOKEN": "",
"ANTHROPIC_BASE_URL": "",
}
# --- Mode 2: Direct Anthropic (no proxy hop) ---
# `openrouter_active` checks the flag *and* credential presence.
if not config.openrouter_active:
return {}
# --- Mode 3: OpenRouter proxy ---
# Strip /v1 suffix — SDK expects the base URL without a version path.
base = (config.base_url or "").rstrip("/")
if base.endswith("/v1"):
base = base[:-3]
env: dict[str, str] = {
"ANTHROPIC_BASE_URL": base,
"ANTHROPIC_AUTH_TOKEN": config.api_key or "",
"ANTHROPIC_API_KEY": "", # force CLI to use AUTH_TOKEN
}
# Inject custom trace headers so OpenRouter forwards traces to Langfuse.
def _safe(v: str) -> str:
"""Sanitise a header value: strip newlines/whitespace and cap length."""
return v.replace("\r", "").replace("\n", "").strip()[:128]
parts = []
if session_id:
parts.append(f"x-session-id: {_safe(session_id)}")
if user_id:
parts.append(f"x-user-id: {_safe(user_id)}")
if parts:
env["ANTHROPIC_CUSTOM_HEADERS"] = "\n".join(parts)
return env
def _make_sdk_cwd(session_id: str) -> str:
"""Create a safe, session-specific working directory path.
@@ -748,11 +800,15 @@ def _format_conversation_context(messages: list[ChatMessage]) -> str | None:
elif msg.role == "assistant":
if msg.content:
lines.append(f"You responded: {msg.content}")
# Omit tool_calls — any text representation gets mimicked
# by the model. Tool results below provide the context.
if msg.tool_calls:
for tc in msg.tool_calls:
func = tc.get("function", {})
tool_name = func.get("name", "unknown")
tool_args = func.get("arguments", "")
lines.append(f"You called tool: {tool_name}({tool_args})")
elif msg.role == "tool":
content = msg.content or ""
lines.append(f"Tool output: {content[:500]}")
lines.append(f"Tool result: {content}")
if not lines:
return None
@@ -1212,14 +1268,6 @@ async def _run_stream_attempt(
consecutive_empty_tool_calls = 0
# --- Intermediate persistence tracking ---
# Flush session messages to DB periodically so page reloads show progress
# during long-running turns (see incident d2f7cba3: 82-min turn lost on refresh).
_last_flush_time = time.monotonic()
_msgs_since_flush = 0
_FLUSH_INTERVAL_SECONDS = 30.0
_FLUSH_MESSAGE_THRESHOLD = 10
# Use manual __aenter__/__aexit__ instead of ``async with`` so we can
# suppress SDK cleanup errors that occur when the SSE client disconnects
# mid-stream. GeneratorExit causes the SDK's ``__aexit__`` to run in a
@@ -1306,21 +1354,6 @@ async def _run_stream_attempt(
error_preview,
)
# Intercept prompt-too-long errors surfaced as
# AssistantMessage.error (not as a Python exception).
# Re-raise so the outer retry loop can compact the
# transcript and retry with reduced context.
# Only check error_text (the error field), not the
# content preview — content may contain arbitrary text
# that false-positives the pattern match.
if _is_prompt_too_long(Exception(error_text)):
logger.warning(
"%s Prompt-too-long detected via AssistantMessage "
"error — raising for retry",
ctx.log_prefix,
)
raise RuntimeError("Prompt is too long")
# Intercept transient API errors (socket closed,
# ECONNRESET) — replace the raw message with a
# user-friendly error text and use the retryable
@@ -1348,17 +1381,28 @@ async def _run_stream_attempt(
ended_with_stream_error = True
break
# Determine if the message is a tool-only batch (all content
# Parallel tool execution: pre-launch every ToolUseBlock as an
# asyncio.Task the moment its AssistantMessage arrives. The SDK
# sends one AssistantMessage per tool call when issuing parallel
# calls, so each message is pre-launched independently. The MCP
# handlers will await the already-running task instead of executing
# fresh, making all concurrent tool calls run in parallel.
#
# Also determine if the message is a tool-only batch (all content
# items are ToolUseBlocks) — such messages have no text output yet,
# so we skip the wait_for_stash flush below.
#
# Note: parallel execution of tools is handled natively by the
# SDK CLI via readOnlyHint annotations on tool definitions.
is_tool_only = False
if isinstance(sdk_msg, AssistantMessage) and sdk_msg.content:
is_tool_only = all(
isinstance(item, ToolUseBlock) for item in sdk_msg.content
)
is_tool_only = True
# NOTE: Pre-launches are sequential (each await completes
# file-ref expansion before the next starts). This is fine
# since expansion is typically sub-ms; a future optimisation
# could gather all pre-launches concurrently.
for tool_use in sdk_msg.content:
if isinstance(tool_use, ToolUseBlock):
await pre_launch_tool_call(tool_use.name, tool_use.input)
else:
is_tool_only = False
# Race-condition fix: SDK hooks (PostToolUse) are
# executed asynchronously via start_soon() — the next
@@ -1414,13 +1458,6 @@ async def _run_stream_attempt(
ctx.log_prefix,
sdk_msg.result or "(no error message provided)",
)
# If the CLI itself rejected the prompt as too long
# (pre-API check, duration_api_ms=0), re-raise as an
# exception so the retry loop can trigger compaction.
# Without this, the ResultMessage is silently consumed
# and the retry/compaction mechanism is never invoked.
if _is_prompt_too_long(RuntimeError(sdk_msg.result or "")):
raise RuntimeError("Prompt is too long")
# Capture token usage from ResultMessage.
# Anthropic reports cached tokens separately:
@@ -1499,34 +1536,6 @@ async def _run_stream_attempt(
model=sdk_msg.model,
)
# --- Intermediate persistence ---
# Flush session messages to DB periodically so page reloads
# show progress during long-running turns.
_msgs_since_flush += 1
now = time.monotonic()
if (
_msgs_since_flush >= _FLUSH_MESSAGE_THRESHOLD
or (now - _last_flush_time) >= _FLUSH_INTERVAL_SECONDS
):
try:
await asyncio.shield(upsert_chat_session(ctx.session))
logger.debug(
"%s Intermediate flush: %d messages "
"(msgs_since=%d, elapsed=%.1fs)",
ctx.log_prefix,
len(ctx.session.messages),
_msgs_since_flush,
now - _last_flush_time,
)
except Exception as flush_err:
logger.warning(
"%s Intermediate flush failed: %s",
ctx.log_prefix,
flush_err,
)
_last_flush_time = now
_msgs_since_flush = 0
if acc.stream_completed:
break
finally:
@@ -1858,7 +1867,7 @@ async def stream_chat_completion_sdk(
)
# Fail fast when no API credentials are available at all.
sdk_env = build_sdk_env(session_id=session_id, user_id=user_id)
sdk_env = _build_sdk_env(session_id=session_id, user_id=user_id)
if not config.api_key and not config.use_claude_code_subscription:
raise RuntimeError(
"No API key configured. Set OPEN_ROUTER_API_KEY, "
@@ -2053,22 +2062,13 @@ async def stream_chat_completion_sdk(
try:
async for event in _run_stream_attempt(stream_ctx, state):
if not isinstance(
event,
(
StreamHeartbeat,
# Compaction UI events are cosmetic and must not
# block retry — they're emitted before the SDK
# query on compacted attempts.
StreamStartStep,
StreamFinishStep,
StreamToolInputStart,
StreamToolInputAvailable,
StreamToolOutputAvailable,
),
):
if not isinstance(event, StreamHeartbeat):
events_yielded += 1
yield event
# Cancel any pre-launched tasks that were never dispatched
# by the SDK (e.g. edge-case SDK behaviour changes). Symmetric
# with the three error-path await cancel_pending_tool_tasks() calls.
await cancel_pending_tool_tasks()
break # Stream completed — exit retry loop
except asyncio.CancelledError:
logger.warning(
@@ -2077,6 +2077,9 @@ async def stream_chat_completion_sdk(
attempt + 1,
_MAX_STREAM_ATTEMPTS,
)
# Cancel any pre-launched tasks so they don't continue executing
# against a rolled-back or abandoned session.
await cancel_pending_tool_tasks()
raise
except _HandledStreamError as exc:
# _run_stream_attempt already yielded a StreamError and
@@ -2108,6 +2111,8 @@ async def stream_chat_completion_sdk(
retryable=True,
)
ended_with_stream_error = True
# Cancel any pre-launched tasks from the failed attempt.
await cancel_pending_tool_tasks()
break
except Exception as e:
stream_err = e
@@ -2124,6 +2129,9 @@ async def stream_chat_completion_sdk(
exc_info=True,
)
session.messages = session.messages[:pre_attempt_msg_count]
# Cancel any pre-launched tasks from the failed attempt so they
# don't continue executing against the rolled-back session.
await cancel_pending_tool_tasks()
if events_yielded > 0:
# Events were already sent to the frontend and cannot be
# unsent. Retrying would produce duplicate/inconsistent

View File
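The tightened empty-tool-call circuit breaker in `service.py` (limit lowered from 5 to 3, on the rationale that empty input `{}` is never legitimate) is a consecutive-counter pattern. A minimal sketch with hypothetical names — the actual breaker lives inside the stream loop, not in a class:

```python
_EMPTY_TOOL_CALL_LIMIT = 3


class EmptyToolCallBreaker:
    """Trip after N consecutive tool calls arriving with empty input ({})."""

    def __init__(self, limit: int = _EMPTY_TOOL_CALL_LIMIT) -> None:
        self.limit = limit
        self.consecutive = 0

    def observe(self, tool_input: dict) -> bool:
        """Record one tool call; return True once the breaker trips."""
        if not tool_input:
            self.consecutive += 1
        else:
            self.consecutive = 0  # any well-formed call resets the streak
        return self.consecutive >= self.limit
```

Resetting on any non-empty call is what makes the limit safe to lower: a model that self-corrects after the MCP wrapper's first-call guidance never accumulates three strikes.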

@@ -392,7 +392,7 @@ class TestFlattenThinkingBlocks:
assert result == ""
def test_mixed_thinking_text_tool(self):
"""Mixed blocks: only text survives flattening; thinking and tool_use dropped."""
"""Mixed blocks: only text and tool_use survive flattening."""
blocks = [
{"type": "thinking", "thinking": "hmm", "signature": "sig"},
{"type": "redacted_thinking", "data": "xyz"},
@@ -403,8 +403,7 @@ class TestFlattenThinkingBlocks:
assert "hmm" not in result
assert "xyz" not in result
assert "I'll read the file." in result
- # tool_use blocks are dropped entirely to prevent model mimicry
- assert "Read" not in result
+ assert "[tool_use: Read]" in result
# ---------------------------------------------------------------------------

View File

@@ -14,7 +14,6 @@ from contextvars import ContextVar
from typing import TYPE_CHECKING, Any
from claude_agent_sdk import create_sdk_mcp_server, tool
from mcp.types import ToolAnnotations
from backend.copilot.context import (
_current_permissions,
@@ -54,6 +53,14 @@ _MCP_MAX_CHARS = 500_000
MCP_SERVER_NAME = "copilot"
MCP_TOOL_PREFIX = f"mcp__{MCP_SERVER_NAME}__"
# Map from tool_name -> Queue of pre-launched (task, args) pairs.
# Initialised per-session in set_execution_context() so concurrent sessions
# never share the same dict.
_TaskQueueItem = tuple[asyncio.Task[dict[str, Any]], dict[str, Any]]
_tool_task_queues: ContextVar[dict[str, asyncio.Queue[_TaskQueueItem]] | None] = (
ContextVar("_tool_task_queues", default=None)
)
# Stash for MCP tool outputs before the SDK potentially truncates them.
# Keyed by tool_name → full output string. Consumed (popped) by the
# response adapter when it builds StreamToolOutputAvailable.
@@ -108,6 +115,7 @@ def set_execution_context(
_current_permissions.set(permissions)
_pending_tool_outputs.set({})
_stash_event.set(asyncio.Event())
_tool_task_queues.set({})
_consecutive_tool_failures.set({})
@@ -124,6 +132,48 @@ def reset_stash_event() -> None:
event.clear()
async def cancel_pending_tool_tasks() -> None:
"""Cancel all queued pre-launched tasks for the current execution context.
Call this when a stream attempt aborts (error, cancellation) to prevent
pre-launched tasks from continuing to execute against a rolled-back session.
Tasks that are already done are skipped; in-flight tasks are cancelled and
awaited so that any cleanup (``finally`` blocks, DB rollbacks) completes
before the next retry starts.
"""
queues = _tool_task_queues.get()
if not queues:
return
cancelled_tasks: list[asyncio.Task] = []
for tool_name, queue in list(queues.items()):
cancelled = 0
while not queue.empty():
task, _args = queue.get_nowait()
if not task.done():
task.cancel()
cancelled_tasks.append(task)
cancelled += 1
if cancelled:
logger.debug(
"Cancelled %d pre-launched task(s) for tool '%s'", cancelled, tool_name
)
queues.clear()
# Await all cancelled tasks so their cleanup (finally blocks, DB rollbacks)
# completes before the next retry attempt starts new pre-launches.
# Use a timeout to prevent hanging indefinitely if a task's cleanup is stuck.
if cancelled_tasks:
try:
await asyncio.wait_for(
asyncio.gather(*cancelled_tasks, return_exceptions=True),
timeout=5.0,
)
except TimeoutError:
logger.warning(
"Timed out waiting for %d cancelled task(s) to clean up",
len(cancelled_tasks),
)
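The cancel-then-await pattern above (cancel every pending task, then gather them under a bounded wait so `finally` blocks finish before the next retry) can be reproduced in a standalone sketch. This is illustrative only, not the project's code; the name `cancel_and_drain` is made up for the example.

```python
import asyncio


async def cancel_and_drain(tasks: list, timeout: float = 5.0) -> int:
    """Cancel pending tasks, then await them so their cleanup completes.

    gather(return_exceptions=True) collects the CancelledErrors instead of
    raising, and wait_for bounds how long we wait on a stuck finally-block.
    """
    cancelled = 0
    for task in tasks:
        if not task.done():
            task.cancel()
            cancelled += 1
    if cancelled:
        try:
            await asyncio.wait_for(
                asyncio.gather(*tasks, return_exceptions=True), timeout=timeout
            )
        except asyncio.TimeoutError:
            pass  # cleanup is stuck; give up rather than hang the retry loop
    return cancelled


async def main() -> int:
    async def worker():
        await asyncio.sleep(60)  # long enough to still be pending

    tasks = [asyncio.create_task(worker()) for _ in range(3)]
    await asyncio.sleep(0)  # let the tasks start
    return await cancel_and_drain(tasks)


print(asyncio.run(main()))  # → 3
```

Awaiting the cancelled tasks (rather than fire-and-forget cancelling) is what guarantees DB rollbacks in their `finally` blocks have completed before a retry launches fresh tasks.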
def reset_tool_failure_counters() -> None:
"""Reset all tool-level circuit breaker counters.
@@ -199,6 +249,10 @@ async def wait_for_stash(timeout: float = 2.0) -> bool:
Uses ``asyncio.Event.wait()`` so it returns the instant the hook signals —
the timeout is purely a safety net for the case where the hook never fires.
Returns ``True`` if the stash signal was received, ``False`` on timeout.
The 2.0 s default was chosen to accommodate slower tool startup in cloud
sandboxes while still failing fast when the hook genuinely will not fire.
With the parallel pre-launch path, hooks typically fire well under 1 ms.
"""
event = _stash_event.get(None)
if event is None:
@@ -217,13 +271,95 @@ async def wait_for_stash(timeout: float = 2.0) -> bool:
return False
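The `wait_for_stash` shape — an `asyncio.Event` that wakes the waiter the instant the hook signals, with the timeout purely as a safety net — can be sketched minimally. Names here are illustrative, not the module's API.

```python
import asyncio


async def wait_for_signal(event: asyncio.Event, timeout: float = 2.0) -> bool:
    """Return True the instant the event is set; False only on timeout.

    Event.wait() returns immediately once set() is called, so the timeout
    only matters when the signal never arrives.
    """
    try:
        await asyncio.wait_for(event.wait(), timeout=timeout)
        return True
    except asyncio.TimeoutError:
        return False


async def demo() -> tuple:
    ev = asyncio.Event()
    # Signal arrives after 10 ms: wait returns True well before the timeout.
    asyncio.get_running_loop().call_later(0.01, ev.set)
    fast = await wait_for_signal(ev, timeout=2.0)
    # Never-set event: falls back to the timeout and returns False.
    slow = await wait_for_signal(asyncio.Event(), timeout=0.05)
    return fast, slow


print(asyncio.run(demo()))  # → (True, False)
```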
async def pre_launch_tool_call(tool_name: str, args: dict[str, Any]) -> None:
"""Pre-launch a tool as a background task so parallel calls run concurrently.
Called when an AssistantMessage with ToolUseBlocks is received, before the
SDK dispatches the MCP tool/call requests. The tool_handler will await the
pre-launched task instead of executing fresh.
The tool_name may include an MCP prefix (e.g. ``mcp__copilot__run_block``);
the prefix is stripped automatically before looking up the tool.
Ordering guarantee: the Claude Agent SDK dispatches MCP ``tools/call`` requests
in the same order as the ToolUseBlocks appear in the AssistantMessage.
Pre-launched tasks are queued FIFO per tool name, so the N-th handler for a
given tool name dequeues the N-th pre-launched task — result and args always
correspond when the SDK preserves order (which it does in the current SDK).
"""
queues = _tool_task_queues.get()
if queues is None:
return
# Strip the MCP server prefix (e.g. "mcp__copilot__") to get the bare tool name.
# Use removeprefix so tool names that themselves contain "__" are handled correctly.
bare_name = tool_name.removeprefix(MCP_TOOL_PREFIX)
base_tool = TOOL_REGISTRY.get(bare_name)
if base_tool is None:
return
user_id, session = get_execution_context()
if session is None:
return
# Expand @@agptfile: references before launching the task.
# The _truncating wrapper (which normally handles expansion) runs AFTER
# pre_launch_tool_call — the pre-launched task would otherwise receive raw
# @@agptfile: tokens and fail to resolve them inside _execute_tool_sync.
# Use _build_input_schema (same path as _truncating) for schema-aware expansion.
input_schema: dict[str, Any] | None
try:
input_schema = _build_input_schema(base_tool)
except Exception:
input_schema = None # schema unavailable — skip schema-aware expansion
try:
args = await expand_file_refs_in_args(
args, user_id, session, input_schema=input_schema
)
except FileRefExpansionError as exc:
logger.warning(
"pre_launch_tool_call: @@agptfile expansion failed for %s: %s — skipping pre-launch",
bare_name,
exc,
)
return
task = asyncio.create_task(_execute_tool_sync(base_tool, user_id, session, args))
# Log unhandled exceptions so "Task exception was never retrieved" warnings
# do not pollute stderr when a task is pre-launched but never dequeued.
task.add_done_callback(
lambda t, name=bare_name: (
logger.warning(
"Pre-launched task for %s raised unhandled: %s",
name,
t.exception(),
)
if not t.cancelled() and t.exception()
else None
)
)
if bare_name not in queues:
queues[bare_name] = asyncio.Queue[_TaskQueueItem]()
# Store (task, args) so the handler can log a warning if the SDK dispatches
# calls in a different order than the ToolUseBlocks appeared in the message.
queues[bare_name].put_nowait((task, args))
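The per-tool FIFO dispatch described above — strip the MCP prefix with `removeprefix`, enqueue `(task, args)` per bare tool name, and let the N-th handler call dequeue the N-th task — can be demonstrated with a self-contained sketch (toy names, not the project's functions):

```python
import asyncio
from typing import Any

PREFIX = "mcp__copilot__"  # stand-in for MCP_TOOL_PREFIX


async def demo() -> list:
    queues: dict[str, asyncio.Queue] = {}

    async def fake_tool(args: dict[str, Any]) -> str:
        return f"ran-{args['n']}"

    def pre_launch(tool_name: str, args: dict[str, Any]) -> None:
        # removeprefix strips only a leading match, so a tool name that
        # itself contains "__" is left intact (unlike lstrip or split).
        bare = tool_name.removeprefix(PREFIX)
        queues.setdefault(bare, asyncio.Queue()).put_nowait(
            (asyncio.create_task(fake_tool(args)), args)
        )

    pre_launch("mcp__copilot__run_block", {"n": 1})
    pre_launch("mcp__copilot__run_block", {"n": 2})

    # Handler side: each dispatch dequeues the oldest pre-launched task,
    # so result and args stay paired as long as dispatch order is preserved.
    out = []
    for _ in range(2):
        task, _launch_args = queues["run_block"].get_nowait()
        out.append(await task)
    return out


print(asyncio.run(demo()))  # → ['ran-1', 'ran-2']
```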
async def _execute_tool_sync(
base_tool: BaseTool,
user_id: str | None,
session: ChatSession,
args: dict[str, Any],
) -> dict[str, Any]:
"""Execute a tool synchronously and return MCP-formatted response."""
"""Execute a tool synchronously and return MCP-formatted response.
Note: ``@@agptfile:`` expansion should be performed by the caller before
invoking this function. For the normal (non-parallel) path it is handled
by the ``_truncating`` wrapper; for the pre-launched parallel path it is
handled in :func:`pre_launch_tool_call` before the task is created.
"""
effective_id = f"sdk-{uuid.uuid4().hex[:12]}"
result = await base_tool.execute(
user_id=user_id,
@@ -319,7 +455,83 @@ def create_tool_handler(base_tool: BaseTool):
"""
async def tool_handler(args: dict[str, Any]) -> dict[str, Any]:
"""Execute the wrapped tool and return MCP-formatted response."""
"""Execute the wrapped tool and return MCP-formatted response.
If a pre-launched task exists (from parallel tool pre-launch in the
message loop), await it instead of executing fresh.
"""
queues = _tool_task_queues.get()
if queues and base_tool.name in queues:
queue = queues[base_tool.name]
if not queue.empty():
task, launch_args = queue.get_nowait()
# Sanity-check: warn if the args don't match — this can happen
# if the SDK dispatches tool calls in a different order than the
# ToolUseBlocks appeared in the AssistantMessage (unlikely but
# could occur in future SDK versions or with SDK bugs).
# We compare full values (not just keys) so that two run_block
# calls with different block_id values are caught even though
# both have the same key set.
if launch_args != args:
logger.warning(
"Pre-launched task for %s: arg mismatch "
"(launch_keys=%s, call_keys=%s) — cancelling "
"pre-launched task and falling back to direct execution",
base_tool.name,
(
sorted(launch_args.keys())
if isinstance(launch_args, dict)
else type(launch_args).__name__
),
(
sorted(args.keys())
if isinstance(args, dict)
else type(args).__name__
),
)
if not task.done():
task.cancel()
# Await cancellation to prevent duplicate concurrent
# execution for blocks with side effects.
try:
await task
except (asyncio.CancelledError, Exception):
pass
# Fall through to the direct-execution path below.
else:
# Args match — await the pre-launched task.
try:
result = await task
except asyncio.CancelledError:
# Re-raise: CancelledError may be propagating from the
# outer streaming loop being cancelled — swallowing it
# would mask the cancellation and prevent proper cleanup.
logger.warning(
"Pre-launched tool %s was cancelled — re-raising",
base_tool.name,
)
raise
except Exception as e:
logger.error(
"Pre-launched tool %s failed: %s",
base_tool.name,
e,
exc_info=True,
)
return _mcp_error(
f"Failed to execute {base_tool.name}. "
"Check server logs for details."
)
# Pre-truncate the result so the _truncating wrapper (which
# wraps this handler) receives an already-within-budget
# value. _truncating handles stashing — we must NOT stash
# here or the output will be appended twice to the FIFO
# queue and pop_pending_tool_output would return a duplicate
# entry on the second call for the same tool.
return truncate(result, _MCP_MAX_CHARS)
# No pre-launched task — execute directly (fallback for non-parallel calls).
user_id, session = get_execution_context()
if session is None:
@@ -436,19 +648,9 @@ def _text_from_mcp_result(result: dict[str, Any]) -> str:
)
_PARALLEL_ANNOTATION = ToolAnnotations(readOnlyHint=True)
def create_copilot_mcp_server(*, use_e2b: bool = False):
"""Create an in-process MCP server configuration for CoPilot tools.
All tools are annotated with ``readOnlyHint=True`` so the SDK CLI
dispatches concurrent tool calls in parallel rather than sequentially.
This is a deliberate override: even side-effect tools use the hint
because the MCP tools are already individually sandboxed and the
pre-launch duplicate-execution bug (SECRT-2204) is worse than
sequential dispatch.
When *use_e2b* is True, five additional MCP file tools are registered
that route directly to the E2B sandbox filesystem, and the caller should
disable the corresponding SDK built-in tools via
@@ -466,28 +668,6 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
Applied once to every registered tool."""
async def wrapper(args: dict[str, Any]) -> dict[str, Any]:
# Empty tool args = model's output was truncated by the API's
# max_tokens limit. Instead of letting the tool fail with a
# confusing error (and eventually tripping the circuit breaker),
# return clear guidance so the model can self-correct.
if not args and input_schema and input_schema.get("required"):
logger.warning(
"[MCP] %s called with empty args (likely output "
"token truncation) — returning guidance",
tool_name,
)
return _mcp_error(
f"Your call to {tool_name} had empty arguments — "
f"this means your previous response was too long and "
f"the tool call input was truncated by the API. "
f"To fix this: break your work into smaller steps. "
f"For large content, first write it to a file using "
f"bash_exec with cat >> (append section by section), "
f"then pass it via @@agptfile:filename reference. "
f"Do NOT retry with the same approach — it will "
f"be truncated again."
)
# Circuit breaker: stop infinite retry loops with identical args.
# Use the original (pre-expansion) args for fingerprinting so
# check and record always use the same key — @@agptfile:
@@ -538,35 +718,24 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
for tool_name, base_tool in TOOL_REGISTRY.items():
handler = create_tool_handler(base_tool)
schema = _build_input_schema(base_tool)
# All tools annotated readOnlyHint=True to enable parallel dispatch.
# The SDK CLI uses this hint to dispatch concurrent tool calls in
# parallel rather than sequentially. Side-effect safety is ensured
# by the tool implementations themselves (idempotency, credentials).
decorated = tool(
tool_name,
base_tool.description,
schema,
annotations=_PARALLEL_ANNOTATION,
)(_truncating(handler, tool_name, input_schema=schema))
sdk_tools.append(decorated)
# E2B file tools replace SDK built-in Read/Write/Edit/Glob/Grep.
if use_e2b:
for name, desc, schema, handler in E2B_FILE_TOOLS:
- decorated = tool(
- name,
- desc,
- schema,
- annotations=_PARALLEL_ANNOTATION,
- )(_truncating(handler, name))
+ decorated = tool(name, desc, schema)(_truncating(handler, name))
sdk_tools.append(decorated)
- # Read tool for SDK-truncated tool results (always needed, read-only).
+ # Read tool for SDK-truncated tool results (always needed).
read_tool = tool(
_READ_TOOL_NAME,
_READ_TOOL_DESCRIPTION,
_READ_TOOL_SCHEMA,
annotations=_PARALLEL_ANNOTATION,
)(_truncating(_read_file_handler, _READ_TOOL_NAME))
sdk_tools.append(read_tool)

View File

@@ -1,21 +1,22 @@
"""Tests for tool_adapter: truncation, stash, context vars, readOnlyHint annotations."""
"""Tests for tool_adapter helpers: truncation, stash, context vars, parallel pre-launch."""
import asyncio
- from unittest.mock import AsyncMock, MagicMock
+ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from mcp.types import ToolAnnotations
from backend.copilot.context import get_sdk_cwd
from backend.copilot.response_model import StreamToolOutputAvailable
from backend.copilot.sdk.file_ref import FileRefExpansionError
from backend.util.truncate import truncate
from .tool_adapter import (
_MCP_MAX_CHARS,
SDK_DISALLOWED_TOOLS,
_text_from_mcp_result,
cancel_pending_tool_tasks,
create_tool_handler,
pop_pending_tool_output,
pre_launch_tool_call,
reset_stash_event,
set_execution_context,
stash_pending_tool_output,
@@ -243,7 +244,7 @@ class TestTruncationAndStashIntegration:
# ---------------------------------------------------------------------------
- # create_tool_handler (direct execution, no pre-launch)
+ # Parallel pre-launch infrastructure
# ---------------------------------------------------------------------------
@@ -276,18 +277,169 @@ def _init_ctx(session=None):
)
- class TestCreateToolHandler:
- """Tests for create_tool_handler — direct tool execution."""
+ class TestPreLaunchToolCall:
+ """Tests for pre_launch_tool_call and the queue-based parallel dispatch."""
@pytest.fixture(autouse=True)
def _init(self):
_init_ctx(session=_make_mock_session())
@pytest.mark.asyncio
- async def test_handler_executes_tool_directly(self):
- """Handler executes the tool and returns MCP-formatted result."""
+ async def test_unknown_tool_is_silently_ignored(self):
+ """pre_launch_tool_call does nothing for tools not in TOOL_REGISTRY."""
# Should not raise even if the tool name is completely unknown
await pre_launch_tool_call("nonexistent_tool", {})
@pytest.mark.asyncio
async def test_mcp_prefix_stripped_before_registry_lookup(self):
"""mcp__copilot__run_block is looked up as 'run_block'."""
mock_tool = _make_mock_tool("run_block")
with patch(
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
{"run_block": mock_tool},
):
await pre_launch_tool_call("mcp__copilot__run_block", {"block_id": "b1"})
# The task was enqueued — mock_tool.execute should be called once
# (may not complete immediately but should start)
await asyncio.sleep(0) # yield to event loop
mock_tool.execute.assert_awaited_once()
@pytest.mark.asyncio
async def test_bare_tool_name_without_prefix(self):
"""Tool names without __ separator are looked up as-is."""
mock_tool = _make_mock_tool("run_block")
with patch(
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
{"run_block": mock_tool},
):
await pre_launch_tool_call("run_block", {"block_id": "b1"})
await asyncio.sleep(0)
mock_tool.execute.assert_awaited_once()
@pytest.mark.asyncio
async def test_task_enqueued_fifo_for_same_tool(self):
"""Two pre-launched calls for the same tool name are enqueued FIFO."""
results = []
async def slow_execute(*args, **kwargs):
results.append(len(results))
return StreamToolOutputAvailable(
toolCallId="id",
output=str(len(results) - 1),
toolName="t",
success=True,
)
mock_tool = _make_mock_tool("t")
mock_tool.execute = AsyncMock(side_effect=slow_execute)
with patch(
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
{"t": mock_tool},
):
await pre_launch_tool_call("t", {"n": 1})
await pre_launch_tool_call("t", {"n": 2})
await asyncio.sleep(0)
assert mock_tool.execute.await_count == 2
@pytest.mark.asyncio
async def test_file_ref_expansion_failure_skips_pre_launch(self):
"""When @@agptfile: expansion fails, pre_launch_tool_call skips the task.
The handler should then fall back to direct execution (which will also
fail with a proper MCP error via _truncating's own expansion).
"""
mock_tool = _make_mock_tool("run_block", output="should-not-execute")
with (
patch(
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
{"run_block": mock_tool},
),
patch(
"backend.copilot.sdk.tool_adapter.expand_file_refs_in_args",
AsyncMock(side_effect=FileRefExpansionError("@@agptfile:missing.txt")),
),
):
# Should not raise — expansion failure is handled gracefully
await pre_launch_tool_call("run_block", {"text": "@@agptfile:missing.txt"})
await asyncio.sleep(0)
# No task was pre-launched — execute was not called
mock_tool.execute.assert_not_awaited()
class TestCreateToolHandlerParallel:
"""Tests for create_tool_handler using pre-launched tasks."""
@pytest.fixture(autouse=True)
def _init(self):
_init_ctx(session=_make_mock_session())
@pytest.mark.asyncio
async def test_handler_uses_prelaunched_task(self):
"""Handler pops and awaits the pre-launched task rather than re-executing."""
mock_tool = _make_mock_tool("run_block", output="pre-launched result")
with patch(
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
{"run_block": mock_tool},
):
await pre_launch_tool_call("run_block", {"block_id": "b1"})
await asyncio.sleep(0) # let task start
handler = create_tool_handler(mock_tool)
result = await handler({"block_id": "b1"})
assert result["isError"] is False
text = result["content"][0]["text"]
assert "pre-launched result" in text
# Should only have been called once (the pre-launched task), not twice
mock_tool.execute.assert_awaited_once()
@pytest.mark.asyncio
async def test_handler_does_not_double_stash_for_prelaunched_task(self):
"""Pre-launched task result must NOT be stashed by tool_handler directly.
The _truncating wrapper wraps tool_handler and handles stashing after
tool_handler returns. If tool_handler also stashed, the output would be
appended twice to the FIFO queue and pop_pending_tool_output would return
a duplicate on the second call.
This test calls tool_handler directly (without _truncating) and asserts
that nothing was stashed — confirming stashing is deferred to _truncating.
"""
mock_tool = _make_mock_tool("run_block", output="stash-me")
with patch(
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
{"run_block": mock_tool},
):
await pre_launch_tool_call("run_block", {"block_id": "b1"})
await asyncio.sleep(0)
handler = create_tool_handler(mock_tool)
result = await handler({"block_id": "b1"})
assert result["isError"] is False
assert "stash-me" in result["content"][0]["text"]
# tool_handler must NOT stash — _truncating (which wraps handler) does it.
# Calling pop here (without going through _truncating) should return None.
not_stashed = pop_pending_tool_output("run_block")
assert not_stashed is None, (
"tool_handler must not stash directly — _truncating handles stashing "
"to prevent double-stash in the FIFO queue"
)
@pytest.mark.asyncio
async def test_handler_falls_back_when_queue_empty(self):
"""When no pre-launched task exists, handler executes directly."""
mock_tool = _make_mock_tool("run_block", output="direct result")
# Don't call pre_launch_tool_call — queue is empty
handler = create_tool_handler(mock_tool)
result = await handler({"block_id": "b1"})
@@ -297,9 +449,104 @@ class TestCreateToolHandler:
mock_tool.execute.assert_awaited_once()
@pytest.mark.asyncio
- async def test_handler_returns_error_on_no_session(self):
- """When session is None, handler returns MCP error."""
+ async def test_handler_cancelled_error_propagates(self):
+ """CancelledError from a pre-launched task is re-raised to preserve cancellation semantics."""
mock_tool = _make_mock_tool("run_block")
mock_tool.execute = AsyncMock(side_effect=asyncio.CancelledError())
with patch(
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
{"run_block": mock_tool},
):
await pre_launch_tool_call("run_block", {"block_id": "b1"})
await asyncio.sleep(0)
handler = create_tool_handler(mock_tool)
with pytest.raises(asyncio.CancelledError):
await handler({"block_id": "b1"})
@pytest.mark.asyncio
async def test_handler_exception_returns_mcp_error(self):
"""Exception from a pre-launched task is caught and returned as MCP error."""
mock_tool = _make_mock_tool("run_block")
mock_tool.execute = AsyncMock(side_effect=RuntimeError("block exploded"))
with patch(
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
{"run_block": mock_tool},
):
await pre_launch_tool_call("run_block", {"block_id": "b1"})
await asyncio.sleep(0)
handler = create_tool_handler(mock_tool)
result = await handler({"block_id": "b1"})
assert result["isError"] is True
assert "Failed to execute run_block" in result["content"][0]["text"]
@pytest.mark.asyncio
async def test_two_same_tool_calls_dispatched_in_order(self):
"""Two pre-launched tasks for the same tool are consumed in FIFO order."""
call_order = []
async def execute_with_tag(*args, **kwargs):
tag = kwargs.get("block_id", "?")
call_order.append(tag)
return StreamToolOutputAvailable(
toolCallId="id", output=f"out-{tag}", toolName="run_block", success=True
)
mock_tool = _make_mock_tool("run_block")
mock_tool.execute = AsyncMock(side_effect=execute_with_tag)
with patch(
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
{"run_block": mock_tool},
):
await pre_launch_tool_call("run_block", {"block_id": "first"})
await pre_launch_tool_call("run_block", {"block_id": "second"})
await asyncio.sleep(0)
handler = create_tool_handler(mock_tool)
r1 = await handler({"block_id": "first"})
r2 = await handler({"block_id": "second"})
assert "out-first" in r1["content"][0]["text"]
assert "out-second" in r2["content"][0]["text"]
assert call_order == [
"first",
"second",
], f"Expected FIFO dispatch order but got {call_order}"
@pytest.mark.asyncio
async def test_arg_mismatch_falls_back_to_direct_execution(self):
"""When pre-launched args differ from SDK args, handler cancels pre-launched
task and falls back to direct execution with the correct args."""
mock_tool = _make_mock_tool("run_block", output="direct-result")
with patch(
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
{"run_block": mock_tool},
):
# Pre-launch with args {"block_id": "wrong"}
await pre_launch_tool_call("run_block", {"block_id": "wrong"})
await asyncio.sleep(0)
# SDK dispatches with different args
handler = create_tool_handler(mock_tool)
result = await handler({"block_id": "correct"})
assert result["isError"] is False
# The tool was called twice: once by pre-launch (wrong args), once by
# direct fallback (correct args). The result should come from the
# direct execution path.
assert mock_tool.execute.await_count == 2
@pytest.mark.asyncio
async def test_no_session_falls_back_gracefully(self):
"""When session is None and no pre-launched task, handler returns MCP error."""
mock_tool = _make_mock_tool("run_block")
# session=None means get_execution_context returns (user_id, None)
set_execution_context(user_id="u", session=None, sandbox=None) # type: ignore[arg-type]
handler = create_tool_handler(mock_tool)
@@ -308,314 +555,220 @@ class TestCreateToolHandler:
assert result["isError"] is True
assert "session" in result["content"][0]["text"].lower()
# ---------------------------------------------------------------------------
# cancel_pending_tool_tasks
# ---------------------------------------------------------------------------
class TestCancelPendingToolTasks:
"""Tests for cancel_pending_tool_tasks — the stream-abort cleanup helper."""
@pytest.fixture(autouse=True)
def _init(self):
_init_ctx(session=_make_mock_session())
@pytest.mark.asyncio
async def test_handler_returns_error_on_exception(self):
"""Exception from tool execution is caught and returned as MCP error."""
async def test_cancels_queued_tasks(self):
"""Queued tasks are cancelled and the queue is cleared."""
ran = False
async def never_run(*_args, **_kwargs):
nonlocal ran
await asyncio.sleep(10) # long enough to still be pending
ran = True
mock_tool = _make_mock_tool("run_block")
mock_tool.execute = AsyncMock(side_effect=RuntimeError("block exploded"))
mock_tool.execute = AsyncMock(side_effect=never_run)
handler = create_tool_handler(mock_tool)
result = await handler({"block_id": "b1"})
with patch(
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
{"run_block": mock_tool},
):
await pre_launch_tool_call("run_block", {"block_id": "b1"})
await asyncio.sleep(0) # let task start
await cancel_pending_tool_tasks()
await asyncio.sleep(0) # let cancellation propagate
assert result["isError"] is True
assert "Failed to execute run_block" in result["content"][0]["text"]
assert not ran, "Task should have been cancelled before completing"
@pytest.mark.asyncio
async def test_handler_executes_once_per_call(self):
"""Each handler call executes the tool exactly once — no duplicate execution."""
mock_tool = _make_mock_tool("run_block", output="single-execution")
async def test_noop_when_no_tasks_queued(self):
"""cancel_pending_tool_tasks does not raise when queues are empty."""
await cancel_pending_tool_tasks() # should not raise
handler = create_tool_handler(mock_tool)
await handler({"block_id": "b1"})
await handler({"block_id": "b2"})
@pytest.mark.asyncio
async def test_handler_does_not_find_cancelled_task(self):
"""After cancel, tool_handler falls back to direct execution."""
mock_tool = _make_mock_tool("run_block", output="direct-fallback")
with patch(
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
{"run_block": mock_tool},
):
await pre_launch_tool_call("run_block", {"block_id": "b1"})
await asyncio.sleep(0)
await cancel_pending_tool_tasks()
# Queue is now empty — handler should execute directly
handler = create_tool_handler(mock_tool)
result = await handler({"block_id": "b1"})
assert result["isError"] is False
assert "direct-fallback" in result["content"][0]["text"]
# ---------------------------------------------------------------------------
# Concurrent / parallel pre-launch scenarios
# ---------------------------------------------------------------------------
class TestAllParallelToolsPrelaunchedIndependently:
"""Simulate SDK sending N separate AssistantMessages for the same tool concurrently."""
@pytest.fixture(autouse=True)
def _init(self):
_init_ctx(session=_make_mock_session())
@pytest.mark.asyncio
async def test_all_parallel_tools_prelaunched_independently(self):
"""5 pre-launches for the same tool all enqueue independently and run concurrently.
Each task sleeps for PER_TASK_S seconds. If they ran sequentially the total
wall time would be ~5*PER_TASK_S. Running concurrently it should finish in
roughly PER_TASK_S (plus scheduling overhead).
"""
PER_TASK_S = 0.05
N = 5
started: list[int] = []
finished: list[int] = []
async def slow_execute(*args, **kwargs):
idx = len(started)
started.append(idx)
await asyncio.sleep(PER_TASK_S)
finished.append(idx)
return StreamToolOutputAvailable(
toolCallId=f"id-{idx}",
output=f"result-{idx}",
toolName="bash_exec",
success=True,
)
mock_tool = _make_mock_tool("bash_exec")
mock_tool.execute = AsyncMock(side_effect=slow_execute)
with patch(
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
{"bash_exec": mock_tool},
):
for i in range(N):
await pre_launch_tool_call("bash_exec", {"cmd": f"echo {i}"})
# Measure only the concurrent execution window, not pre-launch overhead.
# Starting the timer here avoids false failures on slow CI runners where
# the pre_launch_tool_call setup takes longer than the concurrent sleep.
t0 = asyncio.get_running_loop().time()
await asyncio.sleep(PER_TASK_S * 2)
elapsed = asyncio.get_running_loop().time() - t0
assert mock_tool.execute.await_count == N
assert len(finished) == N
# Wall time of the sleep window should be well under N * PER_TASK_S
# (sequential would be ~0.25s; concurrent finishes in ~PER_TASK_S = 0.05s)
assert elapsed < N * PER_TASK_S, (
f"Expected concurrent execution (<{N * PER_TASK_S:.2f}s) "
f"but sleep window took {elapsed:.2f}s"
)
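The wall-time argument this test makes — N overlapping sleeps finish in roughly one sleep's duration, while sequential awaits take ~N times that — can be checked in isolation with `loop.time()` and `asyncio.gather` (standalone sketch, not the test file's code):

```python
import asyncio


async def measure() -> tuple:
    PER_TASK_S = 0.05
    N = 5

    async def job():
        await asyncio.sleep(PER_TASK_S)

    loop = asyncio.get_running_loop()

    # Sequential: each await blocks the next, total ~ N * PER_TASK_S.
    t0 = loop.time()
    for _ in range(N):
        await job()
    sequential = loop.time() - t0

    # Concurrent: gather overlaps the sleeps, total ~ PER_TASK_S.
    t0 = loop.time()
    await asyncio.gather(*(job() for _ in range(N)))
    concurrent = loop.time() - t0
    return sequential, concurrent


seq, conc = asyncio.run(measure())
assert conc < seq  # concurrency shrinks the wall-time window
```

As the test's own comment notes, asserting only a loose upper bound (rather than an exact ratio) keeps such timing checks from flaking on slow CI runners.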
class TestHandlerReturnsResultFromCorrectPrelaunchedTask:
"""Pop pre-launched tasks in order and verify each returns its own result."""
@pytest.fixture(autouse=True)
def _init(self):
_init_ctx(session=_make_mock_session())
@pytest.mark.asyncio
async def test_handler_returns_result_from_correct_prelaunched_task(self):
"""Two pre-launches for the same tool: first handler gets first result, second gets second."""
async def execute_with_cmd(*args, **kwargs):
cmd = kwargs.get("cmd", "?")
return StreamToolOutputAvailable(
toolCallId="id",
output=f"output-for-{cmd}",
toolName="bash_exec",
success=True,
)
mock_tool = _make_mock_tool("bash_exec")
mock_tool.execute = AsyncMock(side_effect=execute_with_cmd)
with patch(
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
{"bash_exec": mock_tool},
):
await pre_launch_tool_call("bash_exec", {"cmd": "alpha"})
await pre_launch_tool_call("bash_exec", {"cmd": "beta"})
await asyncio.sleep(0) # let both tasks start
handler = create_tool_handler(mock_tool)
r1 = await handler({"cmd": "alpha"})
r2 = await handler({"cmd": "beta"})
text1 = r1["content"][0]["text"]
text2 = r2["content"][0]["text"]
assert "output-for-alpha" in text1, f"Expected alpha result, got: {text1}"
assert "output-for-beta" in text2, f"Expected beta result, got: {text2}"
assert mock_tool.execute.await_count == 2
# ---------------------------------------------------------------------------
# Regression tests: bugs fixed by removing pre-launch mechanism
#
# Each test class includes a _buggy_handler fixture that reproduces the old
# pre-launch implementation inline. Tests run against BOTH the buggy handler
# (xfail — proves the bug exists) and the current clean handler (must pass).
# ---------------------------------------------------------------------------
def _make_execute_fn(tool_name: str = "run_block"):
"""Return (execute_fn, call_log) — execute_fn records every call."""
call_log: list[dict] = []
async def execute_fn(*args, **kwargs):
call_log.append(kwargs)
return StreamToolOutputAvailable(
toolCallId=f"id-{len(call_log)}",
output=f"result-{len(call_log)}",
toolName=tool_name,
success=True,
)
return execute_fn, call_log
async def _buggy_prelaunch_handler(mock_tool, pre_launch_args, dispatch_args):
"""Simulate the OLD buggy pre-launch flow.
1. pre_launch_tool_call fires _execute_tool_sync with pre_launch_args
2. SDK dispatches handler with dispatch_args
3. Handler compares args — on mismatch, cancels + re-executes (BUG)
Returns the handler result.
"""
from backend.copilot.sdk.tool_adapter import _execute_tool_sync
user_id, session = "user-1", _make_mock_session()
# Step 1: pre-launch fires immediately (speculative)
task = asyncio.create_task(
_execute_tool_sync(mock_tool, user_id, session, pre_launch_args)
)
await asyncio.sleep(0) # let task start
# Step 2: SDK dispatches with (potentially different) args
if pre_launch_args != dispatch_args:
# Arg mismatch path: cancel pre-launched task + re-execute
if not task.done():
task.cancel()
try:
await task
except (asyncio.CancelledError, Exception):
pass
# Fall through to direct execution (duplicate!)
return await _execute_tool_sync(mock_tool, user_id, session, dispatch_args)
else:
return await task
class TestBug1DuplicateExecution:
    """Bug 1 (SECRT-2204): arg mismatch causes duplicate execution.

    Pre-launch fires with raw args, SDK dispatches with normalised args.
    Mismatch → cancel (too late) + re-execute → 2 API calls.
    """

    @pytest.fixture(autouse=True)
    def _init(self):
        _init_ctx(session=_make_mock_session())

    @pytest.mark.xfail(reason="Old pre-launch code causes duplicate execution")
    @pytest.mark.asyncio
    async def test_old_code_duplicates_on_arg_mismatch(self):
        """OLD CODE: pre-launch with args A, dispatch with args B → 2 calls."""
        execute_fn, call_log = _make_execute_fn()
        mock_tool = _make_mock_tool("run_block")
        mock_tool.execute = AsyncMock(side_effect=execute_fn)
        pre_launch_args = {"block_id": "b1", "input_data": {"title": "Test"}}
        dispatch_args = {
            "block_id": "b1",
            "input_data": {"title": "Test", "priority": None},
        }
        await _buggy_prelaunch_handler(mock_tool, pre_launch_args, dispatch_args)
        # BUG: pre-launch executed once + fallback executed again = 2
        assert len(call_log) == 1, (
            f"Expected 1 execution but got {len(call_log)}: "
            "duplicate execution bug!"
        )

    @pytest.mark.asyncio
    async def test_current_code_no_duplicate(self):
        """FIXED: handler executes exactly once regardless of arg shape."""
        execute_fn, call_log = _make_execute_fn()
        mock_tool = _make_mock_tool("run_block")
        mock_tool.execute = AsyncMock(side_effect=execute_fn)
        handler = create_tool_handler(mock_tool)
        await handler({"block_id": "b1", "input_data": {"title": "Test"}})
        assert len(call_log) == 1, f"Expected 1 execution but got {len(call_log)}"


class TestBug2FIFODesync:
    """Bug 2: FIFO desync when security hook denies a tool.

    Pre-launch queues [task_A, task_B]. Tool A denied (no MCP dispatch).
    Tool B's handler dequeues task_A → returns wrong result.
    """

    @pytest.fixture(autouse=True)
    def _init(self):
        _init_ctx(session=_make_mock_session())

    @pytest.mark.xfail(reason="Old FIFO queue returns wrong result on denial")
    @pytest.mark.asyncio
    async def test_old_code_fifo_desync_on_denial(self):
        """OLD CODE: denied tool's task stays in queue, next tool gets wrong result."""
        from backend.copilot.sdk.tool_adapter import _execute_tool_sync

        call_log: list[str] = []

        async def tagged_execute(*args, **kwargs):
            tag = kwargs.get("block_id", "?")
            call_log.append(tag)
            return StreamToolOutputAvailable(
                toolCallId="id",
                output=f"result-for-{tag}",
                toolName="run_block",
                success=True,
            )

        mock_tool = _make_mock_tool("run_block")
        mock_tool.execute = AsyncMock(side_effect=tagged_execute)
        user_id, session = "user-1", _make_mock_session()

        # Simulate old FIFO queue
        queue: asyncio.Queue = asyncio.Queue()

        # Pre-launch for tool A and tool B
        task_a = asyncio.create_task(
            _execute_tool_sync(mock_tool, user_id, session, {"block_id": "A"})
        )
        task_b = asyncio.create_task(
            _execute_tool_sync(mock_tool, user_id, session, {"block_id": "B"})
        )
        queue.put_nowait(task_a)
        queue.put_nowait(task_b)
        await asyncio.sleep(0)  # let both tasks run

        # Tool A is DENIED by security hook — no MCP dispatch, no dequeue
        # Tool B's handler dequeues from FIFO → gets task_A!
        dequeued_task = queue.get_nowait()
        result = await dequeued_task
        result_text = result["content"][0]["text"]

        # BUG: handler for B got task_A's result
        assert "result-for-B" in result_text, (
            f"Expected result for B but got: {result_text}. "
            "FIFO desync: B got A's result!"
        )

    @pytest.mark.asyncio
    async def test_current_code_no_fifo_desync(self):
        """FIXED: each handler call executes independently, no shared queue."""
        call_log: list[str] = []

        async def tagged_execute(*args, **kwargs):
            tag = kwargs.get("block_id", "?")
            call_log.append(tag)
            return StreamToolOutputAvailable(
                toolCallId="id",
                output=f"result-for-{tag}",
                toolName="run_block",
                success=True,
            )

        mock_tool = _make_mock_tool("run_block")
        mock_tool.execute = AsyncMock(side_effect=tagged_execute)
        handler = create_tool_handler(mock_tool)

        # Tool A denied (never called). Tool B dispatched normally.
        result_b = await handler({"block_id": "B"})
        assert "result-for-B" in result_b["content"][0]["text"]
        assert call_log == ["B"]


class TestBug3CancelRace:
    """Bug 3: cancel race — task completes before cancel arrives.

    Pre-launch fires fast HTTP call (< 1s). By the time handler detects
    mismatch and calls task.cancel(), the API call already completed.
    Side effect (Linear issue created) is irreversible.
    """

    @pytest.fixture(autouse=True)
    def _init(self):
        _init_ctx(session=_make_mock_session())

    @pytest.mark.xfail(reason="Old code: cancel arrives after task completes")
    @pytest.mark.asyncio
    async def test_old_code_cancel_arrives_too_late(self):
        """OLD CODE: fast task completes before cancel, side effect persists."""
        side_effects: list[str] = []

        async def fast_execute_with_side_effect(*args, **kwargs):
            # Side effect happens immediately (like an HTTP POST to Linear)
            side_effects.append("created-issue")
            return StreamToolOutputAvailable(
                toolCallId="id",
                output="issue-created",
                toolName="run_block",
                success=True,
            )

        mock_tool = _make_mock_tool("run_block")
        mock_tool.execute = AsyncMock(side_effect=fast_execute_with_side_effect)

        # Pre-launch fires immediately
        pre_launch_args = {"block_id": "b1"}
        dispatch_args = {"block_id": "b1", "extra": "normalised"}
        await _buggy_prelaunch_handler(mock_tool, pre_launch_args, dispatch_args)

        # BUG: side effect happened TWICE (pre-launch + fallback)
        assert len(side_effects) == 1, (
            f"Expected 1 side effect but got {len(side_effects)}: "
            "cancel race: pre-launch completed before cancel!"
        )

    @pytest.mark.asyncio
    async def test_current_code_single_side_effect(self):
        """FIXED: no speculative execution, exactly 1 side effect per call."""
        side_effects: list[str] = []

        async def execute_with_side_effect(*args, **kwargs):
            side_effects.append("created-issue")
            return StreamToolOutputAvailable(
                toolCallId="id",
                output="issue-created",
                toolName="run_block",
                success=True,
            )

        mock_tool = _make_mock_tool("run_block")
        mock_tool.execute = AsyncMock(side_effect=execute_with_side_effect)
        handler = create_tool_handler(mock_tool)
        await handler({"block_id": "b1"})
        assert len(side_effects) == 1


# ---------------------------------------------------------------------------
# readOnlyHint annotations
# ---------------------------------------------------------------------------
class TestReadOnlyAnnotations:
    """Tests that all tools get readOnlyHint=True for parallel dispatch."""

    def test_parallel_annotation_constant(self):
        """_PARALLEL_ANNOTATION is a ToolAnnotations with readOnlyHint=True."""
        from .tool_adapter import _PARALLEL_ANNOTATION

        assert isinstance(_PARALLEL_ANNOTATION, ToolAnnotations)
        assert _PARALLEL_ANNOTATION.readOnlyHint is True


# ---------------------------------------------------------------------------
# SDK_DISALLOWED_TOOLS
# ---------------------------------------------------------------------------
class TestSDKDisallowedTools:
    """Verify that dangerous SDK built-in tools are in the disallowed list."""

    def test_bash_tool_is_disallowed(self):
        assert "Bash" in SDK_DISALLOWED_TOOLS

    def test_webfetch_tool_is_disallowed(self):
        """WebFetch is disallowed due to SSRF risk."""
        assert "WebFetch" in SDK_DISALLOWED_TOOLS


class TestFiveConcurrentPrelaunchAllComplete:
    """Pre-launch 5 tasks; consume all 5 via handlers; assert all succeed."""

    @pytest.fixture(autouse=True)
    def _init(self):
        _init_ctx(session=_make_mock_session())

    @pytest.mark.asyncio
    async def test_five_concurrent_prelaunch_all_complete(self):
        """All 5 pre-launched tasks complete and return successful results."""
        N = 5
        call_count = 0

        async def counting_execute(*args, **kwargs):
            nonlocal call_count
            call_count += 1
            n = call_count
            return StreamToolOutputAvailable(
                toolCallId=f"id-{n}",
                output=f"done-{n}",
                toolName="bash_exec",
                success=True,
            )

        mock_tool = _make_mock_tool("bash_exec")
        mock_tool.execute = AsyncMock(side_effect=counting_execute)

        with patch(
            "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
            {"bash_exec": mock_tool},
        ):
            for i in range(N):
                await pre_launch_tool_call("bash_exec", {"cmd": f"task-{i}"})
            await asyncio.sleep(0)  # let all tasks start

            handler = create_tool_handler(mock_tool)
            results = []
            for i in range(N):
                results.append(await handler({"cmd": f"task-{i}"}))

        assert (
            mock_tool.execute.await_count == N
        ), f"Expected {N} execute calls, got {mock_tool.execute.await_count}"
        for i, result in enumerate(results):
            assert result["isError"] is False, f"Result {i} should not be an error"
            text = result["content"][0]["text"]
            assert "done-" in text, f"Result {i} missing expected output: {text}"

View File

@@ -43,10 +43,6 @@ STRIPPABLE_TYPES = frozenset(
{"progress", "file-history-snapshot", "queue-operation", "summary", "pr-link"}
)
# Thinking block types that can be stripped from non-last assistant entries.
# The Anthropic API only requires these in the *last* assistant message.
_THINKING_BLOCK_TYPES = frozenset({"thinking", "redacted_thinking"})
@dataclass
class TranscriptDownload:
@@ -454,83 +450,6 @@ def _build_meta_storage_path(user_id: str, session_id: str, backend: object) ->
)
def strip_stale_thinking_blocks(content: str) -> str:
"""Remove thinking/redacted_thinking blocks from non-last assistant entries.
The Anthropic API only requires thinking blocks in the **last** assistant
message to be value-identical to the original response. Older assistant
entries carry stale thinking blocks that consume significant tokens
(often 10-50K each) without providing useful context for ``--resume``.
Stripping them before upload prevents the CLI from triggering compaction
every turn just to compress away the stale thinking bloat.
"""
lines = content.strip().split("\n")
if not lines:
return content
parsed: list[tuple[str, dict | None]] = []
for line in lines:
    try:
        entry = json.loads(line)
    except json.JSONDecodeError:
        entry = None  # non-JSON lines pass through untouched
    parsed.append((line, entry))
# Reverse scan to find the last assistant message ID and index.
last_asst_msg_id: str | None = None
last_asst_idx: int | None = None
for i in range(len(parsed) - 1, -1, -1):
_line, entry = parsed[i]
if not isinstance(entry, dict):
continue
msg = entry.get("message", {})
if msg.get("role") == "assistant":
last_asst_msg_id = msg.get("id")
last_asst_idx = i
break
if last_asst_idx is None:
return content
result_lines: list[str] = []
stripped_count = 0
for i, (line, entry) in enumerate(parsed):
if not isinstance(entry, dict):
result_lines.append(line)
continue
msg = entry.get("message", {})
# Only strip from assistant entries that are NOT the last turn.
# Use msg_id matching when available; fall back to index for entries
# without an id field.
is_last_turn = (
last_asst_msg_id is not None and msg.get("id") == last_asst_msg_id
) or (last_asst_msg_id is None and i == last_asst_idx)
if (
msg.get("role") == "assistant"
and not is_last_turn
and isinstance(msg.get("content"), list)
):
content_blocks = msg["content"]
filtered = [
b
for b in content_blocks
if not (isinstance(b, dict) and b.get("type") in _THINKING_BLOCK_TYPES)
]
if len(filtered) < len(content_blocks):
stripped_count += len(content_blocks) - len(filtered)
entry = {**entry, "message": {**msg, "content": filtered}}
result_lines.append(json.dumps(entry, separators=(",", ":")))
continue
result_lines.append(line)
if stripped_count:
logger.info(
"[Transcript] Stripped %d stale thinking block(s) from non-last entries",
stripped_count,
)
return "\n".join(result_lines) + "\n"
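The stripping rule above can be sketched as a standalone round trip. The helper below is an illustrative simplification (last-assistant detection by index only, no msg-id matching), not the function removed in this diff:

```python
import json

_THINKING = {"thinking", "redacted_thinking"}


def strip_stale(transcript: str) -> str:
    """Drop thinking blocks from every assistant entry except the last one."""
    entries = [json.loads(ln) for ln in transcript.strip().split("\n")]
    last = max(
        (i for i, e in enumerate(entries) if e["message"]["role"] == "assistant"),
        default=None,
    )
    for i, e in enumerate(entries):
        msg = e["message"]
        if msg["role"] == "assistant" and i != last:
            msg["content"] = [
                b for b in msg["content"] if b.get("type") not in _THINKING
            ]
    return "\n".join(json.dumps(e, separators=(",", ":")) for e in entries) + "\n"


# Hypothetical two-entry transcript: only the older entry loses its thinking block.
old = {"message": {"role": "assistant", "content": [
    {"type": "thinking", "thinking": "stale"}, {"type": "text", "text": "hi"}]}}
new = {"message": {"role": "assistant", "content": [
    {"type": "thinking", "thinking": "fresh"}, {"type": "text", "text": "yo"}]}}
out = strip_stale(json.dumps(old) + "\n" + json.dumps(new))
rows = [json.loads(ln) for ln in out.strip().split("\n")]
assert [b["type"] for b in rows[0]["message"]["content"]] == ["text"]
assert [b["type"] for b in rows[1]["message"]["content"]] == ["thinking", "text"]
```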
async def upload_transcript(
user_id: str,
session_id: str,
@@ -553,9 +472,6 @@ async def upload_transcript(
# Strip metadata entries (progress, file-history-snapshot, etc.)
# Note: SDK-built transcripts shouldn't have these, but strip for safety
stripped = strip_progress_entries(content)
# Strip stale thinking blocks from older assistant entries — these consume
# significant tokens and trigger unnecessary CLI compaction every turn.
stripped = strip_stale_thinking_blocks(stripped)
if not validate_transcript(stripped):
# Log entry types for debugging — helps identify why validation failed
entry_types = [
@@ -689,6 +605,9 @@ COMPACT_MSG_ID_PREFIX = "msg_compact_"
ENTRY_TYPE_MESSAGE = "message"
_THINKING_BLOCK_TYPES = frozenset({"thinking", "redacted_thinking"})
def _flatten_assistant_content(blocks: list) -> str:
"""Flatten assistant content blocks into a single plain-text string.
@@ -714,14 +633,11 @@ def _flatten_assistant_content(blocks: list) -> str:
if btype == "text":
parts.append(block.get("text", ""))
elif btype == "tool_use":
# Drop tool_use entirely — any text representation gets
# mimicked by the model as plain text instead of actual
# structured tool calls. The tool results (in the
# following user/tool_result entry) provide sufficient
# context about what happened.
continue
parts.append(f"[tool_use: {block.get('name', '?')}]")
else:
continue
# Preserve non-text blocks (e.g. image) as placeholders.
# Use __prefix__ to distinguish from literal user text.
parts.append(f"[__{btype}__]")
elif isinstance(block, str):
parts.append(block)
return "\n".join(parts) if parts else ""

View File

@@ -13,7 +13,6 @@ from .transcript import (
delete_transcript,
read_compacted_entries,
strip_progress_entries,
strip_stale_thinking_blocks,
validate_transcript,
write_transcript_to_tempfile,
)
@@ -1201,170 +1200,3 @@ class TestCleanupStaleProjectDirs:
removed = cleanup_stale_project_dirs(encoded_cwd="some-other-project")
assert removed == 0
assert non_copilot.exists()
# ---------------------------------------------------------------------------
# strip_stale_thinking_blocks
# ---------------------------------------------------------------------------
class TestStripStaleThinkingBlocks:
"""Tests for strip_stale_thinking_blocks — removes thinking/redacted_thinking
blocks from non-last assistant entries to reduce transcript bloat."""
def _asst_entry(
self, msg_id: str, content: list, uuid: str = "u1", parent: str = ""
) -> dict:
return {
"type": "assistant",
"uuid": uuid,
"parentUuid": parent,
"message": {
"role": "assistant",
"id": msg_id,
"type": "message",
"content": content,
},
}
def _user_entry(self, text: str, uuid: str = "u0", parent: str = "") -> dict:
return {
"type": "user",
"uuid": uuid,
"parentUuid": parent,
"message": {"role": "user", "content": text},
}
def test_strips_thinking_from_older_assistant(self) -> None:
"""Thinking blocks in non-last assistant entries should be removed."""
old_asst = self._asst_entry(
"msg_old",
[
{"type": "thinking", "thinking": "deep thoughts..."},
{"type": "text", "text": "hello"},
{"type": "redacted_thinking", "data": "secret"},
],
uuid="a1",
)
new_asst = self._asst_entry(
"msg_new",
[
{"type": "thinking", "thinking": "latest thoughts"},
{"type": "text", "text": "world"},
],
uuid="a2",
parent="a1",
)
content = _make_jsonl(old_asst, new_asst)
result = strip_stale_thinking_blocks(content)
lines = [json.loads(ln) for ln in result.strip().split("\n")]
# Old assistant should have thinking blocks stripped
old_content = lines[0]["message"]["content"]
assert len(old_content) == 1
assert old_content[0]["type"] == "text"
# New (last) assistant should be untouched
new_content = lines[1]["message"]["content"]
assert len(new_content) == 2
assert new_content[0]["type"] == "thinking"
assert new_content[1]["type"] == "text"
def test_preserves_last_assistant_thinking(self) -> None:
"""The last assistant entry's thinking blocks must be preserved."""
entry = self._asst_entry(
"msg_only",
[
{"type": "thinking", "thinking": "must keep"},
{"type": "text", "text": "response"},
],
)
content = _make_jsonl(entry)
result = strip_stale_thinking_blocks(content)
lines = [json.loads(ln) for ln in result.strip().split("\n")]
assert len(lines[0]["message"]["content"]) == 2
def test_no_assistant_entries_returns_unchanged(self) -> None:
"""Transcripts with only user entries should pass through unchanged."""
user = self._user_entry("hello")
content = _make_jsonl(user)
assert strip_stale_thinking_blocks(content) == content
def test_empty_content_returns_unchanged(self) -> None:
assert strip_stale_thinking_blocks("") == ""
def test_multiple_turns_strips_all_but_last(self) -> None:
"""With 3 assistant turns, only the last keeps thinking blocks."""
entries = [
self._asst_entry(
"msg_1",
[
{"type": "thinking", "thinking": "t1"},
{"type": "text", "text": "a1"},
],
uuid="a1",
),
self._user_entry("q2", uuid="u2", parent="a1"),
self._asst_entry(
"msg_2",
[
{"type": "thinking", "thinking": "t2"},
{"type": "text", "text": "a2"},
],
uuid="a2",
parent="u2",
),
self._user_entry("q3", uuid="u3", parent="a2"),
self._asst_entry(
"msg_3",
[
{"type": "thinking", "thinking": "t3"},
{"type": "text", "text": "a3"},
],
uuid="a3",
parent="u3",
),
]
content = _make_jsonl(*entries)
result = strip_stale_thinking_blocks(content)
lines = [json.loads(ln) for ln in result.strip().split("\n")]
# msg_1: thinking stripped
assert len(lines[0]["message"]["content"]) == 1
assert lines[0]["message"]["content"][0]["type"] == "text"
# msg_2: thinking stripped
assert len(lines[2]["message"]["content"]) == 1
# msg_3 (last): thinking preserved
assert len(lines[4]["message"]["content"]) == 2
assert lines[4]["message"]["content"][0]["type"] == "thinking"
def test_same_msg_id_multi_entry_turn(self) -> None:
"""Multiple entries sharing the same message.id (same turn) are preserved."""
entries = [
self._asst_entry(
"msg_old",
[{"type": "thinking", "thinking": "old"}],
uuid="a1",
),
self._asst_entry(
"msg_last",
[{"type": "thinking", "thinking": "t_part1"}],
uuid="a2",
parent="a1",
),
self._asst_entry(
"msg_last",
[{"type": "text", "text": "response"}],
uuid="a3",
parent="a2",
),
]
content = _make_jsonl(*entries)
result = strip_stale_thinking_blocks(content)
lines = [json.loads(ln) for ln in result.strip().split("\n")]
# Old entry stripped
assert lines[0]["message"]["content"] == []
# Both entries of last turn (msg_last) preserved
assert lines[1]["message"]["content"][0]["type"] == "thinking"
assert lines[2]["message"]["content"][0]["type"] == "text"

View File

@@ -30,7 +30,7 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
if not cfg.claude_agent_use_resume:
return pytest.skip("CLAUDE_AGENT_USE_RESUME is not enabled, skipping test")
session = await create_chat_session(test_user_id, dry_run=False)
session = await create_chat_session(test_user_id)
session = await upsert_chat_session(session)
# --- Turn 1: send a message with a unique keyword ---

View File

@@ -221,21 +221,9 @@ async def create_session(
return session
_meta_ttl_refresh_at: dict[str, float] = {}
"""Tracks the last time the session meta key TTL was refreshed.
Used by `publish_chunk` to avoid refreshing on every single chunk
(expensive). Refreshes at most once every 60 seconds per session.
"""
_META_TTL_REFRESH_INTERVAL = 60 # seconds
async def publish_chunk(
turn_id: str,
chunk: StreamBaseResponse,
*,
session_id: str | None = None,
) -> str:
"""Publish a chunk to Redis Stream.
@@ -244,9 +232,6 @@ async def publish_chunk(
Args:
turn_id: Turn ID (per-turn UUID) identifying the stream
chunk: The stream response chunk to publish
session_id: Chat session ID — when provided, the session meta key
TTL is refreshed periodically to prevent expiration during
long-running turns (see SECRT-2178).
Returns:
The Redis Stream message ID
@@ -280,23 +265,6 @@ async def publish_chunk(
# Set TTL on stream to match session metadata TTL
await redis.expire(stream_key, config.stream_ttl)
# Periodically refresh session-related TTLs so they don't expire
# during long-running turns. Without this, turns exceeding stream_ttl
# (default 1h) lose their "running" status and stream data, making
# the session invisible to the resume endpoint (empty on page reload).
# Both meta key AND stream key are refreshed: the stream key's expire
# above only fires when publish_chunk is called, but during long
# sub-agent gaps (task_progress events don't produce chunks), neither
# key gets refreshed.
if session_id:
now = time.perf_counter()
last_refresh = _meta_ttl_refresh_at.get(session_id, 0)
if now - last_refresh >= _META_TTL_REFRESH_INTERVAL:
meta_key = _get_session_meta_key(session_id)
await redis.expire(meta_key, config.stream_ttl)
await redis.expire(stream_key, config.stream_ttl)
_meta_ttl_refresh_at[session_id] = now
total_time = (time.perf_counter() - start_time) * 1000
# Only log timing for significant chunks or slow operations
if (
@@ -363,7 +331,7 @@ async def stream_and_publish(
async for event in stream:
if turn_id and not isinstance(event, (StreamFinish, StreamError)):
try:
await publish_chunk(turn_id, event, session_id=session_id)
await publish_chunk(turn_id, event)
except (RedisError, ConnectionError, OSError):
if not publish_failed_once:
publish_failed_once = True
@@ -832,9 +800,6 @@ async def mark_session_completed(
# Atomic compare-and-swap: only update if status is "running"
result = await redis.eval(COMPLETE_SESSION_SCRIPT, 1, meta_key, status) # type: ignore[misc]
# Clean up the in-memory TTL refresh tracker to prevent unbounded growth.
_meta_ttl_refresh_at.pop(session_id, None)
if result == 0:
logger.debug(f"Session {session_id} already completed/failed, skipping")
return False

View File

@@ -68,9 +68,6 @@ class AddUnderstandingTool(BaseTool):
Each call merges new data with existing understanding:
- String fields are overwritten if provided
- List fields are appended (with deduplication)
Note: This tool accepts **kwargs because its parameters are derived
dynamically from the BusinessUnderstandingInput model schema.
"""
session_id = session.session_id
@@ -80,21 +77,23 @@ class AddUnderstandingTool(BaseTool):
session_id=session_id,
)
# Build input model from kwargs (only include fields defined in the model)
valid_fields = set(BusinessUnderstandingInput.model_fields.keys())
filtered = {k: v for k, v in kwargs.items() if k in valid_fields}
# Check if any data was provided
if not any(v is not None for v in filtered.values()):
if not any(v is not None for v in kwargs.values()):
return ErrorResponse(
message="Please provide at least one field to update.",
session_id=session_id,
)
input_data = BusinessUnderstandingInput(**filtered)
# Build input model from kwargs (only include fields defined in the model)
valid_fields = set(BusinessUnderstandingInput.model_fields.keys())
input_data = BusinessUnderstandingInput(
**{k: v for k, v in kwargs.items() if k in valid_fields}
)
# Track which fields were updated
updated_fields = [k for k, v in filtered.items() if v is not None]
updated_fields = [
k for k, v in kwargs.items() if k in valid_fields and v is not None
]
# Upsert with merge
understanding = await understanding_db().upsert_business_understanding(

View File

@@ -180,14 +180,12 @@ async def _save_browser_state(
"""
try:
# Gather state in parallel
(
(rc_url, url_out, _),
(rc_ck, ck_out, _),
(rc_ls, ls_out, _),
) = await asyncio.gather(
_run(session_name, "get", "url", timeout=10),
_run(session_name, "cookies", "get", "--json", timeout=10),
_run(session_name, "storage", "local", "--json", timeout=10),
(rc_url, url_out, _), (rc_ck, ck_out, _), (rc_ls, ls_out, _) = (
await asyncio.gather(
_run(session_name, "get", "url", timeout=10),
_run(session_name, "cookies", "get", "--json", timeout=10),
_run(session_name, "storage", "local", "--json", timeout=10),
)
)
state = {
@@ -450,8 +448,6 @@ class BrowserNavigateTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
url: str = "",
wait_for: str = "networkidle",
**kwargs: Any,
) -> ToolResponseBase:
"""Navigate to *url*, wait for the page to settle, and return a snapshot.
@@ -460,8 +456,8 @@ class BrowserNavigateTool(BaseTool):
Note: for slow SPAs that never fully idle, the snapshot may reflect a
partially-loaded state (the wait is best-effort).
"""
url = url.strip()
wait_for = wait_for or "networkidle"
url: str = (kwargs.get("url") or "").strip()
wait_for: str = kwargs.get("wait_for") or "networkidle"
session_name = session.session_id
if not url:
@@ -616,10 +612,6 @@ class BrowserActTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
action: str = "",
target: str = "",
value: str = "",
direction: str = "down",
**kwargs: Any,
) -> ToolResponseBase:
"""Perform a browser action and return an updated page snapshot.
@@ -628,10 +620,10 @@ class BrowserActTool(BaseTool):
``agent-browser``, waits for the page to settle, and returns the
accessibility-tree snapshot so the LLM can plan the next step.
"""
action = action.strip()
target = target.strip()
value = value.strip()
direction = direction.strip()
action: str = (kwargs.get("action") or "").strip()
target: str = (kwargs.get("target") or "").strip()
value: str = (kwargs.get("value") or "").strip()
direction: str = (kwargs.get("direction") or "down").strip()
session_name = session.session_id
if not action:
@@ -785,8 +777,6 @@ class BrowserScreenshotTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
annotate: bool | str = True,
filename: str = "screenshot.png",
**kwargs: Any,
) -> ToolResponseBase:
"""Capture a PNG screenshot and upload it to the workspace.
@@ -796,12 +786,12 @@ class BrowserScreenshotTool(BaseTool):
Returns a :class:`BrowserScreenshotResponse` with the workspace
``file_id`` the LLM should pass to ``read_workspace_file``.
"""
raw_annotate = annotate
raw_annotate = kwargs.get("annotate", True)
if isinstance(raw_annotate, str):
annotate = raw_annotate.strip().lower() in {"1", "true", "yes", "on"}
else:
annotate = bool(raw_annotate)
filename = filename.strip()
filename: str = (kwargs.get("filename") or "screenshot.png").strip()
session_name = session.session_id
# Restore browser state from cloud if this is a different pod

View File

@@ -411,12 +411,7 @@ class AgentOutputTool(BaseTool):
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""Execute the agent_output tool.
Note: This tool accepts **kwargs and delegates to AgentOutputInput
for validation because the parameter set has cross-field validators
defined in the Pydantic model.
"""
"""Execute the agent_output tool."""
session_id = session.session_id
# Parse and validate input

View File

@@ -76,8 +76,6 @@ class BashExecTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
command: str = "",
timeout: int = 30,
**kwargs: Any,
) -> ToolResponseBase:
"""Run a bash command on E2B (if available) or in a bubblewrap sandbox.
@@ -90,8 +88,8 @@ class BashExecTool(BaseTool):
"""
session_id = session.session_id if session else None
command = command.strip()
timeout = int(timeout)
command: str = (kwargs.get("command") or "").strip()
timeout: int = int(kwargs.get("timeout", 30))
if not command:
return ErrorResponse(

View File

@@ -115,9 +115,6 @@ class ConnectIntegrationTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
provider: str = "",
reason: str = "",
scopes: list[str] | None = None,
**kwargs: Any,
) -> ToolResponseBase:
"""Build and return a :class:`SetupRequirementsResponse` for the requested provider.
@@ -131,10 +128,12 @@ class ConnectIntegrationTool(BaseTool):
"""
_ = user_id # setup card is user-agnostic; auth is enforced via requires_auth
session_id = session.session_id if session else None
provider = (provider or "").strip().lower()
reason = (reason or "").strip()[:500] # cap LLM-controlled text
provider: str = (kwargs.get("provider") or "").strip().lower()
reason: str = (kwargs.get("reason") or "").strip()[
:500
] # cap LLM-controlled text
extra_scopes: list[str] = [
str(s).strip() for s in (scopes or []) if str(s).strip()
str(s).strip() for s in (kwargs.get("scopes") or []) if str(s).strip()
]
entry = SUPPORTED_PROVIDERS.get(provider)
@@ -142,7 +141,8 @@ class ConnectIntegrationTool(BaseTool):
supported = ", ".join(f"'{p}'" for p in SUPPORTED_PROVIDERS)
return ErrorResponse(
message=(
f"Unknown provider '{provider}'. Supported providers: {supported}."
f"Unknown provider '{provider}'. "
f"Supported providers: {supported}."
),
error="unknown_provider",
session_id=session_id,
@@ -153,11 +153,11 @@ class ConnectIntegrationTool(BaseTool):
# Merge agent-requested scopes with provider defaults (deduplicated, order preserved).
default_scopes: list[str] = entry["default_scopes"]
seen: set[str] = set()
merged_scopes: list[str] = []
scopes: list[str] = []
for s in default_scopes + extra_scopes:
if s not in seen:
seen.add(s)
merged_scopes.append(s)
scopes.append(s)
field_key = f"{provider}_credentials"
message_parts = [
@@ -171,7 +171,7 @@ class ConnectIntegrationTool(BaseTool):
"title": f"{display_name} Credentials",
"provider": provider,
"types": supported_types,
"scopes": merged_scopes,
"scopes": scopes,
}
missing_credentials: dict[str, _CredentialEntry] = {field_key: credential_entry}

View File

@@ -53,10 +53,11 @@ class ContinueRunBlockTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
review_id: str = "",
**kwargs,
) -> ToolResponseBase:
review_id = review_id.strip() if review_id else ""
review_id = (
kwargs.get("review_id", "").strip() if kwargs.get("review_id") else ""
)
session_id = session.session_id
if not review_id:

View File

@@ -62,12 +62,9 @@ class CreateAgentTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
agent_json: dict[str, Any] | None = None,
save: bool = True,
library_agent_ids: list[str] | None = None,
folder_id: str | None = None,
**kwargs,
) -> ToolResponseBase:
agent_json: dict[str, Any] | None = kwargs.get("agent_json")
session_id = session.session_id if session else None
if not agent_json:
@@ -80,8 +77,9 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)
if library_agent_ids is None:
library_agent_ids = []
save = kwargs.get("save", True)
library_agent_ids = kwargs.get("library_agent_ids", [])
folder_id: str | None = kwargs.get("folder_id")
nodes = agent_json.get("nodes", [])
if not nodes:

View File

@@ -61,12 +61,9 @@ class CustomizeAgentTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
agent_json: dict[str, Any] | None = None,
save: bool = True,
library_agent_ids: list[str] | None = None,
folder_id: str | None = None,
**kwargs,
) -> ToolResponseBase:
agent_json: dict[str, Any] | None = kwargs.get("agent_json")
session_id = session.session_id if session else None
if not agent_json:
@@ -78,8 +75,9 @@ class CustomizeAgentTool(BaseTool):
session_id=session_id,
)
if library_agent_ids is None:
library_agent_ids = []
save = kwargs.get("save", True)
library_agent_ids = kwargs.get("library_agent_ids", [])
folder_id: str | None = kwargs.get("folder_id")
nodes = agent_json.get("nodes", [])
if not nodes:

View File

@@ -62,15 +62,10 @@ class EditAgentTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
agent_id: str = "",
agent_json: dict[str, Any] | None = None,
save: bool = True,
library_agent_ids: list[str] | None = None,
**kwargs,
) -> ToolResponseBase:
agent_id = agent_id.strip()
if library_agent_ids is None:
library_agent_ids = []
agent_id = kwargs.get("agent_id", "").strip()
agent_json: dict[str, Any] | None = kwargs.get("agent_json")
session_id = session.session_id if session else None
if not agent_id:
@@ -89,6 +84,9 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)
save = kwargs.get("save", True)
library_agent_ids = kwargs.get("library_agent_ids", [])
nodes = agent_json.get("nodes", [])
if not nodes:
return ErrorResponse(

View File

@@ -157,10 +157,9 @@ class SearchFeatureRequestsTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
query: str = "",
**kwargs,
) -> ToolResponseBase:
query = (query or "").strip()
query = kwargs.get("query", "").strip()
session_id = session.session_id if session else None
if not query:
@@ -289,13 +288,11 @@ class CreateFeatureRequestTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
title: str = "",
description: str = "",
existing_issue_id: str | None = None,
**kwargs,
) -> ToolResponseBase:
title = (title or "").strip()
description = (description or "").strip()
title = kwargs.get("title", "").strip()
description = kwargs.get("description", "").strip()
existing_issue_id = kwargs.get("existing_issue_id")
session_id = session.session_id if session else None
if not title or not description:

View File

@@ -34,15 +34,11 @@ class FindAgentTool(BaseTool):
}
async def _execute(
self,
user_id: str | None,
session: ChatSession,
query: str = "",
**kwargs,
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:
"""Search marketplace for agents matching the query."""
return await search_agents(
query=query.strip(),
query=kwargs.get("query", "").strip(),
source="marketplace",
session_id=session.session_id,
user_id=user_id,

View File

@@ -86,8 +86,6 @@ class FindBlockTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
query: str = "",
include_schemas: bool = False,
**kwargs,
) -> ToolResponseBase:
"""Search for blocks matching the query.
@@ -96,14 +94,14 @@ class FindBlockTool(BaseTool):
user_id: User ID (required)
session: Chat session
query: Search query
include_schemas: Whether to include block schemas in results
Returns:
BlockListResponse: List of matching blocks
NoResultsResponse: No blocks found
ErrorResponse: Error message
"""
query = (query or "").strip()
query = kwargs.get("query", "").strip()
include_schemas = kwargs.get("include_schemas", False)
session_id = session.session_id
if not query:

View File

@@ -41,14 +41,10 @@ class FindLibraryAgentTool(BaseTool):
return True
async def _execute(
self,
user_id: str | None,
session: ChatSession,
query: str = "",
**kwargs,
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:
return await search_agents(
query=query.strip(),
query=(kwargs.get("query") or "").strip(),
source="library",
session_id=session.session_id,
user_id=user_id,

View File

@@ -51,9 +51,9 @@ class FixAgentGraphTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
agent_json: dict | None = None,
**kwargs,
) -> ToolResponseBase:
agent_json = kwargs.get("agent_json")
session_id = session.session_id if session else None
if not agent_json or not isinstance(agent_json, dict):
@@ -98,7 +98,8 @@ class FixAgentGraphTool(BaseTool):
if is_valid:
return FixResultResponse(
message=(
f"Applied {len(fixes_applied)} fix(es). Agent graph is now valid!"
f"Applied {len(fixes_applied)} fix(es). "
"Agent graph is now valid!"
),
fixed_agent_json=fixed_agent,
fixes_applied=fixes_applied,

View File

@@ -60,7 +60,7 @@ class GetAgentBuildingGuideTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
**kwargs, # no tool-specific params; accepts kwargs for forward-compat
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id if session else None
try:

View File

@@ -68,7 +68,6 @@ class GetDocPageTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
path: str = "",
**kwargs,
) -> ToolResponseBase:
"""Fetch full content of a documentation page.
@@ -82,7 +81,7 @@ class GetDocPageTool(BaseTool):
DocPageResponse: Full document content
ErrorResponse: Error message
"""
path = path.strip()
path = kwargs.get("path", "").strip()
session_id = session.session_id if session else None
if not path:

View File

@@ -56,7 +56,7 @@ class GetMCPGuideTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
**kwargs, # no tool-specific params; accepts kwargs for forward-compat
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id if session else None
try:

View File

@@ -81,7 +81,7 @@ async def execute_block(
node_exec_id: str,
matched_credentials: dict[str, CredentialsMetaInput],
sensitive_action_safe_mode: bool = False,
dry_run: bool,
dry_run: bool = False,
) -> ToolResponseBase:
"""Execute a block with full context setup, credential injection, and error handling.
@@ -114,9 +114,11 @@ async def execute_block(
error=sim_error[0],
session_id=session_id,
)
return BlockOutputResponse(
message=f"Block '{block.name}' executed successfully",
message=(
f"[DRY RUN] Block '{block.name}' simulated successfully "
"— no real execution occurred."
),
block_id=block_id,
block_name=block.name,
outputs=dict(outputs),
@@ -335,7 +337,7 @@ async def prepare_block_for_execution(
user_id: str,
session: ChatSession,
session_id: str,
dry_run: bool,
dry_run: bool = False,
) -> "BlockPreparation | ToolResponseBase":
"""Validate and prepare a block for execution.

View File

@@ -102,7 +102,6 @@ class TestExecuteBlockCreditCharging:
session_id=_SESSION,
node_exec_id="exec-1",
matched_credentials={},
dry_run=False,
)
assert isinstance(result, BlockOutputResponse)
@@ -133,7 +132,6 @@ class TestExecuteBlockCreditCharging:
session_id=_SESSION,
node_exec_id="exec-1",
matched_credentials={},
dry_run=False,
)
assert isinstance(result, ErrorResponse)
@@ -160,7 +158,6 @@ class TestExecuteBlockCreditCharging:
session_id=_SESSION,
node_exec_id="exec-1",
matched_credentials={},
dry_run=False,
)
assert isinstance(result, BlockOutputResponse)
@@ -197,7 +194,6 @@ class TestExecuteBlockCreditCharging:
session_id=_SESSION,
node_exec_id="exec-1",
matched_credentials={},
dry_run=False,
)
# Block already executed (with side effects), so output is returned
@@ -281,7 +277,6 @@ async def test_coerce_json_string_to_nested_list():
session_id=_TEST_SESSION_ID,
node_exec_id="exec-1",
matched_credentials={},
dry_run=False,
)
assert isinstance(response, BlockOutputResponse)
@@ -322,7 +317,6 @@ async def test_coerce_json_string_to_list():
session_id=_TEST_SESSION_ID,
node_exec_id="exec-2",
matched_credentials={},
dry_run=False,
)
assert isinstance(response, BlockOutputResponse)
@@ -355,7 +349,6 @@ async def test_coerce_json_string_to_dict():
session_id=_TEST_SESSION_ID,
node_exec_id="exec-3",
matched_credentials={},
dry_run=False,
)
assert isinstance(response, BlockOutputResponse)
@@ -389,7 +382,6 @@ async def test_no_coercion_when_type_matches():
session_id=_TEST_SESSION_ID,
node_exec_id="exec-4",
matched_credentials={},
dry_run=False,
)
assert isinstance(response, BlockOutputResponse)
@@ -423,7 +415,6 @@ async def test_coerce_string_to_int():
session_id=_TEST_SESSION_ID,
node_exec_id="exec-5",
matched_credentials={},
dry_run=False,
)
assert isinstance(response, BlockOutputResponse)
@@ -457,7 +448,6 @@ async def test_coerce_skips_none_values():
session_id=_TEST_SESSION_ID,
node_exec_id="exec-6",
matched_credentials={},
dry_run=False,
)
assert isinstance(response, BlockOutputResponse)
@@ -491,7 +481,6 @@ async def test_coerce_union_type_preserves_valid_member():
session_id=_TEST_SESSION_ID,
node_exec_id="exec-7",
matched_credentials={},
dry_run=False,
)
assert isinstance(response, BlockOutputResponse)
@@ -527,7 +516,6 @@ async def test_coerce_inner_elements_of_generic():
session_id=_TEST_SESSION_ID,
node_exec_id="exec-8",
matched_credentials={},
dry_run=False,
)
assert isinstance(response, BlockOutputResponse)
@@ -604,7 +592,6 @@ async def test_prepare_block_not_found() -> None:
user_id=_PREP_USER,
session=_make_prep_session(),
session_id=_PREP_SESSION,
dry_run=False,
)
assert isinstance(result, ErrorResponse)
assert "not found" in result.message
@@ -625,7 +612,6 @@ async def test_prepare_block_disabled() -> None:
user_id=_PREP_USER,
session=_make_prep_session(),
session_id=_PREP_SESSION,
dry_run=False,
)
assert isinstance(result, ErrorResponse)
assert "disabled" in result.message
@@ -654,7 +640,6 @@ async def test_prepare_block_unrecognized_fields() -> None:
user_id=_PREP_USER,
session=_make_prep_session(),
session_id=_PREP_SESSION,
dry_run=False,
)
assert isinstance(result, InputValidationErrorResponse)
assert "unknown_field" in result.unrecognized_fields
@@ -684,7 +669,6 @@ async def test_prepare_block_missing_credentials() -> None:
user_id=_PREP_USER,
session=_make_prep_session(),
session_id=_PREP_SESSION,
dry_run=False,
)
assert isinstance(result, SetupRequirementsResponse)
@@ -714,7 +698,6 @@ async def test_prepare_block_success_returns_preparation() -> None:
user_id=_PREP_USER,
session=_make_prep_session(),
session_id=_PREP_SESSION,
dry_run=False,
)
assert isinstance(result, BlockPreparation)
assert result.required_non_credential_keys == {"text"}
@@ -819,7 +802,6 @@ async def test_prepare_block_excluded_by_type() -> None:
user_id=_PREP_USER,
session=_make_prep_session(),
session_id=_PREP_SESSION,
dry_run=False,
)
assert isinstance(result, ErrorResponse)
assert "cannot be run directly" in result.message
@@ -842,7 +824,6 @@ async def test_prepare_block_excluded_by_id() -> None:
user_id=_PREP_USER,
session=_make_prep_session(),
session_id=_PREP_SESSION,
dry_run=False,
)
assert isinstance(result, ErrorResponse)
assert "cannot be run directly" in result.message
@@ -876,7 +857,6 @@ async def test_prepare_block_file_ref_expansion_error() -> None:
user_id=_PREP_USER,
session=_make_prep_session(),
session_id=_PREP_SESSION,
dry_run=False,
)
assert isinstance(result, ErrorResponse)
assert "file reference" in result.message.lower()

View File

@@ -866,7 +866,6 @@ class TestRunBlockToolAuthenticatedHttp:
session=session,
block_id=block.id,
input_data={"url": "https://api.example.com/data", "method": "GET"},
dry_run=False,
)
assert isinstance(response, SetupRequirementsResponse)
@@ -908,7 +907,6 @@ class TestRunBlockToolAuthenticatedHttp:
session=session,
block_id=block.id,
input_data={},
dry_run=False,
)
assert isinstance(response, BlockDetailsResponse)

View File

@@ -120,18 +120,14 @@ class CreateFolderTool(BaseTool):
}
async def _execute(
self,
user_id: str | None,
session: ChatSession,
name: str = "",
parent_id: str | None = None,
icon: str | None = None,
color: str | None = None,
**kwargs,
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:
"""Create a folder with the given name and optional parent/icon/color."""
assert user_id is not None # guaranteed by requires_auth
name = (name or "").strip()
name = (kwargs.get("name") or "").strip()
parent_id = kwargs.get("parent_id")
icon = kwargs.get("icon")
color = kwargs.get("color")
session_id = session.session_id if session else None
if not name:
@@ -200,15 +196,12 @@ class ListFoldersTool(BaseTool):
}
async def _execute(
self,
user_id: str | None,
session: ChatSession,
parent_id: str | None = None,
include_agents: bool = False,
**kwargs,
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:
"""List folders as a flat list (by parent) or full tree."""
assert user_id is not None # guaranteed by requires_auth
parent_id = kwargs.get("parent_id")
include_agents = kwargs.get("include_agents", False)
session_id = session.session_id if session else None
try:
@@ -300,18 +293,14 @@ class UpdateFolderTool(BaseTool):
}
async def _execute(
self,
user_id: str | None,
session: ChatSession,
folder_id: str = "",
name: str | None = None,
icon: str | None = None,
color: str | None = None,
**kwargs,
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:
"""Update a folder's name, icon, or color."""
assert user_id is not None # guaranteed by requires_auth
folder_id = (folder_id or "").strip()
folder_id = (kwargs.get("folder_id") or "").strip()
name = kwargs.get("name")
icon = kwargs.get("icon")
color = kwargs.get("color")
session_id = session.session_id if session else None
if not folder_id:
@@ -376,16 +365,12 @@ class MoveFolderTool(BaseTool):
}
async def _execute(
self,
user_id: str | None,
session: ChatSession,
folder_id: str = "",
target_parent_id: str | None = None,
**kwargs,
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:
"""Move a folder to a new parent or to root level."""
assert user_id is not None # guaranteed by requires_auth
folder_id = (folder_id or "").strip()
folder_id = (kwargs.get("folder_id") or "").strip()
target_parent_id = kwargs.get("target_parent_id")
session_id = session.session_id if session else None
if not folder_id:
@@ -446,15 +431,11 @@ class DeleteFolderTool(BaseTool):
}
async def _execute(
self,
user_id: str | None,
session: ChatSession,
folder_id: str = "",
**kwargs,
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:
"""Soft-delete a folder; agents inside are moved to root level."""
assert user_id is not None # guaranteed by requires_auth
folder_id = (folder_id or "").strip()
folder_id = (kwargs.get("folder_id") or "").strip()
session_id = session.session_id if session else None
if not folder_id:
@@ -518,17 +499,12 @@ class MoveAgentsToFolderTool(BaseTool):
}
async def _execute(
self,
user_id: str | None,
session: ChatSession,
agent_ids: list[str] | None = None,
folder_id: str | None = None,
**kwargs,
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:
"""Move one or more agents to a folder or to root level."""
assert user_id is not None # guaranteed by requires_auth
if agent_ids is None:
agent_ids = []
agent_ids = kwargs.get("agent_ids", [])
folder_id = kwargs.get("folder_id")
session_id = session.session_id if session else None
if not agent_ids:

View File

@@ -71,7 +71,7 @@ class RunAgentInput(BaseModel):
cron: str = ""
timezone: str = "UTC"
wait_for_result: int = Field(default=0, ge=0, le=300)
dry_run: bool
dry_run: bool = False
@field_validator(
"username_agent_slug",
@@ -153,10 +153,14 @@ class RunAgentTool(BaseTool):
},
"dry_run": {
"type": "boolean",
"description": "Execute in preview mode.",
"description": (
"When true, simulates the entire agent execution using an LLM "
"for each block — no real API calls, no credentials needed, "
"no credits charged. Useful for testing agent wiring end-to-end."
),
},
},
"required": ["dry_run"],
"required": [],
}
@property
@@ -170,16 +174,8 @@ class RunAgentTool(BaseTool):
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""Execute the tool with automatic state detection.
Note: This tool accepts **kwargs and delegates to RunAgentInput for
validation because the parameter set is complex with cross-field
validators defined in the Pydantic model.
"""
"""Execute the tool with automatic state detection."""
params = RunAgentInput(**kwargs)
# Session-level dry_run forces all tool calls to use dry-run mode.
if session.dry_run:
params.dry_run = True
session_id = session.session_id
# Validate at least one identifier is provided
@@ -205,18 +201,6 @@ class RunAgentTool(BaseTool):
# Determine if this is a schedule request
is_schedule = bool(params.schedule_name or params.cron)
# Session-level dry-run blocks scheduling — schedules create real
# side effects that cannot be simulated.
if params.dry_run and is_schedule:
return ErrorResponse(
message=(
"Scheduling is disabled in dry-run mode because it creates "
"real side effects. Remove cron/schedule_name to simulate "
"a run, or disable dry-run to create a real schedule."
),
session_id=session_id,
)
try:
# Step 1: Fetch agent details
graph: GraphModel | None = None
@@ -474,8 +458,8 @@ class RunAgentTool(BaseTool):
graph: GraphModel,
graph_credentials: dict[str, CredentialsMetaInput],
inputs: dict[str, Any],
dry_run: bool,
wait_for_result: int = 0,
dry_run: bool = False,
) -> ToolResponseBase:
"""Execute an agent immediately, optionally waiting for completion."""
session_id = session.session_id

View File

@@ -53,7 +53,6 @@ async def test_run_agent(setup_test_data):
tool_call_id=str(uuid.uuid4()),
username_agent_slug=agent_marketplace_id,
inputs={"test_input": "Hello World"},
dry_run=False,
session=session,
)
@@ -94,7 +93,6 @@ async def test_run_agent_missing_inputs(setup_test_data):
tool_call_id=str(uuid.uuid4()),
username_agent_slug=agent_marketplace_id,
inputs={}, # Missing required input
dry_run=False,
session=session,
)
@@ -127,7 +125,6 @@ async def test_run_agent_invalid_agent_id(setup_test_data):
tool_call_id=str(uuid.uuid4()),
username_agent_slug="invalid/agent-id",
inputs={"test_input": "Hello World"},
dry_run=False,
session=session,
)
@@ -168,7 +165,6 @@ async def test_run_agent_with_llm_credentials(setup_llm_test_data):
tool_call_id=str(uuid.uuid4()),
username_agent_slug=agent_marketplace_id,
inputs={"user_prompt": "What is 2+2?"},
dry_run=False,
session=session,
)
@@ -207,7 +203,6 @@ async def test_run_agent_shows_available_inputs_when_none_provided(setup_test_da
username_agent_slug=agent_marketplace_id,
inputs={},
use_defaults=False,
dry_run=False,
session=session,
)
@@ -243,7 +238,6 @@ async def test_run_agent_with_use_defaults(setup_test_data):
username_agent_slug=agent_marketplace_id,
inputs={},
use_defaults=True,
dry_run=False,
session=session,
)
@@ -274,7 +268,6 @@ async def test_run_agent_missing_credentials(setup_firecrawl_test_data):
tool_call_id=str(uuid.uuid4()),
username_agent_slug=agent_marketplace_id,
inputs={"url": "https://example.com"},
dry_run=False,
session=session,
)
@@ -307,7 +300,6 @@ async def test_run_agent_invalid_slug_format(setup_test_data):
tool_call_id=str(uuid.uuid4()),
username_agent_slug="no-slash-here",
inputs={},
dry_run=False,
session=session,
)
@@ -335,7 +327,6 @@ async def test_run_agent_unauthenticated():
tool_call_id=str(uuid.uuid4()),
username_agent_slug="test/test-agent",
inputs={},
dry_run=False,
session=session,
)
@@ -368,7 +359,6 @@ async def test_run_agent_schedule_without_cron(setup_test_data):
inputs={"test_input": "test"},
schedule_name="My Schedule",
cron="", # Empty cron
dry_run=False,
session=session,
)
@@ -401,7 +391,6 @@ async def test_run_agent_schedule_without_name(setup_test_data):
inputs={"test_input": "test"},
schedule_name="", # Empty name
cron="0 9 * * *",
dry_run=False,
session=session,
)
@@ -435,7 +424,6 @@ async def test_run_agent_rejects_unknown_input_fields(setup_test_data):
"unknown_field": "some value",
"another_unknown": "another value",
},
dry_run=False,
session=session,
)

View File

@@ -51,10 +51,14 @@ class RunBlockTool(BaseTool):
},
"dry_run": {
"type": "boolean",
"description": "Execute in preview mode.",
"description": (
"When true, simulates block execution using an LLM without making any "
"real API calls or producing side effects. Useful for testing agent "
"wiring and previewing outputs. Default: false."
),
},
},
"required": ["block_id", "input_data", "dry_run"],
"required": ["block_id", "input_data"],
}
@property
@@ -65,10 +69,6 @@ class RunBlockTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
*,
block_id: str = "",
input_data: dict | None = None,
dry_run: bool,
**kwargs,
) -> ToolResponseBase:
"""Execute a block with the given input data.
@@ -78,19 +78,15 @@ class RunBlockTool(BaseTool):
session: Chat session
block_id: Block UUID to execute
input_data: Input values for the block
dry_run: If True, simulate execution without side effects
Returns:
BlockOutputResponse: Block execution outputs
SetupRequirementsResponse: Missing credentials
ErrorResponse: Error message
"""
block_id = block_id.strip()
if input_data is None:
input_data = {}
# Session-level dry_run forces all tool calls to use dry-run mode.
if session.dry_run:
dry_run = True
block_id = kwargs.get("block_id", "").strip()
input_data = kwargs.get("input_data", {})
dry_run = bool(kwargs.get("dry_run", False))
session_id = session.session_id
if not block_id:
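With `dry_run` dropped from the schema's `"required"` list, a tool call that omits it must still validate and fall back to `False` at extraction time. A rough sketch of that contract, with the schema trimmed to the fields shown above:

```python
PARAMS = {
    "type": "object",
    "properties": {
        "block_id": {"type": "string"},
        "input_data": {"type": "object"},
        "dry_run": {"type": "boolean"},
    },
    "required": ["block_id", "input_data"],  # dry_run is now optional
}

def missing_required(call_args: dict) -> list[str]:
    # Names any schema-required keys absent from the call.
    return [k for k in PARAMS["required"] if k not in call_args]

call = {"block_id": "b1", "input_data": {}}
missing = missing_required(call)             # -> []
dry_run = bool(call.get("dry_run", False))   # omitted flag falls back to False
```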

View File

@@ -103,7 +103,6 @@ class TestRunBlockFiltering:
session=session,
block_id="input-block-id",
input_data={},
dry_run=False,
)
assert isinstance(response, ErrorResponse)
@@ -130,7 +129,6 @@ class TestRunBlockFiltering:
session=session,
block_id=orchestrator_id,
input_data={},
dry_run=False,
)
assert isinstance(response, ErrorResponse)
@@ -156,7 +154,6 @@ class TestRunBlockFiltering:
session=session,
block_id=block_id,
input_data={},
dry_run=False,
)
finally:
_current_permissions.reset(token)
@@ -190,7 +187,6 @@ class TestRunBlockFiltering:
session=session,
block_id=block_id,
input_data={},
dry_run=False,
)
finally:
_current_permissions.reset(token)
@@ -226,7 +222,6 @@ class TestRunBlockFiltering:
session=session,
block_id="standard-id",
input_data={},
dry_run=False,
)
# Should NOT be an ErrorResponse about CoPilot exclusion
@@ -287,7 +282,6 @@ class TestRunBlockInputValidation:
"prompt": "Write a haiku about coding",
"LLM_Model": "claude-opus-4-6",
},
dry_run=False,
)
assert isinstance(response, InputValidationErrorResponse)
@@ -333,7 +327,6 @@ class TestRunBlockInputValidation:
"system_prompt": "Be helpful",
"retries": 5,
},
dry_run=False,
)
assert isinstance(response, InputValidationErrorResponse)
@@ -377,7 +370,6 @@ class TestRunBlockInputValidation:
input_data={
"LLM_Model": "claude-opus-4-6",
},
dry_run=False,
)
assert isinstance(response, InputValidationErrorResponse)
@@ -432,7 +424,6 @@ class TestRunBlockInputValidation:
"prompt": "Write a haiku",
"model": "gpt-4o-mini",
},
dry_run=False,
)
assert isinstance(response, BlockOutputResponse)
@@ -472,7 +463,6 @@ class TestRunBlockInputValidation:
input_data={
"model": "gpt-4o-mini",
},
dry_run=False,
)
assert isinstance(response, BlockDetailsResponse)
@@ -524,7 +514,6 @@ class TestRunBlockSensitiveAction:
session=session,
block_id="delete-branch-id",
input_data=input_data,
dry_run=False,
)
assert isinstance(response, ReviewRequiredResponse)
@@ -585,7 +574,6 @@ class TestRunBlockSensitiveAction:
session=session,
block_id="delete-branch-id",
input_data=input_data,
dry_run=False,
)
assert isinstance(response, BlockOutputResponse)
@@ -640,7 +628,6 @@ class TestRunBlockSensitiveAction:
session=session,
block_id="http-request-id",
input_data=input_data,
dry_run=False,
)
assert isinstance(response, BlockOutputResponse)

View File

@@ -91,40 +91,21 @@ class RunMCPToolTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
server_url: str = "",
tool_name: str = "",
tool_arguments: dict[str, Any] | None = None,
**kwargs,
) -> ToolResponseBase:
server_url = server_url.strip()
tool_name = tool_name.strip()
server_url: str = (kwargs.get("server_url") or "").strip()
tool_name: str = (kwargs.get("tool_name") or "").strip()
raw_tool_arguments = kwargs.get("tool_arguments")
tool_arguments: dict[str, Any] = (
raw_tool_arguments if isinstance(raw_tool_arguments, dict) else {}
)
session_id = session.session_id
# Session-level dry_run prevents real MCP tool execution.
# Discovery (no tool_name) is still allowed so the agent can inspect
# available tools, but actual execution is blocked.
if session.dry_run and tool_name:
return MCPToolOutputResponse(
message=(
f"[dry-run] MCP tool '{tool_name}' on "
f"{server_host(server_url)} was not executed "
"because the session is in dry-run mode."
),
server_url=server_url,
tool_name=tool_name,
result=None,
success=True,
session_id=session_id,
)
if tool_arguments is not None and not isinstance(tool_arguments, dict):
if raw_tool_arguments is not None and not isinstance(raw_tool_arguments, dict):
return ErrorResponse(
message="tool_arguments must be a JSON object.",
session_id=session_id,
)
resolved_tool_arguments: dict[str, Any] = (
tool_arguments if isinstance(tool_arguments, dict) else {}
)
if not server_url:
return ErrorResponse(
@@ -186,7 +167,7 @@ class RunMCPToolTool(BaseTool):
else:
# Stage 2: Execute the selected tool
return await self._execute_tool(
client, server_url, tool_name, resolved_tool_arguments, session_id
client, server_url, tool_name, tool_arguments, session_id
)
except HTTPClientError as e:
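The `tool_arguments` handling above rejects a non-dict value explicitly while treating `None` as "no arguments". Sketched in isolation (the helper name is illustrative, and the real code returns an `ErrorResponse` rather than raising):

```python
from typing import Any

def coerce_tool_arguments(raw: Any) -> dict[str, Any]:
    # None means "no arguments"; anything else must be a JSON object.
    if raw is not None and not isinstance(raw, dict):
        raise TypeError("tool_arguments must be a JSON object.")
    return raw if isinstance(raw, dict) else {}
```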

View File

@@ -85,7 +85,6 @@ class SearchDocsTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
query: str = "",
**kwargs,
) -> ToolResponseBase:
"""Search documentation and return relevant sections.
@@ -100,7 +99,7 @@ class SearchDocsTool(BaseTool):
NoResultsResponse: No results found
ErrorResponse: Error message
"""
query = query.strip()
query = kwargs.get("query", "").strip()
session_id = session.session_id if session else None
if not query:

View File

@@ -73,10 +73,7 @@ def make_openai_response(
@pytest.mark.asyncio
async def test_simulate_block_basic():
"""simulate_block returns correct (output_name, output_data) tuples.
Empty "error" pins are dropped at source — only non-empty errors are yielded.
"""
"""simulate_block returns correct (output_name, output_data) tuples."""
mock_block = make_mock_block()
mock_client = AsyncMock()
mock_client.chat.completions.create = AsyncMock(
@@ -91,8 +88,7 @@ async def test_simulate_block_basic():
outputs.append((name, data))
assert ("result", "simulated output") in outputs
# Empty error pin is dropped at the simulator level
assert ("error", "") not in outputs
assert ("error", "") in outputs
@pytest.mark.asyncio
@@ -117,8 +113,6 @@ async def test_simulate_block_json_retry():
assert mock_client.chat.completions.create.call_count == 3
assert ("result", "ok") in outputs
# Empty error pin is dropped
assert ("error", "") not in outputs
@pytest.mark.asyncio
@@ -147,7 +141,7 @@ async def test_simulate_block_all_retries_exhausted():
@pytest.mark.asyncio
async def test_simulate_block_missing_output_pins():
"""LLM response missing some output pins; verify non-error pins filled with None."""
"""LLM response missing some output pins; verify they're filled with None."""
mock_block = make_mock_block(
output_props={
"result": {"type": "string"},
@@ -170,29 +164,7 @@ async def test_simulate_block_missing_output_pins():
assert outputs["result"] == "hello"
assert outputs["count"] is None # missing pin filled with None
assert "error" not in outputs # missing error pin is omitted entirely
@pytest.mark.asyncio
async def test_simulate_block_keeps_nonempty_error():
"""simulate_block keeps non-empty error pins (simulated logical errors)."""
mock_block = make_mock_block()
mock_client = AsyncMock()
mock_client.chat.completions.create = AsyncMock(
return_value=make_openai_response(
'{"result": "", "error": "API rate limit exceeded"}'
)
)
with patch(
"backend.executor.simulator.get_openai_client", return_value=mock_client
):
outputs = []
async for name, data in simulate_block(mock_block, {"query": "test"}):
outputs.append((name, data))
assert ("result", "") in outputs
assert ("error", "API rate limit exceeded") in outputs
assert outputs["error"] == "" # "error" pin filled with ""
@pytest.mark.asyncio
@@ -228,19 +200,6 @@ async def test_simulate_block_truncates_long_inputs():
assert len(parsed["text"]) < 25000
def test_build_simulation_prompt_excludes_error_from_must_include():
"""The 'MUST include' prompt line should NOT list 'error' — the prompt
already instructs the LLM to OMIT error unless simulating a logical error.
Including it in 'MUST include' would be contradictory."""
block = make_mock_block() # default output_props has "result" and "error"
system_prompt, _ = build_simulation_prompt(block, {"query": "test"})
must_include_line = [
line for line in system_prompt.splitlines() if "MUST include" in line
][0]
assert '"result"' in must_include_line
assert '"error"' not in must_include_line
# ---------------------------------------------------------------------------
# execute_block dry-run tests
# ---------------------------------------------------------------------------
@@ -279,7 +238,7 @@ async def test_execute_block_dry_run_skips_real_execution():
@pytest.mark.asyncio
async def test_execute_block_dry_run_response_format():
"""Dry-run response should match real execution message format and have success=True."""
"""Dry-run response should contain [DRY RUN] in message and success=True."""
mock_block = make_mock_block()
async def fake_simulate(block, input_data):
@@ -300,8 +259,7 @@ async def test_execute_block_dry_run_response_format():
)
assert isinstance(response, BlockOutputResponse)
assert "executed successfully" in response.message
assert "[DRY RUN]" not in response.message # must not leak to LLM context
assert "[DRY RUN]" in response.message
assert response.success is True
assert response.outputs == {"result": ["simulated"]}
@@ -349,24 +307,23 @@ async def test_execute_block_real_execution_unchanged():
def test_run_block_tool_dry_run_param():
"""RunBlockTool parameters should include 'dry_run' as a required field."""
"""RunBlockTool parameters should include 'dry_run'."""
tool = RunBlockTool()
params = tool.parameters
assert "dry_run" in params["properties"]
assert params["properties"]["dry_run"]["type"] == "boolean"
assert "dry_run" in params["required"]
def test_run_block_tool_dry_run_calls_execute():
"""RunBlockTool._execute accepts dry_run as a typed parameter.
"""RunBlockTool._execute extracts dry_run from kwargs correctly.
We verify the parameter exists in the signature and is forwarded to
execute_block.
We verify the extraction logic directly by inspecting the source, then confirm
the kwarg is forwarded in the execute_block call site.
"""
source = inspect.getsource(run_block_module.RunBlockTool._execute)
# Verify dry_run is a typed parameter (not extracted from kwargs)
# Verify dry_run is extracted from kwargs
assert "dry_run" in source
assert "dry_run: bool" in source
assert 'kwargs.get("dry_run"' in source
# Scope to _execute method source only — module-wide search is brittle
# and can match unrelated text/comments.
@@ -375,107 +332,13 @@ def test_run_block_tool_dry_run_calls_execute():
assert "dry_run=dry_run" in source_execute
@pytest.mark.asyncio
async def test_execute_block_dry_run_no_empty_error_from_simulator():
"""The simulator no longer yields empty error pins, so execute_block
simply passes through whatever the simulator produces.
Since the fix is at the simulator level, even if a simulator somehow
yields only non-error outputs, they pass through unchanged.
"""
mock_block = make_mock_block()
async def fake_simulate(block, input_data):
# Simulator now omits empty error pins at source
yield "result", "simulated output"
with patch(
"backend.copilot.tools.helpers.simulate_block", side_effect=fake_simulate
):
response = await execute_block(
block=mock_block,
block_id="test-block-id",
input_data={"query": "hello"},
user_id="user-1",
session_id="session-1",
node_exec_id="node-exec-1",
matched_credentials={},
dry_run=True,
)
assert isinstance(response, BlockOutputResponse)
assert response.success is True
assert response.is_dry_run is True
assert "error" not in response.outputs
assert response.outputs == {"result": ["simulated output"]}
@pytest.mark.asyncio
async def test_execute_block_dry_run_keeps_nonempty_error_pin():
"""Dry-run should keep the 'error' pin when it contains a real error message."""
mock_block = make_mock_block()
async def fake_simulate(block, input_data):
yield "result", ""
yield "error", "API rate limit exceeded"
with patch(
"backend.copilot.tools.helpers.simulate_block", side_effect=fake_simulate
):
response = await execute_block(
block=mock_block,
block_id="test-block-id",
input_data={"query": "hello"},
user_id="user-1",
session_id="session-1",
node_exec_id="node-exec-1",
matched_credentials={},
dry_run=True,
)
assert isinstance(response, BlockOutputResponse)
assert response.success is True
# Non-empty error should be preserved
assert "error" in response.outputs
assert response.outputs["error"] == ["API rate limit exceeded"]
@pytest.mark.asyncio
async def test_execute_block_dry_run_message_includes_completed_status():
"""Dry-run message should clearly indicate COMPLETED status."""
mock_block = make_mock_block()
async def fake_simulate(block, input_data):
yield "result", "simulated"
with patch(
"backend.copilot.tools.helpers.simulate_block", side_effect=fake_simulate
):
response = await execute_block(
block=mock_block,
block_id="test-block-id",
input_data={"query": "hello"},
user_id="user-1",
session_id="session-1",
node_exec_id="node-exec-1",
matched_credentials={},
dry_run=True,
)
assert isinstance(response, BlockOutputResponse)
assert "executed successfully" in response.message
@pytest.mark.asyncio
async def test_execute_block_dry_run_simulator_error_returns_error_response():
"""When simulate_block yields a SIMULATOR ERROR tuple, execute_block returns ErrorResponse."""
mock_block = make_mock_block()
async def fake_simulate_error(block, input_data):
yield (
"error",
"[SIMULATOR ERROR — NOT A BLOCK FAILURE] No LLM client available (missing OpenAI/OpenRouter API key).",
)
yield "error", "[SIMULATOR ERROR — NOT A BLOCK FAILURE] No LLM client available (missing OpenAI/OpenRouter API key)."
with patch(
"backend.copilot.tools.helpers.simulate_block", side_effect=fake_simulate_error

View File

@@ -76,7 +76,6 @@ async def test_run_block_returns_details_when_no_input_provided():
session=session,
block_id="http-block-id",
input_data={}, # Empty input data
dry_run=False,
)
# Should return BlockDetailsResponse showing the schema
@@ -144,7 +143,6 @@ async def test_run_block_returns_details_when_only_credentials_provided():
session=session,
block_id="api-block-id",
input_data={"credentials": {"some": "cred"}}, # Only credential
dry_run=False,
)
# Should return details because no non-credential inputs provided

View File

@@ -151,7 +151,7 @@ async def test_non_dict_tool_arguments_returns_error():
session=session,
server_url=_SERVER_URL,
tool_name="fetch",
tool_arguments=["this", "is", "a", "list"], # type: ignore[arg-type] # intentionally wrong type to test validation
tool_arguments=["this", "is", "a", "list"], # wrong type
)
assert isinstance(response, ErrorResponse)

View File

@@ -1,499 +0,0 @@
"""Tests for session-level dry_run flag propagation.
Verifies that when a session has dry_run=True, run_block, run_agent, and
run_mcp_tool calls are forced to use dry-run mode, regardless of what the
individual tool call specifies. The single source of truth is
``session.dry_run``.
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.copilot.model import ChatSession
from backend.copilot.tools.models import ErrorResponse, MCPToolOutputResponse
from backend.copilot.tools.run_agent import RunAgentInput, RunAgentTool
from backend.copilot.tools.run_block import RunBlockTool
from backend.copilot.tools.run_mcp_tool import RunMCPToolTool
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_session(dry_run: bool = False) -> ChatSession:
"""Create a minimal ChatSession for testing."""
session = ChatSession.new("test-user", dry_run=dry_run)
return session
def _make_mock_block(name: str = "TestBlock"):
"""Create a minimal mock block with jsonschema() methods."""
block = MagicMock()
block.name = name
block.description = "A test block"
block.disabled = False
block.block_type = "STANDARD"
block.id = "test-block-id"
block.input_schema = MagicMock()
block.input_schema.jsonschema.return_value = {
"type": "object",
"properties": {"query": {"type": "string"}},
"required": ["query"],
}
block.input_schema.get_credentials_fields.return_value = {}
block.input_schema.get_credentials_fields_info.return_value = {}
block.output_schema = MagicMock()
block.output_schema.jsonschema.return_value = {
"type": "object",
"properties": {"result": {"type": "string"}},
"required": ["result"],
}
return block
# ---------------------------------------------------------------------------
# RunBlockTool tests
# ---------------------------------------------------------------------------
class TestRunBlockToolSessionDryRun:
"""Test that RunBlockTool respects session-level dry_run."""
@pytest.mark.asyncio
async def test_session_dry_run_forces_block_dry_run(self):
"""When session dry_run is True, run_block should force dry_run=True."""
tool = RunBlockTool()
session = _make_session(dry_run=True)
mock_block = _make_mock_block()
with (
patch(
"backend.copilot.tools.run_block.prepare_block_for_execution"
) as mock_prep,
patch("backend.copilot.tools.run_block.execute_block") as mock_exec,
patch(
"backend.copilot.tools.run_block.get_current_permissions",
return_value=None,
),
):
# Set up prepare_block_for_execution to return a mock prep
mock_prep_result = MagicMock()
mock_prep_result.block = mock_block
mock_prep_result.input_data = {"query": "test"}
mock_prep_result.matched_credentials = {}
mock_prep_result.synthetic_node_id = "node-1"
mock_prep.return_value = mock_prep_result
# Set up execute_block to return a success
mock_exec.return_value = MagicMock(
message="Block 'TestBlock' executed successfully",
success=True,
)
await tool._execute(
user_id="test-user",
session=session,
block_id="test-block-id",
input_data={"query": "test"},
dry_run=False, # User passed False, but session overrides
)
# Verify execute_block was called with dry_run=True
mock_exec.assert_called_once()
call_kwargs = mock_exec.call_args
assert call_kwargs.kwargs.get("dry_run") is True
@pytest.mark.asyncio
async def test_no_session_dry_run_respects_tool_param(self):
"""When session dry_run is False, tool-level dry_run should be respected."""
tool = RunBlockTool()
session = _make_session(dry_run=False)
mock_block = _make_mock_block()
with (
patch(
"backend.copilot.tools.run_block.prepare_block_for_execution"
) as mock_prep,
patch("backend.copilot.tools.run_block.execute_block") as mock_exec,
patch(
"backend.copilot.tools.run_block.get_current_permissions",
return_value=None,
),
patch("backend.copilot.tools.run_block.check_hitl_review") as mock_hitl,
):
mock_prep_result = MagicMock()
mock_prep_result.block = mock_block
mock_prep_result.input_data = {"query": "test"}
mock_prep_result.matched_credentials = {}
mock_prep_result.synthetic_node_id = "node-1"
mock_prep_result.required_non_credential_keys = {"query"}
mock_prep_result.provided_input_keys = {"query"}
mock_prep.return_value = mock_prep_result
mock_hitl.return_value = ("node-exec-1", {"query": "test"})
mock_exec.return_value = MagicMock(
message="Block executed",
success=True,
)
await tool._execute(
user_id="test-user",
session=session,
block_id="test-block-id",
input_data={"query": "test"},
dry_run=False,
)
# Verify execute_block was called with dry_run=False
mock_exec.assert_called_once()
call_kwargs = mock_exec.call_args
assert call_kwargs.kwargs.get("dry_run") is False
# ---------------------------------------------------------------------------
# RunAgentTool tests
# ---------------------------------------------------------------------------
class TestRunAgentToolSessionDryRun:
"""Test that RunAgentTool respects session-level dry_run."""
@pytest.mark.asyncio
async def test_session_dry_run_forces_agent_dry_run(self):
"""When session dry_run is True, run_agent params.dry_run should be forced True."""
tool = RunAgentTool()
session = _make_session(dry_run=True)
# Mock the graph and dependencies
mock_graph = MagicMock()
mock_graph.id = "graph-1"
mock_graph.name = "Test Agent"
mock_graph.description = "A test agent"
mock_graph.input_schema = {"properties": {}, "required": []}
mock_graph.trigger_setup_info = None
mock_library_agent = MagicMock()
mock_library_agent.id = "lib-1"
mock_library_agent.graph_id = "graph-1"
mock_library_agent.graph_version = 1
mock_library_agent.name = "Test Agent"
mock_execution = MagicMock()
mock_execution.id = "exec-1"
with (
patch("backend.copilot.tools.run_agent.graph_db"),
patch("backend.copilot.tools.run_agent.library_db"),
patch(
"backend.copilot.tools.run_agent.fetch_graph_from_store_slug",
return_value=(mock_graph, None),
),
patch(
"backend.copilot.tools.run_agent.match_user_credentials_to_graph",
return_value=({}, []),
),
patch(
"backend.copilot.tools.run_agent.get_or_create_library_agent",
return_value=mock_library_agent,
),
patch("backend.copilot.tools.run_agent.execution_utils") as mock_exec_utils,
patch("backend.copilot.tools.run_agent.track_agent_run_success"),
):
mock_exec_utils.add_graph_execution = AsyncMock(return_value=mock_execution)
await tool._execute(
user_id="test-user",
session=session,
username_agent_slug="user/test-agent",
dry_run=False, # User passed False, but session overrides
use_defaults=True,
)
# Verify add_graph_execution was called with dry_run=True
mock_exec_utils.add_graph_execution.assert_called_once()
call_kwargs = mock_exec_utils.add_graph_execution.call_args
assert call_kwargs.kwargs.get("dry_run") is True
@pytest.mark.asyncio
async def test_session_dry_run_blocks_scheduling(self):
"""When session dry_run is True, scheduling requests should be rejected."""
tool = RunAgentTool()
session = _make_session(dry_run=True)
result = await tool._execute(
user_id="test-user",
session=session,
username_agent_slug="user/test-agent",
schedule_name="daily-run",
cron="0 9 * * *",
dry_run=False, # Session overrides to True
)
assert isinstance(result, ErrorResponse)
assert "dry-run" in result.message.lower()
assert (
"scheduling" in result.message.lower()
or "schedule" in result.message.lower()
)
# ---------------------------------------------------------------------------
# ChatSession model tests
# ---------------------------------------------------------------------------
class TestChatSessionDryRun:
"""Test the dry_run field on ChatSession model."""
def test_new_session_default_dry_run_false(self):
session = ChatSession.new("test-user", dry_run=False)
assert session.dry_run is False
def test_new_session_dry_run_true(self):
session = ChatSession.new("test-user", dry_run=True)
assert session.dry_run is True
def test_new_session_dry_run_false_explicit(self):
session = ChatSession.new("test-user", dry_run=False)
assert session.dry_run is False
# ---------------------------------------------------------------------------
# RunAgentInput tests
# ---------------------------------------------------------------------------
class TestRunAgentInputDryRunOverride:
"""Test that RunAgentInput.dry_run can be mutated by session-level override."""
def test_explicit_dry_run_false(self):
params = RunAgentInput(username_agent_slug="user/agent", dry_run=False)
assert params.dry_run is False
def test_session_override(self):
params = RunAgentInput(username_agent_slug="user/agent", dry_run=False)
# Simulate session-level override
params.dry_run = True
assert params.dry_run is True
# ---------------------------------------------------------------------------
# RunMCPToolTool tests
# ---------------------------------------------------------------------------
class TestRunMCPToolToolSessionDryRun:
"""Test that RunMCPToolTool respects session-level dry_run."""
@pytest.mark.asyncio
async def test_session_dry_run_blocks_mcp_execution(self):
"""When session dry_run is True, MCP tool execution should be skipped."""
tool = RunMCPToolTool()
session = _make_session(dry_run=True)
result = await tool._execute(
user_id="test-user",
session=session,
server_url="https://mcp.example.com/sse",
tool_name="some_tool",
tool_arguments={"key": "value"},
)
assert isinstance(result, MCPToolOutputResponse)
assert result.success is True
assert "dry-run" in result.message
assert result.tool_name == "some_tool"
assert result.result is None
@pytest.mark.asyncio
async def test_session_dry_run_allows_discovery(self):
"""When session dry_run is True, tool discovery (no tool_name) should still work."""
tool = RunMCPToolTool()
session = _make_session(dry_run=True)
# Discovery requires a network call, so we mock the client
with (
patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
return_value=None,
),
patch(
"backend.copilot.tools.run_mcp_tool.validate_url_host",
return_value=None,
),
patch("backend.copilot.tools.run_mcp_tool.MCPClient") as mock_client_cls,
):
mock_client = AsyncMock()
mock_client_cls.return_value = mock_client
mock_tool = MagicMock()
mock_tool.name = "test_tool"
mock_tool.description = "A test tool"
mock_tool.input_schema = {"type": "object", "properties": {}}
mock_client.list_tools.return_value = [mock_tool]
result = await tool._execute(
user_id="test-user",
session=session,
server_url="https://mcp.example.com/sse",
tool_name="", # Discovery mode
)
# Discovery should proceed normally
mock_client.initialize.assert_called_once()
mock_client.list_tools.assert_called_once()
assert "Discovered" in result.message
@pytest.mark.asyncio
async def test_no_session_dry_run_allows_execution(self):
"""When session dry_run is False, MCP tool execution should proceed."""
tool = RunMCPToolTool()
session = _make_session(dry_run=False)
with (
patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
return_value=None,
),
patch(
"backend.copilot.tools.run_mcp_tool.validate_url_host",
return_value=None,
),
patch("backend.copilot.tools.run_mcp_tool.MCPClient") as mock_client_cls,
):
mock_client = AsyncMock()
mock_client_cls.return_value = mock_client
mock_result = MagicMock()
mock_result.is_error = False
mock_result.content = [{"type": "text", "text": "hello"}]
mock_client.call_tool.return_value = mock_result
result = await tool._execute(
user_id="test-user",
session=session,
server_url="https://mcp.example.com/sse",
tool_name="some_tool",
tool_arguments={"key": "value"},
)
# Execution should proceed
mock_client.initialize.assert_called_once()
mock_client.call_tool.assert_called_once_with("some_tool", {"key": "value"})
assert isinstance(result, MCPToolOutputResponse)
assert result.success is True
# ---------------------------------------------------------------------------
# Backward-compatibility tests for ChatSessionMetadata deserialization
# ---------------------------------------------------------------------------
class TestChatSessionMetadataBackwardCompat:
"""Verify that sessions created before the dry_run field existed still load.
The ``metadata`` JSON column in the DB may contain ``{}``, ``null``, or a
dict without the ``dry_run`` key for sessions created before the flag was
introduced. These must deserialize without errors and default to
``dry_run=False``.
"""
def test_metadata_default_construction(self):
"""ChatSessionMetadata() with no args should default dry_run=False."""
from backend.copilot.model import ChatSessionMetadata
meta = ChatSessionMetadata()
assert meta.dry_run is False
def test_metadata_from_empty_dict(self):
"""Deserializing an empty dict (old-format metadata) should succeed."""
from backend.copilot.model import ChatSessionMetadata
meta = ChatSessionMetadata.model_validate({})
assert meta.dry_run is False
def test_metadata_from_dict_without_dry_run_key(self):
"""A metadata dict with other keys but no dry_run should still work."""
from backend.copilot.model import ChatSessionMetadata
meta = ChatSessionMetadata.model_validate({"some_future_field": 42})
# dry_run should fall back to default
assert meta.dry_run is False
def test_metadata_round_trip_with_dry_run_false(self):
"""Serialize then deserialize with dry_run=False."""
from backend.copilot.model import ChatSessionMetadata
original = ChatSessionMetadata(dry_run=False)
raw = original.model_dump()
restored = ChatSessionMetadata.model_validate(raw)
assert restored.dry_run is False
def test_metadata_round_trip_with_dry_run_true(self):
"""Serialize then deserialize with dry_run=True."""
from backend.copilot.model import ChatSessionMetadata
original = ChatSessionMetadata(dry_run=True)
raw = original.model_dump()
restored = ChatSessionMetadata.model_validate(raw)
assert restored.dry_run is True
def test_metadata_json_round_trip(self):
"""Serialize to JSON string and back, simulating Redis cache flow."""
from backend.copilot.model import ChatSessionMetadata
original = ChatSessionMetadata(dry_run=True)
json_str = original.model_dump_json()
restored = ChatSessionMetadata.model_validate_json(json_str)
assert restored.dry_run is True
def test_session_dry_run_property_with_default_metadata(self):
"""ChatSession.dry_run returns False when metadata has no dry_run."""
from backend.copilot.model import ChatSessionMetadata
# Simulate building a session with metadata deserialized from an old row
meta = ChatSessionMetadata.model_validate({})
session = _make_session(dry_run=False)
session.metadata = meta
assert session.dry_run is False
def test_session_info_dry_run_property_with_default_metadata(self):
"""ChatSessionInfo.dry_run returns False when metadata is default."""
from datetime import UTC, datetime
from backend.copilot.model import ChatSessionInfo, ChatSessionMetadata
info = ChatSessionInfo(
session_id="old-session-id",
user_id="test-user",
usage=[],
started_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
metadata=ChatSessionMetadata.model_validate({}),
)
assert info.dry_run is False
def test_session_full_json_round_trip_without_dry_run(self):
"""A full ChatSession JSON round-trip preserves dry_run default."""
session = _make_session(dry_run=False)
json_bytes = session.model_dump_json()
restored = ChatSession.model_validate_json(json_bytes)
assert restored.dry_run is False
assert restored.metadata.dry_run is False
def test_session_full_json_round_trip_with_dry_run(self):
"""A full ChatSession JSON round-trip preserves dry_run=True."""
session = _make_session(dry_run=True)
json_bytes = session.model_dump_json()
restored = ChatSession.model_validate_json(json_bytes)
assert restored.dry_run is True
assert restored.metadata.dry_run is True
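The deleted test file above exercised one invariant: when `session.dry_run` is True it overrides whatever the individual tool call passes, and when it is False the tool-level flag is respected. A minimal standalone sketch of that precedence rule (hypothetical names, not the project's real classes):

```python
from dataclasses import dataclass


@dataclass
class Session:
    # Single source of truth for the whole chat session.
    dry_run: bool = False


def effective_dry_run(session: Session, tool_dry_run: bool) -> bool:
    # Session-level True always wins; otherwise defer to the tool call.
    return True if session.dry_run else tool_dry_run
```

This is the behavior the `test_session_dry_run_forces_*` and `test_no_session_dry_run_respects_tool_param` cases assert on, reduced to its decision logic.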

View File

@@ -48,9 +48,9 @@ class ValidateAgentGraphTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
agent_json: dict | None = None,
**kwargs,
) -> ToolResponseBase:
agent_json = kwargs.get("agent_json")
session_id = session.session_id if session else None
if not agent_json or not isinstance(agent_json, dict):

View File

@@ -87,11 +87,10 @@ class WebFetchTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
url: str = "",
extract_text: bool = True,
**kwargs: Any,
) -> ToolResponseBase:
url = url.strip()
url: str = (kwargs.get("url") or "").strip()
extract_text: bool = kwargs.get("extract_text", True)
session_id = session.session_id if session else None
if not url:

View File

@@ -450,9 +450,6 @@ class ListWorkspaceFilesTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
path_prefix: Optional[str] = None,
limit: int = 50,
include_all_sessions: bool = False,
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id
@@ -461,7 +458,9 @@ class ListWorkspaceFilesTool(BaseTool):
message="Authentication required", session_id=session_id
)
limit = min(limit, 100)
path_prefix: Optional[str] = kwargs.get("path_prefix")
limit = min(kwargs.get("limit", 50), 100)
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
try:
manager = await get_workspace_manager(user_id, session_id)
@@ -568,12 +567,6 @@ class ReadWorkspaceFileTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
file_id: Optional[str] = None,
path: Optional[str] = None,
save_to_path: Optional[str] = None,
force_download_url: bool = False,
offset: int = 0,
length: Optional[int] = None,
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id
@@ -582,8 +575,12 @@ class ReadWorkspaceFileTool(BaseTool):
message="Authentication required", session_id=session_id
)
char_offset: int = max(0, offset)
char_length: Optional[int] = length
file_id: Optional[str] = kwargs.get("file_id")
path: Optional[str] = kwargs.get("path")
save_to_path: Optional[str] = kwargs.get("save_to_path")
force_download_url: bool = kwargs.get("force_download_url", False)
char_offset: int = max(0, kwargs.get("offset", 0))
char_length: Optional[int] = kwargs.get("length")
if not file_id and not path:
return ErrorResponse(
@@ -773,13 +770,6 @@ class WriteWorkspaceFileTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
filename: str = "",
source_path: str | None = None,
content: str | None = None,
content_base64: str | None = None,
path: str | None = None,
mime_type: str | None = None,
overwrite: bool = False,
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id
@@ -788,36 +778,15 @@ class WriteWorkspaceFileTool(BaseTool):
message="Authentication required", session_id=session_id
)
filename: str = kwargs.get("filename", "")
if not filename:
# When ALL parameters are missing, the most likely cause is
# output token truncation: the LLM tried to inline a very large
# file as `content`, the SDK silently truncated the tool call
# arguments to `{}`, and we receive nothing. Return an
# actionable error instead of a generic "filename required".
has_any_content = any(
kwargs.get(k) for k in ("content", "content_base64", "source_path")
)
if not has_any_content:
return ErrorResponse(
message=(
"Tool call appears truncated (no arguments received). "
"This happens when the content is too large for a "
"single tool call. Instead of passing content inline, "
"first write the file to the working directory using "
"bash_exec (e.g. cat > /home/user/file.md << 'EOF'... "
"EOF), then use source_path to copy it to workspace: "
"write_workspace_file(filename='file.md', "
"source_path='/home/user/file.md')"
),
session_id=session_id,
)
return ErrorResponse(
message="Please provide a filename", session_id=session_id
)
source_path_arg: str | None = source_path
content_text: str | None = content
content_b64: str | None = content_base64
source_path_arg: str | None = kwargs.get("source_path")
content_text: str | None = kwargs.get("content")
content_b64: str | None = kwargs.get("content_base64")
resolved = await _resolve_write_content(
content_text,
@@ -827,24 +796,24 @@ class WriteWorkspaceFileTool(BaseTool):
)
if isinstance(resolved, ErrorResponse):
return resolved
content_bytes: bytes = resolved
content: bytes = resolved
max_size = _MAX_FILE_SIZE_MB * 1024 * 1024
if len(content_bytes) > max_size:
if len(content) > max_size:
return ErrorResponse(
message=f"File too large. Maximum size is {_MAX_FILE_SIZE_MB}MB",
session_id=session_id,
)
try:
await scan_content_safe(content_bytes, filename=filename)
await scan_content_safe(content, filename=filename)
manager = await get_workspace_manager(user_id, session_id)
rec = await manager.write_file(
content=content_bytes,
content=content,
filename=filename,
path=path,
mime_type=mime_type,
overwrite=overwrite,
path=kwargs.get("path"),
mime_type=kwargs.get("mime_type"),
overwrite=kwargs.get("overwrite", False),
)
# Build informative source label and message.
@@ -868,8 +837,8 @@ class WriteWorkspaceFileTool(BaseTool):
preview: str | None = None
if _is_text_mime(rec.mime_type):
try:
preview = content_bytes[:200].decode("utf-8", errors="replace")
if len(content_bytes) > 200:
preview = content[:200].decode("utf-8", errors="replace")
if len(content) > 200:
preview += "..."
except Exception:
pass
@@ -941,8 +910,6 @@ class DeleteWorkspaceFileTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
file_id: Optional[str] = None,
path: Optional[str] = None,
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id
@@ -950,6 +917,9 @@ class DeleteWorkspaceFileTool(BaseTool):
return ErrorResponse(
message="Authentication required", session_id=session_id
)
file_id: Optional[str] = kwargs.get("file_id")
path: Optional[str] = kwargs.get("path")
if not file_id and not path:
return ErrorResponse(
message="Please provide either file_id or path", session_id=session_id
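The `WriteWorkspaceFileTool` hunk above distinguishes a genuinely missing filename from a tool call whose arguments were silently truncated to `{}` by the SDK. The heuristic reduces to: no filename *and* no content source of any kind most likely means truncation. A standalone sketch of that classification (hypothetical helper, simplified from the diff):

```python
_CONTENT_KEYS = ("content", "content_base64", "source_path")


def classify_write_call(kwargs: dict) -> str:
    """Return 'ok', 'missing_filename', or 'truncated' for a write call."""
    if kwargs.get("filename"):
        return "ok"
    # No filename: if every content-bearing argument is also absent, the
    # most likely cause is the SDK truncating oversized tool arguments,
    # so an actionable "use source_path instead" error is returned upstream.
    if not any(kwargs.get(k) for k in _CONTENT_KEYS):
        return "truncated"
    return "missing_filename"
```

The real handler maps `"truncated"` to the long error message suggesting `bash_exec` plus `source_path`, and `"missing_filename"` to the plain "Please provide a filename" response.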

View File

@@ -325,8 +325,6 @@ class _BaseCredentials(BaseModel):
id: str = Field(default_factory=lambda: str(uuid4()))
provider: str
title: Optional[str] = None
is_managed: bool = False
metadata: dict[str, Any] = Field(default_factory=dict)
@field_serializer("*")
def dump_secret_strings(value: Any, _info):
@@ -346,6 +344,7 @@ class OAuth2Credentials(_BaseCredentials):
refresh_token_expires_at: Optional[int] = None
"""Unix timestamp (seconds) indicating when the refresh token expires (if at all)"""
scopes: list[str]
metadata: dict[str, Any] = Field(default_factory=dict)
def auth_header(self) -> str:
return f"Bearer {self.access_token.get_secret_value()}"

View File

@@ -3,7 +3,7 @@ import hashlib
import hmac
import logging
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Optional, cast
from typing import Optional, cast
from urllib.parse import quote_plus
from autogpt_libs.auth.models import DEFAULT_USER_ID
@@ -21,9 +21,6 @@ from backend.util.exceptions import DatabaseError
from backend.util.json import SafeJson
from backend.util.settings import Settings
if TYPE_CHECKING:
from backend.integrations.credentials_store import IntegrationCredentialsStore
logger = logging.getLogger(__name__)
settings = Settings()
@@ -456,27 +453,6 @@ async def unsubscribe_user_by_token(token: str) -> None:
raise DatabaseError(f"Failed to unsubscribe user by token {token}: {e}") from e
async def cleanup_user_managed_credentials(
user_id: str,
store: Optional["IntegrationCredentialsStore"] = None,
) -> None:
"""Revoke all externally-provisioned managed credentials for *user_id*.
Call this before deleting a user account so that external resources
(e.g. AgentMail pods, pod-scoped API keys) are properly cleaned up.
The credential rows themselves are cascade-deleted with the User row.
Pass an existing *store* for testability; when omitted a fresh instance
is created.
"""
from backend.integrations.credentials_store import IntegrationCredentialsStore
from backend.integrations.managed_credentials import cleanup_managed_credentials
if store is None:
store = IntegrationCredentialsStore()
await cleanup_managed_credentials(user_id, store)
async def update_user_timezone(user_id: str, timezone: str) -> User:
"""Update a user's timezone setting."""
try:

View File

@@ -13,7 +13,7 @@ Inspired by https://github.com/Significant-Gravitas/agent-simulator
import json
import logging
from collections.abc import AsyncGenerator
from collections.abc import AsyncIterator
from typing import Any
from backend.util.clients import get_openai_client
@@ -96,10 +96,6 @@ def build_simulation_prompt(block: Any, input_data: dict[str, Any]) -> tuple[str
input_pins = _describe_schema_pins(input_schema)
output_pins = _describe_schema_pins(output_schema)
output_properties = list(output_schema.get("properties", {}).keys())
# Build a separate list for the "MUST include" instruction that excludes
# "error" — the prompt already tells the LLM to OMIT the error pin unless
# simulating a logical error. Including it in "MUST include" is contradictory.
required_output_properties = [k for k in output_properties if k != "error"]
block_name = getattr(block, "name", type(block).__name__)
block_description = getattr(block, "description", "No description available.")
@@ -121,10 +117,10 @@ Rules:
- Respond with a single JSON object whose keys are EXACTLY the output pin names listed above.
- Assume all credentials and authentication are present and valid. Never simulate authentication failures.
- Make the simulated outputs realistic and consistent with the inputs.
- If there is an "error" pin, OMIT it entirely unless you are simulating a logical error. Only include the "error" pin when there is a genuine error message to report.
- If there is an "error" pin, set it to "" (empty string) unless you are simulating a logical error.
- Do not include any extra keys beyond the output pins.
Output pin names you MUST include: {json.dumps(required_output_properties)}
Output pin names you MUST include: {json.dumps(output_properties)}
"""
safe_inputs = _truncate_input_values(input_data)
@@ -136,7 +132,7 @@ Output pin names you MUST include: {json.dumps(required_output_properties)}
async def simulate_block(
block: Any,
input_data: dict[str, Any],
) -> AsyncGenerator[tuple[str, Any], None]:
) -> AsyncIterator[tuple[str, Any]]:
"""Simulate block execution using an LLM.
Yields (output_name, output_data) tuples matching the Block.execute() interface.
@@ -176,26 +172,13 @@ async def simulate_block(
if not isinstance(parsed, dict):
raise ValueError(f"LLM returned non-object JSON: {raw[:200]}")
# Fill missing output pins with defaults.
# Skip empty "error" pins — an empty string means "no error" and
# would only confuse downstream consumers (LLM, frontend).
# Fill missing output pins with defaults
result: dict[str, Any] = {}
for pin_name in output_properties:
if pin_name in parsed:
value = parsed[pin_name]
# Drop empty/blank error pins: they carry no information.
# Uses strip() intentionally so whitespace-only strings
# (e.g. " ", "\n") are also treated as empty.
if (
pin_name == "error"
and isinstance(value, str)
and not value.strip()
):
continue
result[pin_name] = value
elif pin_name != "error":
# Only fill non-error missing pins with None
result[pin_name] = None
result[pin_name] = parsed[pin_name]
else:
result[pin_name] = "" if pin_name == "error" else None
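The pin-filling loop in this hunk (the version with the empty-error handling) can be sketched as a standalone function — blank or whitespace-only `"error"` values are dropped because they mean "no error", and missing non-error pins are filled with `None`:

```python
from typing import Any


def fill_output_pins(parsed: dict[str, Any], output_properties: list[str]) -> dict[str, Any]:
    """Merge simulated LLM outputs with the expected output-pin list."""
    result: dict[str, Any] = {}
    for pin in output_properties:
        if pin in parsed:
            value = parsed[pin]
            # Drop empty/blank error pins: they carry no information.
            # strip() catches whitespace-only values such as " " or "\n".
            if pin == "error" and isinstance(value, str) and not value.strip():
                continue
            result[pin] = value
        elif pin != "error":
            # Only fill non-error missing pins with None.
            result[pin] = None
    return result
```

This mirrors the diff's logic exactly: an omitted or blank error pin never reaches downstream consumers, while every other declared pin is always present in the result.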
logger.debug(
"simulate_block: block=%s attempt=%d tokens=%s/%s",

View File

@@ -1,6 +1,5 @@
import base64
import hashlib
import logging
import secrets
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
@@ -22,7 +21,6 @@ from backend.data.redis_client import get_redis_async
from backend.util.settings import Settings
settings = Settings()
logger = logging.getLogger(__name__)
def provider_matches(stored: str, expected: str) -> bool:
@@ -286,7 +284,6 @@ DEFAULT_CREDENTIALS = [
elevenlabs_credentials,
]
SYSTEM_CREDENTIAL_IDS = {cred.id for cred in DEFAULT_CREDENTIALS}
# Set of providers that have system credentials available
@@ -326,45 +323,20 @@ class IntegrationCredentialsStore:
return get_database_manager_async_client()
# =============== USER-MANAGED CREDENTIALS =============== #
async def _get_persisted_user_creds_unlocked(
self, user_id: str
) -> list[Credentials]:
"""Return only the persisted (user-stored) credentials — no side effects.
**Caller must already hold ``locked_user_integrations(user_id)``.**
"""
return list((await self._get_user_integrations(user_id)).credentials)
async def add_creds(self, user_id: str, credentials: Credentials) -> None:
async with await self.locked_user_integrations(user_id):
# Check system/managed IDs without triggering provisioning
if credentials.id in SYSTEM_CREDENTIAL_IDS:
if await self.get_creds_by_id(user_id, credentials.id):
raise ValueError(
f"Can not re-create existing credentials #{credentials.id} "
f"for user #{user_id}"
)
persisted = await self._get_persisted_user_creds_unlocked(user_id)
if any(c.id == credentials.id for c in persisted):
raise ValueError(
f"Can not re-create existing credentials #{credentials.id} "
f"for user #{user_id}"
)
await self._set_user_integration_creds(user_id, [*persisted, credentials])
await self._set_user_integration_creds(
user_id, [*(await self.get_all_creds(user_id)), credentials]
)
async def get_all_creds(self, user_id: str) -> list[Credentials]:
"""Public entry point — acquires lock, then delegates."""
async with await self.locked_user_integrations(user_id):
return await self._get_all_creds_unlocked(user_id)
async def _get_all_creds_unlocked(self, user_id: str) -> list[Credentials]:
"""Return all credentials for *user_id*.
**Caller must already hold ``locked_user_integrations(user_id)``.**
"""
user_integrations = await self._get_user_integrations(user_id)
all_credentials = list(user_integrations.credentials)
users_credentials = (await self._get_user_integrations(user_id)).credentials
all_credentials = users_credentials
# These will always be added
all_credentials.append(ollama_credentials)
@@ -445,22 +417,13 @@ class IntegrationCredentialsStore:
return list(set(c.provider for c in credentials))
async def update_creds(self, user_id: str, updated: Credentials) -> None:
if updated.id in SYSTEM_CREDENTIAL_IDS:
raise ValueError(
f"System credential #{updated.id} cannot be updated directly"
)
async with await self.locked_user_integrations(user_id):
persisted = await self._get_persisted_user_creds_unlocked(user_id)
current = next((c for c in persisted if c.id == updated.id), None)
current = await self.get_creds_by_id(user_id, updated.id)
if not current:
raise ValueError(
f"Credentials with ID {updated.id} "
f"for user with ID {user_id} not found"
)
if current.is_managed:
raise ValueError(
f"AutoGPT-managed credential #{updated.id} cannot be updated"
)
if type(current) is not type(updated):
raise TypeError(
f"Can not update credentials with ID {updated.id} "
@@ -480,53 +443,22 @@ class IntegrationCredentialsStore:
f"to more restrictive set of scopes {updated.scopes}"
)
# Update only persisted credentials — no side-effectful provisioning
# Update the credentials
updated_credentials_list = [
updated if c.id == updated.id else c for c in persisted
updated if c.id == updated.id else c
for c in await self.get_all_creds(user_id)
]
await self._set_user_integration_creds(user_id, updated_credentials_list)
async def delete_creds_by_id(self, user_id: str, credentials_id: str) -> None:
if credentials_id in SYSTEM_CREDENTIAL_IDS:
raise ValueError(f"System credential #{credentials_id} cannot be deleted")
async with await self.locked_user_integrations(user_id):
persisted = await self._get_persisted_user_creds_unlocked(user_id)
target = next((c for c in persisted if c.id == credentials_id), None)
if target and target.is_managed:
raise ValueError(
f"AutoGPT-managed credential #{credentials_id} cannot be deleted"
)
filtered_credentials = [c for c in persisted if c.id != credentials_id]
filtered_credentials = [
c for c in await self.get_all_creds(user_id) if c.id != credentials_id
]
await self._set_user_integration_creds(user_id, filtered_credentials)
# ============== SYSTEM-MANAGED CREDENTIALS ============== #
async def has_managed_credential(self, user_id: str, provider: str) -> bool:
"""Check if a managed credential exists for *provider*."""
user_integrations = await self._get_user_integrations(user_id)
return any(
c.provider == provider and c.is_managed
for c in user_integrations.credentials
)
async def add_managed_credential(
self, user_id: str, credential: Credentials
) -> None:
"""Upsert a managed credential.
Removes any existing managed credential for the same provider,
then appends the new one. The credential MUST have is_managed=True.
"""
if not credential.is_managed:
raise ValueError("credential.is_managed must be True")
async with self.edit_user_integrations(user_id) as user_integrations:
user_integrations.credentials = [
c
for c in user_integrations.credentials
if not (c.provider == credential.provider and c.is_managed)
]
user_integrations.credentials.append(credential)
async def set_ayrshare_profile_key(self, user_id: str, profile_key: str) -> None:
"""Set the Ayrshare profile key for a user.


@@ -1,188 +0,0 @@
"""Generic infrastructure for system-provided, per-user managed credentials.
Managed credentials are provisioned automatically by the platform (e.g. an
AgentMail pod-scoped API key) and stored alongside regular user credentials
with ``is_managed=True``. Users cannot update or delete them.
New integrations register a :class:`ManagedCredentialProvider` at import time;
the two entry-points consumed by the rest of the application are:
* :func:`ensure_managed_credentials` fired as a background task from the
credential-listing endpoints (non-blocking).
* :func:`cleanup_managed_credentials` called during account deletion to
revoke external resources (API keys, pods, etc.).
"""
from __future__ import annotations
import asyncio
import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from cachetools import TTLCache
if TYPE_CHECKING:
from backend.data.model import Credentials
from backend.integrations.credentials_store import IntegrationCredentialsStore
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Abstract provider
# ---------------------------------------------------------------------------
class ManagedCredentialProvider(ABC):
"""Base class for integrations that auto-provision per-user credentials."""
provider_name: str
"""Must match the ``provider`` field on the resulting credential."""
@abstractmethod
async def is_available(self) -> bool:
"""Return ``True`` when the org-level configuration is present."""
@abstractmethod
async def provision(self, user_id: str) -> Credentials:
"""Create external resources and return a credential.
The returned credential **must** have ``is_managed=True``.
"""
@abstractmethod
async def deprovision(self, user_id: str, credential: Credentials) -> None:
"""Revoke external resources during account deletion."""
# ---------------------------------------------------------------------------
# Registry
# ---------------------------------------------------------------------------
_PROVIDERS: dict[str, ManagedCredentialProvider] = {}
# Users whose managed credentials have already been verified recently.
# Avoids redundant DB checks on every GET /credentials call.
# maxsize caps memory; TTL re-checks periodically (e.g. when new providers
# are added). ~100K entries ≈ 4-8 MB.
_provisioned_users: TTLCache[str, bool] = TTLCache(maxsize=100_000, ttl=3600)
def register_managed_provider(provider: ManagedCredentialProvider) -> None:
_PROVIDERS[provider.provider_name] = provider
def get_managed_provider(name: str) -> ManagedCredentialProvider | None:
return _PROVIDERS.get(name)
def get_managed_providers() -> dict[str, ManagedCredentialProvider]:
return dict(_PROVIDERS)
# ---------------------------------------------------------------------------
# Public helpers
# ---------------------------------------------------------------------------
async def _ensure_one(
user_id: str,
store: IntegrationCredentialsStore,
name: str,
provider: ManagedCredentialProvider,
) -> bool:
"""Provision a single managed credential under a distributed Redis lock.
Returns ``True`` if the credential already exists or was successfully
provisioned, ``False`` on transient failure so the caller knows not to
cache the user as fully provisioned.
"""
try:
if not await provider.is_available():
return True
# Use a distributed Redis lock so the check-then-provision operation
# is atomic across all workers, preventing duplicate external
# resource provisioning (e.g. AgentMail API keys).
locks = await store.locks()
key = (f"user:{user_id}", f"managed-provision:{name}")
async with locks.locked(key):
# Re-check under lock to avoid duplicate provisioning.
if await store.has_managed_credential(user_id, name):
return True
credential = await provider.provision(user_id)
await store.add_managed_credential(user_id, credential)
logger.info(
"Provisioned managed credential for provider=%s user=%s",
name,
user_id,
)
return True
except Exception:
logger.warning(
"Failed to provision managed credential for provider=%s user=%s",
name,
user_id,
exc_info=True,
)
return False
async def ensure_managed_credentials(
user_id: str,
store: IntegrationCredentialsStore,
) -> None:
"""Provision missing managed credentials for *user_id*.
Fired as a non-blocking background task from the credential-listing
endpoints. Failures are logged but never propagated — the user simply
will not see the managed credential until the next page load.
Skips entirely if this user has already been checked during the current
process lifetime (in-memory cache). Resets on restart — just a
performance optimisation, not a correctness guarantee.
Providers are checked concurrently via ``asyncio.gather``.
"""
if user_id in _provisioned_users:
return
results = await asyncio.gather(
*(_ensure_one(user_id, store, n, p) for n, p in _PROVIDERS.items())
)
# Only cache the user as provisioned when every provider succeeded or
# was already present. A transient failure (network timeout, Redis
# blip) returns False, so the next page load will retry.
if all(results):
_provisioned_users[user_id] = True
async def cleanup_managed_credentials(
user_id: str,
store: IntegrationCredentialsStore,
) -> None:
"""Revoke all external managed resources for a user being deleted."""
all_creds = await store.get_all_creds(user_id)
managed = [c for c in all_creds if c.is_managed]
for cred in managed:
provider = _PROVIDERS.get(cred.provider)
if not provider:
logger.warning(
"No managed provider registered for %s — skipping cleanup",
cred.provider,
)
continue
try:
await provider.deprovision(user_id, cred)
logger.info(
"Deprovisioned managed credential for provider=%s user=%s",
cred.provider,
user_id,
)
except Exception:
logger.error(
"Failed to deprovision %s for user %s",
cred.provider,
user_id,
exc_info=True,
)


@@ -1,17 +0,0 @@
"""Managed credential providers.
Call :func:`register_all` at application startup (e.g. in ``rest_api.py``)
to populate the provider registry before any requests are processed.
"""
from backend.integrations.managed_credentials import (
get_managed_provider,
register_managed_provider,
)
from backend.integrations.managed_providers.agentmail import AgentMailManagedProvider
def register_all() -> None:
"""Register every built-in managed credential provider (idempotent)."""
if get_managed_provider(AgentMailManagedProvider.provider_name) is None:
register_managed_provider(AgentMailManagedProvider())


@@ -1,90 +0,0 @@
"""AgentMail managed credential provider.
Uses the org-level AgentMail API key to create a per-user pod and a
pod-scoped API key. The pod key is stored as an ``is_managed``
credential so it appears automatically in block credential dropdowns.
"""
from __future__ import annotations
import logging
from pydantic import SecretStr
from backend.data.model import APIKeyCredentials, Credentials
from backend.integrations.managed_credentials import ManagedCredentialProvider
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
settings = Settings()
class AgentMailManagedProvider(ManagedCredentialProvider):
provider_name = "agent_mail"
async def is_available(self) -> bool:
return bool(settings.secrets.agentmail_api_key)
async def provision(self, user_id: str) -> Credentials:
from agentmail import AsyncAgentMail
client = AsyncAgentMail(api_key=settings.secrets.agentmail_api_key)
# client_id makes pod creation idempotent — if a pod already exists
# for this user_id the SDK returns the existing pod.
pod = await client.pods.create(client_id=user_id, name=f"{user_id}-pod")
# NOTE: api_keys.create() is NOT idempotent. If the caller retries
# after a partial failure (pod created, key created, but store write
# failed), a second key will be created and the first becomes orphaned
# on AgentMail's side. The double-check pattern in _ensure_one
# (has_managed_credential under lock) prevents this in normal flow;
# only a crash between key creation and store write can cause it.
api_key_obj = await client.pods.api_keys.create(
pod_id=pod.pod_id, name=f"{user_id}-agpt-managed"
)
return APIKeyCredentials(
provider=self.provider_name,
title="AgentMail (managed by AutoGPT)",
api_key=SecretStr(api_key_obj.api_key),
expires_at=None,
is_managed=True,
metadata={"pod_id": pod.pod_id},
)
async def deprovision(self, user_id: str, credential: Credentials) -> None:
from agentmail import AsyncAgentMail
pod_id = credential.metadata.get("pod_id")
if not pod_id:
logger.warning(
"Managed credential for user %s has no pod_id in metadata — "
"skipping AgentMail cleanup",
user_id,
)
return
client = AsyncAgentMail(api_key=settings.secrets.agentmail_api_key)
try:
# Verify the pod actually belongs to this user before deleting,
# as a safety measure against cross-user deletion via the
# org-level API key.
pod = await client.pods.get(pod_id=pod_id)
if getattr(pod, "client_id", None) and pod.client_id != user_id:
logger.error(
"Pod %s client_id=%s does not match user %s"
"refusing to delete",
pod_id,
pod.client_id,
user_id,
)
return
await client.pods.delete(pod_id=pod_id)
except Exception:
logger.warning(
"Failed to delete AgentMail pod %s for user %s",
pod_id,
user_id,
exc_info=True,
)
