mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Compare commits
6 Commits
feat/ask-q
...
feat/agent
Commits in this compare (author and date columns did not survive extraction):

- 3db2a944f7
- 59192102a6
- 65cca9bef8
- 6b32e43d84
- b73d05c23e
- 8277cce835
@@ -1 +0,0 @@
../.claude/skills
@@ -1,6 +1,6 @@
# AutoGPT Platform Contribution Guide

This guide provides context for coding agents when updating the **autogpt_platform** folder.
This guide provides context for Codex when updating the **autogpt_platform** folder.

## Directory overview
@@ -1,120 +0,0 @@
# AutoGPT Platform

This file provides guidance to coding agents when working with code in this repository.

## Repository Overview

AutoGPT Platform is a monorepo containing:

- **Backend** (`backend`): Python FastAPI server with async support
- **Frontend** (`frontend`): Next.js React application
- **Shared Libraries** (`autogpt_libs`): Common Python utilities

## Component Documentation

- **Backend**: See @backend/AGENTS.md for backend-specific commands, architecture, and development tasks
- **Frontend**: See @frontend/AGENTS.md for frontend-specific commands, architecture, and development patterns

## Key Concepts

1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
2. **Blocks**: Reusable components in `backend/backend/blocks/` that perform specific tasks
3. **Integrations**: OAuth and API connections stored per user
4. **Store**: Marketplace for sharing agent templates
5. **Virus Scanning**: ClamAV integration for file upload security

### Environment Configuration

#### Configuration Files

- **Backend**: `backend/.env.default` (defaults) → `backend/.env` (user overrides)
- **Frontend**: `frontend/.env.default` (defaults) → `frontend/.env` (user overrides)
- **Platform**: `.env.default` (Supabase/shared defaults) → `.env` (user overrides)

#### Docker Environment Loading Order

1. `.env.default` files provide base configuration (tracked in git)
2. `.env` files provide user-specific overrides (gitignored)
3. Docker Compose `environment:` sections provide service-specific overrides
4. Shell environment variables have highest precedence
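This precedence chain behaves like a left-to-right lookup. A quick sketch with Python's `ChainMap`, where every key and value is invented purely for illustration:

```python
from collections import ChainMap

# One dict per source, mirroring the load order above (values are made up).
env_default = {"DB_PORT": "5432", "LOG_LEVEL": "info"}   # 1. .env.default
env_user = {"LOG_LEVEL": "debug"}                        # 2. .env
compose_env = {"DB_PORT": "5433"}                        # 3. environment: section
shell_env = {"LOG_LEVEL": "warning"}                     # 4. shell

# ChainMap searches left to right, so the highest-precedence source goes first.
effective = ChainMap(shell_env, compose_env, env_user, env_default)

print(effective["DB_PORT"])    # "5433" (compose overrides the default)
print(effective["LOG_LEVEL"])  # "warning" (shell wins over everything)
```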
#### Key Points

- All services use hardcoded defaults in docker-compose files (no `${VARIABLE}` substitutions)
- The `env_file` directive loads variables INTO containers at runtime
- Backend/Frontend services use YAML anchors for consistent configuration
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern

### Branching Strategy

- **`dev`** is the main development branch. All PRs should target `dev`.
- **`master`** is the production branch. Only used for production releases.

### Creating Pull Requests

- Create the PR against the `dev` branch of the repository.
- **Split PRs by concern** — each PR should have a single clear purpose. For example, "usage tracking" and "credit charging" should be separate PRs even if related. Combining multiple concerns makes it harder for reviewers to understand what belongs to what.
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)
- Use conventional commit messages (see below)
- **Structure the PR description with Why / What / How** — Why: the motivation (what problem it solves, what's broken/missing without it); What: high-level summary of changes; How: approach, key implementation details, or architecture decisions. Reviewers need all three to judge whether the approach fits the problem.
- Fill out the `.github/PULL_REQUEST_TEMPLATE.md` template as the PR description
- Always use `--body-file` to pass the PR body — this avoids shell interpretation of backticks and special characters:

```bash
PR_BODY=$(mktemp)
cat > "$PR_BODY" << 'PREOF'
## Summary
- use `backticks` freely here
PREOF
gh pr create --title "..." --body-file "$PR_BODY" --base dev
rm "$PR_BODY"
```

- Run the GitHub pre-commit hooks to ensure code quality.

### Test-Driven Development (TDD)

When fixing a bug or adding a feature, follow a test-first approach:

1. **Write a failing test first** — create a test that reproduces the bug or validates the new behavior, marked with `@pytest.mark.xfail` (backend) or `.fixme` (Playwright). Run it to confirm it fails for the right reason.
2. **Implement the fix/feature** — write the minimal code to make the test pass.
3. **Remove the xfail marker** — once the test passes, remove the `xfail`/`.fixme` annotation and run the full test suite to confirm nothing else broke.

This ensures every change is covered by a test and that the test actually validates the intended behavior.

### Reviewing/Revising Pull Requests

Use `/pr-review` to review a PR or `/pr-address` to address comments.

When fetching comments manually:
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate` — top-level reviews
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments --paginate` — inline review comments (always paginate to avoid missing comments beyond page 1)
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` — PR conversation comments

### Conventional Commits

Use this format for commit messages and Pull Request titles:

**Conventional Commit Types:**

- `feat`: Introduces a new feature to the codebase
- `fix`: Patches a bug in the codebase
- `refactor`: Code change that neither fixes a bug nor adds a feature; also applies to removing features
- `ci`: Changes to CI configuration
- `docs`: Documentation-only changes
- `dx`: Improvements to the developer experience

**Recommended Base Scopes:**

- `platform`: Changes affecting both frontend and backend
- `frontend`
- `backend`
- `infra`
- `blocks`: Modifications/additions of individual blocks

**Subscope Examples:**

- `backend/executor`
- `backend/db`
- `frontend/builder` (includes changes to the block UI component)
- `infra/prod`

Use these scopes and subscopes for clarity and consistency in commit messages.
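A title following this convention can be sanity-checked with a small pattern. This regex is purely illustrative, not an enforced hook, and the accepted types and scopes are taken from the lists above:

```python
import re

# Types from the list above; an optional (scope) or (scope/subscope) follows,
# then a colon and a summary. Illustrative only, not an official check.
CONVENTIONAL = re.compile(
    r"^(feat|fix|refactor|ci|docs|dx)"   # type
    r"(\([a-z]+(/[a-z]+)?\))?"           # optional scope, e.g. (backend/executor)
    r": .+"                              # summary text
)

def is_conventional(title: str) -> bool:
    return CONVENTIONAL.match(title) is not None

print(is_conventional("feat(backend/executor): add retry on node failure"))  # True
print(is_conventional("Added some stuff"))                                   # False
```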
@@ -1 +1,120 @@
@AGENTS.md
# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Repository Overview

AutoGPT Platform is a monorepo containing:

- **Backend** (`backend`): Python FastAPI server with async support
- **Frontend** (`frontend`): Next.js React application
- **Shared Libraries** (`autogpt_libs`): Common Python utilities

## Component Documentation

- **Backend**: See @backend/CLAUDE.md for backend-specific commands, architecture, and development tasks
- **Frontend**: See @frontend/CLAUDE.md for frontend-specific commands, architecture, and development patterns

## Key Concepts

1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
2. **Blocks**: Reusable components in `backend/backend/blocks/` that perform specific tasks
3. **Integrations**: OAuth and API connections stored per user
4. **Store**: Marketplace for sharing agent templates
5. **Virus Scanning**: ClamAV integration for file upload security

### Environment Configuration

#### Configuration Files

- **Backend**: `backend/.env.default` (defaults) → `backend/.env` (user overrides)
- **Frontend**: `frontend/.env.default` (defaults) → `frontend/.env` (user overrides)
- **Platform**: `.env.default` (Supabase/shared defaults) → `.env` (user overrides)

#### Docker Environment Loading Order

1. `.env.default` files provide base configuration (tracked in git)
2. `.env` files provide user-specific overrides (gitignored)
3. Docker Compose `environment:` sections provide service-specific overrides
4. Shell environment variables have highest precedence
#### Key Points

- All services use hardcoded defaults in docker-compose files (no `${VARIABLE}` substitutions)
- The `env_file` directive loads variables INTO containers at runtime
- Backend/Frontend services use YAML anchors for consistent configuration
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern

### Branching Strategy

- **`dev`** is the main development branch. All PRs should target `dev`.
- **`master`** is the production branch. Only used for production releases.

### Creating Pull Requests

- Create the PR against the `dev` branch of the repository.
- **Split PRs by concern** — each PR should have a single clear purpose. For example, "usage tracking" and "credit charging" should be separate PRs even if related. Combining multiple concerns makes it harder for reviewers to understand what belongs to what.
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)
- Use conventional commit messages (see below)
- **Structure the PR description with Why / What / How** — Why: the motivation (what problem it solves, what's broken/missing without it); What: high-level summary of changes; How: approach, key implementation details, or architecture decisions. Reviewers need all three to judge whether the approach fits the problem.
- Fill out the `.github/PULL_REQUEST_TEMPLATE.md` template as the PR description
- Always use `--body-file` to pass the PR body — this avoids shell interpretation of backticks and special characters:

```bash
PR_BODY=$(mktemp)
cat > "$PR_BODY" << 'PREOF'
## Summary
- use `backticks` freely here
PREOF
gh pr create --title "..." --body-file "$PR_BODY" --base dev
rm "$PR_BODY"
```

- Run the GitHub pre-commit hooks to ensure code quality.

### Test-Driven Development (TDD)

When fixing a bug or adding a feature, follow a test-first approach:

1. **Write a failing test first** — create a test that reproduces the bug or validates the new behavior, marked with `@pytest.mark.xfail` (backend) or `.fixme` (Playwright). Run it to confirm it fails for the right reason.
2. **Implement the fix/feature** — write the minimal code to make the test pass.
3. **Remove the xfail marker** — once the test passes, remove the `xfail`/`.fixme` annotation and run the full test suite to confirm nothing else broke.

This ensures every change is covered by a test and that the test actually validates the intended behavior.

### Reviewing/Revising Pull Requests

Use `/pr-review` to review a PR or `/pr-address` to address comments.

When fetching comments manually:
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate` — top-level reviews
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments --paginate` — inline review comments (always paginate to avoid missing comments beyond page 1)
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` — PR conversation comments

### Conventional Commits

Use this format for commit messages and Pull Request titles:

**Conventional Commit Types:**

- `feat`: Introduces a new feature to the codebase
- `fix`: Patches a bug in the codebase
- `refactor`: Code change that neither fixes a bug nor adds a feature; also applies to removing features
- `ci`: Changes to CI configuration
- `docs`: Documentation-only changes
- `dx`: Improvements to the developer experience

**Recommended Base Scopes:**

- `platform`: Changes affecting both frontend and backend
- `frontend`
- `backend`
- `infra`
- `blocks`: Modifications/additions of individual blocks

**Subscope Examples:**

- `backend/executor`
- `backend/db`
- `frontend/builder` (includes changes to the block UI component)
- `infra/prod`

Use these scopes and subscopes for clarity and consistency in commit messages.
@@ -1,227 +0,0 @@
# Backend

This file provides guidance to coding agents when working with the backend.

## Essential Commands

To run something with Python package dependencies you MUST use `poetry run ...`.

```bash
# Install dependencies
poetry install

# Run database migrations
poetry run prisma migrate dev

# Start all services (database, redis, rabbitmq, clamav)
docker compose up -d

# Run the backend as a whole
poetry run app

# Run tests
poetry run test

# Run specific test
poetry run pytest path/to/test_file.py::test_function_name

# Run block tests (tests that validate all blocks work correctly)
poetry run pytest backend/blocks/test/test_block.py -xvs

# Run tests for a specific block (e.g., GetCurrentTimeBlock)
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs

# Lint and format
# prefer `format` if you just want to fix issues and see only the errors that can't be autofixed
poetry run format  # Black + isort
poetry run lint    # ruff
```

More details can be found in @TESTING.md

### Creating/Updating Snapshots

When you first write a test or when the expected output changes:

```bash
poetry run pytest path/to/test.py --snapshot-update
```

⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.

## Architecture

- **API Layer**: FastAPI with REST and WebSocket endpoints
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
- **Queue System**: RabbitMQ for async task processing
- **Execution Engine**: Separate executor service processes agent workflows
- **Authentication**: JWT-based with Supabase integration
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies

## Code Style

- **Top-level imports only** — no local/inner imports (lazy imports only for heavy optional deps like `openpyxl`)
- **Absolute imports** — use `from backend.module import ...` for cross-package imports. Single-dot relative (`from .sibling import ...`) is acceptable for sibling modules within the same package (e.g., blocks). Avoid double-dot relative imports (`from ..parent import ...`) — use the absolute path instead
- **No duck typing** — no `hasattr`/`getattr`/`isinstance` for type dispatch; use typed interfaces/unions/protocols
- **Pydantic models** over dataclass/namedtuple/dict for structured data
- **No linter suppressors** — no `# type: ignore`, `# noqa`, `# pyright: ignore`; fix the type/code
- **List comprehensions** over manual loop-and-append
- **Early return** — guard clauses first, avoid deep nesting
- **f-strings vs printf syntax in log statements** — Use `%s` for deferred interpolation in `debug` statements, f-strings elsewhere for readability: `logger.debug("Processing %s items", count)`, `logger.info(f"Processing {count} items")`
- **Sanitize error paths** — `os.path.basename()` in error messages to avoid leaking directory structure
- **TOCTOU awareness** — avoid check-then-act patterns for file access and credit charging
- **`Security()` vs `Depends()`** — use `Security()` for auth deps to get proper OpenAPI security spec
- **Redis pipelines** — `transaction=True` for atomicity on multi-step operations
- **`max(0, value)` guards** — for computed values that should never be negative
- **SSE protocol** — `data:` lines for frontend-parsed events (must match Zod schema), `: comment` lines for heartbeats/status
- **File length** — keep files under ~300 lines; if a file grows beyond this, split by responsibility (e.g. extract helpers, models, or a sub-module into a new file). Never keep appending to a long file.
- **Function length** — keep functions under ~40 lines; extract named helpers when a function grows longer. Long functions are a sign of mixed concerns, not complexity.
- **Top-down ordering** — define the main/public function or class first, then the helpers it uses below. A reader should encounter high-level logic before implementation details.
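Several of these rules (early return, `max(0, ...)` guards, top-down ordering) can be shown in one small sketch; the function names here are invented for illustration:

```python
# Public function first: the high-level logic reads top-down.
def charge_user(balance: int, cost: int) -> int:
    # Guard clauses first, avoiding deep nesting.
    if cost <= 0:
        raise ValueError("cost must be positive")
    if balance < cost:
        raise ValueError("insufficient balance")
    return _apply_charge(balance, cost)

# Helper defined below the caller that uses it.
def _apply_charge(balance: int, cost: int) -> int:
    # max(0, ...) guard for a value that must never go negative.
    return max(0, balance - cost)

print(charge_user(10, 3))  # 7
```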
## Testing Approach

- Uses pytest with snapshot testing for API responses
- Test files are colocated with source files (`*_test.py`)
- Mock at boundaries — mock where the symbol is **used**, not where it's **defined**
- After refactoring, update mock targets to match new module paths
- Use `AsyncMock` for async functions (`from unittest.mock import AsyncMock`)
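A minimal `AsyncMock` sketch follows. In a real test you would patch the symbol at its usage site (e.g. a hypothetical `backend.api.features.foo.fetch_user`), not where it is defined; here the dependency is just a module-level stand-in:

```python
import asyncio
from unittest.mock import AsyncMock

# Stand-in for an async dependency you would normally patch where it is used.
fetch_user = AsyncMock(return_value={"id": "u1", "name": "Alice"})

async def handler() -> str:
    user = await fetch_user("u1")  # awaiting the mock behaves like the real call
    return user["name"]

result = asyncio.run(handler())
print(result)  # Alice
fetch_user.assert_awaited_once_with("u1")
```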
### Test-Driven Development (TDD)

When fixing a bug or adding a feature, write the test **before** the implementation:

```python
# 1. Write a failing test marked xfail
@pytest.mark.xfail(reason="Bug #1234: widget crashes on empty input")
def test_widget_handles_empty_input():
    result = widget.process("")
    assert result == Widget.EMPTY_RESULT

# 2. Run it — confirm it fails (XFAIL)
# poetry run pytest path/to/test.py::test_widget_handles_empty_input -xvs

# 3. Implement the fix

# 4. Remove xfail, run again — confirm it passes
def test_widget_handles_empty_input():
    result = widget.process("")
    assert result == Widget.EMPTY_RESULT
```

This catches regressions and proves the fix actually works. **Every bug fix should include a test that would have caught it.**

## Database Schema

Key models (defined in `schema.prisma`):

- `User`: Authentication and profile data
- `AgentGraph`: Workflow definitions with version control
- `AgentGraphExecution`: Execution history and results
- `AgentNode`: Individual nodes in a workflow
- `StoreListing`: Marketplace listings for sharing agents

## Environment Configuration

- **Backend**: `.env.default` (defaults) → `.env` (user overrides)

## Common Development Tasks

### Adding a new block

Follow the comprehensive [Block SDK Guide](@../../docs/platform/block-sdk-guide.md) which covers:

- Provider configuration with `ProviderBuilder`
- Block schema definition
- Authentication (API keys, OAuth, webhooks)
- Testing and validation
- File organization

Quick steps:

1. Create new file in `backend/blocks/`
2. Configure provider using `ProviderBuilder` in `_config.py`
3. Inherit from `Block` base class
4. Define input/output schemas using `BlockSchema`
5. Implement async `run` method
6. Generate unique block ID using `uuid.uuid4()`
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
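The steps above can be sketched schematically. The `Block`/`BlockSchema` classes below are toy stand-ins for illustration only; the real base classes live in the backend package and have much richer behavior:

```python
import uuid
from typing import AsyncIterator

# Toy stand-ins for the real Block / BlockSchema base classes (illustration only).
class BlockSchema(dict):
    """Maps field names to type names."""

class Block:
    def __init__(self, id: str, input_schema: BlockSchema, output_schema: BlockSchema):
        self.id = id
        self.input_schema = input_schema
        self.output_schema = output_schema

# Step 3: inherit from the Block base class.
class GetGreetingBlock(Block):
    def __init__(self):
        super().__init__(
            id=str(uuid.uuid4()),  # step 6: unique block ID
            # Step 4: input/output schemas.
            input_schema=BlockSchema(name="string"),
            output_schema=BlockSchema(greeting="string"),
        )

    # Step 5: async run method, yielding (output_name, value) pairs.
    async def run(self, input_data: dict) -> AsyncIterator[tuple[str, str]]:
        yield "greeting", f"Hello, {input_data['name']}!"
```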
Note: when creating several new blocks, review their interfaces together and consider whether they would connect well in a graph-based editor or would struggle to connect productively.
For example: do the inputs and outputs tie well together?

If you get pushback or hit complex block conditions, check the new_blocks guide in the docs.

#### Handling files in blocks with `store_media_file()`

When blocks need to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. The `return_format` parameter determines what you get back:

| Format | Use When | Returns |
|--------|----------|---------|
| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) |
| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) |
| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs |

**Examples:**

```python
# INPUT: Need to process file locally with ffmpeg
local_path = await store_media_file(
    file=input_data.video,
    execution_context=execution_context,
    return_format="for_local_processing",
)
# local_path = "video.mp4" - use with Path/ffmpeg/etc

# INPUT: Need to send to external API like Replicate
image_b64 = await store_media_file(
    file=input_data.image,
    execution_context=execution_context,
    return_format="for_external_api",
)
# image_b64 = "data:image/png;base64,iVBORw0..." - send to API

# OUTPUT: Returning result from block
result_url = await store_media_file(
    file=generated_image_url,
    execution_context=execution_context,
    return_format="for_block_output",
)
yield "image_url", result_url
# In CoPilot: result_url = "workspace://abc123"
# In graphs: result_url = "data:image/png;base64,..."
```

**Key points:**

- `for_block_output` is the ONLY format that auto-adapts to execution context
- Always use `for_block_output` for block outputs unless you have a specific reason not to
- Never hardcode workspace checks - let `for_block_output` handle it

### Modifying the API

1. Update route in `backend/api/features/`
2. Add/update Pydantic models in same directory
3. Write tests alongside the route file
4. Run `poetry run test` to verify

## Workspace & Media Files

**Read [Workspace & Media Architecture](../../docs/platform/workspace-media-architecture.md) when:**
- Working on CoPilot file upload/download features
- Building blocks that handle `MediaFileType` inputs/outputs
- Modifying `WorkspaceManager` or `store_media_file()`
- Debugging file persistence or virus scanning issues

Covers: `WorkspaceManager` (persistent storage with session scoping), `store_media_file()` (media normalization pipeline), and responsibility boundaries for virus scanning and persistence.

## Security Implementation

### Cache Protection Middleware

- Located in `backend/api/middleware/security.py`
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
- Uses an allow list approach - only explicitly permitted paths can be cached
- Cacheable paths include: static assets (`static/*`, `_next/static/*`), health checks, public store pages, documentation
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
- Applied to both main API server and external API applications
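The allow-list idea can be sketched as follows. The paths, TTL, and helper name here are invented; the real middleware lives in `backend/api/middleware/security.py`:

```python
# Illustrative allow list: everything defaults to no-store (paths are made up).
CACHEABLE_PREFIXES = ("/static/", "/_next/static/", "/health")

def cache_control_for(path: str) -> str:
    if any(path.startswith(prefix) for prefix in CACHEABLE_PREFIXES):
        return "public, max-age=3600"  # illustrative TTL
    # Default: never cache, matching the header quoted above.
    return "no-store, no-cache, must-revalidate, private"

print(cache_control_for("/static/logo.png"))  # public, max-age=3600
print(cache_control_for("/api/credits"))      # no-store, no-cache, must-revalidate, private
```

A request for an unlisted path such as `/api/credits` therefore gets the no-store header, while static assets stay cacheable.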
@@ -1 +1,227 @@
@AGENTS.md
# CLAUDE.md - Backend

This file provides guidance to Claude Code when working with the backend.

## Essential Commands

To run something with Python package dependencies you MUST use `poetry run ...`.

```bash
# Install dependencies
poetry install

# Run database migrations
poetry run prisma migrate dev

# Start all services (database, redis, rabbitmq, clamav)
docker compose up -d

# Run the backend as a whole
poetry run app

# Run tests
poetry run test

# Run specific test
poetry run pytest path/to/test_file.py::test_function_name

# Run block tests (tests that validate all blocks work correctly)
poetry run pytest backend/blocks/test/test_block.py -xvs

# Run tests for a specific block (e.g., GetCurrentTimeBlock)
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs

# Lint and format
# prefer `format` if you just want to fix issues and see only the errors that can't be autofixed
poetry run format  # Black + isort
poetry run lint    # ruff
```

More details can be found in @TESTING.md

### Creating/Updating Snapshots

When you first write a test or when the expected output changes:

```bash
poetry run pytest path/to/test.py --snapshot-update
```

⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.

## Architecture

- **API Layer**: FastAPI with REST and WebSocket endpoints
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
- **Queue System**: RabbitMQ for async task processing
- **Execution Engine**: Separate executor service processes agent workflows
- **Authentication**: JWT-based with Supabase integration
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies

## Code Style

- **Top-level imports only** — no local/inner imports (lazy imports only for heavy optional deps like `openpyxl`)
- **Absolute imports** — use `from backend.module import ...` for cross-package imports. Single-dot relative (`from .sibling import ...`) is acceptable for sibling modules within the same package (e.g., blocks). Avoid double-dot relative imports (`from ..parent import ...`) — use the absolute path instead
- **No duck typing** — no `hasattr`/`getattr`/`isinstance` for type dispatch; use typed interfaces/unions/protocols
- **Pydantic models** over dataclass/namedtuple/dict for structured data
- **No linter suppressors** — no `# type: ignore`, `# noqa`, `# pyright: ignore`; fix the type/code
- **List comprehensions** over manual loop-and-append
- **Early return** — guard clauses first, avoid deep nesting
- **f-strings vs printf syntax in log statements** — Use `%s` for deferred interpolation in `debug` statements, f-strings elsewhere for readability: `logger.debug("Processing %s items", count)`, `logger.info(f"Processing {count} items")`
- **Sanitize error paths** — `os.path.basename()` in error messages to avoid leaking directory structure
- **TOCTOU awareness** — avoid check-then-act patterns for file access and credit charging
- **`Security()` vs `Depends()`** — use `Security()` for auth deps to get proper OpenAPI security spec
- **Redis pipelines** — `transaction=True` for atomicity on multi-step operations
- **`max(0, value)` guards** — for computed values that should never be negative
- **SSE protocol** — `data:` lines for frontend-parsed events (must match Zod schema), `: comment` lines for heartbeats/status
- **File length** — keep files under ~300 lines; if a file grows beyond this, split by responsibility (e.g. extract helpers, models, or a sub-module into a new file). Never keep appending to a long file.
- **Function length** — keep functions under ~40 lines; extract named helpers when a function grows longer. Long functions are a sign of mixed concerns, not complexity.
- **Top-down ordering** — define the main/public function or class first, then the helpers it uses below. A reader should encounter high-level logic before implementation details.
|
||||
|
||||
## Testing Approach
|
||||
|
||||
- Uses pytest with snapshot testing for API responses
|
||||
- Test files are colocated with source files (`*_test.py`)
|
||||
- Mock at boundaries — mock where the symbol is **used**, not where it's **defined**
|
||||
- After refactoring, update mock targets to match new module paths
|
||||
- Use `AsyncMock` for async functions (`from unittest.mock import AsyncMock`)
|
||||
|
||||
### Test-Driven Development (TDD)
|
||||
|
||||
When fixing a bug or adding a feature, write the test **before** the implementation:
|
||||
|
||||
```python
|
||||
# 1. Write a failing test marked xfail
|
||||
@pytest.mark.xfail(reason="Bug #1234: widget crashes on empty input")
|
||||
def test_widget_handles_empty_input():
|
||||
result = widget.process("")
|
||||
assert result == Widget.EMPTY_RESULT
|
||||
|
||||
# 2. Run it — confirm it fails (XFAIL)
|
||||
# poetry run pytest path/to/test.py::test_widget_handles_empty_input -xvs
|
||||
|
||||
# 3. Implement the fix
|
||||
|
||||
# 4. Remove xfail, run again — confirm it passes
|
||||
def test_widget_handles_empty_input():
|
||||
result = widget.process("")
|
||||
assert result == Widget.EMPTY_RESULT
|
||||
```
|
||||
|
||||
This catches regressions and proves the fix actually works. **Every bug fix should include a test that would have caught it.**
|
||||
|
||||
## Database Schema
|
||||
|
||||
Key models (defined in `schema.prisma`):
|
||||
|
||||
- `User`: Authentication and profile data
|
||||
- `AgentGraph`: Workflow definitions with version control
|
||||
- `AgentGraphExecution`: Execution history and results
|
||||
- `AgentNode`: Individual nodes in a workflow
|
||||
- `StoreListing`: Marketplace listings for sharing agents
|
||||
|
||||
## Environment Configuration
|
||||
|
||||
- **Backend**: `.env.default` (defaults) → `.env` (user overrides)
|
||||
|
||||
## Common Development Tasks
|
||||
|
||||
### Adding a new block
|
||||
|
||||
Follow the comprehensive [Block SDK Guide](@../../docs/content/platform/block-sdk-guide.md) which covers:
|
||||
|
||||
- Provider configuration with `ProviderBuilder`
|
||||
- Block schema definition
|
||||
- Authentication (API keys, OAuth, webhooks)
|
||||
- Testing and validation
|
||||
- File organization
|
||||
|
||||
Quick steps:
|
||||
|
||||
1. Create new file in `backend/blocks/`
|
||||
2. Configure provider using `ProviderBuilder` in `_config.py`
|
||||
3. Inherit from `Block` base class
|
||||
4. Define input/output schemas using `BlockSchema`
|
||||
5. Implement async `run` method
|
||||
6. Generate unique block ID using `uuid.uuid4()`
|
||||
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
|
||||
|
||||
Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph-based editor or would they struggle to connect productively?
|
||||
ex: do the inputs and outputs tie well together?
|
||||
|
||||
If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.
|
||||
|
||||
#### Handling files in blocks with `store_media_file()`
|
||||
|
||||
When blocks need to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. The `return_format` parameter determines what you get back:

| Format | Use When | Returns |
|--------|----------|---------|
| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) |
| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) |
| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs |

**Examples:**

```python
# INPUT: Need to process file locally with ffmpeg
local_path = await store_media_file(
    file=input_data.video,
    execution_context=execution_context,
    return_format="for_local_processing",
)
# local_path = "video.mp4" - use with Path/ffmpeg/etc

# INPUT: Need to send to external API like Replicate
image_b64 = await store_media_file(
    file=input_data.image,
    execution_context=execution_context,
    return_format="for_external_api",
)
# image_b64 = "data:image/png;base64,iVBORw0..." - send to API

# OUTPUT: Returning result from block
result_url = await store_media_file(
    file=generated_image_url,
    execution_context=execution_context,
    return_format="for_block_output",
)
yield "image_url", result_url
# In CoPilot: result_url = "workspace://abc123"
# In graphs: result_url = "data:image/png;base64,..."
```

**Key points:**

- `for_block_output` is the ONLY format that auto-adapts to execution context
- Always use `for_block_output` for block outputs unless you have a specific reason not to
- Never hardcode workspace checks - let `for_block_output` handle it

### Modifying the API

1. Update route in `backend/api/features/`
2. Add/update Pydantic models in same directory
3. Write tests alongside the route file
4. Run `poetry run test` to verify
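
For the Pydantic-model step, a hedged sketch (the model name and fields are hypothetical) of a common pattern in this codebase's request models: rejecting unknown fields so client mistakes surface as 422 errors instead of being silently ignored.

```python
from pydantic import BaseModel, ConfigDict, Field

# Hypothetical request model; the name and fields are illustrative.
# extra="forbid" makes FastAPI return 422 for unknown fields instead of
# silently dropping them.
class CreateWidgetRequest(BaseModel):
    model_config = ConfigDict(extra="forbid")

    name: str = Field(min_length=1)
    dry_run: bool = False  # top-level flag, not nested inside other fields

ok = CreateWidgetRequest(name="demo")

try:
    CreateWidgetRequest(name="demo", unknown_field=True)
    rejected = False
except Exception:  # pydantic.ValidationError
    rejected = True
```

This mirrors the `extra="forbid"` approach used by session-creation request models elsewhere in the backend.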

## Workspace & Media Files

**Read [Workspace & Media Architecture](../../docs/platform/workspace-media-architecture.md) when:**
- Working on CoPilot file upload/download features
- Building blocks that handle `MediaFileType` inputs/outputs
- Modifying `WorkspaceManager` or `store_media_file()`
- Debugging file persistence or virus scanning issues

Covers: `WorkspaceManager` (persistent storage with session scoping), `store_media_file()` (media normalization pipeline), and responsibility boundaries for virus scanning and persistence.

## Security Implementation

### Cache Protection Middleware

- Located in `backend/api/middleware/security.py`
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
- Uses an allow list approach - only explicitly permitted paths can be cached
- Cacheable paths include: static assets (`static/*`, `_next/static/*`), health checks, public store pages, documentation
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
- Applied to both main API server and external API applications

@@ -72,7 +72,7 @@ class RunAgentRequest(BaseModel):

def _create_ephemeral_session(user_id: str) -> ChatSession:
"""Create an ephemeral session for stateless API requests."""
return ChatSession.new(user_id, dry_run=False)
return ChatSession.new(user_id)


@tools_router.post(

@@ -11,7 +11,7 @@ from autogpt_libs import auth
from fastapi import APIRouter, HTTPException, Query, Response, Security
from fastapi.responses import StreamingResponse
from prisma.models import UserWorkspaceFile
from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic import BaseModel, Field, field_validator

from backend.copilot import service as chat_service
from backend.copilot import stream_registry
@@ -20,7 +20,6 @@ from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_
from backend.copilot.model import (
ChatMessage,
ChatSession,
ChatSessionMetadata,
append_and_save_message,
create_chat_session,
delete_chat_session,
@@ -113,25 +112,12 @@ class StreamChatRequest(BaseModel):
) # Workspace file IDs attached to this message


class CreateSessionRequest(BaseModel):
"""Request model for creating a new chat session.

``dry_run`` is a **top-level** field — do not nest it inside ``metadata``.
Extra/unknown fields are rejected (422) to prevent silent mis-use.
"""

model_config = ConfigDict(extra="forbid")

dry_run: bool = False


class CreateSessionResponse(BaseModel):
"""Response model containing information on a newly created chat session."""

id: str
created_at: str
user_id: str | None
metadata: ChatSessionMetadata = ChatSessionMetadata()


class ActiveStreamInfo(BaseModel):
@@ -152,7 +138,6 @@ class SessionDetailResponse(BaseModel):
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
total_prompt_tokens: int = 0
total_completion_tokens: int = 0
metadata: ChatSessionMetadata = ChatSessionMetadata()


class SessionSummaryResponse(BaseModel):
@@ -263,7 +248,6 @@ async def list_sessions(
)
async def create_session(
user_id: Annotated[str, Security(auth.get_user_id)],
request: CreateSessionRequest | None = None,
) -> CreateSessionResponse:
"""
Create a new chat session.
@@ -272,28 +256,22 @@ async def create_session(

Args:
user_id: The authenticated user ID parsed from the JWT (required).
request: Optional request body. When provided, ``dry_run=True``
forces run_block and run_agent calls to use dry-run simulation.

Returns:
CreateSessionResponse: Details of the created session.

"""
dry_run = request.dry_run if request else False

logger.info(
f"Creating session with user_id: "
f"...{user_id[-8:] if len(user_id) > 8 else '<redacted>'}"
f"{', dry_run=True' if dry_run else ''}"
)

session = await create_chat_session(user_id, dry_run=dry_run)
session = await create_chat_session(user_id)

return CreateSessionResponse(
id=session.session_id,
created_at=session.started_at.isoformat(),
user_id=session.user_id,
metadata=session.metadata,
)


@@ -442,7 +420,6 @@ async def get_session(
active_stream=active_stream_info,
total_prompt_tokens=total_prompt,
total_completion_tokens=total_completion,
metadata=session.metadata,
)


@@ -1197,7 +1174,7 @@ async def health_check() -> dict:
)

# Create and retrieve session to verify full data layer
session = await create_chat_session(health_check_user_id, dry_run=False)
session = await create_chat_session(health_check_user_id)
await get_chat_session(session.session_id, health_check_user_id)

return {

@@ -469,60 +469,3 @@ def test_suggested_prompts_empty_prompts(

assert response.status_code == 200
assert response.json() == {"themes": []}


# ─── Create session: dry_run contract ─────────────────────────────────


def _mock_create_chat_session(mocker: pytest_mock.MockerFixture):
"""Mock create_chat_session to return a fake session."""
from backend.copilot.model import ChatSession

async def _fake_create(user_id: str, *, dry_run: bool):
return ChatSession.new(user_id, dry_run=dry_run)

return mocker.patch(
"backend.api.features.chat.routes.create_chat_session",
new_callable=AsyncMock,
side_effect=_fake_create,
)


def test_create_session_dry_run_true(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""Sending ``{"dry_run": true}`` sets metadata.dry_run to True."""
_mock_create_chat_session(mocker)

response = client.post("/sessions", json={"dry_run": True})

assert response.status_code == 200
assert response.json()["metadata"]["dry_run"] is True


def test_create_session_dry_run_default_false(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""Empty body defaults dry_run to False."""
_mock_create_chat_session(mocker)

response = client.post("/sessions")

assert response.status_code == 200
assert response.json()["metadata"]["dry_run"] is False


def test_create_session_rejects_nested_metadata(
test_user_id: str,
) -> None:
"""Sending ``{"metadata": {"dry_run": true}}`` must return 422, not silently
default to ``dry_run=False``. This guards against the common mistake of
nesting dry_run inside metadata instead of providing it at the top level."""
response = client.post(
"/sessions",
json={"metadata": {"dry_run": True}},
)

assert response.status_code == 422

@@ -146,21 +146,6 @@ class AutoPilotBlock(Block):
advanced=True,
)

dry_run: bool = SchemaField(
description=(
"When enabled, run_block and run_agent tool calls in this "
"autopilot session are forced to use dry-run simulation mode. "
"No real API calls, side effects, or credits are consumed "
"by those tools. Useful for testing agent wiring and "
"previewing outputs. "
"Only applies when creating a new session (session_id is empty). "
"When reusing an existing session_id, the session's original "
"dry_run setting is preserved."
),
default=False,
advanced=True,
)

# timeout_seconds removed: the SDK manages its own heartbeat-based
# timeouts internally; wrapping with asyncio.timeout corrupts the
# SDK's internal stream (see service.py CRITICAL comment).
@@ -247,11 +232,11 @@ class AutoPilotBlock(Block):
},
)

async def create_session(self, user_id: str, *, dry_run: bool) -> str:
async def create_session(self, user_id: str) -> str:
"""Create a new chat session and return its ID (mockable for tests)."""
from backend.copilot.model import create_chat_session # avoid circular import

session = await create_chat_session(user_id, dry_run=dry_run)
session = await create_chat_session(user_id)
return session.session_id

async def execute_copilot(
@@ -382,9 +367,7 @@ class AutoPilotBlock(Block):
# even if the downstream stream fails (avoids orphaned sessions).
sid = input_data.session_id
if not sid:
sid = await self.create_session(
execution_context.user_id, dry_run=input_data.dry_run
)
sid = await self.create_session(execution_context.user_id)

# NOTE: No asyncio.timeout() here — the SDK manages its own
# heartbeat-based timeouts internally. Wrapping with asyncio.timeout

@@ -2,8 +2,6 @@ import copy
from datetime import date, time
from typing import Any, Optional

from pydantic import AliasChoices, Field

from backend.blocks._base import (
Block,
BlockCategory,
@@ -469,8 +467,7 @@ class AgentFileInputBlock(AgentInputBlock):

class AgentDropdownInputBlock(AgentInputBlock):
"""
A specialized text input block that presents a dropdown selector
restricted to a fixed set of values.
A specialized text input block that relies on placeholder_values to present a dropdown.
"""

class Input(AgentInputBlock.Input):
@@ -480,23 +477,16 @@ class AgentDropdownInputBlock(AgentInputBlock):
advanced=False,
title="Default Value",
)
# Use Field() directly (not SchemaField) to pass validation_alias,
# which handles backward compat for legacy "placeholder_values" across
# all construction paths (model_construct, __init__, model_validate).
options: list = Field(
placeholder_values: list = SchemaField(
description="Possible values for the dropdown.",
default_factory=list,
advanced=False,
title="Dropdown Options",
description=(
"If provided, renders the input as a dropdown selector "
"restricted to these values. Leave empty for free-text input."
),
validation_alias=AliasChoices("options", "placeholder_values"),
json_schema_extra={"advanced": False, "secret": False},
)

def generate_schema(self):
schema = super().generate_schema()
if possible_values := self.options:
if possible_values := self.placeholder_values:
schema["enum"] = possible_values
return schema

@@ -514,13 +504,13 @@ class AgentDropdownInputBlock(AgentInputBlock):
{
"value": "Option A",
"name": "dropdown_1",
"options": ["Option A", "Option B", "Option C"],
"placeholder_values": ["Option A", "Option B", "Option C"],
"description": "Dropdown example 1",
},
{
"value": "Option C",
"name": "dropdown_2",
"options": ["Option A", "Option B", "Option C"],
"placeholder_values": ["Option A", "Option B", "Option C"],
"description": "Dropdown example 2",
},
],

@@ -300,27 +300,13 @@ def test_agent_input_block_ignores_legacy_placeholder_values():


def test_dropdown_input_block_produces_enum():
"""Verify AgentDropdownInputBlock.Input.generate_schema() produces enum
using the canonical 'options' field name."""
opts = ["Option A", "Option B"]
"""Verify AgentDropdownInputBlock.Input.generate_schema() produces enum."""
options = ["Option A", "Option B"]
instance = AgentDropdownInputBlock.Input.model_construct(
name="choice", value=None, options=opts
name="choice", value=None, placeholder_values=options
)
schema = instance.generate_schema()
assert schema.get("enum") == opts


def test_dropdown_input_block_legacy_placeholder_values_produces_enum():
"""Verify backward compat: passing legacy 'placeholder_values' to
AgentDropdownInputBlock still produces enum via model_construct remap."""
opts = ["Option A", "Option B"]
instance = AgentDropdownInputBlock.Input.model_construct(
name="choice", value=None, placeholder_values=opts
)
schema = instance.generate_schema()
assert (
schema.get("enum") == opts
), "Legacy placeholder_values should be remapped to options"
assert schema.get("enum") == options


def test_generate_schema_integration_legacy_placeholder_values():
@@ -343,11 +329,11 @@ def test_generate_schema_integration_legacy_placeholder_values():

def test_generate_schema_integration_dropdown_produces_enum():
"""Test the full Graph._generate_schema path with AgentDropdownInputBlock
— verifies enum IS produced for dropdown blocks using canonical field name."""
— verifies enum IS produced for dropdown blocks."""
dropdown_input_default = {
"name": "color",
"value": None,
"options": ["Red", "Green", "Blue"],
"placeholder_values": ["Red", "Green", "Blue"],
}
result = BaseGraph._generate_schema(
(AgentDropdownInputBlock.Input, dropdown_input_default),
@@ -358,36 +344,3 @@ def test_generate_schema_integration_dropdown_produces_enum():
"Green",
"Blue",
], "Graph schema should contain enum from AgentDropdownInputBlock"


def test_generate_schema_integration_dropdown_legacy_placeholder_values():
"""Test the full Graph._generate_schema path with AgentDropdownInputBlock
using legacy 'placeholder_values' — verifies backward compat produces enum."""
legacy_dropdown_input_default = {
"name": "color",
"value": None,
"placeholder_values": ["Red", "Green", "Blue"],
}
result = BaseGraph._generate_schema(
(AgentDropdownInputBlock.Input, legacy_dropdown_input_default),
)
color_props = result["properties"]["color"]
assert color_props.get("enum") == [
"Red",
"Green",
"Blue",
], "Legacy placeholder_values should still produce enum via model_construct remap"


def test_dropdown_input_block_init_legacy_placeholder_values():
"""Verify backward compat: constructing AgentDropdownInputBlock.Input via
model_validate with legacy 'placeholder_values' correctly maps to 'options'."""
opts = ["Option A", "Option B"]
instance = AgentDropdownInputBlock.Input.model_validate(
{"name": "choice", "value": None, "placeholder_values": opts}
)
assert (
instance.options == opts
), "Legacy placeholder_values should be remapped to options via model_validate"
schema = instance.generate_schema()
assert schema.get("enum") == opts

@@ -18,7 +18,6 @@ import orjson
from langfuse import propagate_attributes
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolParam

from backend.copilot.context import set_execution_context
from backend.copilot.model import (
ChatMessage,
ChatSession,
@@ -458,9 +457,6 @@ async def stream_chat_completion_baseline(

tools = get_available_tools()

# Propagate execution context so tool handlers can read session-level flags.
set_execution_context(user_id, session)

yield StreamStart(messageId=message_id, sessionId=session_id)

# Propagate user/session context to Langfuse so all LLM calls within

@@ -31,7 +31,7 @@ async def test_baseline_multi_turn(setup_test_user, test_user_id):
if not api_key:
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")

session = await create_chat_session(test_user_id, dry_run=False)
session = await create_chat_session(test_user_id)
session = await upsert_chat_session(session)

# --- Turn 1: send a message with a unique keyword ---

@@ -18,13 +18,7 @@ from prisma.types import (
from backend.data import db
from backend.util.json import SafeJson, sanitize_string

from .model import (
ChatMessage,
ChatSession,
ChatSessionInfo,
ChatSessionMetadata,
invalidate_session_cache,
)
from .model import ChatMessage, ChatSession, ChatSessionInfo, invalidate_session_cache

logger = logging.getLogger(__name__)

@@ -41,7 +35,6 @@ async def get_chat_session(session_id: str) -> ChatSession | None:
async def create_chat_session(
session_id: str,
user_id: str,
metadata: ChatSessionMetadata | None = None,
) -> ChatSessionInfo:
"""Create a new chat session in the database."""
data = ChatSessionCreateInput(
@@ -50,7 +43,6 @@ async def create_chat_session(
credentials=SafeJson({}),
successfulAgentRuns=SafeJson({}),
successfulAgentSchedules=SafeJson({}),
metadata=SafeJson((metadata or ChatSessionMetadata()).model_dump()),
)
prisma_session = await PrismaChatSession.prisma().create(data=data)
return ChatSessionInfo.from_db(prisma_session)
@@ -65,12 +57,7 @@ async def update_chat_session(
total_completion_tokens: int | None = None,
title: str | None = None,
) -> ChatSession | None:
"""Update a chat session's mutable fields.

Note: ``metadata`` (which includes ``dry_run``) is intentionally omitted —
it is set once at creation time and treated as immutable for the lifetime
of the session.
"""
"""Update a chat session's metadata."""
data: ChatSessionUpdateInput = {"updatedAt": datetime.now(UTC)}

if credentials is not None:

@@ -123,7 +123,6 @@ async def get_provider_token(user_id: str, provider: str) -> str | None:
[c for c in creds_list if c.type == "oauth2"],
key=lambda c: 0 if "repo" in (cast(OAuth2Credentials, c).scopes or []) else 1,
)
refresh_failed = False
for creds in oauth2_creds:
if creds.type == "oauth2":
try:
@@ -142,7 +141,6 @@ async def get_provider_token(user_id: str, provider: str) -> str | None:
# Do NOT fall back to the stale token — it is likely expired
# or revoked. Returning None forces the caller to re-auth,
# preventing the LLM from receiving a non-functional token.
refresh_failed = True
continue
_token_cache[cache_key] = token
return token
@@ -154,12 +152,8 @@ async def get_provider_token(user_id: str, provider: str) -> str | None:
_token_cache[cache_key] = token
return token

# Only cache "not connected" when the user truly has no credentials for this
# provider. If we had OAuth credentials but refresh failed (e.g. transient
# network error, event-loop mismatch), do NOT cache the negative result —
# the next call should retry the refresh instead of being blocked for 60 s.
if not refresh_failed:
_null_cache[cache_key] = True
# No credentials found — cache to avoid repeated DB hits.
_null_cache[cache_key] = True
return None


@@ -129,15 +129,8 @@ class TestGetProviderToken:
assert result == "oauth-tok"

@pytest.mark.asyncio(loop_scope="session")
async def test_oauth2_refresh_failure_returns_none_without_null_cache(self):
"""On refresh failure, return None but do NOT cache in null_cache.

The user has credentials — they just couldn't be refreshed right now
(e.g. transient network error or event-loop mismatch in the copilot
executor). Caching a negative result would block all credential
lookups for 60 s even though the creds exist and may refresh fine
on the next attempt.
"""
async def test_oauth2_refresh_failure_returns_none(self):
"""On refresh failure, return None instead of caching a stale token."""
oauth_creds = _make_oauth2_creds("stale-oauth-tok")
mock_manager = MagicMock()
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[oauth_creds])
@@ -148,8 +141,6 @@ class TestGetProviderToken:

# Stale tokens must NOT be returned — forces re-auth.
assert result is None
# Must NOT cache negative result when refresh failed — next call retries.
assert (_USER, _PROVIDER) not in _null_cache

@pytest.mark.asyncio(loop_scope="session")
async def test_no_credentials_caches_null_entry(self):
@@ -185,96 +176,6 @@ class TestGetProviderToken:
assert _NULL_CACHE_TTL < _TOKEN_CACHE_TTL


class TestThreadSafetyLocks:
"""Bug reproduction: shared AsyncRedisKeyedMutex across threads caused
'Future attached to a different loop' when copilot workers accessed
credentials from different event loops."""

@pytest.mark.asyncio(loop_scope="session")
async def test_store_locks_returns_per_thread_instance(self):
"""IntegrationCredentialsStore.locks() must return different instances
for different threads (via @thread_cached)."""
import asyncio
import concurrent.futures

from backend.integrations.credentials_store import IntegrationCredentialsStore

store = IntegrationCredentialsStore()

async def get_locks_id():
mock_redis = AsyncMock()
with patch(
"backend.integrations.credentials_store.get_redis_async",
return_value=mock_redis,
):
locks = await store.locks()
return id(locks)

# Get locks from main thread
main_id = await get_locks_id()

# Get locks from a worker thread
def run_in_thread():
loop = asyncio.new_event_loop()
try:
return loop.run_until_complete(get_locks_id())
finally:
loop.close()

with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
worker_id = await asyncio.get_event_loop().run_in_executor(
pool, run_in_thread
)

assert main_id != worker_id, (
"Store.locks() returned the same instance across threads. "
"This would cause 'Future attached to a different loop' errors."
)

@pytest.mark.asyncio(loop_scope="session")
async def test_manager_delegates_to_store_locks(self):
"""IntegrationCredentialsManager.locks() should delegate to store."""
from backend.integrations.creds_manager import IntegrationCredentialsManager

manager = IntegrationCredentialsManager()
mock_redis = AsyncMock()

with patch(
"backend.integrations.credentials_store.get_redis_async",
return_value=mock_redis,
):
locks = await manager.locks()

# Should have gotten it from the store
assert locks is not None


class TestRefreshUnlockedPath:
"""Bug reproduction: copilot worker threads need lock-free refresh because
Redis-backed asyncio.Lock created on one event loop can't be used on another."""

@pytest.mark.asyncio(loop_scope="session")
async def test_refresh_if_needed_lock_false_skips_redis(self):
"""refresh_if_needed(lock=False) must not touch Redis locks at all."""
from backend.integrations.creds_manager import IntegrationCredentialsManager

manager = IntegrationCredentialsManager()
creds = _make_oauth2_creds()

mock_handler = MagicMock()
mock_handler.needs_refresh = MagicMock(return_value=False)

with patch(
"backend.integrations.creds_manager._get_provider_oauth_handler",
new_callable=AsyncMock,
return_value=mock_handler,
):
result = await manager.refresh_if_needed(_USER, creds, lock=False)

# Should return credentials without touching locks
assert result.id == creds.id


class TestGetIntegrationEnvVars:
@pytest.mark.asyncio(loop_scope="session")
async def test_injects_all_env_vars_for_provider(self):

@@ -46,16 +46,6 @@ def _get_session_cache_key(session_id: str) -> str:
# ===================== Chat data models ===================== #


class ChatSessionMetadata(BaseModel):
"""Typed metadata stored in the ``metadata`` JSON column of ChatSession.

Add new session-level flags here instead of adding DB columns —
no migration required for new fields as long as a default is provided.
"""

dry_run: bool = False


class ChatMessage(BaseModel):
role: str
content: str | None = None
@@ -100,12 +90,6 @@ class ChatSessionInfo(BaseModel):
updated_at: datetime
successful_agent_runs: dict[str, int] = {}
successful_agent_schedules: dict[str, int] = {}
metadata: ChatSessionMetadata = ChatSessionMetadata()

@property
def dry_run(self) -> bool:
"""Convenience accessor for ``metadata.dry_run``."""
return self.metadata.dry_run

@classmethod
def from_db(cls, prisma_session: PrismaChatSession) -> Self:
@@ -119,10 +103,6 @@ class ChatSessionInfo(BaseModel):
prisma_session.successfulAgentSchedules, default={}
)

# Parse typed metadata from the JSON column.
raw_metadata = _parse_json_field(prisma_session.metadata, default={})
metadata = ChatSessionMetadata.model_validate(raw_metadata)

# Calculate usage from token counts.
# NOTE: Per-turn cache_read_tokens / cache_creation_tokens breakdown
# is lost after persistence — the DB only stores aggregate prompt and
@@ -148,7 +128,6 @@ class ChatSessionInfo(BaseModel):
updated_at=prisma_session.updatedAt,
successful_agent_runs=successful_agent_runs,
successful_agent_schedules=successful_agent_schedules,
metadata=metadata,
)


@@ -156,7 +135,7 @@ class ChatSession(ChatSessionInfo):
messages: list[ChatMessage]

@classmethod
def new(cls, user_id: str, *, dry_run: bool) -> Self:
def new(cls, user_id: str) -> Self:
return cls(
session_id=str(uuid.uuid4()),
user_id=user_id,
@@ -166,7 +145,6 @@ class ChatSession(ChatSessionInfo):
credentials={},
started_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
metadata=ChatSessionMetadata(dry_run=dry_run),
)

@classmethod
@@ -554,7 +532,6 @@ async def _save_session_to_db(
await db.create_chat_session(
session_id=session.session_id,
user_id=session.user_id,
metadata=session.metadata,
)
existing_message_count = 0

@@ -632,27 +609,21 @@ async def append_and_save_message(session_id: str, message: ChatMessage) -> Chat
return session


async def create_chat_session(user_id: str, *, dry_run: bool) -> ChatSession:
async def create_chat_session(user_id: str) -> ChatSession:
"""Create a new chat session and persist it.

Args:
user_id: The authenticated user ID.
dry_run: When True, run_block and run_agent tool calls in this
session are forced to use dry-run simulation mode.

Raises:
DatabaseError: If the database write fails. We fail fast to ensure
callers never receive a non-persisted session that only exists
in cache (which would be lost when the cache expires).
"""
session = ChatSession.new(user_id, dry_run=dry_run)
session = ChatSession.new(user_id)

# Create in database first - fail fast if this fails
try:
await chat_db().create_chat_session(
session_id=session.session_id,
user_id=user_id,
metadata=session.metadata,
)
except Exception as e:
logger.error(f"Failed to create session {session.session_id} in database: {e}")

@@ -46,7 +46,7 @@ messages = [

@pytest.mark.asyncio(loop_scope="session")
async def test_chatsession_serialization_deserialization():
s = ChatSession.new(user_id="abc123", dry_run=False)
s = ChatSession.new(user_id="abc123")
s.messages = messages
s.usage = [Usage(prompt_tokens=100, completion_tokens=200, total_tokens=300)]
serialized = s.model_dump_json()
@@ -57,7 +57,7 @@ async def test_chatsession_serialization_deserialization():
@pytest.mark.asyncio(loop_scope="session")
async def test_chatsession_redis_storage(setup_test_user, test_user_id):

s = ChatSession.new(user_id=test_user_id, dry_run=False)
s = ChatSession.new(user_id=test_user_id)
s.messages = messages

s = await upsert_chat_session(s)
@@ -75,7 +75,7 @@ async def test_chatsession_redis_storage_user_id_mismatch(
setup_test_user, test_user_id
):

s = ChatSession.new(user_id=test_user_id, dry_run=False)
s = ChatSession.new(user_id=test_user_id)
s.messages = messages
s = await upsert_chat_session(s)

@@ -90,7 +90,7 @@ async def test_chatsession_db_storage(setup_test_user, test_user_id):
from backend.data.redis_client import get_redis_async

# Create session with messages including assistant message
s = ChatSession.new(user_id=test_user_id, dry_run=False)
s = ChatSession.new(user_id=test_user_id)
s.messages = messages # Contains user, assistant, and tool messages
assert s.session_id is not None, "Session id is not set"
# Upsert to save to both cache and DB
@@ -241,7 +241,7 @@ _raw_tc2 = {

def test_add_tool_call_appends_to_existing_assistant():
"""When the last assistant is from the current turn, tool_call is added to it."""
session = ChatSession.new(user_id="u", dry_run=False)
session = ChatSession.new(user_id="u")
session.messages = [
ChatMessage(role="user", content="hi"),
ChatMessage(role="assistant", content="working on it"),
@@ -254,7 +254,7 @@ def test_add_tool_call_appends_to_existing_assistant():

def test_add_tool_call_creates_assistant_when_none_exists():
"""When there's no current-turn assistant, a new one is created."""
session = ChatSession.new(user_id="u", dry_run=False)
session = ChatSession.new(user_id="u")
session.messages = [
ChatMessage(role="user", content="hi"),
]
@@ -267,7 +267,7 @@ def test_add_tool_call_creates_assistant_when_none_exists():

||||
def test_add_tool_call_does_not_cross_user_boundary():
|
||||
"""A user message acts as a boundary — previous assistant is not modified."""
|
||||
session = ChatSession.new(user_id="u", dry_run=False)
|
||||
session = ChatSession.new(user_id="u")
|
||||
session.messages = [
|
||||
ChatMessage(role="assistant", content="old turn"),
|
||||
ChatMessage(role="user", content="new message"),
|
||||
@@ -282,7 +282,7 @@ def test_add_tool_call_does_not_cross_user_boundary():
|
||||
|
||||
def test_add_tool_call_multiple_times():
|
||||
"""Multiple long-running tool calls accumulate on the same assistant."""
|
||||
session = ChatSession.new(user_id="u", dry_run=False)
|
||||
session = ChatSession.new(user_id="u")
|
||||
session.messages = [
|
||||
ChatMessage(role="user", content="hi"),
|
||||
ChatMessage(role="assistant", content="doing stuff"),
|
||||
@@ -300,7 +300,7 @@ def test_add_tool_call_multiple_times():
|
||||
|
||||
def test_to_openai_messages_merges_split_assistants():
|
||||
"""End-to-end: session with split assistants produces valid OpenAI messages."""
|
||||
session = ChatSession.new(user_id="u", dry_run=False)
|
||||
session = ChatSession.new(user_id="u")
|
||||
session.messages = [
|
||||
ChatMessage(role="user", content="build agent"),
|
||||
ChatMessage(role="assistant", content="Let me build that"),
|
||||
@@ -352,7 +352,7 @@ async def test_concurrent_saves_collision_detection(setup_test_user, test_user_i
|
||||
import asyncio
|
||||
|
||||
# Create a session with initial messages
|
||||
session = ChatSession.new(user_id=test_user_id, dry_run=False)
|
||||
session = ChatSession.new(user_id=test_user_id)
|
||||
for i in range(3):
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
|
||||
@@ -66,7 +66,6 @@ from pydantic import BaseModel, PrivateAttr
 ToolName = Literal[
     # Platform tools (must match keys in TOOL_REGISTRY)
     "add_understanding",
-    "ask_question",
     "bash_exec",
     "browser_act",
     "browser_navigate",
@@ -107,13 +107,6 @@ Do not re-fetch or re-generate data you already have from prior tool calls.
 After building the file, reference it with `@@agptfile:` in other tools:
 `@@agptfile:/home/user/report.md`
 
-### Web search best practices
-- If 3 similar web searches don't return the specific data you need, conclude
-  it isn't publicly available and work with what you have.
-- Prefer fewer, well-targeted searches over many variations of the same query.
-- When spawning sub-agents for research, ensure each has a distinct
-  non-overlapping scope to avoid redundant searches.
-
 ### Sub-agent tasks
 - When using the Task tool, NEVER set `run_in_background` to true.
   All tasks must run in the foreground.
@@ -1,28 +0,0 @@
|
||||
"""Tests for agent generation guide — verifies clarification section."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class TestAgentGenerationGuideContainsClarifySection:
|
||||
"""The agent generation guide must include the clarification section."""
|
||||
|
||||
def test_guide_includes_clarify_section(self):
|
||||
guide_path = Path(__file__).parent / "sdk" / "agent_generation_guide.md"
|
||||
content = guide_path.read_text(encoding="utf-8")
|
||||
assert "Before or During Building" in content
|
||||
|
||||
def test_guide_mentions_find_block_for_clarification(self):
|
||||
guide_path = Path(__file__).parent / "sdk" / "agent_generation_guide.md"
|
||||
content = guide_path.read_text(encoding="utf-8")
|
||||
clarify_section = content.split("Before or During Building")[1].split(
|
||||
"### Workflow"
|
||||
)[0]
|
||||
assert "find_block" in clarify_section
|
||||
|
||||
def test_guide_mentions_ask_question_tool(self):
|
||||
guide_path = Path(__file__).parent / "sdk" / "agent_generation_guide.md"
|
||||
content = guide_path.read_text(encoding="utf-8")
|
||||
clarify_section = content.split("Before or During Building")[1].split(
|
||||
"### Workflow"
|
||||
)[0]
|
||||
assert "ask_question" in clarify_section
|
||||
@@ -3,29 +3,6 @@
 You can create, edit, and customize agents directly. You ARE the brain —
 generate the agent JSON yourself using block schemas, then validate and save.
 
-### Clarifying — Before or During Building
-
-Use `ask_question` whenever the user's intent is ambiguous — whether
-that's before starting or midway through the workflow. Common moments:
-
-- **Before building**: output format, delivery channel, data source, or
-  trigger is unspecified.
-- **During block discovery**: multiple blocks could fit and the user
-  should choose.
-- **During JSON generation**: a wiring decision depends on user
-  preference.
-
-Steps:
-1. Call `find_block` (or another discovery tool) to learn what the
-   platform actually supports for the ambiguous dimension.
-2. Call `ask_question` with a concrete question listing the discovered
-   options (e.g. "The platform supports Gmail, Slack, and Google Docs —
-   which should the agent use for delivery?").
-3. **Wait for the user's answer** before continuing.
-
-**Skip this** when the goal already specifies all dimensions (e.g.
-"scrape prices from Amazon and email me daily").
-
 ### Workflow for Creating/Editing Agents
 
 1. **Discover blocks**: Call `find_block(query, include_schemas=true)` to
@@ -97,8 +74,8 @@ These define the agent's interface — what it accepts and what it produces.
 
 **AgentDropdownInputBlock** (ID: `655d6fdf-a334-421c-b733-520549c07cd1`):
 - Specialized input block that presents a dropdown/select to the user
-- Required `input_default` fields: `name` (str)
-- Optional: `options` (list of dropdown values; when omitted/empty, input behaves as free-text), `title`, `description`, `value` (default selection)
+- Required `input_default` fields: `name` (str), `placeholder_values` (list of options, must have at least one)
+- Optional: `title`, `description`, `value` (default selection)
 - Output: `result` — the user-selected value at runtime
 - Use this instead of AgentInputBlock when the user should pick from a fixed set of options
 
@@ -25,7 +25,7 @@ from backend.copilot.sdk.compaction import (
 
 
 def _make_session() -> ChatSession:
-    return ChatSession.new(user_id="test-user", dry_run=False)
+    return ChatSession.new(user_id="test-user")
 
 
 # ---------------------------------------------------------------------------
 
@@ -275,7 +275,7 @@ class TestCompactionE2E:
 
         # --- Step 7: CompactionTracker receives PreCompact hook ---
         tracker = CompactionTracker()
-        session = ChatSession.new(user_id="test-user", dry_run=False)
+        session = ChatSession.new(user_id="test-user")
         tracker.on_compact(str(session_file))
 
         # --- Step 8: Next SDK message arrives → emit_start ---
@@ -376,7 +376,7 @@ class TestCompactionE2E:
         monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
 
         tracker = CompactionTracker()
-        session = ChatSession.new(user_id="test", dry_run=False)
+        session = ChatSession.new(user_id="test")
         builder = TranscriptBuilder()
 
         # --- First query with compaction ---
@@ -38,7 +38,7 @@ class TestFlattenAssistantContent:
 
     def test_tool_use_blocks(self):
         blocks = [{"type": "tool_use", "name": "read_file", "input": {}}]
-        assert _flatten_assistant_content(blocks) == ""
+        assert _flatten_assistant_content(blocks) == "[tool_use: read_file]"
 
     def test_mixed_blocks(self):
         blocks = [
@@ -47,22 +47,19 @@ class TestFlattenAssistantContent:
         ]
         result = _flatten_assistant_content(blocks)
         assert "Let me read that." in result
-        # tool_use blocks are dropped entirely to prevent model mimicry
-        assert "Read" not in result
+        assert "[tool_use: Read]" in result
 
     def test_raw_strings(self):
         assert _flatten_assistant_content(["hello", "world"]) == "hello\nworld"
 
-    def test_unknown_block_type_dropped(self):
+    def test_unknown_block_type_preserved_as_placeholder(self):
         blocks = [
             {"type": "text", "text": "See this image:"},
             {"type": "image", "source": {"type": "base64", "data": "..."}},
         ]
         result = _flatten_assistant_content(blocks)
         assert "See this image:" in result
-        # Unknown block types are dropped to prevent model mimicry
-        assert "[__image__]" not in result
-        assert "base64" not in result
+        assert "[__image__]" in result
 
     def test_empty(self):
         assert _flatten_assistant_content([]) == ""
@@ -282,8 +279,7 @@ class TestTranscriptToMessages:
         messages = _transcript_to_messages(content)
         assert len(messages) == 2
         assert "Let me check." in messages[0]["content"]
-        # tool_use blocks are dropped entirely to prevent model mimicry
-        assert "read_file" not in messages[0]["content"]
+        assert "[tool_use: read_file]" in messages[0]["content"]
         assert messages[1]["content"] == "file contents"
 
 
@@ -49,22 +49,22 @@ def test_format_assistant_tool_calls():
         )
     ]
     result = _format_conversation_context(msgs)
-    # Assistant with no content and tool_calls omitted produces no lines
-    assert result is None
+    assert result is not None
+    assert 'You called tool: search({"q": "test"})' in result
 
 
 def test_format_tool_result():
     msgs = [ChatMessage(role="tool", content='{"result": "ok"}')]
     result = _format_conversation_context(msgs)
     assert result is not None
-    assert 'Tool output: {"result": "ok"}' in result
+    assert 'Tool result: {"result": "ok"}' in result
 
 
 def test_format_tool_result_none_content():
     msgs = [ChatMessage(role="tool", content=None)]
     result = _format_conversation_context(msgs)
     assert result is not None
-    assert "Tool output: " in result
+    assert "Tool result: " in result
 
 
 def test_format_full_conversation():
@@ -84,8 +84,8 @@ def test_format_full_conversation():
     assert result is not None
     assert "User: find agents" in result
     assert "You responded: I'll search for agents." in result
-    # tool_calls are omitted to prevent model mimicry
-    assert "Tool output:" in result
+    assert "You called tool: find_agents" in result
+    assert "Tool result:" in result
     assert "You responded: Found Agent1." in result
 
 
@@ -27,7 +27,6 @@ from backend.copilot.response_model import (
     StreamError,
     StreamFinish,
     StreamFinishStep,
-    StreamHeartbeat,
     StreamStart,
     StreamStartStep,
     StreamTextDelta,
@@ -77,12 +76,6 @@ class SDKResponseAdapter:
                 # Open the first step (matches non-SDK: StreamStart then StreamStartStep)
                 responses.append(StreamStartStep())
                 self.step_open = True
-            elif sdk_message.subtype == "task_progress":
-                # Emit a heartbeat so publish_chunk is called during long
-                # sub-agent runs. Without this, the Redis stream and meta
-                # key TTLs expire during gaps where no real chunks are
-                # produced (task_progress events were previously silent).
-                responses.append(StreamHeartbeat())
 
         elif isinstance(sdk_message, AssistantMessage):
             # Flush any SDK built-in tool calls that didn't get a UserMessage
 
@@ -18,7 +18,6 @@ from backend.copilot.response_model import (
     StreamError,
     StreamFinish,
     StreamFinishStep,
-    StreamHeartbeat,
     StreamStart,
     StreamStartStep,
     StreamTextDelta,
@@ -29,7 +28,6 @@ from backend.copilot.response_model import (
     StreamToolOutputAvailable,
 )
 
-from .compaction import compaction_events
 from .response_adapter import SDKResponseAdapter
 from .tool_adapter import MCP_TOOL_PREFIX
 from .tool_adapter import _pending_tool_outputs as _pto
@@ -61,14 +59,6 @@ def test_system_non_init_emits_nothing():
     assert results == []
 
 
-def test_task_progress_emits_heartbeat():
-    """task_progress events emit a StreamHeartbeat to keep Redis TTL alive."""
-    adapter = _adapter()
-    results = adapter.convert_message(SystemMessage(subtype="task_progress", data={}))
-    assert len(results) == 1
-    assert isinstance(results[0], StreamHeartbeat)
-
-
 # -- AssistantMessage with TextBlock -----------------------------------------
 
 
@@ -690,102 +680,3 @@ def test_already_resolved_tool_skipped_in_user_message():
     assert (
         len(output_events) == 0
     ), "Already-resolved tool should not emit duplicate output"
-
-
-# -- _end_text_if_open before compaction -------------------------------------
-
-
-def test_end_text_if_open_emits_text_end_before_finish_step():
-    """StreamTextEnd must be emitted before StreamFinishStep during compaction.
-
-    When ``emit_end_if_ready`` fires compaction events while a text block is
-    still open, ``_end_text_if_open`` must close it first. If StreamFinishStep
-    arrives before StreamTextEnd, the Vercel AI SDK clears ``activeTextParts``
-    and raises "Received text-end for missing text part".
-    """
-    adapter = _adapter()
-
-    # Open a text block by processing an AssistantMessage with text
-    msg = AssistantMessage(content=[TextBlock(text="partial response")], model="test")
-    adapter.convert_message(msg)
-    assert adapter.has_started_text
-    assert not adapter.has_ended_text
-
-    # Simulate what service.py does before yielding compaction events
-    pre_close: list[StreamBaseResponse] = []
-    adapter._end_text_if_open(pre_close)
-    combined = pre_close + list(compaction_events("Compacted transcript"))
-
-    text_end_idx = next(
-        (i for i, e in enumerate(combined) if isinstance(e, StreamTextEnd)), None
-    )
-    finish_step_idx = next(
-        (i for i, e in enumerate(combined) if isinstance(e, StreamFinishStep)), None
-    )
-
-    assert text_end_idx is not None, "StreamTextEnd must be present"
-    assert finish_step_idx is not None, "StreamFinishStep must be present"
-    assert text_end_idx < finish_step_idx, (
-        f"StreamTextEnd (idx={text_end_idx}) must precede "
-        f"StreamFinishStep (idx={finish_step_idx}) — otherwise the Vercel AI SDK "
-        "clears activeTextParts before text-end arrives"
-    )
-
-
-def test_step_open_must_reset_after_compaction_finish_step():
-    """Adapter step_open must be reset when compaction emits StreamFinishStep.
-
-    Compaction events bypass the adapter, so service.py must explicitly clear
-    step_open after yielding a StreamFinishStep from compaction. Without this,
-    the next AssistantMessage skips StreamStartStep because the adapter still
-    thinks a step is open.
-    """
-    adapter = _adapter()
-
-    # Open a step + text block via an AssistantMessage
-    msg = AssistantMessage(content=[TextBlock(text="thinking...")], model="test")
-    adapter.convert_message(msg)
-    assert adapter.step_open is True
-
-    # Simulate what service.py does: close text, then check compaction events
-    pre_close: list[StreamBaseResponse] = []
-    adapter._end_text_if_open(pre_close)
-
-    events = list(compaction_events("Compacted transcript"))
-    if any(isinstance(ev, StreamFinishStep) for ev in events):
-        adapter.step_open = False
-
-    assert (
-        adapter.step_open is False
-    ), "step_open must be False after compaction emits StreamFinishStep"
-
-    # Next AssistantMessage must open a new step
-    msg2 = AssistantMessage(content=[TextBlock(text="continued")], model="test")
-    results = adapter.convert_message(msg2)
-    assert any(
-        isinstance(r, StreamStartStep) for r in results
-    ), "A new StreamStartStep must be emitted after compaction closed the step"
-
-
-def test_end_text_if_open_no_op_when_no_text_open():
-    """_end_text_if_open emits nothing when no text block is open."""
-    adapter = _adapter()
-    results: list[StreamBaseResponse] = []
-    adapter._end_text_if_open(results)
-    assert results == []
-
-
-def test_end_text_if_open_no_op_after_text_already_ended():
-    """_end_text_if_open emits nothing when the text block is already closed."""
-    adapter = _adapter()
-    msg = AssistantMessage(content=[TextBlock(text="hello")], model="test")
-    adapter.convert_message(msg)
-    # Close it once
-    first: list[StreamBaseResponse] = []
-    adapter._end_text_if_open(first)
-    assert len(first) == 1
-    assert isinstance(first[0], StreamTextEnd)
-    # Second call must be a no-op
-    second: list[StreamBaseResponse] = []
-    adapter._end_text_if_open(second)
-    assert second == []
@@ -904,14 +904,14 @@ class TestTranscriptEdgeCases:
         assert restored[1]["content"] == "Second"
 
     def test_flatten_assistant_with_only_tool_use(self):
-        """Assistant message with only tool_use blocks (no text) flattens to empty."""
+        """Assistant message with only tool_use blocks (no text)."""
         blocks = [
            {"type": "tool_use", "name": "bash", "input": {"cmd": "ls"}},
            {"type": "tool_use", "name": "read", "input": {"path": "/f"}},
         ]
         result = _flatten_assistant_content(blocks)
-        # tool_use blocks are dropped entirely to prevent model mimicry
-        assert result == ""
+        assert "[tool_use: bash]" in result
+        assert "[tool_use: read]" in result
 
     def test_flatten_tool_result_nested_image(self):
         """Tool result containing image blocks uses placeholder."""
@@ -1414,261 +1414,3 @@ class TestStreamChatCompletionRetryIntegration:
         # Verify user-friendly message (not raw SDK text)
         assert "Authentication" in errors[0].errorText
         assert any(isinstance(e, StreamStart) for e in events)
-
-    @pytest.mark.asyncio
-    async def test_result_message_prompt_too_long_triggers_compaction(self):
-        """CLI returns ResultMessage(subtype="error") with "Prompt is too long".
-
-        When the Claude CLI rejects the prompt pre-API (model=<synthetic>,
-        duration_api_ms=0), it sends a ResultMessage with is_error=True
-        instead of raising a Python exception. The retry loop must still
-        detect this as a context-length error and trigger compaction.
-        """
-        import contextlib
-
-        from claude_agent_sdk import ResultMessage
-
-        from backend.copilot.response_model import StreamError, StreamStart
-        from backend.copilot.sdk.service import stream_chat_completion_sdk
-
-        session = self._make_session()
-        success_result = self._make_result_message()
-        attempt_count = [0]
-
-        error_result = ResultMessage(
-            subtype="error",
-            result="Prompt is too long",
-            duration_ms=100,
-            duration_api_ms=0,
-            is_error=True,
-            num_turns=0,
-            session_id="test-session-id",
-        )
-
-        def _client_factory(*args, **kwargs):
-            attempt_count[0] += 1
-            if attempt_count[0] == 1:
-                # First attempt: CLI returns error ResultMessage
-                return self._make_client_mock(result_message=error_result)
-            # Second attempt (after compaction): succeeds
-            return self._make_client_mock(result_message=success_result)
-
-        original_transcript = _build_transcript(
-            [("user", "prior question"), ("assistant", "prior answer")]
-        )
-        compacted_transcript = _build_transcript(
-            [("user", "[summary]"), ("assistant", "summary reply")]
-        )
-
-        patches = _make_sdk_patches(
-            session,
-            original_transcript=original_transcript,
-            compacted_transcript=compacted_transcript,
-            client_side_effect=_client_factory,
-        )
-
-        events = []
-        with contextlib.ExitStack() as stack:
-            for target, kwargs in patches:
-                stack.enter_context(patch(target, **kwargs))
-            async for event in stream_chat_completion_sdk(
-                session_id="test-session-id",
-                message="hello",
-                is_user_message=True,
-                user_id="test-user",
-                session=session,
-            ):
-                events.append(event)
-
-        assert attempt_count[0] == 2, (
-            f"Expected 2 SDK attempts (CLI error ResultMessage "
-            f"should trigger compaction retry), got {attempt_count[0]}"
-        )
-        errors = [e for e in events if isinstance(e, StreamError)]
-        assert not errors, f"Unexpected StreamError: {errors}"
-        assert any(isinstance(e, StreamStart) for e in events)
-
-    @pytest.mark.asyncio
-    async def test_result_message_success_subtype_prompt_too_long_triggers_compaction(
-        self,
-    ):
-        """CLI returns ResultMessage(subtype="success") with result="Prompt is too long".
-
-        The SDK internally compacts but the transcript is still too long. It
-        returns subtype="success" (process completed) with result="Prompt is
-        too long" (the actual rejection message). The retry loop must detect
-        this as a context-length error and trigger compaction — the subtype
-        "success" must not fool it into treating this as a real response.
-        """
-        import contextlib
-
-        from claude_agent_sdk import ResultMessage
-
-        from backend.copilot.response_model import StreamError, StreamStart
-        from backend.copilot.sdk.service import stream_chat_completion_sdk
-
-        session = self._make_session()
-        success_result = self._make_result_message()
-        attempt_count = [0]
-
-        error_result = ResultMessage(
-            subtype="success",
-            result="Prompt is too long",
-            duration_ms=100,
-            duration_api_ms=0,
-            is_error=False,
-            num_turns=1,
-            session_id="test-session-id",
-        )
-
-        def _client_factory(*args, **kwargs):
-            attempt_count[0] += 1
-
-            async def _receive_error():
-                yield error_result
-
-            async def _receive_success():
-                yield success_result
-
-            client = MagicMock()
-            client._transport = MagicMock()
-            client._transport.write = AsyncMock()
-            client.query = AsyncMock()
-            if attempt_count[0] == 1:
-                client.receive_response = _receive_error
-            else:
-                client.receive_response = _receive_success
-            cm = AsyncMock()
-            cm.__aenter__.return_value = client
-            cm.__aexit__.return_value = None
-            return cm
-
-        original_transcript = _build_transcript(
-            [("user", "prior question"), ("assistant", "prior answer")]
-        )
-        compacted_transcript = _build_transcript(
-            [("user", "[summary]"), ("assistant", "summary reply")]
-        )
-
-        patches = _make_sdk_patches(
-            session,
-            original_transcript=original_transcript,
-            compacted_transcript=compacted_transcript,
-            client_side_effect=_client_factory,
-        )
-
-        events = []
-        with contextlib.ExitStack() as stack:
-            for target, kwargs in patches:
-                stack.enter_context(patch(target, **kwargs))
-            async for event in stream_chat_completion_sdk(
-                session_id="test-session-id",
-                message="hello",
-                is_user_message=True,
-                user_id="test-user",
-                session=session,
-            ):
-                events.append(event)
-
-        assert attempt_count[0] == 2, (
-            f"Expected 2 SDK attempts (subtype='success' with 'Prompt is too long' "
-            f"result should trigger compaction retry), got {attempt_count[0]}"
-        )
-        errors = [e for e in events if isinstance(e, StreamError)]
-        assert not errors, f"Unexpected StreamError: {errors}"
-        assert any(isinstance(e, StreamStart) for e in events)
-
-    @pytest.mark.asyncio
-    async def test_assistant_message_error_content_prompt_too_long_triggers_compaction(
-        self,
-    ):
-        """AssistantMessage.error="invalid_request" with content "Prompt is too long".
-
-        The SDK returns error type "invalid_request" but puts the actual
-        rejection message ("Prompt is too long") in the content blocks.
-        The retry loop must detect this via content inspection (sdk_error
-        being set confirms it's an error message, not user content).
-        """
-        import contextlib
-
-        from claude_agent_sdk import AssistantMessage, ResultMessage, TextBlock
-
-        from backend.copilot.response_model import StreamError, StreamStart
-        from backend.copilot.sdk.service import stream_chat_completion_sdk
-
-        session = self._make_session()
-        success_result = self._make_result_message()
-        attempt_count = [0]
-
-        def _client_factory(*args, **kwargs):
-            attempt_count[0] += 1
-
-            async def _receive_error():
-                # SDK returns invalid_request with "Prompt is too long" in content.
-                # ResultMessage.result is a non-PTL value ("done") to isolate
-                # the AssistantMessage content detection path exclusively.
-                yield AssistantMessage(
-                    content=[TextBlock(text="Prompt is too long")],
-                    model="<synthetic>",
-                    error="invalid_request",
-                )
-                yield ResultMessage(
-                    subtype="success",
-                    result="done",
-                    duration_ms=100,
-                    duration_api_ms=0,
-                    is_error=False,
-                    num_turns=1,
-                    session_id="test-session-id",
-                )
-
-            async def _receive_success():
-                yield success_result
-
-            client = MagicMock()
-            client._transport = MagicMock()
-            client._transport.write = AsyncMock()
-            client.query = AsyncMock()
-            if attempt_count[0] == 1:
-                client.receive_response = _receive_error
-            else:
-                client.receive_response = _receive_success
-            cm = AsyncMock()
-            cm.__aenter__.return_value = client
-            cm.__aexit__.return_value = None
-            return cm
-
-        original_transcript = _build_transcript(
-            [("user", "prior question"), ("assistant", "prior answer")]
-        )
-        compacted_transcript = _build_transcript(
-            [("user", "[summary]"), ("assistant", "summary reply")]
-        )
-
-        patches = _make_sdk_patches(
-            session,
-            original_transcript=original_transcript,
-            compacted_transcript=compacted_transcript,
-            client_side_effect=_client_factory,
-        )
-
-        events = []
-        with contextlib.ExitStack() as stack:
-            for target, kwargs in patches:
-                stack.enter_context(patch(target, **kwargs))
-            async for event in stream_chat_completion_sdk(
-                session_id="test-session-id",
-                message="hello",
-                is_user_message=True,
-                user_id="test-user",
-                session=session,
-            ):
-                events.append(event)
-
-        assert attempt_count[0] == 2, (
-            f"Expected 2 SDK attempts (AssistantMessage error content 'Prompt is "
-            f"too long' should trigger compaction retry), got {attempt_count[0]}"
-        )
-        errors = [e for e in events if isinstance(e, StreamError)]
-        assert not errors, f"Unexpected StreamError: {errors}"
-        assert any(isinstance(e, StreamStart) for e in events)
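The three removed tests above all exercise one detection rule: a context-length rejection can surface as an error `ResultMessage`, as a `ResultMessage` with `subtype="success"` but the rejection text in `result`, or as an `AssistantMessage` error body. A hedged sketch of that predicate, distilled from the test descriptions (the function name and signature are illustrative, not the service's actual helper):

```python
PROMPT_TOO_LONG_MARKER = "prompt is too long"


def is_context_length_error(result_text, is_error_message=True) -> bool:
    """Return True when SDK/CLI output signals a context-length rejection.

    Sketch of the rule the removed tests encode: the marker text decides,
    regardless of subtype or is_error flags — even a "success" ResultMessage
    carrying the marker must still trigger compaction and a retry. The
    is_error_message flag guards against matching the marker inside
    ordinary user content.
    """
    if not result_text or not is_error_message:
        return False
    return PROMPT_TOO_LONG_MARKER in result_text.lower()
```

In a retry loop, a True result would trigger transcript compaction and a second SDK attempt instead of surfacing a `StreamError` to the client.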
@@ -313,7 +313,8 @@ def create_security_hooks(
                 .replace("\r", "")
             )
             logger.info(
-                "[SDK] Context compaction triggered: %s, user=%s, transcript_path=%s",
+                "[SDK] Context compaction triggered: %s, user=%s, "
+                "transcript_path=%s",
                 trigger,
                 user_id,
                 transcript_path,
 
@@ -11,11 +11,7 @@ import pytest
 
 from backend.copilot.context import _current_project_dir
 
-from .security_hooks import (
-    _validate_tool_access,
-    _validate_user_isolation,
-    create_security_hooks,
-)
+from .security_hooks import _validate_tool_access, _validate_user_isolation
 
 SDK_CWD = "/tmp/copilot-abc123"
 
@@ -224,6 +220,8 @@ def test_bash_builtin_blocked_message_clarity():
 @pytest.fixture()
 def _hooks():
     """Create security hooks and return (pre, post, post_failure) handlers."""
+    from .security_hooks import create_security_hooks
+
     hooks = create_security_hooks(user_id="u1", sdk_cwd=SDK_CWD, max_subtasks=2)
     pre = hooks["PreToolUse"][0].hooks[0]
     post = hooks["PostToolUse"][0].hooks[0]
@@ -59,14 +59,11 @@ from ..response_model import (
     StreamBaseResponse,
     StreamError,
     StreamFinish,
     StreamFinishStep,
-    StreamHeartbeat,
     StreamStart,
     StreamStartStep,
     StreamStatus,
     StreamTextDelta,
     StreamToolInputAvailable,
     StreamToolInputStart,
     StreamToolOutputAvailable,
     StreamUsage,
 )
@@ -84,9 +81,11 @@ from .env import build_sdk_env  # noqa: F401 — re-export for backward compat
 from .response_adapter import SDKResponseAdapter
 from .security_hooks import create_security_hooks
 from .tool_adapter import (
     cancel_pending_tool_tasks,
     create_copilot_mcp_server,
     get_copilot_tool_names,
     get_sdk_disallowed_tools,
     pre_launch_tool_call,
     reset_stash_event,
     reset_tool_failure_counters,
     set_execution_context,
@@ -116,10 +115,9 @@ _MAX_STREAM_ATTEMPTS = 3
 
 # Hard circuit breaker: abort the stream if the model sends this many
 # consecutive tool calls with empty parameters (a sign of context
-# saturation or serialization failure). The MCP wrapper now returns
-# guidance on the first empty call, giving the model a chance to
-# self-correct. The limit is generous to allow recovery attempts.
-_EMPTY_TOOL_CALL_LIMIT = 5
+# saturation or serialization failure). Empty input ({}) is never
+# legitimate — even one is suspicious, three is conclusive.
+_EMPTY_TOOL_CALL_LIMIT = 3
 
 # User-facing error shown when the empty-tool-call circuit breaker trips.
 _CIRCUIT_BREAKER_ERROR_MSG = (
@@ -748,11 +746,15 @@ def _format_conversation_context(messages: list[ChatMessage]) -> str | None:
         elif msg.role == "assistant":
             if msg.content:
                 lines.append(f"You responded: {msg.content}")
-            # Omit tool_calls — any text representation gets mimicked
-            # by the model. Tool results below provide the context.
+            if msg.tool_calls:
+                for tc in msg.tool_calls:
+                    func = tc.get("function", {})
+                    tool_name = func.get("name", "unknown")
+                    tool_args = func.get("arguments", "")
+                    lines.append(f"You called tool: {tool_name}({tool_args})")
         elif msg.role == "tool":
             content = msg.content or ""
-            lines.append(f"Tool output: {content[:500]}")
+            lines.append(f"Tool result: {content}")
 
     if not lines:
         return None
@@ -1212,14 +1214,6 @@ async def _run_stream_attempt(
 
     consecutive_empty_tool_calls = 0
 
-    # --- Intermediate persistence tracking ---
-    # Flush session messages to DB periodically so page reloads show progress
-    # during long-running turns (see incident d2f7cba3: 82-min turn lost on refresh).
-    _last_flush_time = time.monotonic()
-    _msgs_since_flush = 0
-    _FLUSH_INTERVAL_SECONDS = 30.0
-    _FLUSH_MESSAGE_THRESHOLD = 10
-
     # Use manual __aenter__/__aexit__ instead of ``async with`` so we can
     # suppress SDK cleanup errors that occur when the SSE client disconnects
     # mid-stream. GeneratorExit causes the SDK's ``__aexit__`` to run in a
@@ -1306,27 +1300,6 @@ async def _run_stream_attempt(
                         error_preview,
                     )
 
-                    # Intercept prompt-too-long errors surfaced as
-                    # AssistantMessage.error (not as a Python exception).
-                    # Re-raise so the outer retry loop can compact the
-                    # transcript and retry with reduced context.
-                    # Check both error_text and error_preview: sdk_error
-                    # being set confirms this is an error message (not user
-                    # content), so checking content is safe. The actual
-                    # error description (e.g. "Prompt is too long") may be
-                    # in the content, not the error type field
-                    # (e.g. error="invalid_request", content="Prompt is
-                    # too long").
-                    if _is_prompt_too_long(Exception(error_text)) or _is_prompt_too_long(
-                        Exception(error_preview)
-                    ):
-                        logger.warning(
-                            "%s Prompt-too-long detected via AssistantMessage "
-                            "error — raising for retry",
-                            ctx.log_prefix,
-                        )
-                        raise RuntimeError("Prompt is too long")
-
                     # Intercept transient API errors (socket closed,
                     # ECONNRESET) — replace the raw message with a
                     # user-friendly error text and use the retryable
@@ -1354,17 +1327,28 @@ async def _run_stream_attempt(
                     ended_with_stream_error = True
                     break
 
-            # Determine if the message is a tool-only batch (all content
+            # Parallel tool execution: pre-launch every ToolUseBlock as an
+            # asyncio.Task the moment its AssistantMessage arrives. The SDK
+            # sends one AssistantMessage per tool call when issuing parallel
+            # calls, so each message is pre-launched independently. The MCP
+            # handlers will await the already-running task instead of executing
+            # fresh, making all concurrent tool calls run in parallel.
+            #
+            # Also determine if the message is a tool-only batch (all content
             # items are ToolUseBlocks) — such messages have no text output yet,
             # so we skip the wait_for_stash flush below.
-            #
-            # Note: parallel execution of tools is handled natively by the
-            # SDK CLI via readOnlyHint annotations on tool definitions.
             is_tool_only = False
             if isinstance(sdk_msg, AssistantMessage) and sdk_msg.content:
-                is_tool_only = all(
-                    isinstance(item, ToolUseBlock) for item in sdk_msg.content
-                )
+                is_tool_only = True
+                # NOTE: Pre-launches are sequential (each await completes
+                # file-ref expansion before the next starts). This is fine
+                # since expansion is typically sub-ms; a future optimisation
+                # could gather all pre-launches concurrently.
+                for tool_use in sdk_msg.content:
+                    if isinstance(tool_use, ToolUseBlock):
+                        await pre_launch_tool_call(tool_use.name, tool_use.input)
+                    else:
+                        is_tool_only = False
 
             # Race-condition fix: SDK hooks (PostToolUse) are
             # executed asynchronously via start_soon() — the next
@@ -1421,16 +1405,6 @@ async def _run_stream_attempt(
                     sdk_msg.result or "(no error message provided)",
                 )
 
-                # Check for prompt-too-long regardless of subtype — the
-                # SDK may return subtype="success" with result="Prompt is
-                # too long" when the CLI rejects the prompt before calling
-                # the API (cost_usd=0, no tokens consumed). If we only
-                # check the "error" subtype path, the stream appears to
-                # complete normally, the synthetic error text is stored
-                # in the transcript, and the session grows without bound.
-                if _is_prompt_too_long(RuntimeError(sdk_msg.result or "")):
-                    raise RuntimeError("Prompt is too long")
-
                 # Capture token usage from ResultMessage.
                 # Anthropic reports cached tokens separately:
                 #   input_tokens = uncached only
@@ -1462,23 +1436,6 @@ async def _run_stream_attempt(
                 # Emit compaction end if SDK finished compacting.
                 # Sync TranscriptBuilder with the CLI's active context.
                 compact_result = await ctx.compaction.emit_end_if_ready(ctx.session)
                 if compact_result.events:
-                    # Compaction events end with StreamFinishStep, which maps to
-                    # Vercel AI SDK's "finish-step" — that clears activeTextParts.
-                    # Close any open text block BEFORE the compaction events so
-                    # the text-end arrives before finish-step, preventing
-                    # "text-end for missing text part" errors on the frontend.
-                    pre_close: list[StreamBaseResponse] = []
-                    state.adapter._end_text_if_open(pre_close)
-                    # Compaction events bypass the adapter, so sync step state
-                    # when a StreamFinishStep is present — otherwise the adapter
-                    # will skip StreamStartStep on the next AssistantMessage.
-                    if any(
-                        isinstance(ev, StreamFinishStep) for ev in compact_result.events
-                    ):
-                        state.adapter.step_open = False
-                    for r in pre_close:
-                        yield r
                     for ev in compact_result.events:
                         yield ev
                     entries_replaced = False
@@ -1525,34 +1482,6 @@ async def _run_stream_attempt(
                     model=sdk_msg.model,
                 )
 
-                # --- Intermediate persistence ---
-                # Flush session messages to DB periodically so page reloads
-                # show progress during long-running turns.
-                _msgs_since_flush += 1
-                now = time.monotonic()
-                if (
-                    _msgs_since_flush >= _FLUSH_MESSAGE_THRESHOLD
-                    or (now - _last_flush_time) >= _FLUSH_INTERVAL_SECONDS
-                ):
-                    try:
-                        await asyncio.shield(upsert_chat_session(ctx.session))
-                        logger.debug(
-                            "%s Intermediate flush: %d messages "
-                            "(msgs_since=%d, elapsed=%.1fs)",
-                            ctx.log_prefix,
-                            len(ctx.session.messages),
-                            _msgs_since_flush,
-                            now - _last_flush_time,
-                        )
-                    except Exception as flush_err:
-                        logger.warning(
-                            "%s Intermediate flush failed: %s",
-                            ctx.log_prefix,
-                            flush_err,
-                        )
-                    _last_flush_time = now
-                    _msgs_since_flush = 0
-
             if acc.stream_completed:
                 break
     finally:
@@ -2079,22 +2008,13 @@ async def stream_chat_completion_sdk(
 
         try:
             async for event in _run_stream_attempt(stream_ctx, state):
-                if not isinstance(
-                    event,
-                    (
-                        StreamHeartbeat,
-                        # Compaction UI events are cosmetic and must not
-                        # block retry — they're emitted before the SDK
-                        # query on compacted attempts.
-                        StreamStartStep,
-                        StreamFinishStep,
-                        StreamToolInputStart,
-                        StreamToolInputAvailable,
-                        StreamToolOutputAvailable,
-                    ),
-                ):
+                if not isinstance(event, StreamHeartbeat):
                     events_yielded += 1
                 yield event
+            # Cancel any pre-launched tasks that were never dispatched
+            # by the SDK (e.g. edge-case SDK behaviour changes). Symmetric
+            # with the three error-path await cancel_pending_tool_tasks() calls.
+            await cancel_pending_tool_tasks()
             break  # Stream completed — exit retry loop
         except asyncio.CancelledError:
             logger.warning(
@@ -2103,6 +2023,9 @@ async def stream_chat_completion_sdk(
                 attempt + 1,
                 _MAX_STREAM_ATTEMPTS,
             )
+            # Cancel any pre-launched tasks so they don't continue executing
+            # against a rolled-back or abandoned session.
+            await cancel_pending_tool_tasks()
             raise
         except _HandledStreamError as exc:
             # _run_stream_attempt already yielded a StreamError and
@@ -2134,6 +2057,8 @@ async def stream_chat_completion_sdk(
                 retryable=True,
             )
             ended_with_stream_error = True
+            # Cancel any pre-launched tasks from the failed attempt.
+            await cancel_pending_tool_tasks()
             break
         except Exception as e:
             stream_err = e
@@ -2150,6 +2075,9 @@ async def stream_chat_completion_sdk(
                 exc_info=True,
             )
             session.messages = session.messages[:pre_attempt_msg_count]
+            # Cancel any pre-launched tasks from the failed attempt so they
+            # don't continue executing against the rolled-back session.
+            await cancel_pending_tool_tasks()
             if events_yielded > 0:
                 # Events were already sent to the frontend and cannot be
                 # unsent. Retrying would produce duplicate/inconsistent
@@ -392,7 +392,7 @@ class TestFlattenThinkingBlocks:
         assert result == ""
 
     def test_mixed_thinking_text_tool(self):
-        """Mixed blocks: only text survives flattening; thinking and tool_use dropped."""
+        """Mixed blocks: only text and tool_use survive flattening."""
         blocks = [
             {"type": "thinking", "thinking": "hmm", "signature": "sig"},
             {"type": "redacted_thinking", "data": "xyz"},
@@ -403,8 +403,7 @@ class TestFlattenThinkingBlocks:
         assert "hmm" not in result
         assert "xyz" not in result
         assert "I'll read the file." in result
-        # tool_use blocks are dropped entirely to prevent model mimicry
-        assert "Read" not in result
+        assert "[tool_use: Read]" in result
 
 
 # ---------------------------------------------------------------------------
@@ -14,7 +14,6 @@ from contextvars import ContextVar
 from typing import TYPE_CHECKING, Any
 
 from claude_agent_sdk import create_sdk_mcp_server, tool
-from mcp.types import ToolAnnotations
 
 from backend.copilot.context import (
     _current_permissions,
@@ -54,6 +53,14 @@ _MCP_MAX_CHARS = 500_000
 MCP_SERVER_NAME = "copilot"
 MCP_TOOL_PREFIX = f"mcp__{MCP_SERVER_NAME}__"
 
+# Map from tool_name -> Queue of pre-launched (task, args) pairs.
+# Initialised per-session in set_execution_context() so concurrent sessions
+# never share the same dict.
+_TaskQueueItem = tuple[asyncio.Task[dict[str, Any]], dict[str, Any]]
+_tool_task_queues: ContextVar[dict[str, asyncio.Queue[_TaskQueueItem]] | None] = (
+    ContextVar("_tool_task_queues", default=None)
+)
+
 # Stash for MCP tool outputs before the SDK potentially truncates them.
 # Keyed by tool_name → full output string. Consumed (popped) by the
 # response adapter when it builds StreamToolOutputAvailable.
@@ -108,6 +115,7 @@ def set_execution_context(
     _current_permissions.set(permissions)
     _pending_tool_outputs.set({})
     _stash_event.set(asyncio.Event())
+    _tool_task_queues.set({})
     _consecutive_tool_failures.set({})
@@ -124,6 +132,48 @@ def reset_stash_event() -> None:
         event.clear()
 
 
+async def cancel_pending_tool_tasks() -> None:
+    """Cancel all queued pre-launched tasks for the current execution context.
+
+    Call this when a stream attempt aborts (error, cancellation) to prevent
+    pre-launched tasks from continuing to execute against a rolled-back session.
+    Tasks that are already done are skipped; in-flight tasks are cancelled and
+    awaited so that any cleanup (``finally`` blocks, DB rollbacks) completes
+    before the next retry starts.
+    """
+    queues = _tool_task_queues.get()
+    if not queues:
+        return
+    cancelled_tasks: list[asyncio.Task] = []
+    for tool_name, queue in list(queues.items()):
+        cancelled = 0
+        while not queue.empty():
+            task, _args = queue.get_nowait()
+            if not task.done():
+                task.cancel()
+                cancelled_tasks.append(task)
+                cancelled += 1
+        if cancelled:
+            logger.debug(
+                "Cancelled %d pre-launched task(s) for tool '%s'", cancelled, tool_name
+            )
+    queues.clear()
+    # Await all cancelled tasks so their cleanup (finally blocks, DB rollbacks)
+    # completes before the next retry attempt starts new pre-launches.
+    # Use a timeout to prevent hanging indefinitely if a task's cleanup is stuck.
+    if cancelled_tasks:
+        try:
+            await asyncio.wait_for(
+                asyncio.gather(*cancelled_tasks, return_exceptions=True),
+                timeout=5.0,
+            )
+        except TimeoutError:
+            logger.warning(
+                "Timed out waiting for %d cancelled task(s) to clean up",
+                len(cancelled_tasks),
+            )
+
+
 def reset_tool_failure_counters() -> None:
     """Reset all tool-level circuit breaker counters.
@@ -199,6 +249,10 @@ async def wait_for_stash(timeout: float = 2.0) -> bool:
     Uses ``asyncio.Event.wait()`` so it returns the instant the hook signals —
     the timeout is purely a safety net for the case where the hook never fires.
     Returns ``True`` if the stash signal was received, ``False`` on timeout.
+
+    The 2.0 s default was chosen to accommodate slower tool startup in cloud
+    sandboxes while still failing fast when the hook genuinely will not fire.
+    With the parallel pre-launch path, hooks typically fire well under 1 ms.
     """
     event = _stash_event.get(None)
     if event is None:
@@ -217,13 +271,95 @@ async def wait_for_stash(timeout: float = 2.0) -> bool:
     return False
 
 
+async def pre_launch_tool_call(tool_name: str, args: dict[str, Any]) -> None:
+    """Pre-launch a tool as a background task so parallel calls run concurrently.
+
+    Called when an AssistantMessage with ToolUseBlocks is received, before the
+    SDK dispatches the MCP tool/call requests. The tool_handler will await the
+    pre-launched task instead of executing fresh.
+
+    The tool_name may include an MCP prefix (e.g. ``mcp__copilot__run_block``);
+    the prefix is stripped automatically before looking up the tool.
+
+    Ordering guarantee: the Claude Agent SDK dispatches MCP ``tools/call`` requests
+    in the same order as the ToolUseBlocks appear in the AssistantMessage.
+    Pre-launched tasks are queued FIFO per tool name, so the N-th handler for a
+    given tool name dequeues the N-th pre-launched task — result and args always
+    correspond when the SDK preserves order (which it does in the current SDK).
+    """
+    queues = _tool_task_queues.get()
+    if queues is None:
+        return
+
+    # Strip the MCP server prefix (e.g. "mcp__copilot__") to get the bare tool name.
+    # Use removeprefix so tool names that themselves contain "__" are handled correctly.
+    bare_name = tool_name.removeprefix(MCP_TOOL_PREFIX)
+
+    base_tool = TOOL_REGISTRY.get(bare_name)
+    if base_tool is None:
+        return
+
+    user_id, session = get_execution_context()
+    if session is None:
+        return
+
+    # Expand @@agptfile: references before launching the task.
+    # The _truncating wrapper (which normally handles expansion) runs AFTER
+    # pre_launch_tool_call — the pre-launched task would otherwise receive raw
+    # @@agptfile: tokens and fail to resolve them inside _execute_tool_sync.
+    # Use _build_input_schema (same path as _truncating) for schema-aware expansion.
+    input_schema: dict[str, Any] | None
+    try:
+        input_schema = _build_input_schema(base_tool)
+    except Exception:
+        input_schema = None  # schema unavailable — skip schema-aware expansion
+    try:
+        args = await expand_file_refs_in_args(
+            args, user_id, session, input_schema=input_schema
+        )
+    except FileRefExpansionError as exc:
+        logger.warning(
+            "pre_launch_tool_call: @@agptfile expansion failed for %s: %s — skipping pre-launch",
+            bare_name,
+            exc,
+        )
+        return
+
+    task = asyncio.create_task(_execute_tool_sync(base_tool, user_id, session, args))
+    # Log unhandled exceptions so "Task exception was never retrieved" warnings
+    # do not pollute stderr when a task is pre-launched but never dequeued.
+    task.add_done_callback(
+        lambda t, name=bare_name: (
+            logger.warning(
+                "Pre-launched task for %s raised unhandled: %s",
+                name,
+                t.exception(),
+            )
+            if not t.cancelled() and t.exception()
+            else None
+        )
+    )
+
+    if bare_name not in queues:
+        queues[bare_name] = asyncio.Queue[_TaskQueueItem]()
+    # Store (task, args) so the handler can log a warning if the SDK dispatches
+    # calls in a different order than the ToolUseBlocks appeared in the message.
+    queues[bare_name].put_nowait((task, args))
+
+
 async def _execute_tool_sync(
     base_tool: BaseTool,
     user_id: str | None,
     session: ChatSession,
     args: dict[str, Any],
 ) -> dict[str, Any]:
-    """Execute a tool synchronously and return MCP-formatted response."""
+    """Execute a tool synchronously and return MCP-formatted response.
+
+    Note: ``@@agptfile:`` expansion should be performed by the caller before
+    invoking this function. For the normal (non-parallel) path it is handled
+    by the ``_truncating`` wrapper; for the pre-launched parallel path it is
+    handled in :func:`pre_launch_tool_call` before the task is created.
+    """
     effective_id = f"sdk-{uuid.uuid4().hex[:12]}"
     result = await base_tool.execute(
         user_id=user_id,
@@ -319,7 +455,83 @@ def create_tool_handler(base_tool: BaseTool):
     """
 
     async def tool_handler(args: dict[str, Any]) -> dict[str, Any]:
-        """Execute the wrapped tool and return MCP-formatted response."""
+        """Execute the wrapped tool and return MCP-formatted response.
+
+        If a pre-launched task exists (from parallel tool pre-launch in the
+        message loop), await it instead of executing fresh.
+        """
+        queues = _tool_task_queues.get()
+        if queues and base_tool.name in queues:
+            queue = queues[base_tool.name]
+            if not queue.empty():
+                task, launch_args = queue.get_nowait()
+                # Sanity-check: warn if the args don't match — this can happen
+                # if the SDK dispatches tool calls in a different order than the
+                # ToolUseBlocks appeared in the AssistantMessage (unlikely but
+                # could occur in future SDK versions or with SDK bugs).
+                # We compare full values (not just keys) so that two run_block
+                # calls with different block_id values are caught even though
+                # both have the same key set.
+                if launch_args != args:
+                    logger.warning(
+                        "Pre-launched task for %s: arg mismatch "
+                        "(launch_keys=%s, call_keys=%s) — cancelling "
+                        "pre-launched task and falling back to direct execution",
+                        base_tool.name,
+                        (
+                            sorted(launch_args.keys())
+                            if isinstance(launch_args, dict)
+                            else type(launch_args).__name__
+                        ),
+                        (
+                            sorted(args.keys())
+                            if isinstance(args, dict)
+                            else type(args).__name__
+                        ),
+                    )
+                    if not task.done():
+                        task.cancel()
+                        # Await cancellation to prevent duplicate concurrent
+                        # execution for blocks with side effects.
+                        try:
+                            await task
+                        except (asyncio.CancelledError, Exception):
+                            pass
+                    # Fall through to the direct-execution path below.
+                else:
+                    # Args match — await the pre-launched task.
+                    try:
+                        result = await task
+                    except asyncio.CancelledError:
+                        # Re-raise: CancelledError may be propagating from the
+                        # outer streaming loop being cancelled — swallowing it
+                        # would mask the cancellation and prevent proper cleanup.
+                        logger.warning(
+                            "Pre-launched tool %s was cancelled — re-raising",
+                            base_tool.name,
+                        )
+                        raise
+                    except Exception as e:
+                        logger.error(
+                            "Pre-launched tool %s failed: %s",
+                            base_tool.name,
+                            e,
+                            exc_info=True,
+                        )
+                        return _mcp_error(
+                            f"Failed to execute {base_tool.name}. "
+                            "Check server logs for details."
+                        )
+
+                    # Pre-truncate the result so the _truncating wrapper (which
+                    # wraps this handler) receives an already-within-budget
+                    # value. _truncating handles stashing — we must NOT stash
+                    # here or the output will be appended twice to the FIFO
+                    # queue and pop_pending_tool_output would return a duplicate
+                    # entry on the second call for the same tool.
+                    return truncate(result, _MCP_MAX_CHARS)
+
+        # No pre-launched task — execute directly (fallback for non-parallel calls).
         user_id, session = get_execution_context()
 
         if session is None:
@@ -436,19 +648,9 @@ def _text_from_mcp_result(result: dict[str, Any]) -> str:
     )
 
 
-_PARALLEL_ANNOTATION = ToolAnnotations(readOnlyHint=True)
-
-
 def create_copilot_mcp_server(*, use_e2b: bool = False):
     """Create an in-process MCP server configuration for CoPilot tools.
 
-    All tools are annotated with ``readOnlyHint=True`` so the SDK CLI
-    dispatches concurrent tool calls in parallel rather than sequentially.
-    This is a deliberate override: even side-effect tools use the hint
-    because the MCP tools are already individually sandboxed and the
-    pre-launch duplicate-execution bug (SECRT-2204) is worse than
-    sequential dispatch.
-
     When *use_e2b* is True, five additional MCP file tools are registered
     that route directly to the E2B sandbox filesystem, and the caller should
     disable the corresponding SDK built-in tools via
|
||||
Applied once to every registered tool."""
|
||||
|
||||
async def wrapper(args: dict[str, Any]) -> dict[str, Any]:
|
||||
# Empty tool args = model's output was truncated by the API's
|
||||
# max_tokens limit. Instead of letting the tool fail with a
|
||||
# confusing error (and eventually tripping the circuit breaker),
|
||||
# return clear guidance so the model can self-correct.
|
||||
if not args and input_schema and input_schema.get("required"):
|
||||
logger.warning(
|
||||
"[MCP] %s called with empty args (likely output "
|
||||
"token truncation) — returning guidance",
|
||||
tool_name,
|
||||
)
|
||||
return _mcp_error(
|
||||
f"Your call to {tool_name} had empty arguments — "
|
||||
f"this means your previous response was too long and "
|
||||
f"the tool call input was truncated by the API. "
|
||||
f"To fix this: break your work into smaller steps. "
|
||||
f"For large content, first write it to a file using "
|
||||
f"bash_exec with cat >> (append section by section), "
|
||||
f"then pass it via @@agptfile:filename reference. "
|
||||
f"Do NOT retry with the same approach — it will "
|
||||
f"be truncated again."
|
||||
)
|
||||
|
||||
# Circuit breaker: stop infinite retry loops with identical args.
|
||||
# Use the original (pre-expansion) args for fingerprinting so
|
||||
# check and record always use the same key — @@agptfile:
|
||||
@@ -538,35 +718,24 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
     for tool_name, base_tool in TOOL_REGISTRY.items():
         handler = create_tool_handler(base_tool)
         schema = _build_input_schema(base_tool)
-        # All tools annotated readOnlyHint=True to enable parallel dispatch.
-        # The SDK CLI uses this hint to dispatch concurrent tool calls in
-        # parallel rather than sequentially. Side-effect safety is ensured
-        # by the tool implementations themselves (idempotency, credentials).
         decorated = tool(
             tool_name,
             base_tool.description,
             schema,
-            annotations=_PARALLEL_ANNOTATION,
         )(_truncating(handler, tool_name, input_schema=schema))
         sdk_tools.append(decorated)
 
     # E2B file tools replace SDK built-in Read/Write/Edit/Glob/Grep.
     if use_e2b:
         for name, desc, schema, handler in E2B_FILE_TOOLS:
-            decorated = tool(
-                name,
-                desc,
-                schema,
-                annotations=_PARALLEL_ANNOTATION,
-            )(_truncating(handler, name))
+            decorated = tool(name, desc, schema)(_truncating(handler, name))
             sdk_tools.append(decorated)
 
-    # Read tool for SDK-truncated tool results (always needed, read-only).
+    # Read tool for SDK-truncated tool results (always needed).
     read_tool = tool(
         _READ_TOOL_NAME,
         _READ_TOOL_DESCRIPTION,
         _READ_TOOL_SCHEMA,
-        annotations=_PARALLEL_ANNOTATION,
     )(_truncating(_read_file_handler, _READ_TOOL_NAME))
     sdk_tools.append(read_tool)
@@ -1,21 +1,22 @@
-"""Tests for tool_adapter: truncation, stash, context vars, readOnlyHint annotations."""
+"""Tests for tool_adapter helpers: truncation, stash, context vars, parallel pre-launch."""
 
 import asyncio
-from unittest.mock import AsyncMock, MagicMock
+from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
-from mcp.types import ToolAnnotations
 
 from backend.copilot.context import get_sdk_cwd
 from backend.copilot.response_model import StreamToolOutputAvailable
+from backend.copilot.sdk.file_ref import FileRefExpansionError
 from backend.util.truncate import truncate
 
 from .tool_adapter import (
     _MCP_MAX_CHARS,
     SDK_DISALLOWED_TOOLS,
     _text_from_mcp_result,
+    cancel_pending_tool_tasks,
     create_tool_handler,
     pop_pending_tool_output,
+    pre_launch_tool_call,
     reset_stash_event,
     set_execution_context,
     stash_pending_tool_output,
@@ -243,7 +244,7 @@ class TestTruncationAndStashIntegration:
 
 
 # ---------------------------------------------------------------------------
-# create_tool_handler (direct execution, no pre-launch)
+# Parallel pre-launch infrastructure
 # ---------------------------------------------------------------------------
 
 
@@ -276,18 +277,169 @@ def _init_ctx(session=None):
     )
 
 
-class TestCreateToolHandler:
-    """Tests for create_tool_handler — direct tool execution."""
+class TestPreLaunchToolCall:
+    """Tests for pre_launch_tool_call and the queue-based parallel dispatch."""
 
     @pytest.fixture(autouse=True)
     def _init(self):
         _init_ctx(session=_make_mock_session())
 
     @pytest.mark.asyncio
-    async def test_handler_executes_tool_directly(self):
-        """Handler executes the tool and returns MCP-formatted result."""
+    async def test_unknown_tool_is_silently_ignored(self):
+        """pre_launch_tool_call does nothing for tools not in TOOL_REGISTRY."""
+        # Should not raise even if the tool name is completely unknown
+        await pre_launch_tool_call("nonexistent_tool", {})
+
+    @pytest.mark.asyncio
+    async def test_mcp_prefix_stripped_before_registry_lookup(self):
+        """mcp__copilot__run_block is looked up as 'run_block'."""
+        mock_tool = _make_mock_tool("run_block")
+        with patch(
+            "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
+            {"run_block": mock_tool},
+        ):
+            await pre_launch_tool_call("mcp__copilot__run_block", {"block_id": "b1"})
+
+            # The task was enqueued — mock_tool.execute should be called once
+            # (may not complete immediately but should start)
+            await asyncio.sleep(0)  # yield to event loop
+            mock_tool.execute.assert_awaited_once()
+
+    @pytest.mark.asyncio
+    async def test_bare_tool_name_without_prefix(self):
+        """Tool names without __ separator are looked up as-is."""
+        mock_tool = _make_mock_tool("run_block")
+        with patch(
+            "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
+            {"run_block": mock_tool},
+        ):
+            await pre_launch_tool_call("run_block", {"block_id": "b1"})
+
+            await asyncio.sleep(0)
+            mock_tool.execute.assert_awaited_once()
+
+    @pytest.mark.asyncio
+    async def test_task_enqueued_fifo_for_same_tool(self):
+        """Two pre-launched calls for the same tool name are enqueued FIFO."""
+        results = []
+
+        async def slow_execute(*args, **kwargs):
+            results.append(len(results))
+            return StreamToolOutputAvailable(
+                toolCallId="id",
+                output=str(len(results) - 1),
+                toolName="t",
+                success=True,
+            )
+
+        mock_tool = _make_mock_tool("t")
+        mock_tool.execute = AsyncMock(side_effect=slow_execute)
+
+        with patch(
+            "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
+            {"t": mock_tool},
+        ):
+            await pre_launch_tool_call("t", {"n": 1})
+            await pre_launch_tool_call("t", {"n": 2})
+            await asyncio.sleep(0)
+
+            assert mock_tool.execute.await_count == 2
+
+    @pytest.mark.asyncio
+    async def test_file_ref_expansion_failure_skips_pre_launch(self):
+        """When @@agptfile: expansion fails, pre_launch_tool_call skips the task.
+
+        The handler should then fall back to direct execution (which will also
+        fail with a proper MCP error via _truncating's own expansion).
+        """
+        mock_tool = _make_mock_tool("run_block", output="should-not-execute")
+
+        with (
+            patch(
+                "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
+                {"run_block": mock_tool},
+            ),
+            patch(
+                "backend.copilot.sdk.tool_adapter.expand_file_refs_in_args",
+                AsyncMock(side_effect=FileRefExpansionError("@@agptfile:missing.txt")),
+            ),
+        ):
+            # Should not raise — expansion failure is handled gracefully
+            await pre_launch_tool_call("run_block", {"text": "@@agptfile:missing.txt"})
+            await asyncio.sleep(0)
+
+            # No task was pre-launched — execute was not called
+            mock_tool.execute.assert_not_awaited()
+
+
+class TestCreateToolHandlerParallel:
+    """Tests for create_tool_handler using pre-launched tasks."""
+
+    @pytest.fixture(autouse=True)
+    def _init(self):
+        _init_ctx(session=_make_mock_session())
+
+    @pytest.mark.asyncio
+    async def test_handler_uses_prelaunched_task(self):
+        """Handler pops and awaits the pre-launched task rather than re-executing."""
+        mock_tool = _make_mock_tool("run_block", output="pre-launched result")
+
+        with patch(
+            "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
+            {"run_block": mock_tool},
+        ):
+            await pre_launch_tool_call("run_block", {"block_id": "b1"})
+            await asyncio.sleep(0)  # let task start
+
+            handler = create_tool_handler(mock_tool)
+            result = await handler({"block_id": "b1"})
+
+            assert result["isError"] is False
+            text = result["content"][0]["text"]
+            assert "pre-launched result" in text
+            # Should only have been called once (the pre-launched task), not twice
+            mock_tool.execute.assert_awaited_once()
+
+    @pytest.mark.asyncio
+    async def test_handler_does_not_double_stash_for_prelaunched_task(self):
+        """Pre-launched task result must NOT be stashed by tool_handler directly.
+
+        The _truncating wrapper wraps tool_handler and handles stashing after
+        tool_handler returns. If tool_handler also stashed, the output would be
+        appended twice to the FIFO queue and pop_pending_tool_output would return
+        a duplicate on the second call.
+
+        This test calls tool_handler directly (without _truncating) and asserts
+        that nothing was stashed — confirming stashing is deferred to _truncating.
+        """
+        mock_tool = _make_mock_tool("run_block", output="stash-me")
+
+        with patch(
+            "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
+            {"run_block": mock_tool},
+        ):
+            await pre_launch_tool_call("run_block", {"block_id": "b1"})
+            await asyncio.sleep(0)
+
+            handler = create_tool_handler(mock_tool)
+            result = await handler({"block_id": "b1"})
+
+            assert result["isError"] is False
+            assert "stash-me" in result["content"][0]["text"]
+            # tool_handler must NOT stash — _truncating (which wraps handler) does it.
|
||||
# Calling pop here (without going through _truncating) should return None.
|
||||
not_stashed = pop_pending_tool_output("run_block")
|
||||
assert not_stashed is None, (
|
||||
"tool_handler must not stash directly — _truncating handles stashing "
|
||||
"to prevent double-stash in the FIFO queue"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_falls_back_when_queue_empty(self):
|
||||
"""When no pre-launched task exists, handler executes directly."""
|
||||
mock_tool = _make_mock_tool("run_block", output="direct result")
|
||||
|
||||
# Don't call pre_launch_tool_call — queue is empty
|
||||
handler = create_tool_handler(mock_tool)
|
||||
result = await handler({"block_id": "b1"})
|
||||
|
||||
@@ -297,9 +449,104 @@ class TestCreateToolHandler:
        mock_tool.execute.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_handler_returns_error_on_no_session(self):
        """When session is None, handler returns MCP error."""
    async def test_handler_cancelled_error_propagates(self):
        """CancelledError from a pre-launched task is re-raised to preserve cancellation semantics."""
        mock_tool = _make_mock_tool("run_block")
        mock_tool.execute = AsyncMock(side_effect=asyncio.CancelledError())

        with patch(
            "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
            {"run_block": mock_tool},
        ):
            await pre_launch_tool_call("run_block", {"block_id": "b1"})
            await asyncio.sleep(0)

            handler = create_tool_handler(mock_tool)
            with pytest.raises(asyncio.CancelledError):
                await handler({"block_id": "b1"})

    @pytest.mark.asyncio
    async def test_handler_exception_returns_mcp_error(self):
        """Exception from a pre-launched task is caught and returned as MCP error."""
        mock_tool = _make_mock_tool("run_block")
        mock_tool.execute = AsyncMock(side_effect=RuntimeError("block exploded"))

        with patch(
            "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
            {"run_block": mock_tool},
        ):
            await pre_launch_tool_call("run_block", {"block_id": "b1"})
            await asyncio.sleep(0)

            handler = create_tool_handler(mock_tool)
            result = await handler({"block_id": "b1"})

        assert result["isError"] is True
        assert "Failed to execute run_block" in result["content"][0]["text"]

    @pytest.mark.asyncio
    async def test_two_same_tool_calls_dispatched_in_order(self):
        """Two pre-launched tasks for the same tool are consumed in FIFO order."""
        call_order = []

        async def execute_with_tag(*args, **kwargs):
            tag = kwargs.get("block_id", "?")
            call_order.append(tag)
            return StreamToolOutputAvailable(
                toolCallId="id", output=f"out-{tag}", toolName="run_block", success=True
            )

        mock_tool = _make_mock_tool("run_block")
        mock_tool.execute = AsyncMock(side_effect=execute_with_tag)

        with patch(
            "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
            {"run_block": mock_tool},
        ):
            await pre_launch_tool_call("run_block", {"block_id": "first"})
            await pre_launch_tool_call("run_block", {"block_id": "second"})
            await asyncio.sleep(0)

            handler = create_tool_handler(mock_tool)
            r1 = await handler({"block_id": "first"})
            r2 = await handler({"block_id": "second"})

        assert "out-first" in r1["content"][0]["text"]
        assert "out-second" in r2["content"][0]["text"]
        assert call_order == [
            "first",
            "second",
        ], f"Expected FIFO dispatch order but got {call_order}"

    @pytest.mark.asyncio
    async def test_arg_mismatch_falls_back_to_direct_execution(self):
        """When pre-launched args differ from SDK args, handler cancels pre-launched
        task and falls back to direct execution with the correct args."""
        mock_tool = _make_mock_tool("run_block", output="direct-result")

        with patch(
            "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
            {"run_block": mock_tool},
        ):
            # Pre-launch with args {"block_id": "wrong"}
            await pre_launch_tool_call("run_block", {"block_id": "wrong"})
            await asyncio.sleep(0)

            # SDK dispatches with different args
            handler = create_tool_handler(mock_tool)
            result = await handler({"block_id": "correct"})

        assert result["isError"] is False
        # The tool was called twice: once by pre-launch (wrong args), once by
        # direct fallback (correct args). The result should come from the
        # direct execution path.
        assert mock_tool.execute.await_count == 2

    @pytest.mark.asyncio
    async def test_no_session_falls_back_gracefully(self):
        """When session is None and no pre-launched task, handler returns MCP error."""
        mock_tool = _make_mock_tool("run_block")
        # session=None means get_execution_context returns (user_id, None)
        set_execution_context(user_id="u", session=None, sandbox=None)  # type: ignore[arg-type]

        handler = create_tool_handler(mock_tool)
@@ -308,314 +555,220 @@ class TestCreateToolHandler:
        assert result["isError"] is True
        assert "session" in result["content"][0]["text"].lower()


# ---------------------------------------------------------------------------
# cancel_pending_tool_tasks
# ---------------------------------------------------------------------------


class TestCancelPendingToolTasks:
    """Tests for cancel_pending_tool_tasks — the stream-abort cleanup helper."""

    @pytest.fixture(autouse=True)
    def _init(self):
        _init_ctx(session=_make_mock_session())

    @pytest.mark.asyncio
    async def test_handler_returns_error_on_exception(self):
        """Exception from tool execution is caught and returned as MCP error."""
    async def test_cancels_queued_tasks(self):
        """Queued tasks are cancelled and the queue is cleared."""
        ran = False

        async def never_run(*_args, **_kwargs):
            nonlocal ran
            await asyncio.sleep(10)  # long enough to still be pending
            ran = True

        mock_tool = _make_mock_tool("run_block")
        mock_tool.execute = AsyncMock(side_effect=RuntimeError("block exploded"))
        mock_tool.execute = AsyncMock(side_effect=never_run)

        handler = create_tool_handler(mock_tool)
        result = await handler({"block_id": "b1"})
        with patch(
            "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
            {"run_block": mock_tool},
        ):
            await pre_launch_tool_call("run_block", {"block_id": "b1"})
            await asyncio.sleep(0)  # let task start
            await cancel_pending_tool_tasks()
            await asyncio.sleep(0)  # let cancellation propagate

        assert result["isError"] is True
        assert "Failed to execute run_block" in result["content"][0]["text"]
        assert not ran, "Task should have been cancelled before completing"

    @pytest.mark.asyncio
    async def test_handler_executes_once_per_call(self):
        """Each handler call executes the tool exactly once — no duplicate execution."""
        mock_tool = _make_mock_tool("run_block", output="single-execution")
    async def test_noop_when_no_tasks_queued(self):
        """cancel_pending_tool_tasks does not raise when queues are empty."""
        await cancel_pending_tool_tasks()  # should not raise

        handler = create_tool_handler(mock_tool)
        await handler({"block_id": "b1"})
        await handler({"block_id": "b2"})
    @pytest.mark.asyncio
    async def test_handler_does_not_find_cancelled_task(self):
        """After cancel, tool_handler falls back to direct execution."""
        mock_tool = _make_mock_tool("run_block", output="direct-fallback")

        with patch(
            "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
            {"run_block": mock_tool},
        ):
            await pre_launch_tool_call("run_block", {"block_id": "b1"})
            await asyncio.sleep(0)
            await cancel_pending_tool_tasks()

            # Queue is now empty — handler should execute directly
            handler = create_tool_handler(mock_tool)
            result = await handler({"block_id": "b1"})

        assert result["isError"] is False
        assert "direct-fallback" in result["content"][0]["text"]


# ---------------------------------------------------------------------------
# Concurrent / parallel pre-launch scenarios
# ---------------------------------------------------------------------------

class TestAllParallelToolsPrelaunchedIndependently:
    """Simulate SDK sending N separate AssistantMessages for the same tool concurrently."""

    @pytest.fixture(autouse=True)
    def _init(self):
        _init_ctx(session=_make_mock_session())

    @pytest.mark.asyncio
    async def test_all_parallel_tools_prelaunched_independently(self):
        """5 pre-launches for the same tool all enqueue independently and run concurrently.

        Each task sleeps for PER_TASK_S seconds. If they ran sequentially the total
        wall time would be ~5*PER_TASK_S. Running concurrently it should finish in
        roughly PER_TASK_S (plus scheduling overhead).
        """
        PER_TASK_S = 0.05
        N = 5
        started: list[int] = []
        finished: list[int] = []

        async def slow_execute(*args, **kwargs):
            idx = len(started)
            started.append(idx)
            await asyncio.sleep(PER_TASK_S)
            finished.append(idx)
            return StreamToolOutputAvailable(
                toolCallId=f"id-{idx}",
                output=f"result-{idx}",
                toolName="bash_exec",
                success=True,
            )

        mock_tool = _make_mock_tool("bash_exec")
        mock_tool.execute = AsyncMock(side_effect=slow_execute)

        with patch(
            "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
            {"bash_exec": mock_tool},
        ):
            for i in range(N):
                await pre_launch_tool_call("bash_exec", {"cmd": f"echo {i}"})

            # Measure only the concurrent execution window, not pre-launch overhead.
            # Starting the timer here avoids false failures on slow CI runners where
            # the pre_launch_tool_call setup takes longer than the concurrent sleep.
            t0 = asyncio.get_running_loop().time()
            await asyncio.sleep(PER_TASK_S * 2)
            elapsed = asyncio.get_running_loop().time() - t0

        assert mock_tool.execute.await_count == N
        assert len(finished) == N
        # Wall time of the sleep window should be well under N * PER_TASK_S
        # (sequential would be ~0.25s; concurrent finishes in ~PER_TASK_S = 0.05s)
        assert elapsed < N * PER_TASK_S, (
            f"Expected concurrent execution (<{N * PER_TASK_S:.2f}s) "
            f"but sleep window took {elapsed:.2f}s"
        )


class TestHandlerReturnsResultFromCorrectPrelaunchedTask:
    """Pop pre-launched tasks in order and verify each returns its own result."""

    @pytest.fixture(autouse=True)
    def _init(self):
        _init_ctx(session=_make_mock_session())

    @pytest.mark.asyncio
    async def test_handler_returns_result_from_correct_prelaunched_task(self):
        """Two pre-launches for the same tool: first handler gets first result, second gets second."""

        async def execute_with_cmd(*args, **kwargs):
            cmd = kwargs.get("cmd", "?")
            return StreamToolOutputAvailable(
                toolCallId="id",
                output=f"output-for-{cmd}",
                toolName="bash_exec",
                success=True,
            )

        mock_tool = _make_mock_tool("bash_exec")
        mock_tool.execute = AsyncMock(side_effect=execute_with_cmd)

        with patch(
            "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
            {"bash_exec": mock_tool},
        ):
            await pre_launch_tool_call("bash_exec", {"cmd": "alpha"})
            await pre_launch_tool_call("bash_exec", {"cmd": "beta"})
            await asyncio.sleep(0)  # let both tasks start

            handler = create_tool_handler(mock_tool)
            r1 = await handler({"cmd": "alpha"})
            r2 = await handler({"cmd": "beta"})

        text1 = r1["content"][0]["text"]
        text2 = r2["content"][0]["text"]
        assert "output-for-alpha" in text1, f"Expected alpha result, got: {text1}"
        assert "output-for-beta" in text2, f"Expected beta result, got: {text2}"
        assert mock_tool.execute.await_count == 2

# ---------------------------------------------------------------------------
# Regression tests: bugs fixed by removing pre-launch mechanism
#
# Each test class includes a _buggy_handler fixture that reproduces the old
# pre-launch implementation inline. Tests run against BOTH the buggy handler
# (xfail — proves the bug exists) and the current clean handler (must pass).
# ---------------------------------------------------------------------------


def _make_execute_fn(tool_name: str = "run_block"):
    """Return (execute_fn, call_log) — execute_fn records every call."""
    call_log: list[dict] = []

    async def execute_fn(*args, **kwargs):
        call_log.append(kwargs)
        return StreamToolOutputAvailable(
            toolCallId=f"id-{len(call_log)}",
            output=f"result-{len(call_log)}",
            toolName=tool_name,
            success=True,
        )

    return execute_fn, call_log


async def _buggy_prelaunch_handler(mock_tool, pre_launch_args, dispatch_args):
    """Simulate the OLD buggy pre-launch flow.

    1. pre_launch_tool_call fires _execute_tool_sync with pre_launch_args
    2. SDK dispatches handler with dispatch_args
    3. Handler compares args — on mismatch, cancels + re-executes (BUG)

    Returns the handler result.
    """
    from backend.copilot.sdk.tool_adapter import _execute_tool_sync

    user_id, session = "user-1", _make_mock_session()

    # Step 1: pre-launch fires immediately (speculative)
    task = asyncio.create_task(
        _execute_tool_sync(mock_tool, user_id, session, pre_launch_args)
    )
    await asyncio.sleep(0)  # let task start

    # Step 2: SDK dispatches with (potentially different) args
    if pre_launch_args != dispatch_args:
        # Arg mismatch path: cancel pre-launched task + re-execute
        if not task.done():
            task.cancel()
            try:
                await task
            except (asyncio.CancelledError, Exception):
                pass
        # Fall through to direct execution (duplicate!)
        return await _execute_tool_sync(mock_tool, user_id, session, dispatch_args)
    else:
        return await task

class TestBug1DuplicateExecution:
    """Bug 1 (SECRT-2204): arg mismatch causes duplicate execution.

    Pre-launch fires with raw args, SDK dispatches with normalised args.
    Mismatch → cancel (too late) + re-execute → 2 API calls.
    """
class TestFiveConcurrentPrelaunchAllComplete:
    """Pre-launch 5 tasks; consume all 5 via handlers; assert all succeed."""

    @pytest.fixture(autouse=True)
    def _init(self):
        _init_ctx(session=_make_mock_session())

    @pytest.mark.xfail(reason="Old pre-launch code causes duplicate execution")
    @pytest.mark.asyncio
    async def test_old_code_duplicates_on_arg_mismatch(self):
        """OLD CODE: pre-launch with args A, dispatch with args B → 2 calls."""
        execute_fn, call_log = _make_execute_fn()
        mock_tool = _make_mock_tool("run_block")
        mock_tool.execute = AsyncMock(side_effect=execute_fn)
    async def test_five_concurrent_prelaunch_all_complete(self):
        """All 5 pre-launched tasks complete and return successful results."""
        N = 5
        call_count = 0

        pre_launch_args = {"block_id": "b1", "input_data": {"title": "Test"}}
        dispatch_args = {
            "block_id": "b1",
            "input_data": {"title": "Test", "priority": None},
        }

        await _buggy_prelaunch_handler(mock_tool, pre_launch_args, dispatch_args)

        # BUG: pre-launch executed once + fallback executed again = 2
        assert len(call_log) == 1, (
            f"Expected 1 execution but got {len(call_log)} — "
            f"duplicate execution bug!"
        )

    @pytest.mark.asyncio
    async def test_current_code_no_duplicate(self):
        """FIXED: handler executes exactly once regardless of arg shape."""
        execute_fn, call_log = _make_execute_fn()
        mock_tool = _make_mock_tool("run_block")
        mock_tool.execute = AsyncMock(side_effect=execute_fn)

        handler = create_tool_handler(mock_tool)
        await handler({"block_id": "b1", "input_data": {"title": "Test"}})

        assert len(call_log) == 1, f"Expected 1 execution but got {len(call_log)}"

class TestBug2FIFODesync:
    """Bug 2: FIFO desync when security hook denies a tool.

    Pre-launch queues [task_A, task_B]. Tool A denied (no MCP dispatch).
    Tool B's handler dequeues task_A → returns wrong result.
    """

    @pytest.fixture(autouse=True)
    def _init(self):
        _init_ctx(session=_make_mock_session())

    @pytest.mark.xfail(reason="Old FIFO queue returns wrong result on denial")
    @pytest.mark.asyncio
    async def test_old_code_fifo_desync_on_denial(self):
        """OLD CODE: denied tool's task stays in queue, next tool gets wrong result."""
        from backend.copilot.sdk.tool_adapter import _execute_tool_sync

        call_log: list[str] = []

        async def tagged_execute(*args, **kwargs):
            tag = kwargs.get("block_id", "?")
            call_log.append(tag)
        async def counting_execute(*args, **kwargs):
            nonlocal call_count
            call_count += 1
            n = call_count
            return StreamToolOutputAvailable(
                toolCallId="id",
                output=f"result-for-{tag}",
                toolName="run_block",
                toolCallId=f"id-{n}",
                output=f"done-{n}",
                toolName="bash_exec",
                success=True,
            )

        mock_tool = _make_mock_tool("run_block")
        mock_tool.execute = AsyncMock(side_effect=tagged_execute)
        user_id, session = "user-1", _make_mock_session()
        mock_tool = _make_mock_tool("bash_exec")
        mock_tool.execute = AsyncMock(side_effect=counting_execute)

        # Simulate old FIFO queue
        queue: asyncio.Queue = asyncio.Queue()
        with patch(
            "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
            {"bash_exec": mock_tool},
        ):
            for i in range(N):
                await pre_launch_tool_call("bash_exec", {"cmd": f"task-{i}"})

        # Pre-launch for tool A and tool B
        task_a = asyncio.create_task(
            _execute_tool_sync(mock_tool, user_id, session, {"block_id": "A"})
        )
        task_b = asyncio.create_task(
            _execute_tool_sync(mock_tool, user_id, session, {"block_id": "B"})
        )
        queue.put_nowait(task_a)
        queue.put_nowait(task_b)
        await asyncio.sleep(0)  # let both tasks run
        await asyncio.sleep(0)  # let all tasks start

        # Tool A is DENIED by security hook — no MCP dispatch, no dequeue
        # Tool B's handler dequeues from FIFO → gets task_A!
        dequeued_task = queue.get_nowait()
        result = await dequeued_task
        result_text = result["content"][0]["text"]
        handler = create_tool_handler(mock_tool)
        results = []
        for i in range(N):
            results.append(await handler({"cmd": f"task-{i}"}))

        # BUG: handler for B got task_A's result
        assert "result-for-B" in result_text, (
            f"Expected result for B but got: {result_text} — "
            f"FIFO desync: B got A's result!"
        )

    @pytest.mark.asyncio
    async def test_current_code_no_fifo_desync(self):
        """FIXED: each handler call executes independently, no shared queue."""
        call_log: list[str] = []

        async def tagged_execute(*args, **kwargs):
            tag = kwargs.get("block_id", "?")
            call_log.append(tag)
            return StreamToolOutputAvailable(
                toolCallId="id",
                output=f"result-for-{tag}",
                toolName="run_block",
                success=True,
            )

        mock_tool = _make_mock_tool("run_block")
        mock_tool.execute = AsyncMock(side_effect=tagged_execute)

        handler = create_tool_handler(mock_tool)

        # Tool A denied (never called). Tool B dispatched normally.
        result_b = await handler({"block_id": "B"})

        assert "result-for-B" in result_b["content"][0]["text"]
        assert call_log == ["B"]

class TestBug3CancelRace:
    """Bug 3: cancel race — task completes before cancel arrives.

    Pre-launch fires fast HTTP call (< 1s). By the time handler detects
    mismatch and calls task.cancel(), the API call already completed.
    Side effect (Linear issue created) is irreversible.
    """

    @pytest.fixture(autouse=True)
    def _init(self):
        _init_ctx(session=_make_mock_session())

    @pytest.mark.xfail(reason="Old code: cancel arrives after task completes")
    @pytest.mark.asyncio
    async def test_old_code_cancel_arrives_too_late(self):
        """OLD CODE: fast task completes before cancel, side effect persists."""
        side_effects: list[str] = []

        async def fast_execute_with_side_effect(*args, **kwargs):
            # Side effect happens immediately (like an HTTP POST to Linear)
            side_effects.append("created-issue")
            return StreamToolOutputAvailable(
                toolCallId="id",
                output="issue-created",
                toolName="run_block",
                success=True,
            )

        mock_tool = _make_mock_tool("run_block")
        mock_tool.execute = AsyncMock(side_effect=fast_execute_with_side_effect)

        # Pre-launch fires immediately
        pre_launch_args = {"block_id": "b1"}
        dispatch_args = {"block_id": "b1", "extra": "normalised"}

        await _buggy_prelaunch_handler(mock_tool, pre_launch_args, dispatch_args)

        # BUG: side effect happened TWICE (pre-launch + fallback)
        assert len(side_effects) == 1, (
            f"Expected 1 side effect but got {len(side_effects)} — "
            f"cancel race: pre-launch completed before cancel!"
        )

    @pytest.mark.asyncio
    async def test_current_code_single_side_effect(self):
        """FIXED: no speculative execution, exactly 1 side effect per call."""
        side_effects: list[str] = []

        async def execute_with_side_effect(*args, **kwargs):
            side_effects.append("created-issue")
            return StreamToolOutputAvailable(
                toolCallId="id",
                output="issue-created",
                toolName="run_block",
                success=True,
            )

        mock_tool = _make_mock_tool("run_block")
        mock_tool.execute = AsyncMock(side_effect=execute_with_side_effect)

        handler = create_tool_handler(mock_tool)
        await handler({"block_id": "b1"})

        assert len(side_effects) == 1


# ---------------------------------------------------------------------------
# readOnlyHint annotations
# ---------------------------------------------------------------------------

class TestReadOnlyAnnotations:
    """Tests that all tools get readOnlyHint=True for parallel dispatch."""

    def test_parallel_annotation_constant(self):
        """_PARALLEL_ANNOTATION is a ToolAnnotations with readOnlyHint=True."""
        from .tool_adapter import _PARALLEL_ANNOTATION

        assert isinstance(_PARALLEL_ANNOTATION, ToolAnnotations)
        assert _PARALLEL_ANNOTATION.readOnlyHint is True


# ---------------------------------------------------------------------------
# SDK_DISALLOWED_TOOLS
# ---------------------------------------------------------------------------


class TestSDKDisallowedTools:
    """Verify that dangerous SDK built-in tools are in the disallowed list."""

    def test_bash_tool_is_disallowed(self):
        assert "Bash" in SDK_DISALLOWED_TOOLS

    def test_webfetch_tool_is_disallowed(self):
        """WebFetch is disallowed due to SSRF risk."""
        assert "WebFetch" in SDK_DISALLOWED_TOOLS
        assert (
            mock_tool.execute.await_count == N
        ), f"Expected {N} execute calls, got {mock_tool.execute.await_count}"
        for i, result in enumerate(results):
            assert result["isError"] is False, f"Result {i} should not be an error"
            text = result["content"][0]["text"]
            assert "done-" in text, f"Result {i} missing expected output: {text}"

@@ -43,10 +43,6 @@ STRIPPABLE_TYPES = frozenset(
    {"progress", "file-history-snapshot", "queue-operation", "summary", "pr-link"}
)

# Thinking block types that can be stripped from non-last assistant entries.
# The Anthropic API only requires these in the *last* assistant message.
_THINKING_BLOCK_TYPES = frozenset({"thinking", "redacted_thinking"})


@dataclass
class TranscriptDownload:
@@ -454,83 +450,6 @@ def _build_meta_storage_path(user_id: str, session_id: str, backend: object) ->
    )


def strip_stale_thinking_blocks(content: str) -> str:
    """Remove thinking/redacted_thinking blocks from non-last assistant entries.

    The Anthropic API only requires thinking blocks in the **last** assistant
    message to be value-identical to the original response. Older assistant
    entries carry stale thinking blocks that consume significant tokens
    (often 10-50K each) without providing useful context for ``--resume``.

    Stripping them before upload prevents the CLI from triggering compaction
    every turn just to compress away the stale thinking bloat.
    """
    lines = content.strip().split("\n")
    if not lines:
        return content

    parsed: list[tuple[str, dict | None]] = []
    for line in lines:
        parsed.append((line, json.loads(line, fallback=None)))

    # Reverse scan to find the last assistant message ID and index.
    last_asst_msg_id: str | None = None
    last_asst_idx: int | None = None
    for i in range(len(parsed) - 1, -1, -1):
        _line, entry = parsed[i]
        if not isinstance(entry, dict):
            continue
        msg = entry.get("message", {})
        if msg.get("role") == "assistant":
            last_asst_msg_id = msg.get("id")
            last_asst_idx = i
            break

    if last_asst_idx is None:
        return content

    result_lines: list[str] = []
    stripped_count = 0
    for i, (line, entry) in enumerate(parsed):
        if not isinstance(entry, dict):
            result_lines.append(line)
            continue

        msg = entry.get("message", {})
        # Only strip from assistant entries that are NOT the last turn.
        # Use msg_id matching when available; fall back to index for entries
        # without an id field.
        is_last_turn = (
            last_asst_msg_id is not None and msg.get("id") == last_asst_msg_id
        ) or (last_asst_msg_id is None and i == last_asst_idx)
        if (
            msg.get("role") == "assistant"
            and not is_last_turn
            and isinstance(msg.get("content"), list)
        ):
            content_blocks = msg["content"]
            filtered = [
                b
                for b in content_blocks
                if not (isinstance(b, dict) and b.get("type") in _THINKING_BLOCK_TYPES)
            ]
            if len(filtered) < len(content_blocks):
                stripped_count += len(content_blocks) - len(filtered)
                entry = {**entry, "message": {**msg, "content": filtered}}
                result_lines.append(json.dumps(entry, separators=(",", ":")))
                continue

        result_lines.append(line)

    if stripped_count:
        logger.info(
            "[Transcript] Stripped %d stale thinking block(s) from non-last entries",
            stripped_count,
        )

    return "\n".join(result_lines) + "\n"


async def upload_transcript(
    user_id: str,
    session_id: str,
@@ -553,9 +472,6 @@ async def upload_transcript(
    # Strip metadata entries (progress, file-history-snapshot, etc.)
    # Note: SDK-built transcripts shouldn't have these, but strip for safety
    stripped = strip_progress_entries(content)
    # Strip stale thinking blocks from older assistant entries — these consume
    # significant tokens and trigger unnecessary CLI compaction every turn.
    stripped = strip_stale_thinking_blocks(stripped)
    if not validate_transcript(stripped):
        # Log entry types for debugging — helps identify why validation failed
        entry_types = [
@@ -689,6 +605,9 @@ COMPACT_MSG_ID_PREFIX = "msg_compact_"
ENTRY_TYPE_MESSAGE = "message"


_THINKING_BLOCK_TYPES = frozenset({"thinking", "redacted_thinking"})


def _flatten_assistant_content(blocks: list) -> str:
    """Flatten assistant content blocks into a single plain-text string.

@@ -714,14 +633,11 @@ def _flatten_assistant_content(blocks: list) -> str:
        if btype == "text":
            parts.append(block.get("text", ""))
        elif btype == "tool_use":
            # Drop tool_use entirely — any text representation gets
            # mimicked by the model as plain text instead of actual
            # structured tool calls. The tool results (in the
            # following user/tool_result entry) provide sufficient
            # context about what happened.
            continue
            parts.append(f"[tool_use: {block.get('name', '?')}]")
        else:
            continue
            # Preserve non-text blocks (e.g. image) as placeholders.
            # Use __prefix__ to distinguish from literal user text.
            parts.append(f"[__{btype}__]")
    elif isinstance(block, str):
        parts.append(block)
    return "\n".join(parts) if parts else ""

@@ -13,7 +13,6 @@ from .transcript import (
    delete_transcript,
    read_compacted_entries,
    strip_progress_entries,
    strip_stale_thinking_blocks,
    validate_transcript,
    write_transcript_to_tempfile,
)
@@ -1201,170 +1200,3 @@ class TestCleanupStaleProjectDirs:
        removed = cleanup_stale_project_dirs(encoded_cwd="some-other-project")
        assert removed == 0
        assert non_copilot.exists()


# ---------------------------------------------------------------------------
# strip_stale_thinking_blocks
# ---------------------------------------------------------------------------


class TestStripStaleThinkingBlocks:
    """Tests for strip_stale_thinking_blocks — removes thinking/redacted_thinking
    blocks from non-last assistant entries to reduce transcript bloat."""

    def _asst_entry(
        self, msg_id: str, content: list, uuid: str = "u1", parent: str = ""
    ) -> dict:
        return {
            "type": "assistant",
            "uuid": uuid,
            "parentUuid": parent,
            "message": {
                "role": "assistant",
                "id": msg_id,
                "type": "message",
                "content": content,
            },
        }

    def _user_entry(self, text: str, uuid: str = "u0", parent: str = "") -> dict:
        return {
            "type": "user",
            "uuid": uuid,
            "parentUuid": parent,
            "message": {"role": "user", "content": text},
        }

    def test_strips_thinking_from_older_assistant(self) -> None:
        """Thinking blocks in non-last assistant entries should be removed."""
        old_asst = self._asst_entry(
            "msg_old",
            [
                {"type": "thinking", "thinking": "deep thoughts..."},
                {"type": "text", "text": "hello"},
                {"type": "redacted_thinking", "data": "secret"},
            ],
            uuid="a1",
        )
        new_asst = self._asst_entry(
            "msg_new",
            [
                {"type": "thinking", "thinking": "latest thoughts"},
                {"type": "text", "text": "world"},
            ],
            uuid="a2",
            parent="a1",
        )
        content = _make_jsonl(old_asst, new_asst)
        result = strip_stale_thinking_blocks(content)
        lines = [json.loads(ln) for ln in result.strip().split("\n")]

        # Old assistant should have thinking blocks stripped
        old_content = lines[0]["message"]["content"]
        assert len(old_content) == 1
        assert old_content[0]["type"] == "text"

        # New (last) assistant should be untouched
        new_content = lines[1]["message"]["content"]
        assert len(new_content) == 2
        assert new_content[0]["type"] == "thinking"
        assert new_content[1]["type"] == "text"

    def test_preserves_last_assistant_thinking(self) -> None:
        """The last assistant entry's thinking blocks must be preserved."""
        entry = self._asst_entry(
            "msg_only",
            [
                {"type": "thinking", "thinking": "must keep"},
                {"type": "text", "text": "response"},
            ],
        )
        content = _make_jsonl(entry)
        result = strip_stale_thinking_blocks(content)
        lines = [json.loads(ln) for ln in result.strip().split("\n")]
        assert len(lines[0]["message"]["content"]) == 2

    def test_no_assistant_entries_returns_unchanged(self) -> None:
        """Transcripts with only user entries should pass through unchanged."""
        user = self._user_entry("hello")
        content = _make_jsonl(user)
        assert strip_stale_thinking_blocks(content) == content

    def test_empty_content_returns_unchanged(self) -> None:
        assert strip_stale_thinking_blocks("") == ""

    def test_multiple_turns_strips_all_but_last(self) -> None:
        """With 3 assistant turns, only the last keeps thinking blocks."""
        entries = [
            self._asst_entry(
                "msg_1",
                [
                    {"type": "thinking", "thinking": "t1"},
                    {"type": "text", "text": "a1"},
                ],
                uuid="a1",
            ),
            self._user_entry("q2", uuid="u2", parent="a1"),
            self._asst_entry(
                "msg_2",
                [
                    {"type": "thinking", "thinking": "t2"},
                    {"type": "text", "text": "a2"},
                ],
                uuid="a2",
                parent="u2",
            ),
            self._user_entry("q3", uuid="u3", parent="a2"),
            self._asst_entry(
                "msg_3",
                [
                    {"type": "thinking", "thinking": "t3"},
                    {"type": "text", "text": "a3"},
                ],
                uuid="a3",
                parent="u3",
            ),
        ]
        content = _make_jsonl(*entries)
        result = strip_stale_thinking_blocks(content)
        lines = [json.loads(ln) for ln in result.strip().split("\n")]

        # msg_1: thinking stripped
        assert len(lines[0]["message"]["content"]) == 1
        assert lines[0]["message"]["content"][0]["type"] == "text"
        # msg_2: thinking stripped
        assert len(lines[2]["message"]["content"]) == 1
        # msg_3 (last): thinking preserved
        assert len(lines[4]["message"]["content"]) == 2
        assert lines[4]["message"]["content"][0]["type"] == "thinking"

    def test_same_msg_id_multi_entry_turn(self) -> None:
        """Multiple entries sharing the same message.id (same turn) are preserved."""
        entries = [
            self._asst_entry(
                "msg_old",
                [{"type": "thinking", "thinking": "old"}],
                uuid="a1",
            ),
            self._asst_entry(
                "msg_last",
                [{"type": "thinking", "thinking": "t_part1"}],
                uuid="a2",
                parent="a1",
            ),
            self._asst_entry(
                "msg_last",
                [{"type": "text", "text": "response"}],
                uuid="a3",
                parent="a2",
            ),
        ]
        content = _make_jsonl(*entries)
        result = strip_stale_thinking_blocks(content)
        lines = [json.loads(ln) for ln in result.strip().split("\n")]

        # Old entry stripped
        assert lines[0]["message"]["content"] == []
        # Both entries of last turn (msg_last) preserved
        assert lines[1]["message"]["content"][0]["type"] == "thinking"
        assert lines[2]["message"]["content"][0]["type"] == "text"
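The tests above pin down the contract of `strip_stale_thinking_blocks`. As a reading aid, here is a minimal self-contained sketch consistent with those tests and with the fragments shown earlier in this diff (the names match the diff; the real implementation may differ in details):

```python
import json

_THINKING_BLOCK_TYPES = frozenset({"thinking", "redacted_thinking"})


def strip_stale_thinking_blocks(content: str) -> str:
    """Drop thinking blocks from every assistant entry except the last turn."""
    if not content.strip():
        return content
    lines = content.strip().split("\n")
    entries = [json.loads(ln) for ln in lines]

    # The "last turn" is identified by the message.id of the final assistant
    # entry; multiple entries may share that id and are all preserved.
    last_id = None
    for entry in entries:
        if entry.get("type") == "assistant":
            last_id = (entry.get("message") or {}).get("id")
    if last_id is None:
        return content  # no assistant entries: pass through unchanged

    result_lines = []
    for line, entry in zip(lines, entries):
        msg = entry.get("message") or {}
        if entry.get("type") == "assistant" and msg.get("id") != last_id:
            blocks = msg.get("content")
            if isinstance(blocks, list):
                filtered = [
                    b
                    for b in blocks
                    if not (
                        isinstance(b, dict)
                        and b.get("type") in _THINKING_BLOCK_TYPES
                    )
                ]
                if len(filtered) < len(blocks):
                    # Re-serialize only entries that actually changed.
                    entry = {**entry, "message": {**msg, "content": filtered}}
                    result_lines.append(
                        json.dumps(entry, separators=(",", ":"))
                    )
                    continue
        result_lines.append(line)
    return "\n".join(result_lines) + "\n"
```

Entries that lose no blocks are appended verbatim, so transcripts without stale thinking blocks round-trip byte-for-byte, which is what the "returns unchanged" tests rely on.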
@@ -30,7 +30,7 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
    if not cfg.claude_agent_use_resume:
        return pytest.skip("CLAUDE_AGENT_USE_RESUME is not enabled, skipping test")

    session = await create_chat_session(test_user_id, dry_run=False)
    session = await create_chat_session(test_user_id)
    session = await upsert_chat_session(session)

    # --- Turn 1: send a message with a unique keyword ---

@@ -221,21 +221,9 @@ async def create_session(
    return session


_meta_ttl_refresh_at: dict[str, float] = {}
"""Tracks the last time the session meta key TTL was refreshed.

Used by `publish_chunk` to avoid refreshing on every single chunk
(expensive). Refreshes at most once every 60 seconds per session.
"""

_META_TTL_REFRESH_INTERVAL = 60  # seconds


async def publish_chunk(
    turn_id: str,
    chunk: StreamBaseResponse,
    *,
    session_id: str | None = None,
) -> str:
    """Publish a chunk to Redis Stream.

@@ -244,9 +232,6 @@ async def publish_chunk(
    Args:
        turn_id: Turn ID (per-turn UUID) identifying the stream
        chunk: The stream response chunk to publish
        session_id: Chat session ID — when provided, the session meta key
            TTL is refreshed periodically to prevent expiration during
            long-running turns (see SECRT-2178).

    Returns:
        The Redis Stream message ID
@@ -280,23 +265,6 @@ async def publish_chunk(
    # Set TTL on stream to match session metadata TTL
    await redis.expire(stream_key, config.stream_ttl)

    # Periodically refresh session-related TTLs so they don't expire
    # during long-running turns. Without this, turns exceeding stream_ttl
    # (default 1h) lose their "running" status and stream data, making
    # the session invisible to the resume endpoint (empty on page reload).
    # Both meta key AND stream key are refreshed: the stream key's expire
    # above only fires when publish_chunk is called, but during long
    # sub-agent gaps (task_progress events don't produce chunks), neither
    # key gets refreshed.
    if session_id:
        now = time.perf_counter()
        last_refresh = _meta_ttl_refresh_at.get(session_id, 0)
        if now - last_refresh >= _META_TTL_REFRESH_INTERVAL:
            meta_key = _get_session_meta_key(session_id)
            await redis.expire(meta_key, config.stream_ttl)
            await redis.expire(stream_key, config.stream_ttl)
            _meta_ttl_refresh_at[session_id] = now

    total_time = (time.perf_counter() - start_time) * 1000
    # Only log timing for significant chunks or slow operations
    if (
@@ -363,7 +331,7 @@ async def stream_and_publish(
    async for event in stream:
        if turn_id and not isinstance(event, (StreamFinish, StreamError)):
            try:
                await publish_chunk(turn_id, event, session_id=session_id)
                await publish_chunk(turn_id, event)
            except (RedisError, ConnectionError, OSError):
                if not publish_failed_once:
                    publish_failed_once = True
@@ -832,9 +800,6 @@ async def mark_session_completed(
    # Atomic compare-and-swap: only update if status is "running"
    result = await redis.eval(COMPLETE_SESSION_SCRIPT, 1, meta_key, status)  # type: ignore[misc]

    # Clean up the in-memory TTL refresh tracker to prevent unbounded growth.
    _meta_ttl_refresh_at.pop(session_id, None)

    if result == 0:
        logger.debug(f"Session {session_id} already completed/failed, skipping")
        return False

@@ -10,7 +10,6 @@ from backend.copilot.tracking import track_tool_called
from .add_understanding import AddUnderstandingTool
from .agent_browser import BrowserActTool, BrowserNavigateTool, BrowserScreenshotTool
from .agent_output import AgentOutputTool
from .ask_question import AskQuestionTool
from .base import BaseTool
from .bash_exec import BashExecTool
from .connect_integration import ConnectIntegrationTool
@@ -56,7 +55,6 @@ logger = logging.getLogger(__name__)
# Single source of truth for all tools
TOOL_REGISTRY: dict[str, BaseTool] = {
    "add_understanding": AddUnderstandingTool(),
    "ask_question": AskQuestionTool(),
    "create_agent": CreateAgentTool(),
    "customize_agent": CustomizeAgentTool(),
    "edit_agent": EditAgentTool(),

@@ -68,9 +68,6 @@ class AddUnderstandingTool(BaseTool):
        Each call merges new data with existing understanding:
        - String fields are overwritten if provided
        - List fields are appended (with deduplication)

        Note: This tool accepts **kwargs because its parameters are derived
        dynamically from the BusinessUnderstandingInput model schema.
        """
        session_id = session.session_id

@@ -80,21 +77,23 @@ class AddUnderstandingTool(BaseTool):
                session_id=session_id,
            )

        # Build input model from kwargs (only include fields defined in the model)
        valid_fields = set(BusinessUnderstandingInput.model_fields.keys())
        filtered = {k: v for k, v in kwargs.items() if k in valid_fields}

        # Check if any data was provided
        if not any(v is not None for v in filtered.values()):
        if not any(v is not None for v in kwargs.values()):
            return ErrorResponse(
                message="Please provide at least one field to update.",
                session_id=session_id,
            )

        input_data = BusinessUnderstandingInput(**filtered)
        # Build input model from kwargs (only include fields defined in the model)
        valid_fields = set(BusinessUnderstandingInput.model_fields.keys())
        input_data = BusinessUnderstandingInput(
            **{k: v for k, v in kwargs.items() if k in valid_fields}
        )

        # Track which fields were updated
        updated_fields = [k for k, v in filtered.items() if v is not None]
        updated_fields = [
            k for k, v in kwargs.items() if k in valid_fields and v is not None
        ]

        # Upsert with merge
        understanding = await understanding_db().upsert_business_understanding(

@@ -180,14 +180,12 @@ async def _save_browser_state(
    """
    try:
        # Gather state in parallel
        (
            (rc_url, url_out, _),
            (rc_ck, ck_out, _),
            (rc_ls, ls_out, _),
        ) = await asyncio.gather(
            _run(session_name, "get", "url", timeout=10),
            _run(session_name, "cookies", "get", "--json", timeout=10),
            _run(session_name, "storage", "local", "--json", timeout=10),
        (rc_url, url_out, _), (rc_ck, ck_out, _), (rc_ls, ls_out, _) = (
            await asyncio.gather(
                _run(session_name, "get", "url", timeout=10),
                _run(session_name, "cookies", "get", "--json", timeout=10),
                _run(session_name, "storage", "local", "--json", timeout=10),
            )
        )

        state = {
@@ -450,8 +448,6 @@ class BrowserNavigateTool(BaseTool):
        self,
        user_id: str | None,
        session: ChatSession,
        url: str = "",
        wait_for: str = "networkidle",
        **kwargs: Any,
    ) -> ToolResponseBase:
        """Navigate to *url*, wait for the page to settle, and return a snapshot.
@@ -460,8 +456,8 @@ class BrowserNavigateTool(BaseTool):
        Note: for slow SPAs that never fully idle, the snapshot may reflect a
        partially-loaded state (the wait is best-effort).
        """
        url = url.strip()
        wait_for = wait_for or "networkidle"
        url: str = (kwargs.get("url") or "").strip()
        wait_for: str = kwargs.get("wait_for") or "networkidle"
        session_name = session.session_id

        if not url:
@@ -616,10 +612,6 @@ class BrowserActTool(BaseTool):
        self,
        user_id: str | None,
        session: ChatSession,
        action: str = "",
        target: str = "",
        value: str = "",
        direction: str = "down",
        **kwargs: Any,
    ) -> ToolResponseBase:
        """Perform a browser action and return an updated page snapshot.
@@ -628,10 +620,10 @@ class BrowserActTool(BaseTool):
        ``agent-browser``, waits for the page to settle, and returns the
        accessibility-tree snapshot so the LLM can plan the next step.
        """
        action = action.strip()
        target = target.strip()
        value = value.strip()
        direction = direction.strip()
        action: str = (kwargs.get("action") or "").strip()
        target: str = (kwargs.get("target") or "").strip()
        value: str = (kwargs.get("value") or "").strip()
        direction: str = (kwargs.get("direction") or "down").strip()
        session_name = session.session_id

        if not action:
@@ -785,8 +777,6 @@ class BrowserScreenshotTool(BaseTool):
        self,
        user_id: str | None,
        session: ChatSession,
        annotate: bool | str = True,
        filename: str = "screenshot.png",
        **kwargs: Any,
    ) -> ToolResponseBase:
        """Capture a PNG screenshot and upload it to the workspace.
@@ -796,12 +786,12 @@ class BrowserScreenshotTool(BaseTool):
        Returns a :class:`BrowserScreenshotResponse` with the workspace
        ``file_id`` the LLM should pass to ``read_workspace_file``.
        """
        raw_annotate = annotate
        raw_annotate = kwargs.get("annotate", True)
        if isinstance(raw_annotate, str):
            annotate = raw_annotate.strip().lower() in {"1", "true", "yes", "on"}
        else:
            annotate = bool(raw_annotate)
        filename = filename.strip()
        filename: str = (kwargs.get("filename") or "screenshot.png").strip()
        session_name = session.session_id

        # Restore browser state from cloud if this is a different pod

@@ -411,12 +411,7 @@ class AgentOutputTool(BaseTool):
        session: ChatSession,
        **kwargs,
    ) -> ToolResponseBase:
        """Execute the agent_output tool.

        Note: This tool accepts **kwargs and delegates to AgentOutputInput
        for validation because the parameter set has cross-field validators
        defined in the Pydantic model.
        """
        """Execute the agent_output tool."""
        session_id = session.session_id

        # Parse and validate input

@@ -1,93 +0,0 @@
"""AskQuestionTool - Ask the user a clarifying question before proceeding."""

from typing import Any

from backend.copilot.model import ChatSession

from .base import BaseTool
from .models import ClarificationNeededResponse, ClarifyingQuestion, ToolResponseBase


class AskQuestionTool(BaseTool):
    """Ask the user a clarifying question and wait for their answer.

    Use this tool when the user's request is ambiguous and you need more
    information before proceeding. Call find_block or other discovery tools
    first to ground your question in real platform options, then call this
    tool with a concrete question listing those options.
    """

    @property
    def name(self) -> str:
        return "ask_question"

    @property
    def description(self) -> str:
        return (
            "Ask the user a clarifying question. Use when the request is "
            "ambiguous and you need to confirm intent, choose between options, "
            "or gather missing details before proceeding."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "question": {
                    "type": "string",
                    "description": (
                        "The concrete question to ask the user. Should list "
                        "real options when applicable."
                    ),
                },
                "options": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": (
                        "Options for the user to choose from "
                        "(e.g. ['Email', 'Slack', 'Google Docs'])."
                    ),
                },
                "keyword": {
                    "type": "string",
                    "description": "Short label identifying what the question is about.",
                },
            },
            "required": ["question"],
        }

    @property
    def requires_auth(self) -> bool:
        return False

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs: Any,
    ) -> ToolResponseBase:
        del user_id  # unused; required by BaseTool contract
        question_raw = kwargs.get("question")
        if not isinstance(question_raw, str) or not question_raw.strip():
            raise ValueError("ask_question requires a non-empty 'question' string")
        question = question_raw.strip()
        raw_options = kwargs.get("options", [])
        if not isinstance(raw_options, list):
            raw_options = []
        options: list[str] = [str(o) for o in raw_options if o]
        raw_keyword = kwargs.get("keyword", "")
        keyword: str = str(raw_keyword) if raw_keyword else ""
        session_id = session.session_id if session else None

        example = ", ".join(options) if options else None
        clarifying_question = ClarifyingQuestion(
            question=question,
            keyword=keyword,
            example=example,
        )
        return ClarificationNeededResponse(
            message=question,
            session_id=session_id,
            questions=[clarifying_question],
        )
@@ -1,99 +0,0 @@
"""Tests for AskQuestionTool."""

import pytest

from backend.copilot.model import ChatSession
from backend.copilot.tools.ask_question import AskQuestionTool
from backend.copilot.tools.models import ClarificationNeededResponse


@pytest.fixture()
def tool() -> AskQuestionTool:
    return AskQuestionTool()


@pytest.fixture()
def session() -> ChatSession:
    return ChatSession.new(user_id="test-user", dry_run=False)


@pytest.mark.asyncio
async def test_execute_with_options(tool: AskQuestionTool, session: ChatSession):
    result = await tool._execute(
        user_id=None,
        session=session,
        question="Which channel?",
        options=["Email", "Slack", "Google Docs"],
        keyword="channel",
    )

    assert isinstance(result, ClarificationNeededResponse)
    assert result.message == "Which channel?"
    assert result.session_id == session.session_id
    assert len(result.questions) == 1

    q = result.questions[0]
    assert q.question == "Which channel?"
    assert q.keyword == "channel"
    assert q.example == "Email, Slack, Google Docs"


@pytest.mark.asyncio
async def test_execute_without_options(tool: AskQuestionTool, session: ChatSession):
    result = await tool._execute(
        user_id=None,
        session=session,
        question="What format do you want?",
    )

    assert isinstance(result, ClarificationNeededResponse)
    assert result.message == "What format do you want?"
    assert len(result.questions) == 1

    q = result.questions[0]
    assert q.question == "What format do you want?"
    assert q.keyword == ""
    assert q.example is None


@pytest.mark.asyncio
async def test_execute_with_keyword_only(tool: AskQuestionTool, session: ChatSession):
    result = await tool._execute(
        user_id=None,
        session=session,
        question="How often should it run?",
        keyword="trigger",
    )

    assert isinstance(result, ClarificationNeededResponse)
    q = result.questions[0]
    assert q.keyword == "trigger"
    assert q.example is None


@pytest.mark.asyncio
async def test_execute_rejects_empty_question(
    tool: AskQuestionTool, session: ChatSession
):
    with pytest.raises(ValueError, match="non-empty"):
        await tool._execute(user_id=None, session=session, question="")

    with pytest.raises(ValueError, match="non-empty"):
        await tool._execute(user_id=None, session=session, question=" ")


@pytest.mark.asyncio
async def test_execute_coerces_invalid_options(
    tool: AskQuestionTool, session: ChatSession
):
    """LLM may send options as a string instead of a list; should not crash."""
    result = await tool._execute(
        user_id=None,
        session=session,
        question="Pick one",
        options="not-a-list",  # type: ignore[arg-type]
    )

    assert isinstance(result, ClarificationNeededResponse)
    q = result.questions[0]
    assert q.example is None
@@ -76,8 +76,6 @@ class BashExecTool(BaseTool):
        self,
        user_id: str | None,
        session: ChatSession,
        command: str = "",
        timeout: int = 30,
        **kwargs: Any,
    ) -> ToolResponseBase:
        """Run a bash command on E2B (if available) or in a bubblewrap sandbox.
@@ -90,8 +88,8 @@ class BashExecTool(BaseTool):
        """
        session_id = session.session_id if session else None

        command = command.strip()
        timeout = int(timeout)
        command: str = (kwargs.get("command") or "").strip()
        timeout: int = int(kwargs.get("timeout", 30))

        if not command:
            return ErrorResponse(

@@ -115,9 +115,6 @@ class ConnectIntegrationTool(BaseTool):
        self,
        user_id: str | None,
        session: ChatSession,
        provider: str = "",
        reason: str = "",
        scopes: list[str] | None = None,
        **kwargs: Any,
    ) -> ToolResponseBase:
        """Build and return a :class:`SetupRequirementsResponse` for the requested provider.
@@ -131,10 +128,12 @@ class ConnectIntegrationTool(BaseTool):
        """
        _ = user_id  # setup card is user-agnostic; auth is enforced via requires_auth
        session_id = session.session_id if session else None
        provider = (provider or "").strip().lower()
        reason = (reason or "").strip()[:500]  # cap LLM-controlled text
        provider: str = (kwargs.get("provider") or "").strip().lower()
        reason: str = (kwargs.get("reason") or "").strip()[
            :500
        ]  # cap LLM-controlled text
        extra_scopes: list[str] = [
            str(s).strip() for s in (scopes or []) if str(s).strip()
            str(s).strip() for s in (kwargs.get("scopes") or []) if str(s).strip()
        ]

        entry = SUPPORTED_PROVIDERS.get(provider)
@@ -142,7 +141,8 @@ class ConnectIntegrationTool(BaseTool):
            supported = ", ".join(f"'{p}'" for p in SUPPORTED_PROVIDERS)
            return ErrorResponse(
                message=(
                    f"Unknown provider '{provider}'. Supported providers: {supported}."
                    f"Unknown provider '{provider}'. "
                    f"Supported providers: {supported}."
                ),
                error="unknown_provider",
                session_id=session_id,
@@ -153,11 +153,11 @@ class ConnectIntegrationTool(BaseTool):
        # Merge agent-requested scopes with provider defaults (deduplicated, order preserved).
        default_scopes: list[str] = entry["default_scopes"]
        seen: set[str] = set()
        merged_scopes: list[str] = []
        scopes: list[str] = []
        for s in default_scopes + extra_scopes:
            if s not in seen:
                seen.add(s)
                merged_scopes.append(s)
                scopes.append(s)
        field_key = f"{provider}_credentials"

        message_parts = [
@@ -171,7 +171,7 @@ class ConnectIntegrationTool(BaseTool):
            "title": f"{display_name} Credentials",
            "provider": provider,
            "types": supported_types,
            "scopes": merged_scopes,
            "scopes": scopes,
        }
        missing_credentials: dict[str, _CredentialEntry] = {field_key: credential_entry}


@@ -53,10 +53,11 @@ class ContinueRunBlockTool(BaseTool):
        self,
        user_id: str | None,
        session: ChatSession,
        review_id: str = "",
        **kwargs,
    ) -> ToolResponseBase:
        review_id = review_id.strip() if review_id else ""
        review_id = (
            kwargs.get("review_id", "").strip() if kwargs.get("review_id") else ""
        )
        session_id = session.session_id

        if not review_id:

@@ -62,12 +62,9 @@ class CreateAgentTool(BaseTool):
        self,
        user_id: str | None,
        session: ChatSession,
        agent_json: dict[str, Any] | None = None,
        save: bool = True,
        library_agent_ids: list[str] | None = None,
        folder_id: str | None = None,
        **kwargs,
    ) -> ToolResponseBase:
        agent_json: dict[str, Any] | None = kwargs.get("agent_json")
        session_id = session.session_id if session else None

        if not agent_json:
@@ -80,8 +77,9 @@ class CreateAgentTool(BaseTool):
                session_id=session_id,
            )

        if library_agent_ids is None:
            library_agent_ids = []
        save = kwargs.get("save", True)
        library_agent_ids = kwargs.get("library_agent_ids", [])
        folder_id: str | None = kwargs.get("folder_id")

        nodes = agent_json.get("nodes", [])
        if not nodes:

@@ -61,12 +61,9 @@ class CustomizeAgentTool(BaseTool):
        self,
        user_id: str | None,
        session: ChatSession,
        agent_json: dict[str, Any] | None = None,
        save: bool = True,
        library_agent_ids: list[str] | None = None,
        folder_id: str | None = None,
        **kwargs,
    ) -> ToolResponseBase:
        agent_json: dict[str, Any] | None = kwargs.get("agent_json")
        session_id = session.session_id if session else None

        if not agent_json:
@@ -78,8 +75,9 @@ class CustomizeAgentTool(BaseTool):
                session_id=session_id,
            )

        if library_agent_ids is None:
            library_agent_ids = []
        save = kwargs.get("save", True)
        library_agent_ids = kwargs.get("library_agent_ids", [])
        folder_id: str | None = kwargs.get("folder_id")

        nodes = agent_json.get("nodes", [])
        if not nodes:

@@ -62,15 +62,10 @@ class EditAgentTool(BaseTool):
        self,
        user_id: str | None,
        session: ChatSession,
        agent_id: str = "",
        agent_json: dict[str, Any] | None = None,
        save: bool = True,
        library_agent_ids: list[str] | None = None,
        **kwargs,
    ) -> ToolResponseBase:
        agent_id = agent_id.strip()
        if library_agent_ids is None:
            library_agent_ids = []
        agent_id = kwargs.get("agent_id", "").strip()
        agent_json: dict[str, Any] | None = kwargs.get("agent_json")
        session_id = session.session_id if session else None

        if not agent_id:
@@ -89,6 +84,9 @@ class EditAgentTool(BaseTool):
                session_id=session_id,
            )

        save = kwargs.get("save", True)
        library_agent_ids = kwargs.get("library_agent_ids", [])

        nodes = agent_json.get("nodes", [])
        if not nodes:
            return ErrorResponse(

@@ -157,10 +157,9 @@ class SearchFeatureRequestsTool(BaseTool):
        self,
        user_id: str | None,
        session: ChatSession,
        query: str = "",
        **kwargs,
    ) -> ToolResponseBase:
        query = (query or "").strip()
        query = kwargs.get("query", "").strip()
        session_id = session.session_id if session else None

        if not query:
@@ -289,13 +288,11 @@ class CreateFeatureRequestTool(BaseTool):
        self,
        user_id: str | None,
        session: ChatSession,
        title: str = "",
        description: str = "",
        existing_issue_id: str | None = None,
        **kwargs,
    ) -> ToolResponseBase:
        title = (title or "").strip()
        description = (description or "").strip()
        title = kwargs.get("title", "").strip()
        description = kwargs.get("description", "").strip()
        existing_issue_id = kwargs.get("existing_issue_id")
        session_id = session.session_id if session else None

        if not title or not description:

@@ -34,15 +34,11 @@ class FindAgentTool(BaseTool):
    }

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        query: str = "",
        **kwargs,
        self, user_id: str | None, session: ChatSession, **kwargs
    ) -> ToolResponseBase:
        """Search marketplace for agents matching the query."""
        return await search_agents(
            query=query.strip(),
            query=kwargs.get("query", "").strip(),
            source="marketplace",
            session_id=session.session_id,
            user_id=user_id,

@@ -86,8 +86,6 @@ class FindBlockTool(BaseTool):
        self,
        user_id: str | None,
        session: ChatSession,
        query: str = "",
        include_schemas: bool = False,
        **kwargs,
    ) -> ToolResponseBase:
        """Search for blocks matching the query.
@@ -96,14 +94,14 @@ class FindBlockTool(BaseTool):
            user_id: User ID (required)
            session: Chat session
            query: Search query
            include_schemas: Whether to include block schemas in results

        Returns:
            BlockListResponse: List of matching blocks
            NoResultsResponse: No blocks found
            ErrorResponse: Error message
        """
        query = (query or "").strip()
        query = kwargs.get("query", "").strip()
        include_schemas = kwargs.get("include_schemas", False)
        session_id = session.session_id

        if not query:

@@ -41,14 +41,10 @@ class FindLibraryAgentTool(BaseTool):
        return True

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        query: str = "",
        **kwargs,
        self, user_id: str | None, session: ChatSession, **kwargs
    ) -> ToolResponseBase:
        return await search_agents(
            query=query.strip(),
            query=(kwargs.get("query") or "").strip(),
            source="library",
            session_id=session.session_id,
            user_id=user_id,

@@ -51,9 +51,9 @@ class FixAgentGraphTool(BaseTool):
        self,
        user_id: str | None,
        session: ChatSession,
        agent_json: dict | None = None,
        **kwargs,
    ) -> ToolResponseBase:
        agent_json = kwargs.get("agent_json")
        session_id = session.session_id if session else None

        if not agent_json or not isinstance(agent_json, dict):
@@ -98,7 +98,8 @@ class FixAgentGraphTool(BaseTool):
        if is_valid:
            return FixResultResponse(
                message=(
                    f"Applied {len(fixes_applied)} fix(es). Agent graph is now valid!"
                    f"Applied {len(fixes_applied)} fix(es). "
                    "Agent graph is now valid!"
                ),
                fixed_agent_json=fixed_agent,
                fixes_applied=fixes_applied,

@@ -60,7 +60,7 @@ class GetAgentBuildingGuideTool(BaseTool):
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs,  # no tool-specific params; accepts kwargs for forward-compat
        **kwargs,
    ) -> ToolResponseBase:
        session_id = session.session_id if session else None
        try:

@@ -68,7 +68,6 @@ class GetDocPageTool(BaseTool):
        self,
        user_id: str | None,
        session: ChatSession,
        path: str = "",
        **kwargs,
    ) -> ToolResponseBase:
        """Fetch full content of a documentation page.
@@ -82,7 +81,7 @@ class GetDocPageTool(BaseTool):
            DocPageResponse: Full document content
            ErrorResponse: Error message
        """
        path = path.strip()
        path = kwargs.get("path", "").strip()
        session_id = session.session_id if session else None

        if not path:

@@ -56,7 +56,7 @@ class GetMCPGuideTool(BaseTool):
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs,  # no tool-specific params; accepts kwargs for forward-compat
        **kwargs,
    ) -> ToolResponseBase:
        session_id = session.session_id if session else None
        try:

@@ -81,7 +81,7 @@ async def execute_block(
    node_exec_id: str,
    matched_credentials: dict[str, CredentialsMetaInput],
    sensitive_action_safe_mode: bool = False,
    dry_run: bool,
    dry_run: bool = False,
) -> ToolResponseBase:
    """Execute a block with full context setup, credential injection, and error handling.

@@ -114,9 +114,11 @@ async def execute_block(
                error=sim_error[0],
                session_id=session_id,
            )

        return BlockOutputResponse(
            message=f"Block '{block.name}' executed successfully",
            message=(
                f"[DRY RUN] Block '{block.name}' simulated successfully "
                "— no real execution occurred."
            ),
            block_id=block_id,
            block_name=block.name,
            outputs=dict(outputs),
@@ -335,7 +337,7 @@ async def prepare_block_for_execution(
    user_id: str,
    session: ChatSession,
    session_id: str,
    dry_run: bool,
    dry_run: bool = False,
) -> "BlockPreparation | ToolResponseBase":
    """Validate and prepare a block for execution.

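The two message variants in the `execute_block` hunk above differ only in how a dry run is labeled. A hedged sketch of that labeling (the helper `run_message` is illustrative, not part of the diff):

```python
def run_message(block_name: str, dry_run: bool) -> str:
    # Dry runs are labeled so a simulated result is distinguishable
    # from a real execution in the returned message.
    if dry_run:
        return (
            f"[DRY RUN] Block '{block_name}' simulated successfully "
            "— no real execution occurred."
        )
    return f"Block '{block_name}' executed successfully"
```

Whether the label should appear at all is exactly what the later test hunks in this diff flip between, so treat this as a sketch of one side of that change.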
@@ -102,7 +102,6 @@ class TestExecuteBlockCreditCharging:
            session_id=_SESSION,
            node_exec_id="exec-1",
            matched_credentials={},
            dry_run=False,
        )

        assert isinstance(result, BlockOutputResponse)
@@ -133,7 +132,6 @@ class TestExecuteBlockCreditCharging:
            session_id=_SESSION,
            node_exec_id="exec-1",
            matched_credentials={},
            dry_run=False,
        )

        assert isinstance(result, ErrorResponse)
@@ -160,7 +158,6 @@ class TestExecuteBlockCreditCharging:
            session_id=_SESSION,
            node_exec_id="exec-1",
            matched_credentials={},
            dry_run=False,
        )

        assert isinstance(result, BlockOutputResponse)
@@ -197,7 +194,6 @@ class TestExecuteBlockCreditCharging:
            session_id=_SESSION,
            node_exec_id="exec-1",
            matched_credentials={},
            dry_run=False,
        )

        # Block already executed (with side effects), so output is returned
@@ -281,7 +277,6 @@ async def test_coerce_json_string_to_nested_list():
        session_id=_TEST_SESSION_ID,
        node_exec_id="exec-1",
        matched_credentials={},
        dry_run=False,
    )

    assert isinstance(response, BlockOutputResponse)
@@ -322,7 +317,6 @@ async def test_coerce_json_string_to_list():
        session_id=_TEST_SESSION_ID,
        node_exec_id="exec-2",
        matched_credentials={},
        dry_run=False,
    )

    assert isinstance(response, BlockOutputResponse)
@@ -355,7 +349,6 @@ async def test_coerce_json_string_to_dict():
        session_id=_TEST_SESSION_ID,
        node_exec_id="exec-3",
        matched_credentials={},
        dry_run=False,
    )

    assert isinstance(response, BlockOutputResponse)
@@ -389,7 +382,6 @@ async def test_no_coercion_when_type_matches():
        session_id=_TEST_SESSION_ID,
        node_exec_id="exec-4",
        matched_credentials={},
        dry_run=False,
    )

    assert isinstance(response, BlockOutputResponse)
@@ -423,7 +415,6 @@ async def test_coerce_string_to_int():
        session_id=_TEST_SESSION_ID,
        node_exec_id="exec-5",
        matched_credentials={},
        dry_run=False,
    )

    assert isinstance(response, BlockOutputResponse)
@@ -457,7 +448,6 @@ async def test_coerce_skips_none_values():
        session_id=_TEST_SESSION_ID,
        node_exec_id="exec-6",
        matched_credentials={},
        dry_run=False,
    )

    assert isinstance(response, BlockOutputResponse)
@@ -491,7 +481,6 @@ async def test_coerce_union_type_preserves_valid_member():
        session_id=_TEST_SESSION_ID,
        node_exec_id="exec-7",
        matched_credentials={},
        dry_run=False,
    )

    assert isinstance(response, BlockOutputResponse)
@@ -527,7 +516,6 @@ async def test_coerce_inner_elements_of_generic():
        session_id=_TEST_SESSION_ID,
        node_exec_id="exec-8",
        matched_credentials={},
        dry_run=False,
    )

    assert isinstance(response, BlockOutputResponse)
@@ -604,7 +592,6 @@ async def test_prepare_block_not_found() -> None:
        user_id=_PREP_USER,
        session=_make_prep_session(),
        session_id=_PREP_SESSION,
        dry_run=False,
    )
    assert isinstance(result, ErrorResponse)
    assert "not found" in result.message
@@ -625,7 +612,6 @@ async def test_prepare_block_disabled() -> None:
        user_id=_PREP_USER,
        session=_make_prep_session(),
        session_id=_PREP_SESSION,
        dry_run=False,
    )
    assert isinstance(result, ErrorResponse)
    assert "disabled" in result.message
@@ -654,7 +640,6 @@ async def test_prepare_block_unrecognized_fields() -> None:
        user_id=_PREP_USER,
        session=_make_prep_session(),
        session_id=_PREP_SESSION,
        dry_run=False,
    )
    assert isinstance(result, InputValidationErrorResponse)
    assert "unknown_field" in result.unrecognized_fields
@@ -684,7 +669,6 @@ async def test_prepare_block_missing_credentials() -> None:
        user_id=_PREP_USER,
        session=_make_prep_session(),
        session_id=_PREP_SESSION,
        dry_run=False,
    )
    assert isinstance(result, SetupRequirementsResponse)

@@ -714,7 +698,6 @@ async def test_prepare_block_success_returns_preparation() -> None:
        user_id=_PREP_USER,
        session=_make_prep_session(),
        session_id=_PREP_SESSION,
        dry_run=False,
    )
    assert isinstance(result, BlockPreparation)
    assert result.required_non_credential_keys == {"text"}
@@ -819,7 +802,6 @@ async def test_prepare_block_excluded_by_type() -> None:
        user_id=_PREP_USER,
        session=_make_prep_session(),
        session_id=_PREP_SESSION,
        dry_run=False,
    )
    assert isinstance(result, ErrorResponse)
    assert "cannot be run directly" in result.message
@@ -842,7 +824,6 @@ async def test_prepare_block_excluded_by_id() -> None:
        user_id=_PREP_USER,
        session=_make_prep_session(),
        session_id=_PREP_SESSION,
        dry_run=False,
    )
    assert isinstance(result, ErrorResponse)
    assert "cannot be run directly" in result.message
@@ -876,7 +857,6 @@ async def test_prepare_block_file_ref_expansion_error() -> None:
        user_id=_PREP_USER,
        session=_make_prep_session(),
        session_id=_PREP_SESSION,
        dry_run=False,
    )
    assert isinstance(result, ErrorResponse)
    assert "file reference" in result.message.lower()

@@ -866,7 +866,6 @@ class TestRunBlockToolAuthenticatedHttp:
            session=session,
            block_id=block.id,
            input_data={"url": "https://api.example.com/data", "method": "GET"},
            dry_run=False,
        )

        assert isinstance(response, SetupRequirementsResponse)
@@ -908,7 +907,6 @@ class TestRunBlockToolAuthenticatedHttp:
            session=session,
            block_id=block.id,
            input_data={},
            dry_run=False,
        )

        assert isinstance(response, BlockDetailsResponse)

@@ -120,18 +120,14 @@ class CreateFolderTool(BaseTool):
        }

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        name: str = "",
        parent_id: str | None = None,
        icon: str | None = None,
        color: str | None = None,
        **kwargs,
        self, user_id: str | None, session: ChatSession, **kwargs
    ) -> ToolResponseBase:
        """Create a folder with the given name and optional parent/icon/color."""
        assert user_id is not None  # guaranteed by requires_auth
        name = (name or "").strip()
        name = (kwargs.get("name") or "").strip()
        parent_id = kwargs.get("parent_id")
        icon = kwargs.get("icon")
        color = kwargs.get("color")
        session_id = session.session_id if session else None

        if not name:
@@ -200,15 +196,12 @@ class ListFoldersTool(BaseTool):
        }

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        parent_id: str | None = None,
        include_agents: bool = False,
        **kwargs,
        self, user_id: str | None, session: ChatSession, **kwargs
    ) -> ToolResponseBase:
        """List folders as a flat list (by parent) or full tree."""
        assert user_id is not None  # guaranteed by requires_auth
        parent_id = kwargs.get("parent_id")
        include_agents = kwargs.get("include_agents", False)
        session_id = session.session_id if session else None

        try:
@@ -300,18 +293,14 @@ class UpdateFolderTool(BaseTool):
        }

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        folder_id: str = "",
        name: str | None = None,
        icon: str | None = None,
        color: str | None = None,
        **kwargs,
        self, user_id: str | None, session: ChatSession, **kwargs
    ) -> ToolResponseBase:
        """Update a folder's name, icon, or color."""
        assert user_id is not None  # guaranteed by requires_auth
        folder_id = (folder_id or "").strip()
        folder_id = (kwargs.get("folder_id") or "").strip()
        name = kwargs.get("name")
        icon = kwargs.get("icon")
        color = kwargs.get("color")
        session_id = session.session_id if session else None

        if not folder_id:
@@ -376,16 +365,12 @@ class MoveFolderTool(BaseTool):
        }

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        folder_id: str = "",
        target_parent_id: str | None = None,
        **kwargs,
        self, user_id: str | None, session: ChatSession, **kwargs
    ) -> ToolResponseBase:
        """Move a folder to a new parent or to root level."""
        assert user_id is not None  # guaranteed by requires_auth
        folder_id = (folder_id or "").strip()
        folder_id = (kwargs.get("folder_id") or "").strip()
        target_parent_id = kwargs.get("target_parent_id")
        session_id = session.session_id if session else None

        if not folder_id:
@@ -446,15 +431,11 @@ class DeleteFolderTool(BaseTool):
        }

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        folder_id: str = "",
        **kwargs,
        self, user_id: str | None, session: ChatSession, **kwargs
    ) -> ToolResponseBase:
        """Soft-delete a folder; agents inside are moved to root level."""
        assert user_id is not None  # guaranteed by requires_auth
        folder_id = (folder_id or "").strip()
        folder_id = (kwargs.get("folder_id") or "").strip()
        session_id = session.session_id if session else None

        if not folder_id:
@@ -518,17 +499,12 @@ class MoveAgentsToFolderTool(BaseTool):
        }

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        agent_ids: list[str] | None = None,
        folder_id: str | None = None,
        **kwargs,
        self, user_id: str | None, session: ChatSession, **kwargs
    ) -> ToolResponseBase:
        """Move one or more agents to a folder or to root level."""
        assert user_id is not None  # guaranteed by requires_auth
        if agent_ids is None:
            agent_ids = []
        agent_ids = kwargs.get("agent_ids", [])
        folder_id = kwargs.get("folder_id")
        session_id = session.session_id if session else None

        if not agent_ids:

@@ -71,7 +71,7 @@ class RunAgentInput(BaseModel):
    cron: str = ""
    timezone: str = "UTC"
    wait_for_result: int = Field(default=0, ge=0, le=300)
    dry_run: bool
    dry_run: bool = False

    @field_validator(
        "username_agent_slug",
@@ -153,10 +153,14 @@ class RunAgentTool(BaseTool):
            },
            "dry_run": {
                "type": "boolean",
                "description": "Execute in preview mode.",
                "description": (
                    "When true, simulates the entire agent execution using an LLM "
                    "for each block — no real API calls, no credentials needed, "
                    "no credits charged. Useful for testing agent wiring end-to-end."
                ),
            },
        },
        "required": ["dry_run"],
        "required": [],
    }

    @property
@@ -170,16 +174,8 @@ class RunAgentTool(BaseTool):
        session: ChatSession,
        **kwargs,
    ) -> ToolResponseBase:
        """Execute the tool with automatic state detection.

        Note: This tool accepts **kwargs and delegates to RunAgentInput for
        validation because the parameter set is complex with cross-field
        validators defined in the Pydantic model.
        """
        """Execute the tool with automatic state detection."""
        params = RunAgentInput(**kwargs)
        # Session-level dry_run forces all tool calls to use dry-run mode.
        if session.dry_run:
            params.dry_run = True
        session_id = session.session_id

        # Validate at least one identifier is provided
@@ -205,18 +201,6 @@ class RunAgentTool(BaseTool):
        # Determine if this is a schedule request
        is_schedule = bool(params.schedule_name or params.cron)

        # Session-level dry-run blocks scheduling — schedules create real
        # side effects that cannot be simulated.
        if params.dry_run and is_schedule:
            return ErrorResponse(
                message=(
                    "Scheduling is disabled in dry-run mode because it creates "
                    "real side effects. Remove cron/schedule_name to simulate "
                    "a run, or disable dry-run to create a real schedule."
                ),
                session_id=session_id,
            )

        try:
            # Step 1: Fetch agent details
            graph: GraphModel | None = None
@@ -474,8 +458,8 @@ class RunAgentTool(BaseTool):
        graph: GraphModel,
        graph_credentials: dict[str, CredentialsMetaInput],
        inputs: dict[str, Any],
        dry_run: bool,
        wait_for_result: int = 0,
        dry_run: bool = False,
    ) -> ToolResponseBase:
        """Execute an agent immediately, optionally waiting for completion."""
        session_id = session.session_id

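One side of the `RunAgentTool` hunks above carries two related dry-run rules: a session-level `dry_run` flag overrides the per-call flag, and scheduling is refused while in dry-run because schedules have real side effects. A hedged sketch of just that decision logic (the `SessionStub` class and helper names are ours, standing in for the real `ChatSession`):

```python
from dataclasses import dataclass


@dataclass
class SessionStub:
    dry_run: bool = False


def resolve_dry_run(session: SessionStub, requested: bool) -> bool:
    # A dry-run session wins over whatever the individual tool call asked for.
    return True if session.dry_run else requested


def schedule_allowed(dry_run: bool, is_schedule: bool) -> bool:
    # Scheduling creates real side effects that cannot be simulated,
    # so it is refused whenever dry-run is in effect.
    return not (dry_run and is_schedule)
```

Note that this diff appears to move this logic rather than simply add it, so treat the sketch as documenting the pattern, not the final state of the file.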
@@ -53,7 +53,6 @@ async def test_run_agent(setup_test_data):
        tool_call_id=str(uuid.uuid4()),
        username_agent_slug=agent_marketplace_id,
        inputs={"test_input": "Hello World"},
        dry_run=False,
        session=session,
    )

@@ -94,7 +93,6 @@ async def test_run_agent_missing_inputs(setup_test_data):
        tool_call_id=str(uuid.uuid4()),
        username_agent_slug=agent_marketplace_id,
        inputs={},  # Missing required input
        dry_run=False,
        session=session,
    )

@@ -127,7 +125,6 @@ async def test_run_agent_invalid_agent_id(setup_test_data):
        tool_call_id=str(uuid.uuid4()),
        username_agent_slug="invalid/agent-id",
        inputs={"test_input": "Hello World"},
        dry_run=False,
        session=session,
    )

@@ -168,7 +165,6 @@ async def test_run_agent_with_llm_credentials(setup_llm_test_data):
        tool_call_id=str(uuid.uuid4()),
        username_agent_slug=agent_marketplace_id,
        inputs={"user_prompt": "What is 2+2?"},
        dry_run=False,
        session=session,
    )

@@ -207,7 +203,6 @@ async def test_run_agent_shows_available_inputs_when_none_provided(setup_test_da
        username_agent_slug=agent_marketplace_id,
        inputs={},
        use_defaults=False,
        dry_run=False,
        session=session,
    )

@@ -243,7 +238,6 @@ async def test_run_agent_with_use_defaults(setup_test_data):
        username_agent_slug=agent_marketplace_id,
        inputs={},
        use_defaults=True,
        dry_run=False,
        session=session,
    )

@@ -274,7 +268,6 @@ async def test_run_agent_missing_credentials(setup_firecrawl_test_data):
        tool_call_id=str(uuid.uuid4()),
        username_agent_slug=agent_marketplace_id,
        inputs={"url": "https://example.com"},
        dry_run=False,
        session=session,
    )

@@ -307,7 +300,6 @@ async def test_run_agent_invalid_slug_format(setup_test_data):
        tool_call_id=str(uuid.uuid4()),
        username_agent_slug="no-slash-here",
        inputs={},
        dry_run=False,
        session=session,
    )

@@ -335,7 +327,6 @@ async def test_run_agent_unauthenticated():
        tool_call_id=str(uuid.uuid4()),
        username_agent_slug="test/test-agent",
        inputs={},
        dry_run=False,
        session=session,
    )

@@ -368,7 +359,6 @@ async def test_run_agent_schedule_without_cron(setup_test_data):
        inputs={"test_input": "test"},
        schedule_name="My Schedule",
        cron="",  # Empty cron
        dry_run=False,
        session=session,
    )

@@ -401,7 +391,6 @@ async def test_run_agent_schedule_without_name(setup_test_data):
        inputs={"test_input": "test"},
        schedule_name="",  # Empty name
        cron="0 9 * * *",
        dry_run=False,
        session=session,
    )

@@ -435,7 +424,6 @@ async def test_run_agent_rejects_unknown_input_fields(setup_test_data):
            "unknown_field": "some value",
            "another_unknown": "another value",
        },
        dry_run=False,
        session=session,
    )

@@ -51,10 +51,14 @@ class RunBlockTool(BaseTool):
            },
            "dry_run": {
                "type": "boolean",
                "description": "Execute in preview mode.",
                "description": (
                    "When true, simulates block execution using an LLM without making any "
                    "real API calls or producing side effects. Useful for testing agent "
                    "wiring and previewing outputs. Default: false."
                ),
            },
        },
        "required": ["block_id", "input_data", "dry_run"],
        "required": ["block_id", "input_data"],
    }

    @property
@@ -65,10 +69,6 @@ class RunBlockTool(BaseTool):
        self,
        user_id: str | None,
        session: ChatSession,
        *,
        block_id: str = "",
        input_data: dict | None = None,
        dry_run: bool,
        **kwargs,
    ) -> ToolResponseBase:
        """Execute a block with the given input data.
@@ -78,19 +78,15 @@ class RunBlockTool(BaseTool):
            session: Chat session
            block_id: Block UUID to execute
            input_data: Input values for the block
            dry_run: If True, simulate execution without side effects

        Returns:
            BlockOutputResponse: Block execution outputs
            SetupRequirementsResponse: Missing credentials
            ErrorResponse: Error message
        """
        block_id = block_id.strip()
        if input_data is None:
            input_data = {}
        # Session-level dry_run forces all tool calls to use dry-run mode.
        if session.dry_run:
            dry_run = True
        block_id = kwargs.get("block_id", "").strip()
        input_data = kwargs.get("input_data", {})
        dry_run = bool(kwargs.get("dry_run", False))
        session_id = session.session_id

        if not block_id:

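The `RunBlockTool` hunk above flips `dry_run` between a required and an optional schema field. A hypothetical fragment in the same style, showing the optional variant: describing `dry_run` under `"properties"` while leaving it out of `"required"` lets callers omit it, and the tool then falls back to its `False` default.

```python
# Illustrative parameter schema fragment (not copied from the repository).
parameters = {
    "type": "object",
    "properties": {
        "block_id": {"type": "string"},
        "input_data": {"type": "object"},
        # Described but optional: absent means "not a dry run".
        "dry_run": {"type": "boolean"},
    },
    "required": ["block_id", "input_data"],
}
```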
@@ -103,7 +103,6 @@ class TestRunBlockFiltering:
            session=session,
            block_id="input-block-id",
            input_data={},
            dry_run=False,
        )

        assert isinstance(response, ErrorResponse)
@@ -130,7 +129,6 @@ class TestRunBlockFiltering:
            session=session,
            block_id=orchestrator_id,
            input_data={},
            dry_run=False,
        )

        assert isinstance(response, ErrorResponse)
@@ -156,7 +154,6 @@ class TestRunBlockFiltering:
            session=session,
            block_id=block_id,
            input_data={},
            dry_run=False,
        )
        finally:
            _current_permissions.reset(token)
@@ -190,7 +187,6 @@ class TestRunBlockFiltering:
            session=session,
            block_id=block_id,
            input_data={},
            dry_run=False,
        )
        finally:
            _current_permissions.reset(token)
@@ -226,7 +222,6 @@ class TestRunBlockFiltering:
            session=session,
            block_id="standard-id",
            input_data={},
            dry_run=False,
        )

        # Should NOT be an ErrorResponse about CoPilot exclusion
@@ -287,7 +282,6 @@ class TestRunBlockInputValidation:
                "prompt": "Write a haiku about coding",
                "LLM_Model": "claude-opus-4-6",
            },
            dry_run=False,
        )

        assert isinstance(response, InputValidationErrorResponse)
@@ -333,7 +327,6 @@ class TestRunBlockInputValidation:
                "system_prompt": "Be helpful",
                "retries": 5,
            },
            dry_run=False,
        )

        assert isinstance(response, InputValidationErrorResponse)
@@ -377,7 +370,6 @@ class TestRunBlockInputValidation:
            input_data={
                "LLM_Model": "claude-opus-4-6",
            },
            dry_run=False,
        )

        assert isinstance(response, InputValidationErrorResponse)
@@ -432,7 +424,6 @@ class TestRunBlockInputValidation:
                "prompt": "Write a haiku",
                "model": "gpt-4o-mini",
            },
            dry_run=False,
        )

        assert isinstance(response, BlockOutputResponse)
@@ -472,7 +463,6 @@ class TestRunBlockInputValidation:
            input_data={
                "model": "gpt-4o-mini",
            },
            dry_run=False,
        )

        assert isinstance(response, BlockDetailsResponse)
@@ -524,7 +514,6 @@ class TestRunBlockSensitiveAction:
            session=session,
            block_id="delete-branch-id",
            input_data=input_data,
            dry_run=False,
        )

        assert isinstance(response, ReviewRequiredResponse)
@@ -585,7 +574,6 @@ class TestRunBlockSensitiveAction:
            session=session,
            block_id="delete-branch-id",
            input_data=input_data,
            dry_run=False,
        )

        assert isinstance(response, BlockOutputResponse)
@@ -640,7 +628,6 @@ class TestRunBlockSensitiveAction:
            session=session,
            block_id="http-request-id",
            input_data=input_data,
            dry_run=False,
        )

        assert isinstance(response, BlockOutputResponse)

@@ -91,40 +91,21 @@ class RunMCPToolTool(BaseTool):
        self,
        user_id: str | None,
        session: ChatSession,
        server_url: str = "",
        tool_name: str = "",
        tool_arguments: dict[str, Any] | None = None,
        **kwargs,
    ) -> ToolResponseBase:
        server_url = server_url.strip()
        tool_name = tool_name.strip()
        server_url: str = (kwargs.get("server_url") or "").strip()
        tool_name: str = (kwargs.get("tool_name") or "").strip()
        raw_tool_arguments = kwargs.get("tool_arguments")
        tool_arguments: dict[str, Any] = (
            raw_tool_arguments if isinstance(raw_tool_arguments, dict) else {}
        )
        session_id = session.session_id

        # Session-level dry_run prevents real MCP tool execution.
        # Discovery (no tool_name) is still allowed so the agent can inspect
        # available tools, but actual execution is blocked.
        if session.dry_run and tool_name:
            return MCPToolOutputResponse(
                message=(
                    f"[dry-run] MCP tool '{tool_name}' on "
                    f"{server_host(server_url)} was not executed "
                    "because the session is in dry-run mode."
                ),
                server_url=server_url,
                tool_name=tool_name,
                result=None,
                success=True,
                session_id=session_id,
            )

        if tool_arguments is not None and not isinstance(tool_arguments, dict):
        if raw_tool_arguments is not None and not isinstance(raw_tool_arguments, dict):
            return ErrorResponse(
                message="tool_arguments must be a JSON object.",
                session_id=session_id,
            )
        resolved_tool_arguments: dict[str, Any] = (
            tool_arguments if isinstance(tool_arguments, dict) else {}
        )

        if not server_url:
            return ErrorResponse(
@@ -186,7 +167,7 @@ class RunMCPToolTool(BaseTool):
        else:
            # Stage 2: Execute the selected tool
            return await self._execute_tool(
                client, server_url, tool_name, resolved_tool_arguments, session_id
                client, server_url, tool_name, tool_arguments, session_id
            )

        except HTTPClientError as e:

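The MCP dry-run gate in the hunk above has an asymmetry worth noting: discovery (a call without a `tool_name`) is still permitted so the agent can inspect the server's tools, while actual execution is blocked. A hedged sketch of just that predicate (the helper name is ours):

```python
def mcp_execution_allowed(session_dry_run: bool, tool_name: str) -> bool:
    # Discovery (empty tool_name) stays allowed so the agent can still
    # list the server's tools; only real tool execution is blocked when
    # the session is in dry-run mode.
    return not (session_dry_run and bool(tool_name.strip()))
```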
@@ -85,7 +85,6 @@ class SearchDocsTool(BaseTool):
        self,
        user_id: str | None,
        session: ChatSession,
        query: str = "",
        **kwargs,
    ) -> ToolResponseBase:
        """Search documentation and return relevant sections.
@@ -100,7 +99,7 @@ class SearchDocsTool(BaseTool):
            NoResultsResponse: No results found
            ErrorResponse: Error message
        """
        query = query.strip()
        query = kwargs.get("query", "").strip()
        session_id = session.session_id if session else None

        if not query:

@@ -73,10 +73,7 @@ def make_openai_response(

@pytest.mark.asyncio
async def test_simulate_block_basic():
    """simulate_block returns correct (output_name, output_data) tuples.

    Empty "error" pins are dropped at source — only non-empty errors are yielded.
    """
    """simulate_block returns correct (output_name, output_data) tuples."""
    mock_block = make_mock_block()
    mock_client = AsyncMock()
    mock_client.chat.completions.create = AsyncMock(
@@ -91,8 +88,7 @@ async def test_simulate_block_basic():
        outputs.append((name, data))

    assert ("result", "simulated output") in outputs
    # Empty error pin is dropped at the simulator level
    assert ("error", "") not in outputs
    assert ("error", "") in outputs


@pytest.mark.asyncio
@@ -117,8 +113,6 @@ async def test_simulate_block_json_retry():

    assert mock_client.chat.completions.create.call_count == 3
    assert ("result", "ok") in outputs
    # Empty error pin is dropped
    assert ("error", "") not in outputs


@pytest.mark.asyncio
@@ -147,7 +141,7 @@ async def test_simulate_block_all_retries_exhausted():

@pytest.mark.asyncio
async def test_simulate_block_missing_output_pins():
    """LLM response missing some output pins; verify non-error pins filled with None."""
    """LLM response missing some output pins; verify they're filled with None."""
    mock_block = make_mock_block(
        output_props={
            "result": {"type": "string"},
@@ -170,29 +164,7 @@ async def test_simulate_block_missing_output_pins():

    assert outputs["result"] == "hello"
    assert outputs["count"] is None  # missing pin filled with None
    assert "error" not in outputs  # missing error pin is omitted entirely


@pytest.mark.asyncio
async def test_simulate_block_keeps_nonempty_error():
    """simulate_block keeps non-empty error pins (simulated logical errors)."""
    mock_block = make_mock_block()
    mock_client = AsyncMock()
    mock_client.chat.completions.create = AsyncMock(
        return_value=make_openai_response(
            '{"result": "", "error": "API rate limit exceeded"}'
        )
    )

    with patch(
        "backend.executor.simulator.get_openai_client", return_value=mock_client
    ):
        outputs = []
        async for name, data in simulate_block(mock_block, {"query": "test"}):
            outputs.append((name, data))

    assert ("result", "") in outputs
    assert ("error", "API rate limit exceeded") in outputs
    assert outputs["error"] == ""  # "error" pin filled with ""


@pytest.mark.asyncio
@@ -228,19 +200,6 @@ async def test_simulate_block_truncates_long_inputs():
    assert len(parsed["text"]) < 25000


def test_build_simulation_prompt_excludes_error_from_must_include():
    """The 'MUST include' prompt line should NOT list 'error' — the prompt
    already instructs the LLM to OMIT error unless simulating a logical error.
    Including it in 'MUST include' would be contradictory."""
    block = make_mock_block()  # default output_props has "result" and "error"
    system_prompt, _ = build_simulation_prompt(block, {"query": "test"})
    must_include_line = [
        line for line in system_prompt.splitlines() if "MUST include" in line
    ][0]
    assert '"result"' in must_include_line
    assert '"error"' not in must_include_line


# ---------------------------------------------------------------------------
# execute_block dry-run tests
# ---------------------------------------------------------------------------
@@ -279,7 +238,7 @@ async def test_execute_block_dry_run_skips_real_execution():

@pytest.mark.asyncio
async def test_execute_block_dry_run_response_format():
    """Dry-run response should match real execution message format and have success=True."""
    """Dry-run response should contain [DRY RUN] in message and success=True."""
    mock_block = make_mock_block()

    async def fake_simulate(block, input_data):
@@ -300,8 +259,7 @@ async def test_execute_block_dry_run_response_format():
    )

    assert isinstance(response, BlockOutputResponse)
    assert "executed successfully" in response.message
    assert "[DRY RUN]" not in response.message  # must not leak to LLM context
    assert "[DRY RUN]" in response.message
    assert response.success is True
    assert response.outputs == {"result": ["simulated"]}

@@ -349,24 +307,23 @@ async def test_execute_block_real_execution_unchanged():


def test_run_block_tool_dry_run_param():
    """RunBlockTool parameters should include 'dry_run' as a required field."""
    """RunBlockTool parameters should include 'dry_run'."""
    tool = RunBlockTool()
    params = tool.parameters
    assert "dry_run" in params["properties"]
    assert params["properties"]["dry_run"]["type"] == "boolean"
    assert "dry_run" in params["required"]


def test_run_block_tool_dry_run_calls_execute():
    """RunBlockTool._execute accepts dry_run as a typed parameter.
    """RunBlockTool._execute extracts dry_run from kwargs correctly.

    We verify the parameter exists in the signature and is forwarded to
    execute_block.
    We verify the extraction logic directly by inspecting the source, then confirm
    the kwarg is forwarded in the execute_block call site.
    """
    source = inspect.getsource(run_block_module.RunBlockTool._execute)
    # Verify dry_run is a typed parameter (not extracted from kwargs)
    # Verify dry_run is extracted from kwargs
    assert "dry_run" in source
    assert "dry_run: bool" in source
    assert 'kwargs.get("dry_run"' in source

    # Scope to _execute method source only — module-wide search is brittle
|
||||
# and can match unrelated text/comments.
|
||||
@@ -375,107 +332,13 @@ def test_run_block_tool_dry_run_calls_execute():
|
||||
assert "dry_run=dry_run" in source_execute
|
||||
|
||||

@pytest.mark.asyncio
async def test_execute_block_dry_run_no_empty_error_from_simulator():
    """The simulator no longer yields empty error pins, so execute_block
    simply passes through whatever the simulator produces.

    Since the fix is at the simulator level, even if a simulator somehow
    yields only non-error outputs, they pass through unchanged.
    """
    mock_block = make_mock_block()

    async def fake_simulate(block, input_data):
        # Simulator now omits empty error pins at source
        yield "result", "simulated output"

    with patch(
        "backend.copilot.tools.helpers.simulate_block", side_effect=fake_simulate
    ):
        response = await execute_block(
            block=mock_block,
            block_id="test-block-id",
            input_data={"query": "hello"},
            user_id="user-1",
            session_id="session-1",
            node_exec_id="node-exec-1",
            matched_credentials={},
            dry_run=True,
        )

    assert isinstance(response, BlockOutputResponse)
    assert response.success is True
    assert response.is_dry_run is True
    assert "error" not in response.outputs
    assert response.outputs == {"result": ["simulated output"]}


@pytest.mark.asyncio
async def test_execute_block_dry_run_keeps_nonempty_error_pin():
    """Dry-run should keep the 'error' pin when it contains a real error message."""
    mock_block = make_mock_block()

    async def fake_simulate(block, input_data):
        yield "result", ""
        yield "error", "API rate limit exceeded"

    with patch(
        "backend.copilot.tools.helpers.simulate_block", side_effect=fake_simulate
    ):
        response = await execute_block(
            block=mock_block,
            block_id="test-block-id",
            input_data={"query": "hello"},
            user_id="user-1",
            session_id="session-1",
            node_exec_id="node-exec-1",
            matched_credentials={},
            dry_run=True,
        )

    assert isinstance(response, BlockOutputResponse)
    assert response.success is True
    # Non-empty error should be preserved
    assert "error" in response.outputs
    assert response.outputs["error"] == ["API rate limit exceeded"]


@pytest.mark.asyncio
async def test_execute_block_dry_run_message_includes_completed_status():
    """Dry-run message should clearly indicate COMPLETED status."""
    mock_block = make_mock_block()

    async def fake_simulate(block, input_data):
        yield "result", "simulated"

    with patch(
        "backend.copilot.tools.helpers.simulate_block", side_effect=fake_simulate
    ):
        response = await execute_block(
            block=mock_block,
            block_id="test-block-id",
            input_data={"query": "hello"},
            user_id="user-1",
            session_id="session-1",
            node_exec_id="node-exec-1",
            matched_credentials={},
            dry_run=True,
        )

    assert isinstance(response, BlockOutputResponse)
    assert "executed successfully" in response.message


@pytest.mark.asyncio
async def test_execute_block_dry_run_simulator_error_returns_error_response():
    """When simulate_block yields a SIMULATOR ERROR tuple, execute_block returns ErrorResponse."""
    mock_block = make_mock_block()

    async def fake_simulate_error(block, input_data):
        yield (
            "error",
            "[SIMULATOR ERROR — NOT A BLOCK FAILURE] No LLM client available (missing OpenAI/OpenRouter API key).",
        )
        yield "error", "[SIMULATOR ERROR — NOT A BLOCK FAILURE] No LLM client available (missing OpenAI/OpenRouter API key)."

    with patch(
        "backend.copilot.tools.helpers.simulate_block", side_effect=fake_simulate_error

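The test above relies on execute_block distinguishing an infrastructure failure of the simulator itself from a simulated block-level error, using a sentinel prefix on the "error" pin. A minimal, self-contained sketch of that detection pattern — the helper name `collect_outputs` and the stream shape are illustrative assumptions, only the sentinel string comes from the tests above:

```python
import asyncio

# Sentinel prefix taken from the test fixtures above.
SIMULATOR_ERROR_PREFIX = "[SIMULATOR ERROR — NOT A BLOCK FAILURE]"


async def collect_outputs(stream):
    """Drain an async generator of (pin, value) pairs.

    Values on the "error" pin that carry the simulator-error sentinel are
    diverted into a separate slot instead of being treated as block output.
    """
    outputs: dict[str, list] = {}
    simulator_error = None
    async for name, value in stream:
        if (
            name == "error"
            and isinstance(value, str)
            and value.startswith(SIMULATOR_ERROR_PREFIX)
        ):
            simulator_error = value  # infrastructure failure, not a block error
            continue
        outputs.setdefault(name, []).append(value)
    return outputs, simulator_error


async def fake_stream():
    yield "result", "ok"
    yield "error", SIMULATOR_ERROR_PREFIX + " No LLM client available."


outputs, sim_err = asyncio.run(collect_outputs(fake_stream()))
print(outputs)               # {'result': ['ok']}
print(sim_err is not None)   # True
```

A caller would then map a non-None `simulator_error` to an ErrorResponse rather than a successful BlockOutputResponse.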
@@ -76,7 +76,6 @@ async def test_run_block_returns_details_when_no_input_provided():
            session=session,
            block_id="http-block-id",
            input_data={},  # Empty input data
            dry_run=False,
        )

        # Should return BlockDetailsResponse showing the schema
@@ -144,7 +143,6 @@ async def test_run_block_returns_details_when_only_credentials_provided():
            session=session,
            block_id="api-block-id",
            input_data={"credentials": {"some": "cred"}},  # Only credential
            dry_run=False,
        )

        # Should return details because no non-credential inputs provided

@@ -151,7 +151,7 @@ async def test_non_dict_tool_arguments_returns_error():
        session=session,
        server_url=_SERVER_URL,
        tool_name="fetch",
        tool_arguments=["this", "is", "a", "list"],  # type: ignore[arg-type] # intentionally wrong type to test validation
        tool_arguments=["this", "is", "a", "list"],  # wrong type
    )

    assert isinstance(response, ErrorResponse)

@@ -1,499 +0,0 @@
"""Tests for session-level dry_run flag propagation.

Verifies that when a session has dry_run=True, run_block, run_agent, and
run_mcp_tool calls are forced to use dry-run mode, regardless of what the
individual tool call specifies. The single source of truth is
``session.dry_run``.
"""

from __future__ import annotations

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from backend.copilot.model import ChatSession
from backend.copilot.tools.models import ErrorResponse, MCPToolOutputResponse
from backend.copilot.tools.run_agent import RunAgentInput, RunAgentTool
from backend.copilot.tools.run_block import RunBlockTool
from backend.copilot.tools.run_mcp_tool import RunMCPToolTool

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_session(dry_run: bool = False) -> ChatSession:
    """Create a minimal ChatSession for testing."""
    session = ChatSession.new("test-user", dry_run=dry_run)
    return session


def _make_mock_block(name: str = "TestBlock"):
    """Create a minimal mock block with jsonschema() methods."""
    block = MagicMock()
    block.name = name
    block.description = "A test block"
    block.disabled = False
    block.block_type = "STANDARD"
    block.id = "test-block-id"

    block.input_schema = MagicMock()
    block.input_schema.jsonschema.return_value = {
        "type": "object",
        "properties": {"query": {"type": "string"}},
        "required": ["query"],
    }
    block.input_schema.get_credentials_fields.return_value = {}
    block.input_schema.get_credentials_fields_info.return_value = {}

    block.output_schema = MagicMock()
    block.output_schema.jsonschema.return_value = {
        "type": "object",
        "properties": {"result": {"type": "string"}},
        "required": ["result"],
    }

    return block


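The module docstring above states the precedence rule these tests enforce: a session-level dry_run=True always wins over the per-call flag. A minimal sketch of that rule, using a hypothetical `Session` stand-in for ChatSession:

```python
from dataclasses import dataclass


@dataclass
class Session:
    # Hypothetical stand-in for ChatSession; only the flag matters here.
    dry_run: bool = False


def effective_dry_run(session: Session, tool_dry_run: bool) -> bool:
    """Session-level dry_run wins: a dry-run session never triggers real execution."""
    return True if session.dry_run else tool_dry_run


print(effective_dry_run(Session(dry_run=True), False))   # True — session overrides
print(effective_dry_run(Session(dry_run=False), False))  # False — tool param respected
```

The asymmetry is deliberate: a dry-run session can only be escaped by creating a new session, while a normal session still honours per-call dry_run requests.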
# ---------------------------------------------------------------------------
# RunBlockTool tests
# ---------------------------------------------------------------------------


class TestRunBlockToolSessionDryRun:
    """Test that RunBlockTool respects session-level dry_run."""

    @pytest.mark.asyncio
    async def test_session_dry_run_forces_block_dry_run(self):
        """When session dry_run is True, run_block should force dry_run=True."""
        tool = RunBlockTool()
        session = _make_session(dry_run=True)

        mock_block = _make_mock_block()

        with (
            patch(
                "backend.copilot.tools.run_block.prepare_block_for_execution"
            ) as mock_prep,
            patch("backend.copilot.tools.run_block.execute_block") as mock_exec,
            patch(
                "backend.copilot.tools.run_block.get_current_permissions",
                return_value=None,
            ),
        ):
            # Set up prepare_block_for_execution to return a mock prep
            mock_prep_result = MagicMock()
            mock_prep_result.block = mock_block
            mock_prep_result.input_data = {"query": "test"}
            mock_prep_result.matched_credentials = {}
            mock_prep_result.synthetic_node_id = "node-1"
            mock_prep.return_value = mock_prep_result

            # Set up execute_block to return a success
            mock_exec.return_value = MagicMock(
                message="Block 'TestBlock' executed successfully",
                success=True,
            )

            await tool._execute(
                user_id="test-user",
                session=session,
                block_id="test-block-id",
                input_data={"query": "test"},
                dry_run=False,  # User passed False, but session overrides
            )

            # Verify execute_block was called with dry_run=True
            mock_exec.assert_called_once()
            call_kwargs = mock_exec.call_args
            assert call_kwargs.kwargs.get("dry_run") is True

    @pytest.mark.asyncio
    async def test_no_session_dry_run_respects_tool_param(self):
        """When session dry_run is False, tool-level dry_run should be respected."""
        tool = RunBlockTool()
        session = _make_session(dry_run=False)

        mock_block = _make_mock_block()

        with (
            patch(
                "backend.copilot.tools.run_block.prepare_block_for_execution"
            ) as mock_prep,
            patch("backend.copilot.tools.run_block.execute_block") as mock_exec,
            patch(
                "backend.copilot.tools.run_block.get_current_permissions",
                return_value=None,
            ),
            patch("backend.copilot.tools.run_block.check_hitl_review") as mock_hitl,
        ):
            mock_prep_result = MagicMock()
            mock_prep_result.block = mock_block
            mock_prep_result.input_data = {"query": "test"}
            mock_prep_result.matched_credentials = {}
            mock_prep_result.synthetic_node_id = "node-1"
            mock_prep_result.required_non_credential_keys = {"query"}
            mock_prep_result.provided_input_keys = {"query"}
            mock_prep.return_value = mock_prep_result

            mock_hitl.return_value = ("node-exec-1", {"query": "test"})

            mock_exec.return_value = MagicMock(
                message="Block executed",
                success=True,
            )

            await tool._execute(
                user_id="test-user",
                session=session,
                block_id="test-block-id",
                input_data={"query": "test"},
                dry_run=False,
            )

            # Verify execute_block was called with dry_run=False
            mock_exec.assert_called_once()
            call_kwargs = mock_exec.call_args
            assert call_kwargs.kwargs.get("dry_run") is False


# ---------------------------------------------------------------------------
# RunAgentTool tests
# ---------------------------------------------------------------------------


class TestRunAgentToolSessionDryRun:
    """Test that RunAgentTool respects session-level dry_run."""

    @pytest.mark.asyncio
    async def test_session_dry_run_forces_agent_dry_run(self):
        """When session dry_run is True, run_agent params.dry_run should be forced True."""
        tool = RunAgentTool()
        session = _make_session(dry_run=True)

        # Mock the graph and dependencies
        mock_graph = MagicMock()
        mock_graph.id = "graph-1"
        mock_graph.name = "Test Agent"
        mock_graph.description = "A test agent"
        mock_graph.input_schema = {"properties": {}, "required": []}
        mock_graph.trigger_setup_info = None

        mock_library_agent = MagicMock()
        mock_library_agent.id = "lib-1"
        mock_library_agent.graph_id = "graph-1"
        mock_library_agent.graph_version = 1
        mock_library_agent.name = "Test Agent"

        mock_execution = MagicMock()
        mock_execution.id = "exec-1"

        with (
            patch("backend.copilot.tools.run_agent.graph_db"),
            patch("backend.copilot.tools.run_agent.library_db"),
            patch(
                "backend.copilot.tools.run_agent.fetch_graph_from_store_slug",
                return_value=(mock_graph, None),
            ),
            patch(
                "backend.copilot.tools.run_agent.match_user_credentials_to_graph",
                return_value=({}, []),
            ),
            patch(
                "backend.copilot.tools.run_agent.get_or_create_library_agent",
                return_value=mock_library_agent,
            ),
            patch("backend.copilot.tools.run_agent.execution_utils") as mock_exec_utils,
            patch("backend.copilot.tools.run_agent.track_agent_run_success"),
        ):
            mock_exec_utils.add_graph_execution = AsyncMock(return_value=mock_execution)

            await tool._execute(
                user_id="test-user",
                session=session,
                username_agent_slug="user/test-agent",
                dry_run=False,  # User passed False, but session overrides
                use_defaults=True,
            )

            # Verify add_graph_execution was called with dry_run=True
            mock_exec_utils.add_graph_execution.assert_called_once()
            call_kwargs = mock_exec_utils.add_graph_execution.call_args
            assert call_kwargs.kwargs.get("dry_run") is True

    @pytest.mark.asyncio
    async def test_session_dry_run_blocks_scheduling(self):
        """When session dry_run is True, scheduling requests should be rejected."""
        tool = RunAgentTool()
        session = _make_session(dry_run=True)

        result = await tool._execute(
            user_id="test-user",
            session=session,
            username_agent_slug="user/test-agent",
            schedule_name="daily-run",
            cron="0 9 * * *",
            dry_run=False,  # Session overrides to True
        )

        assert isinstance(result, ErrorResponse)
        assert "dry-run" in result.message.lower()
        assert (
            "scheduling" in result.message.lower()
            or "schedule" in result.message.lower()
        )


# ---------------------------------------------------------------------------
# ChatSession model tests
# ---------------------------------------------------------------------------


class TestChatSessionDryRun:
    """Test the dry_run field on ChatSession model."""

    def test_new_session_default_dry_run_false(self):
        session = ChatSession.new("test-user", dry_run=False)
        assert session.dry_run is False

    def test_new_session_dry_run_true(self):
        session = ChatSession.new("test-user", dry_run=True)
        assert session.dry_run is True

    def test_new_session_dry_run_false_explicit(self):
        session = ChatSession.new("test-user", dry_run=False)
        assert session.dry_run is False


# ---------------------------------------------------------------------------
# RunAgentInput tests
# ---------------------------------------------------------------------------


class TestRunAgentInputDryRunOverride:
    """Test that RunAgentInput.dry_run can be mutated by session-level override."""

    def test_explicit_dry_run_false(self):
        params = RunAgentInput(username_agent_slug="user/agent", dry_run=False)
        assert params.dry_run is False

    def test_session_override(self):
        params = RunAgentInput(username_agent_slug="user/agent", dry_run=False)
        # Simulate session-level override
        params.dry_run = True
        assert params.dry_run is True


# ---------------------------------------------------------------------------
# RunMCPToolTool tests
# ---------------------------------------------------------------------------


class TestRunMCPToolToolSessionDryRun:
    """Test that RunMCPToolTool respects session-level dry_run."""

    @pytest.mark.asyncio
    async def test_session_dry_run_blocks_mcp_execution(self):
        """When session dry_run is True, MCP tool execution should be skipped."""
        tool = RunMCPToolTool()
        session = _make_session(dry_run=True)

        result = await tool._execute(
            user_id="test-user",
            session=session,
            server_url="https://mcp.example.com/sse",
            tool_name="some_tool",
            tool_arguments={"key": "value"},
        )

        assert isinstance(result, MCPToolOutputResponse)
        assert result.success is True
        assert "dry-run" in result.message
        assert result.tool_name == "some_tool"
        assert result.result is None

    @pytest.mark.asyncio
    async def test_session_dry_run_allows_discovery(self):
        """When session dry_run is True, tool discovery (no tool_name) should still work."""
        tool = RunMCPToolTool()
        session = _make_session(dry_run=True)

        # Discovery requires a network call, so we mock the client
        with (
            patch(
                "backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
                return_value=None,
            ),
            patch(
                "backend.copilot.tools.run_mcp_tool.validate_url_host",
                return_value=None,
            ),
            patch("backend.copilot.tools.run_mcp_tool.MCPClient") as mock_client_cls,
        ):
            mock_client = AsyncMock()
            mock_client_cls.return_value = mock_client

            mock_tool = MagicMock()
            mock_tool.name = "test_tool"
            mock_tool.description = "A test tool"
            mock_tool.input_schema = {"type": "object", "properties": {}}
            mock_client.list_tools.return_value = [mock_tool]

            result = await tool._execute(
                user_id="test-user",
                session=session,
                server_url="https://mcp.example.com/sse",
                tool_name="",  # Discovery mode
            )

            # Discovery should proceed normally
            mock_client.initialize.assert_called_once()
            mock_client.list_tools.assert_called_once()
            assert "Discovered" in result.message

    @pytest.mark.asyncio
    async def test_no_session_dry_run_allows_execution(self):
        """When session dry_run is False, MCP tool execution should proceed."""
        tool = RunMCPToolTool()
        session = _make_session(dry_run=False)

        with (
            patch(
                "backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
                return_value=None,
            ),
            patch(
                "backend.copilot.tools.run_mcp_tool.validate_url_host",
                return_value=None,
            ),
            patch("backend.copilot.tools.run_mcp_tool.MCPClient") as mock_client_cls,
        ):
            mock_client = AsyncMock()
            mock_client_cls.return_value = mock_client

            mock_result = MagicMock()
            mock_result.is_error = False
            mock_result.content = [{"type": "text", "text": "hello"}]
            mock_client.call_tool.return_value = mock_result

            result = await tool._execute(
                user_id="test-user",
                session=session,
                server_url="https://mcp.example.com/sse",
                tool_name="some_tool",
                tool_arguments={"key": "value"},
            )

            # Execution should proceed
            mock_client.initialize.assert_called_once()
            mock_client.call_tool.assert_called_once_with("some_tool", {"key": "value"})
            assert isinstance(result, MCPToolOutputResponse)
            assert result.success is True


# ---------------------------------------------------------------------------
# Backward-compatibility tests for ChatSessionMetadata deserialization
# ---------------------------------------------------------------------------


class TestChatSessionMetadataBackwardCompat:
    """Verify that sessions created before the dry_run field existed still load.

    The ``metadata`` JSON column in the DB may contain ``{}``, ``null``, or a
    dict without the ``dry_run`` key for sessions created before the flag was
    introduced. These must deserialize without errors and default to
    ``dry_run=False``.
    """

    def test_metadata_default_construction(self):
        """ChatSessionMetadata() with no args should default dry_run=False."""
        from backend.copilot.model import ChatSessionMetadata

        meta = ChatSessionMetadata()
        assert meta.dry_run is False

    def test_metadata_from_empty_dict(self):
        """Deserializing an empty dict (old-format metadata) should succeed."""
        from backend.copilot.model import ChatSessionMetadata

        meta = ChatSessionMetadata.model_validate({})
        assert meta.dry_run is False

    def test_metadata_from_dict_without_dry_run_key(self):
        """A metadata dict with other keys but no dry_run should still work."""
        from backend.copilot.model import ChatSessionMetadata

        meta = ChatSessionMetadata.model_validate({"some_future_field": 42})
        # dry_run should fall back to default
        assert meta.dry_run is False

    def test_metadata_round_trip_with_dry_run_false(self):
        """Serialize then deserialize with dry_run=False."""
        from backend.copilot.model import ChatSessionMetadata

        original = ChatSessionMetadata(dry_run=False)
        raw = original.model_dump()
        restored = ChatSessionMetadata.model_validate(raw)
        assert restored.dry_run is False

    def test_metadata_round_trip_with_dry_run_true(self):
        """Serialize then deserialize with dry_run=True."""
        from backend.copilot.model import ChatSessionMetadata

        original = ChatSessionMetadata(dry_run=True)
        raw = original.model_dump()
        restored = ChatSessionMetadata.model_validate(raw)
        assert restored.dry_run is True

    def test_metadata_json_round_trip(self):
        """Serialize to JSON string and back, simulating Redis cache flow."""
        from backend.copilot.model import ChatSessionMetadata

        original = ChatSessionMetadata(dry_run=True)
        json_str = original.model_dump_json()
        restored = ChatSessionMetadata.model_validate_json(json_str)
        assert restored.dry_run is True

    def test_session_dry_run_property_with_default_metadata(self):
        """ChatSession.dry_run returns False when metadata has no dry_run."""
        from backend.copilot.model import ChatSessionMetadata

        # Simulate building a session with metadata deserialized from an old row
        meta = ChatSessionMetadata.model_validate({})
        session = _make_session(dry_run=False)
        session.metadata = meta
        assert session.dry_run is False

    def test_session_info_dry_run_property_with_default_metadata(self):
        """ChatSessionInfo.dry_run returns False when metadata is default."""
        from datetime import UTC, datetime

        from backend.copilot.model import ChatSessionInfo, ChatSessionMetadata

        info = ChatSessionInfo(
            session_id="old-session-id",
            user_id="test-user",
            usage=[],
            started_at=datetime.now(UTC),
            updated_at=datetime.now(UTC),
            metadata=ChatSessionMetadata.model_validate({}),
        )
        assert info.dry_run is False

    def test_session_full_json_round_trip_without_dry_run(self):
        """A full ChatSession JSON round-trip preserves dry_run default."""
        session = _make_session(dry_run=False)
        json_bytes = session.model_dump_json()
        restored = ChatSession.model_validate_json(json_bytes)
        assert restored.dry_run is False
        assert restored.metadata.dry_run is False

    def test_session_full_json_round_trip_with_dry_run(self):
        """A full ChatSession JSON round-trip preserves dry_run=True."""
        session = _make_session(dry_run=True)
        json_bytes = session.model_dump_json()
        restored = ChatSession.model_validate_json(json_bytes)
        assert restored.dry_run is True
        assert restored.metadata.dry_run is True
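The class docstring above describes the compatibility requirement: metadata rows written before the flag existed may be `null`, `{}`, or a dict missing `dry_run`, and all must load with `dry_run=False`. A stdlib-only sketch of that tolerant deserialization (the real code uses a pydantic model with a defaulted field; `load_metadata` is a hypothetical helper):

```python
import json


def load_metadata(raw):
    """Deserialize a metadata JSON column written before dry_run existed.

    None, "null", "{}", and dicts without "dry_run" all come back with
    dry_run defaulting to False, mirroring a defaulted model field.
    """
    data = json.loads(raw) if raw else None
    if not isinstance(data, dict):
        data = {}
    data.setdefault("dry_run", False)
    return data


print(load_metadata(None)["dry_run"])                  # False
print(load_metadata("{}")["dry_run"])                  # False
print(load_metadata('{"dry_run": true}')["dry_run"])   # True
```

With pydantic, the same effect falls out of declaring `dry_run: bool = False` on the metadata model, which is what the round-trip tests above exercise.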
@@ -48,9 +48,9 @@ class ValidateAgentGraphTool(BaseTool):
        self,
        user_id: str | None,
        session: ChatSession,
        agent_json: dict | None = None,
        **kwargs,
    ) -> ToolResponseBase:
        agent_json = kwargs.get("agent_json")
        session_id = session.session_id if session else None

        if not agent_json or not isinstance(agent_json, dict):

@@ -87,11 +87,10 @@ class WebFetchTool(BaseTool):
        self,
        user_id: str | None,
        session: ChatSession,
        url: str = "",
        extract_text: bool = True,
        **kwargs: Any,
    ) -> ToolResponseBase:
        url = url.strip()
        url: str = (kwargs.get("url") or "").strip()
        extract_text: bool = kwargs.get("extract_text", True)
        session_id = session.session_id if session else None

        if not url:

@@ -450,9 +450,6 @@ class ListWorkspaceFilesTool(BaseTool):
        self,
        user_id: str | None,
        session: ChatSession,
        path_prefix: Optional[str] = None,
        limit: int = 50,
        include_all_sessions: bool = False,
        **kwargs,
    ) -> ToolResponseBase:
        session_id = session.session_id
@@ -461,7 +458,9 @@ class ListWorkspaceFilesTool(BaseTool):
                message="Authentication required", session_id=session_id
            )

        limit = min(limit, 100)
        path_prefix: Optional[str] = kwargs.get("path_prefix")
        limit = min(kwargs.get("limit", 50), 100)
        include_all_sessions: bool = kwargs.get("include_all_sessions", False)

        try:
            manager = await get_workspace_manager(user_id, session_id)
@@ -568,12 +567,6 @@ class ReadWorkspaceFileTool(BaseTool):
        self,
        user_id: str | None,
        session: ChatSession,
        file_id: Optional[str] = None,
        path: Optional[str] = None,
        save_to_path: Optional[str] = None,
        force_download_url: bool = False,
        offset: int = 0,
        length: Optional[int] = None,
        **kwargs,
    ) -> ToolResponseBase:
        session_id = session.session_id
@@ -582,8 +575,12 @@ class ReadWorkspaceFileTool(BaseTool):
                message="Authentication required", session_id=session_id
            )

        char_offset: int = max(0, offset)
        char_length: Optional[int] = length
        file_id: Optional[str] = kwargs.get("file_id")
        path: Optional[str] = kwargs.get("path")
        save_to_path: Optional[str] = kwargs.get("save_to_path")
        force_download_url: bool = kwargs.get("force_download_url", False)
        char_offset: int = max(0, kwargs.get("offset", 0))
        char_length: Optional[int] = kwargs.get("length")

        if not file_id and not path:
            return ErrorResponse(
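The hunks in this file toggle each tool's `_execute` between typed keyword parameters and `**kwargs` extraction with explicit defaults. A minimal sketch of the kwargs-extraction style, with hypothetical parameter names modelled on ListWorkspaceFilesTool; the point is that a sparse or malformed tool call can never raise TypeError at the call boundary:

```python
from typing import Any, Optional


def execute(user_id: Optional[str], **kwargs: Any) -> dict:
    """Pull every optional argument from **kwargs with an explicit default,
    then clamp/normalize, so missing or extra keys are harmless."""
    path_prefix: Optional[str] = kwargs.get("path_prefix")
    limit: int = min(kwargs.get("limit", 50), 100)  # clamp to a server-side cap
    include_all_sessions: bool = kwargs.get("include_all_sessions", False)
    return {"path_prefix": path_prefix, "limit": limit, "all": include_all_sessions}


print(execute("u1"))             # {'path_prefix': None, 'limit': 50, 'all': False}
print(execute("u1", limit=500))  # {'path_prefix': None, 'limit': 100, 'all': False}
```

The trade-off against typed parameters is losing signature-level validation in exchange for tolerance of LLM-generated tool calls, which is why tests elsewhere in this diff inspect `_execute`'s source for `kwargs.get("dry_run"`.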
@@ -773,13 +770,6 @@ class WriteWorkspaceFileTool(BaseTool):
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
filename: str = "",
|
||||
source_path: str | None = None,
|
||||
content: str | None = None,
|
||||
content_base64: str | None = None,
|
||||
path: str | None = None,
|
||||
mime_type: str | None = None,
|
||||
overwrite: bool = False,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
session_id = session.session_id
|
||||
@@ -788,36 +778,15 @@ class WriteWorkspaceFileTool(BaseTool):
|
||||
message="Authentication required", session_id=session_id
|
||||
)
|
||||
|
||||
filename: str = kwargs.get("filename", "")
|
||||
if not filename:
|
||||
# When ALL parameters are missing, the most likely cause is
|
||||
# output token truncation: the LLM tried to inline a very large
|
||||
# file as `content`, the SDK silently truncated the tool call
|
||||
# arguments to `{}`, and we receive nothing. Return an
|
||||
# actionable error instead of a generic "filename required".
has_any_content = any(
kwargs.get(k) for k in ("content", "content_base64", "source_path")
)
if not has_any_content:
return ErrorResponse(
message=(
"Tool call appears truncated (no arguments received). "
"This happens when the content is too large for a "
"single tool call. Instead of passing content inline, "
"first write the file to the working directory using "
"bash_exec (e.g. cat > /home/user/file.md << 'EOF'... "
"EOF), then use source_path to copy it to workspace: "
"write_workspace_file(filename='file.md', "
"source_path='/home/user/file.md')"
),
session_id=session_id,
)
return ErrorResponse(
message="Please provide a filename", session_id=session_id
)

source_path_arg: str | None = source_path
content_text: str | None = content
content_b64: str | None = content_base64
source_path_arg: str | None = kwargs.get("source_path")
content_text: str | None = kwargs.get("content")
content_b64: str | None = kwargs.get("content_base64")

resolved = await _resolve_write_content(
content_text,
@@ -827,24 +796,24 @@ class WriteWorkspaceFileTool(BaseTool):
)
if isinstance(resolved, ErrorResponse):
return resolved
content_bytes: bytes = resolved
content: bytes = resolved

max_size = _MAX_FILE_SIZE_MB * 1024 * 1024
if len(content_bytes) > max_size:
if len(content) > max_size:
return ErrorResponse(
message=f"File too large. Maximum size is {_MAX_FILE_SIZE_MB}MB",
session_id=session_id,
)

try:
await scan_content_safe(content_bytes, filename=filename)
await scan_content_safe(content, filename=filename)
manager = await get_workspace_manager(user_id, session_id)
rec = await manager.write_file(
content=content_bytes,
content=content,
filename=filename,
path=path,
mime_type=mime_type,
overwrite=overwrite,
path=kwargs.get("path"),
mime_type=kwargs.get("mime_type"),
overwrite=kwargs.get("overwrite", False),
)

# Build informative source label and message.
@@ -868,8 +837,8 @@ class WriteWorkspaceFileTool(BaseTool):
preview: str | None = None
if _is_text_mime(rec.mime_type):
try:
preview = content_bytes[:200].decode("utf-8", errors="replace")
if len(content_bytes) > 200:
preview = content[:200].decode("utf-8", errors="replace")
if len(content) > 200:
preview += "..."
except Exception:
pass
@@ -941,8 +910,6 @@ class DeleteWorkspaceFileTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
file_id: Optional[str] = None,
path: Optional[str] = None,
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id
@@ -950,6 +917,9 @@ class DeleteWorkspaceFileTool(BaseTool):
return ErrorResponse(
message="Authentication required", session_id=session_id
)

file_id: Optional[str] = kwargs.get("file_id")
path: Optional[str] = kwargs.get("path")
if not file_id and not path:
return ErrorResponse(
message="Please provide either file_id or path", session_id=session_id

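The workaround that the truncation error message above recommends can be sketched as a shell session. The heredoc writes the large content into the working directory first; `write_workspace_file` (the copilot tool the message refers to) then copies it by path instead of receiving it inline. The filename here is illustrative:

```shell
# Step 1: write the large content to the working directory via heredoc
cat > file.md << 'EOF'
# My large document
(many kilobytes of content that would not survive a single tool call)
EOF

# Step 2: confirm the file landed, then pass it by path, e.g.:
#   write_workspace_file(filename='file.md', source_path='/home/user/file.md')
wc -c file.md
```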
@@ -13,7 +13,7 @@ Inspired by https://github.com/Significant-Gravitas/agent-simulator

import json
import logging
from collections.abc import AsyncGenerator
from collections.abc import AsyncIterator
from typing import Any

from backend.util.clients import get_openai_client
@@ -96,10 +96,6 @@ def build_simulation_prompt(block: Any, input_data: dict[str, Any]) -> tuple[str
input_pins = _describe_schema_pins(input_schema)
output_pins = _describe_schema_pins(output_schema)
output_properties = list(output_schema.get("properties", {}).keys())
# Build a separate list for the "MUST include" instruction that excludes
# "error" — the prompt already tells the LLM to OMIT the error pin unless
# simulating a logical error. Including it in "MUST include" is contradictory.
required_output_properties = [k for k in output_properties if k != "error"]

block_name = getattr(block, "name", type(block).__name__)
block_description = getattr(block, "description", "No description available.")
@@ -121,10 +117,10 @@ Rules:
- Respond with a single JSON object whose keys are EXACTLY the output pin names listed above.
- Assume all credentials and authentication are present and valid. Never simulate authentication failures.
- Make the simulated outputs realistic and consistent with the inputs.
- If there is an "error" pin, OMIT it entirely unless you are simulating a logical error. Only include the "error" pin when there is a genuine error message to report.
- If there is an "error" pin, set it to "" (empty string) unless you are simulating a logical error.
- Do not include any extra keys beyond the output pins.

Output pin names you MUST include: {json.dumps(required_output_properties)}
Output pin names you MUST include: {json.dumps(output_properties)}
"""

safe_inputs = _truncate_input_values(input_data)
@@ -136,7 +132,7 @@ Output pin names you MUST include: {json.dumps(required_output_properties)}
async def simulate_block(
block: Any,
input_data: dict[str, Any],
) -> AsyncGenerator[tuple[str, Any], None]:
) -> AsyncIterator[tuple[str, Any]]:
"""Simulate block execution using an LLM.

Yields (output_name, output_data) tuples matching the Block.execute() interface.
@@ -176,26 +172,13 @@ async def simulate_block(
if not isinstance(parsed, dict):
raise ValueError(f"LLM returned non-object JSON: {raw[:200]}")

# Fill missing output pins with defaults.
# Skip empty "error" pins — an empty string means "no error" and
# would only confuse downstream consumers (LLM, frontend).
# Fill missing output pins with defaults
result: dict[str, Any] = {}
for pin_name in output_properties:
if pin_name in parsed:
value = parsed[pin_name]
# Drop empty/blank error pins: they carry no information.
# Uses strip() intentionally so whitespace-only strings
# (e.g. " ", "\n") are also treated as empty.
if (
pin_name == "error"
and isinstance(value, str)
and not value.strip()
):
continue
result[pin_name] = value
elif pin_name != "error":
# Only fill non-error missing pins with None
result[pin_name] = None
result[pin_name] = parsed[pin_name]
else:
result[pin_name] = "" if pin_name == "error" else None

logger.debug(
"simulate_block: block=%s attempt=%d tokens=%s/%s",

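The pin-filling loop in this hunk appears in two variants; the one that omits blank `error` pins can be sketched standalone (the function name is illustrative):

```python
def fill_output_pins(parsed: dict, output_properties: list[str]) -> dict:
    """Keep returned pins, drop blank "error" pins, default missing
    non-error pins to None (a sketch of the loop shown above)."""
    result = {}
    for pin_name in output_properties:
        if pin_name in parsed:
            value = parsed[pin_name]
            # Empty or whitespace-only "error" means "no error": omit it
            if pin_name == "error" and isinstance(value, str) and not value.strip():
                continue
            result[pin_name] = value
        elif pin_name != "error":
            result[pin_name] = None
    return result

print(fill_output_pins({"out": 1, "error": " "}, ["out", "extra", "error"]))
# {'out': 1, 'extra': None}
```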
@@ -19,7 +19,6 @@ from backend.data.model import (
UserPasswordCredentials,
)
from backend.data.redis_client import get_redis_async
from backend.util.cache import thread_cached
from backend.util.settings import Settings

settings = Settings()
@@ -305,12 +304,15 @@ def is_system_provider(provider: str) -> bool:


class IntegrationCredentialsStore:
@thread_cached
def __init__(self):
self._locks = None

async def locks(self) -> AsyncRedisKeyedMutex:
# Per-thread: copilot executor runs worker threads with separate event
# loops; AsyncRedisKeyedMutex's internal asyncio.Lock is bound to the
# loop it was created on.
return AsyncRedisKeyedMutex(await get_redis_async())
if self._locks:
return self._locks

self._locks = AsyncRedisKeyedMutex(await get_redis_async())
return self._locks

@property
def db_manager(self):

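One side of this hunk caches the mutex on the instance; that lazy-initialization pattern can be sketched generically (all names below are illustrative, with a placeholder in place of the Redis connection):

```python
import asyncio

class LockRegistry:
    """Illustrative sketch of the self._locks lazy-caching shown above."""

    def __init__(self):
        self._locks = None

    async def locks(self):
        if self._locks:          # cached from an earlier call
            return self._locks
        # Stands in for AsyncRedisKeyedMutex(await get_redis_async())
        self._locks = await self._connect()
        return self._locks

    async def _connect(self):
        await asyncio.sleep(0)   # placeholder for the real async connect
        return object()

async def demo():
    reg = LockRegistry()
    first = await reg.locks()
    second = await reg.locks()
    return first is second       # second call reuses the cached instance

print(asyncio.run(demo()))  # True
```

Note the hunk's truthiness check (`if self._locks:`) works because the cached object is never falsy; the alternative `@thread_cached` variant instead recreates the mutex per thread, which matters when worker threads run separate event loops.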
@@ -8,6 +8,7 @@ from autogpt_libs.utils.synchronize import AsyncRedisKeyedMutex
from redis.asyncio.lock import Lock as AsyncRedisLock

from backend.data.model import Credentials, OAuth2Credentials
from backend.data.redis_client import get_redis_async
from backend.integrations.credentials_store import (
IntegrationCredentialsStore,
provider_matches,
@@ -105,13 +106,14 @@ class IntegrationCredentialsManager:

def __init__(self):
self.store = IntegrationCredentialsStore()
self._locks = None

async def locks(self) -> AsyncRedisKeyedMutex:
# Delegate to store's @thread_cached locks. Manager uses these for
# fine-grained per-credential locking (refresh, acquire); the store
# uses its own for coarse per-user integrations locking. Same mutex
# type, different key spaces — no collision.
return await self.store.locks()
if self._locks:
return self._locks

self._locks = AsyncRedisKeyedMutex(await get_redis_async())
return self._locks

async def create(self, user_id: str, credentials: Credentials) -> None:
result = await self.store.add_creds(user_id, credentials)
@@ -186,74 +188,35 @@ class IntegrationCredentialsManager:

async def refresh_if_needed(
self, user_id: str, credentials: OAuth2Credentials, lock: bool = True
) -> OAuth2Credentials:
# When lock=False, skip ALL Redis locking (both the outer "refresh" scope
# lock and the inner credential lock). This is used by the copilot's
# integration_creds module which runs across multiple threads with separate
# event loops; acquiring a Redis lock whose asyncio.Lock() was created on
# a different loop raises "Future attached to a different loop".
if lock:
return await self._refresh_locked(user_id, credentials)
return await self._refresh_unlocked(user_id, credentials)

async def _get_oauth_handler(
self, credentials: OAuth2Credentials
) -> "BaseOAuthHandler":
"""Resolve the appropriate OAuth handler for the given credentials."""
if provider_matches(credentials.provider, ProviderName.MCP.value):
return create_mcp_oauth_handler(credentials)
return await _get_provider_oauth_handler(credentials.provider)

async def _refresh_locked(
self, user_id: str, credentials: OAuth2Credentials
) -> OAuth2Credentials:
async with self._locked(user_id, credentials.id, "refresh"):
oauth_handler = await self._get_oauth_handler(credentials)
if provider_matches(credentials.provider, ProviderName.MCP.value):
oauth_handler = create_mcp_oauth_handler(credentials)
else:
oauth_handler = await _get_provider_oauth_handler(credentials.provider)
if oauth_handler.needs_refresh(credentials):
logger.debug(
"Refreshing '%s' credentials #%s",
credentials.provider,
credentials.id,
f"Refreshing '{credentials.provider}' credentials #{credentials.id}"
)
# Wait until the credentials are no longer in use anywhere
_lock = await self._acquire_lock(user_id, credentials.id)
try:
fresh_credentials = await oauth_handler.refresh_tokens(credentials)
await self.store.update_creds(user_id, fresh_credentials)
_invoke_creds_changed_hook(user_id, fresh_credentials.provider)
credentials = fresh_credentials
finally:
if (await _lock.locked()) and (await _lock.owned()):
try:
await _lock.release()
except Exception:
logger.warning(
"Failed to release OAuth refresh lock",
exc_info=True,
)
return credentials
_lock = None
if lock:
# Wait until the credentials are no longer in use anywhere
_lock = await self._acquire_lock(user_id, credentials.id)

async def _refresh_unlocked(
self, user_id: str, credentials: OAuth2Credentials
) -> OAuth2Credentials:
"""Best-effort token refresh without any Redis locking.
fresh_credentials = await oauth_handler.refresh_tokens(credentials)
await self.store.update_creds(user_id, fresh_credentials)
# Notify listeners so the refreshed token is picked up immediately.
_invoke_creds_changed_hook(user_id, fresh_credentials.provider)
if _lock and (await _lock.locked()) and (await _lock.owned()):
try:
await _lock.release()
except Exception:
logger.warning(
"Failed to release OAuth refresh lock",
exc_info=True,
)

Safe for use from multi-threaded contexts (e.g. copilot workers) where
each thread has its own event loop and sharing Redis-backed asyncio locks
is not possible. Concurrent refreshes are tolerated: the last writer
wins, and stale tokens are overwritten.
"""
oauth_handler = await self._get_oauth_handler(credentials)
if oauth_handler.needs_refresh(credentials):
logger.debug(
"Refreshing '%s' credentials #%s (lock-free)",
credentials.provider,
credentials.id,
)
fresh_credentials = await oauth_handler.refresh_tokens(credentials)
await self.store.update_creds(user_id, fresh_credentials)
_invoke_creds_changed_hook(user_id, fresh_credentials.provider)
credentials = fresh_credentials
credentials = fresh_credentials
return credentials

async def update(self, user_id: str, updated: Credentials) -> None:
@@ -301,6 +264,7 @@ class IntegrationCredentialsManager:

async def release_all_locks(self):
"""Call this on process termination to ensure all locks are released"""
await (await self.locks()).release_all_locks()
await (await self.store.locks()).release_all_locks()

@@ -251,50 +251,6 @@ def estimate_token_count_str(
DEFAULT_TOKEN_THRESHOLD = 120_000
DEFAULT_KEEP_RECENT = 15

# Reserve tokens for system prompt, tool definitions, and per-turn overhead.
# The actual model context limit minus this reserve = compression target.
_CONTEXT_OVERHEAD_RESERVE = 60_000


def get_context_window(model: str) -> int | None:
"""Return the context window size for a model, or None if unknown.

Looks up the model in the :class:`LlmModel` enum (which already
carries ``context_window`` via ``MODEL_METADATA``). Handles
provider-prefixed names (``anthropic/claude-opus-4-6``) and
case-insensitive input automatically.
"""
from backend.blocks.llm import LlmModel  # lazy to avoid circular import

try:
llm_model = LlmModel(model)
return llm_model.context_window
except (ValueError, KeyError):
pass

# Retry with lowercase for case-insensitive lookup
try:
llm_model = LlmModel(model.lower())
return llm_model.context_window
except (ValueError, KeyError):
return None


def get_compression_target(model: str) -> int:
"""Compute a model-aware compression target for conversation history.

Returns ``context_window - overhead_reserve``, floored at 10K.
Falls back to ``DEFAULT_TOKEN_THRESHOLD`` for unknown models or
models whose context window is too small for the overhead reserve.
"""
window = get_context_window(model)
if window is None:
return DEFAULT_TOKEN_THRESHOLD
target = window - _CONTEXT_OVERHEAD_RESERVE
if target < 10_000:
return DEFAULT_TOKEN_THRESHOLD
return target


@dataclass
class CompressResult:
@@ -704,7 +660,7 @@ async def _summarize_messages_llm(

async def compress_context(
messages: list[dict],
target_tokens: int | None = None,
target_tokens: int = DEFAULT_TOKEN_THRESHOLD,
*,
model: str = "gpt-4o",
client: AsyncOpenAI | None = None,
@@ -716,11 +672,6 @@ async def compress_context(
"""
Unified context compression that combines summarization and truncation strategies.

When ``target_tokens`` is None (the default), it is computed from the
model's context window via ``get_compression_target(model)``. This
ensures large-context models (e.g. Opus 200K) retain more history
while smaller models compress more aggressively.

Strategy (in order):
1. **LLM summarization** – If client provided, summarize old messages into a
single context message while keeping recent messages intact. This is the
@@ -748,10 +699,6 @@ async def compress_context(
-------
CompressResult with compressed messages and metadata.
"""
# Resolve model-aware target when caller doesn't specify an explicit limit.
if target_tokens is None:
target_tokens = get_compression_target(model)

# Guard clause for empty messages
if not messages:
return CompressResult(

@@ -7,7 +7,6 @@ from tiktoken import encoding_for_model

from backend.util import json
from backend.util.prompt import (
DEFAULT_TOKEN_THRESHOLD,
CompressResult,
_ensure_tool_pairs_intact,
_msg_tokens,
@@ -16,8 +15,6 @@ from backend.util.prompt import (
_truncate_tool_message_content,
compress_context,
estimate_token_count,
get_compression_target,
get_context_window,
)

@@ -977,43 +974,3 @@ class TestCompressResultDataclass:
assert result.original_token_count == 500
assert result.messages_summarized == 10
assert result.messages_dropped == 5


class TestGetContextWindow:
    def test_claude_opus(self) -> None:
        assert get_context_window("claude-opus-4-20250514") == 200_000

    def test_claude_sonnet(self) -> None:
        assert get_context_window("claude-sonnet-4-20250514") == 200_000

    def test_openrouter_prefix(self) -> None:
        assert get_context_window("anthropic/claude-opus-4-6") == 200_000

    def test_version_suffix(self) -> None:
        assert get_context_window("claude-opus-4-6") == 200_000

    def test_gpt4o(self) -> None:
        assert get_context_window("gpt-4o") == 128_000

    def test_unknown_model(self) -> None:
        assert get_context_window("some-unknown-model") is None

    def test_case_insensitive(self) -> None:
        assert get_context_window("GPT-4o") == 128_000


class TestGetCompressionTarget:
    def test_claude_opus_200k(self) -> None:
        target = get_compression_target("anthropic/claude-opus-4-6")
        assert target == 140_000  # 200K - 60K overhead

    def test_gpt4o_128k(self) -> None:
        target = get_compression_target("gpt-4o")
        assert target == 68_000  # 128K - 60K overhead

    def test_unknown_model_returns_default(self) -> None:
        assert get_compression_target("unknown-model") == DEFAULT_TOKEN_THRESHOLD

    def test_small_model_returns_default(self) -> None:
        # Unknown models fall back to DEFAULT_TOKEN_THRESHOLD
        assert get_compression_target("some-tiny-model") == DEFAULT_TOKEN_THRESHOLD

@@ -22,6 +22,7 @@ function generateTestGraph(name = null) {
input_default: {
name: "Load Test Input",
description: "Test input for load testing",
placeholder_values: {},
},
input_nodes: [],
output_nodes: ["output_node"],
@@ -58,7 +59,11 @@ function generateExecutionInputs() {
"Load Test Input": {
name: "Load Test Input",
description: "Test input for load testing",
value: `Test execution at ${new Date().toISOString()}`,
placeholder_values: {
test_data: `Test execution at ${new Date().toISOString()}`,
test_parameter: Math.random().toString(36).substr(2, 9),
numeric_value: Math.floor(Math.random() * 1000),
},
},
};
}

@@ -1,4 +0,0 @@
-- Add extensible metadata JSONB column to ChatSession.
-- New session-level flags (e.g. dry_run) live inside this JSON
-- so future additions need no extra migrations.
ALTER TABLE "ChatSession" ADD COLUMN "metadata" JSONB NOT NULL DEFAULT '{}';
@@ -220,10 +220,6 @@ model ChatSession {
successfulAgentRuns Json @default("{}") // Map of graph_id -> count
successfulAgentSchedules Json @default("{}") // Map of graph_id -> count

// Extensible session metadata (typed via ChatSessionMetadata in Python).
// Avoids DB migrations for each new flag (e.g. dry_run, future fields).
metadata Json @default("{}")

// Usage tracking
totalPromptTokens Int @default(0)
totalCompletionTokens Int @default(0)

@@ -1,297 +0,0 @@
#!/usr/bin/env python3
"""Download CoPilot transcripts from prod GCS and load into local dev environment.

Usage:
    # Step 1: Download from prod GCS (needs MEDIA_GCS_BUCKET_NAME + gcloud auth)
    MEDIA_GCS_BUCKET_NAME=<prod-bucket> USER_ID=<user-uuid> \
        poetry run python scripts/download_transcripts.py download <session_id> ...

    # Step 2: Load downloaded transcripts into local storage + DB
    poetry run python scripts/download_transcripts.py load <session_id> ...

    # Or do both in one step (if you have GCS access):
    MEDIA_GCS_BUCKET_NAME=<prod-bucket> USER_ID=<user-uuid> \
        poetry run python scripts/download_transcripts.py both <session_id> ...

The "download" step saves transcripts to transcripts/<session_id>.jsonl.
The "load" step reads those files and:
1. Creates a ChatSession in local DB (or reuses existing)
2. Populates messages from the transcript
3. Stores transcript in local workspace storage
4. Creates metadata so --resume works on the next turn

After "load", you can send a message to the session via the CoPilot UI
and it will use --resume with the loaded transcript.
"""

from __future__ import annotations

import asyncio
import json
import os
import re
import sys
import time

sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))

_SAFE_RE = re.compile(r"[^0-9a-fA-F-]")
TRANSCRIPTS_DIR = os.path.join(os.path.dirname(__file__), "..", "transcripts")


def _sanitize(raw: str) -> str:
    cleaned = _SAFE_RE.sub("", raw or "")[:36]
    if not cleaned:
        raise ValueError(f"Invalid ID: {raw!r}")
    return cleaned


def _transcript_path(session_id: str) -> str:
    return os.path.join(TRANSCRIPTS_DIR, f"{_sanitize(session_id)}.jsonl")


def _meta_path(session_id: str) -> str:
    return os.path.join(TRANSCRIPTS_DIR, f"{_sanitize(session_id)}.meta.json")


# ── Download from GCS ─────────────────────────────────────────────────────


async def cmd_download(session_ids: list[str]) -> None:
    """Download transcripts from prod GCS to transcripts/ directory."""
    from backend.copilot.sdk.transcript import download_transcript

    user_id = os.environ.get("USER_ID", "")
    if not user_id:
        print("ERROR: Set USER_ID env var to the session owner's user ID.")
        print(" You can find it in Sentry breadcrumbs or the DB.")
        sys.exit(1)

    bucket = os.environ.get("MEDIA_GCS_BUCKET_NAME", "")
    if not bucket:
        print("ERROR: Set MEDIA_GCS_BUCKET_NAME to the prod GCS bucket.")
        sys.exit(1)

    os.makedirs(TRANSCRIPTS_DIR, exist_ok=True)
    print(f"Downloading from GCS bucket: {bucket}")
    print(f"User ID: {user_id}\n")

    for sid in session_ids:
        print(f"[{sid[:12]}] Downloading...")
        try:
            dl = await download_transcript(user_id, sid)
        except Exception as e:
            print(f"[{sid[:12]}] Failed: {e}")
            continue

        if not dl or not dl.content:
            print(f"[{sid[:12]}] Not found in GCS")
            continue

        out = _transcript_path(sid)
        with open(out, "w") as f:
            f.write(dl.content)

        lines = len(dl.content.strip().split("\n"))
        meta = {
            "session_id": sid,
            "user_id": user_id,
            "message_count": dl.message_count,
            "uploaded_at": dl.uploaded_at,
            "transcript_bytes": len(dl.content),
            "transcript_lines": lines,
        }
        with open(_meta_path(sid), "w") as f:
            json.dump(meta, f, indent=2)

        print(
            f"[{sid[:12]}] Saved: {lines} entries, "
            f"{len(dl.content)} bytes, msg_count={dl.message_count}"
        )
    print("\nDone. Run 'load' command to import into local dev environment.")


# ── Load into local dev ───────────────────────────────────────────────────


def _parse_messages_from_transcript(content: str) -> list[dict]:
    """Extract user/assistant messages from JSONL transcript for DB seeding."""
    messages: list[dict] = []
    for line in content.strip().split("\n"):
        if not line.strip():
            continue
        try:
            entry = json.loads(line)
        except json.JSONDecodeError:
            continue
        if not isinstance(entry, dict):
            continue
        msg = entry.get("message", {})
        role = msg.get("role")
        if role not in ("user", "assistant"):
            continue

        content_blocks = msg.get("content", "")
        if isinstance(content_blocks, list):
            # Flatten content blocks to text
            text_parts = []
            for block in content_blocks:
                if isinstance(block, dict):
                    if block.get("type") == "text":
                        text_parts.append(block.get("text", ""))
                elif isinstance(block, str):
                    text_parts.append(block)
            text = "\n".join(text_parts)
        elif isinstance(content_blocks, str):
            text = content_blocks
        else:
            text = ""

        if text:
            messages.append({"role": role, "content": text})

    return messages


async def cmd_load(session_ids: list[str]) -> None:
    """Load downloaded transcripts into local workspace storage + DB."""
    from backend.copilot.sdk.transcript import upload_transcript

    # Use the user_id from meta file or env var
    default_user_id = os.environ.get("USER_ID", "")

    for sid in session_ids:
        transcript_file = _transcript_path(sid)
        meta_file = _meta_path(sid)

        if not os.path.exists(transcript_file):
            print(f"[{sid[:12]}] No transcript file at {transcript_file}")
            print(" Run 'download' first, or place the file manually.")
            continue

        with open(transcript_file) as f:
            content = f.read()

        # Load meta if available
        user_id = default_user_id
        msg_count = 0
        if os.path.exists(meta_file):
            with open(meta_file) as f:
                meta = json.load(f)
            user_id = meta.get("user_id", user_id)
            msg_count = meta.get("message_count", 0)

        if not user_id:
            print(f"[{sid[:12]}] No user_id — set USER_ID env var or download first")
            continue

        lines = len(content.strip().split("\n"))
        print(f"[{sid[:12]}] Loading transcript: {lines} entries, {len(content)} bytes")

        # Parse messages from transcript for DB
        messages = _parse_messages_from_transcript(content)
        if not msg_count:
            msg_count = len(messages)
        print(f"[{sid[:12]}] Parsed {len(messages)} messages for DB")

        # Create chat session in DB
        try:
            from backend.copilot.db import create_chat_session, get_chat_session

            existing = await get_chat_session(sid)
            if existing:
                print(f"[{sid[:12]}] Session already exists in DB, skipping creation")
            else:
                await create_chat_session(sid, user_id)
                print(f"[{sid[:12]}] Created ChatSession in DB")
        except Exception as e:
            print(f"[{sid[:12]}] DB session creation failed: {e}")
            print(" You may need to create it manually or run with DB access.")

        # Add messages to DB
        if messages:
            try:
                from backend.copilot.db import add_chat_messages_batch

                msg_dicts = [
                    {"role": m["role"], "content": m["content"]} for m in messages
                ]
                await add_chat_messages_batch(sid, msg_dicts, start_sequence=0)
                print(f"[{sid[:12]}] Added {len(messages)} messages to DB")
            except Exception as e:
                print(f"[{sid[:12]}] Message insertion failed: {e}")
                print(" (Session may already have messages)")

        # Store transcript in local workspace storage
        try:
            await upload_transcript(
                user_id=user_id,
                session_id=sid,
                content=content,
                message_count=msg_count,
            )
            print(f"[{sid[:12]}] Stored transcript in local workspace storage")
        except Exception as e:
            print(f"[{sid[:12]}] Transcript storage failed: {e}")

        # Also store directly to filesystem as fallback
        try:
            from backend.util.settings import Settings

            settings = Settings()
            storage_dir = settings.config.workspace_storage_dir or os.path.join(
                os.path.expanduser("~"), ".autogpt", "workspaces"
            )
            ts_dir = os.path.join(storage_dir, "chat-transcripts", _sanitize(user_id))
            os.makedirs(ts_dir, exist_ok=True)

            ts_path = os.path.join(ts_dir, f"{_sanitize(sid)}.jsonl")
            with open(ts_path, "w") as f:
                f.write(content)

            meta_storage = {
                "message_count": msg_count,
                "uploaded_at": time.time(),
            }
            meta_storage_path = os.path.join(ts_dir, f"{_sanitize(sid)}.meta.json")
            with open(meta_storage_path, "w") as f:
                json.dump(meta_storage, f)

            print(f"[{sid[:12]}] Also wrote to: {ts_path}")
        except Exception as e:
            print(f"[{sid[:12]}] Direct file write failed: {e}")

        print(f"[{sid[:12]}] Ready — send a message to this session to test")
        print()

    print("Done. Start the backend and send a message to the session(s).")
    print("The CoPilot will use --resume with the loaded transcript.")


# ── Main ──────────────────────────────────────────────────────────────────


async def main() -> None:
    if len(sys.argv) < 3:
        print(__doc__)
        sys.exit(1)

    command = sys.argv[1]
    session_ids = sys.argv[2:]

    if command == "download":
        await cmd_download(session_ids)
    elif command == "load":
        await cmd_load(session_ids)
    elif command == "both":
        await cmd_download(session_ids)
        print("\n" + "=" * 60 + "\n")
        await cmd_load(session_ids)
    else:
        print(f"Unknown command: {command}")
        print("Usage: download | load | both")
        sys.exit(1)


if __name__ == "__main__":
    asyncio.run(main())

@@ -1,93 +0,0 @@
|
||||
# Frontend
|
||||
|
||||
This file provides guidance to coding agents when working with the frontend.
|
||||
|
||||
## Essential Commands
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
pnpm i
|
||||
|
||||
# Generate API client from OpenAPI spec
|
||||
pnpm generate:api
|
||||
|
||||
# Start development server
|
||||
pnpm dev
|
||||
|
||||
# Run E2E tests
|
||||
pnpm test
|
||||
|
||||
# Run Storybook for component development
|
||||
pnpm storybook
|
||||
|
||||
# Build production
|
||||
pnpm build
|
||||
|
||||
# Format and lint
|
||||
pnpm format
|
||||
|
||||
# Type checking
|
||||
pnpm types
|
||||
```
|
||||
|
||||
### Pre-completion Checks (MANDATORY)
|
||||
|
||||
After making **any** code changes in the frontend, you MUST run the following commands **in order** before reporting work as done, creating commits, or opening PRs:
|
||||
|
||||
1. `pnpm format` — auto-fix formatting issues
|
||||
2. `pnpm lint` — check for lint errors; fix any that appear
|
||||
3. `pnpm types` — check for type errors; fix any that appear
|
||||
|
||||
Do NOT skip these steps. If any command reports errors, fix them and re-run until clean. Only then may you consider the task complete. If typing keeps failing, stop and ask the user.
|
||||
|
||||
### Code Style
|
||||
|
||||
- Fully capitalize acronyms in symbols, e.g. `graphID`, `useBackendAPI`
|
||||
- Use function declarations (not arrow functions) for components/handlers
|
||||
- No `dark:` Tailwind classes — the design system handles dark mode
|
||||
- Use Next.js `<Link>` for internal navigation — never raw `<a>` tags
|
||||
- No `any` types unless the value genuinely can be anything
|
||||
- No linter suppressors (`// @ts-ignore`, `// eslint-disable`) — fix the actual issue
|
||||
- **File length** — keep files under ~200 lines; extract sub-components or hooks into their own files when a file grows beyond this
|
||||
- **Function/component length** — keep render functions and hooks under ~50 lines; extract named helpers or sub-components when they grow longer
|
||||
|
||||
## Architecture

- **Framework**: Next.js 15 App Router (client-first approach)
- **Data Fetching**: Type-safe generated API hooks via Orval + React Query
- **State Management**: React Query for server state, co-located UI state in components/hooks
- **Component Structure**: Separate render logic (`.tsx`) from business logic (`use*.ts` hooks)
- **Workflow Builder**: Visual graph editor using @xyflow/react
- **UI Components**: shadcn/ui (Radix UI primitives) with Tailwind CSS styling
- **Icons**: Phosphor Icons only
- **Feature Flags**: LaunchDarkly integration
- **Error Handling**: ErrorCard for render errors, toast for mutations, Sentry for exceptions
- **Testing**: Playwright for E2E, Storybook for component development

## Environment Configuration

`.env.default` (defaults) → `.env` (user overrides)

## Feature Development

See @CONTRIBUTING.md for complete patterns. Quick reference:

1. **Pages**: Create in `src/app/(platform)/feature-name/page.tsx`
   - Extract component logic into custom hooks grouped by concern, not by component. Each hook should represent a cohesive domain of functionality (e.g., useSearch, useFilters, usePagination) rather than bundling all state into one useComponentState hook.
   - Put each hook in its own `.ts` file
   - Put sub-components in a local `components/` folder
   - Component props should be `type Props = { ... }` (not exported) unless they need to be used outside the component
2. **Components**: Structure as `ComponentName/ComponentName.tsx` + `useComponentName.ts` + `helpers.ts`
   - Use design system components from `src/components/` (atoms, molecules, organisms)
   - Never use `src/components/__legacy__/*`
3. **Data fetching**: Use generated API hooks from `@/app/api/__generated__/endpoints/`
   - Regenerate with `pnpm generate:api`
   - Pattern: `use{Method}{Version}{OperationName}`
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
5. **Testing**: Add Storybook stories for new components, Playwright for E2E. When fixing a bug, write a failing Playwright test first (use the `.fixme` annotation), implement the fix, then remove the annotation.
6. **Code conventions**:
   - Use function declarations (not arrow functions) for components/handlers
   - Do not use `useCallback` or `useMemo` unless asked to optimise a given function
   - Do not type hook returns; let TypeScript infer as much as possible
   - Never type with `any` unless a variable/attribute can ACTUALLY be of any type
   - Avoid index and barrel files

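
As a hedged sketch of the "hooks grouped by concern" and `helpers.ts` guidance above (illustrative names, not actual platform code), pure logic is extracted from the hook so it can be tested without React:

```typescript
// Hypothetical helpers.ts for a search concern.
export function filterBySearch(items: string[], query: string): string[] {
  const q = query.trim().toLowerCase();
  if (q === "") return items;
  return items.filter((item) => item.toLowerCase().includes(q));
}

// The matching useSearch.ts would wrap this in state, letting
// TypeScript infer the return type:
//
//   export function useSearch(items: string[]) {
//     const [query, setQuery] = useState("");
//     return { query, setQuery, results: filterBySearch(items, query) };
//   }
```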
@@ -1 +1,93 @@
@AGENTS.md
# CLAUDE.md - Frontend

This file provides guidance to Claude Code when working with the frontend.

## Essential Commands

```bash
# Install dependencies
pnpm i

# Generate API client from OpenAPI spec
pnpm generate:api

# Start development server
pnpm dev

# Run E2E tests
pnpm test

# Run Storybook for component development
pnpm storybook

# Build production
pnpm build

# Format and lint
pnpm format

# Type checking
pnpm types
```

### Pre-completion Checks (MANDATORY)

After making **any** code changes in the frontend, you MUST run the following commands **in order** before reporting work as done, creating commits, or opening PRs:

1. `pnpm format` — auto-fix formatting issues
2. `pnpm lint` — check for lint errors; fix any that appear
3. `pnpm types` — check for type errors; fix any that appear

Do NOT skip these steps. If any command reports errors, fix them and re-run until clean. Only then may you consider the task complete. If type checking keeps failing, stop and ask the user.

### Code Style

- Fully capitalize acronyms in symbols, e.g. `graphID`, `useBackendAPI`
- Use function declarations (not arrow functions) for components/handlers
- No `dark:` Tailwind classes — the design system handles dark mode
- Use Next.js `<Link>` for internal navigation — never raw `<a>` tags
- No `any` types unless the value genuinely can be anything
- No linter suppressors (`// @ts-ignore`, `// eslint-disable`) — fix the actual issue
- **File length** — keep files under ~200 lines; extract sub-components or hooks into their own files when a file grows beyond this
- **Function/component length** — keep render functions and hooks under ~50 lines; extract named helpers or sub-components when they grow longer

## Architecture

- **Framework**: Next.js 15 App Router (client-first approach)
- **Data Fetching**: Type-safe generated API hooks via Orval + React Query
- **State Management**: React Query for server state, co-located UI state in components/hooks
- **Component Structure**: Separate render logic (`.tsx`) from business logic (`use*.ts` hooks)
- **Workflow Builder**: Visual graph editor using @xyflow/react
- **UI Components**: shadcn/ui (Radix UI primitives) with Tailwind CSS styling
- **Icons**: Phosphor Icons only
- **Feature Flags**: LaunchDarkly integration
- **Error Handling**: ErrorCard for render errors, toast for mutations, Sentry for exceptions
- **Testing**: Playwright for E2E, Storybook for component development

## Environment Configuration

`.env.default` (defaults) → `.env` (user overrides)

## Feature Development

See @CONTRIBUTING.md for complete patterns. Quick reference:

1. **Pages**: Create in `src/app/(platform)/feature-name/page.tsx`
   - Extract component logic into custom hooks grouped by concern, not by component. Each hook should represent a cohesive domain of functionality (e.g., useSearch, useFilters, usePagination) rather than bundling all state into one useComponentState hook.
   - Put each hook in its own `.ts` file
   - Put sub-components in a local `components/` folder
   - Component props should be `type Props = { ... }` (not exported) unless they need to be used outside the component
2. **Components**: Structure as `ComponentName/ComponentName.tsx` + `useComponentName.ts` + `helpers.ts`
   - Use design system components from `src/components/` (atoms, molecules, organisms)
   - Never use `src/components/__legacy__/*`
3. **Data fetching**: Use generated API hooks from `@/app/api/__generated__/endpoints/`
   - Regenerate with `pnpm generate:api`
   - Pattern: `use{Method}{Version}{OperationName}`
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
5. **Testing**: Add Storybook stories for new components, Playwright for E2E. When fixing a bug, write a failing Playwright test first (use the `.fixme` annotation), implement the fix, then remove the annotation.
6. **Code conventions**:
   - Use function declarations (not arrow functions) for components/handlers
   - Do not use `useCallback` or `useMemo` unless asked to optimise a given function
   - Do not type hook returns; let TypeScript infer as much as possible
   - Never type with `any` unless a variable/attribute can ACTUALLY be of any type
   - Avoid index and barrel files

@@ -39,49 +39,39 @@ export const AgentOutputs = ({ flowID }: { flowID: string | null }) => {
return outputNodes
.map((node) => {
const executionResults = node.data.nodeExecutionResults || [];
const latestResult =
executionResults.length > 0
? executionResults[executionResults.length - 1]
: undefined;
const outputData = latestResult?.output_data?.output;

const items = executionResults
.filter((result) => result.output_data?.output !== undefined)
.map((result) => {
const outputData = result.output_data!.output;
const renderer = globalRegistry.getRenderer(outputData);
return {
nodeExecID: result.node_exec_id,
value: outputData,
renderer,
};
})
.filter(
(
item,
): item is typeof item & {
renderer: NonNullable<typeof item.renderer>;
} => item.renderer !== null,
);

if (items.length === 0) return null;
const renderer = globalRegistry.getRenderer(outputData);

return {
nodeID: node.id,
metadata: {
name: node.data.hardcodedValues?.name || "Output",
description:
node.data.hardcodedValues?.description || "Output from the agent",
},
items,
value: outputData ?? "No output yet",
renderer,
};
})
.filter((group): group is NonNullable<typeof group> => group !== null);
.filter(
(
output,
): output is typeof output & {
renderer: NonNullable<typeof output.renderer>;
} => output.renderer !== null,
);
}, [nodes]);

const actionItems = useMemo(() => {
return outputs.flatMap((group) =>
group.items.map((item) => ({
value: item.value,
metadata: group.metadata,
renderer: item.renderer,
})),
);
return outputs.map((output) => ({
value: output.value,
metadata: {},
renderer: output.renderer,
}));
}, [outputs]);

return (
@@ -126,27 +116,24 @@ export const AgentOutputs = ({ flowID }: { flowID: string | null }) => {
<ScrollArea className="h-full overflow-auto pr-4">
<div className="space-y-6">
{outputs && outputs.length > 0 ? (
outputs.map((group) => (
<div key={group.nodeID} className="space-y-2">
outputs.map((output, i) => (
<div key={i} className="space-y-2">
<div>
<Label className="text-base font-semibold">
{group.metadata.name || "Unnamed Output"}
{output.metadata.name || "Unnamed Output"}
</Label>
{group.metadata.description && (
{output.metadata.description && (
<Label className="mt-1 block text-sm text-gray-600">
{group.metadata.description}
{output.metadata.description}
</Label>
)}
</div>

{group.items.map((item) => (
<OutputItem
key={item.nodeExecID}
value={item.value}
metadata={group.metadata}
renderer={item.renderer}
/>
))}
<OutputItem
value={output.value}
metadata={{}}
renderer={output.renderer}
/>
</div>
))
) : (

@@ -33,12 +33,6 @@ export const useRunGraph = () => {
const clearAllNodeErrors = useNodeStore(
useShallow((state) => state.clearAllNodeErrors),
);
const cleanNodesStatuses = useNodeStore(
useShallow((state) => state.cleanNodesStatuses),
);
const clearAllNodeExecutionResults = useNodeStore(
useShallow((state) => state.clearAllNodeExecutionResults),
);

// Tutorial integration - force open dialog when tutorial requests it
const forceOpenRunInputDialog = useTutorialStore(
@@ -143,9 +137,6 @@ export const useRunGraph = () => {
if (!dryRun && (hasInputs() || hasCredentials())) {
setOpenRunInputDialog(true);
} else {
// Clear stale results so the UI shows fresh output from this execution
clearAllNodeExecutionResults();
cleanNodesStatuses();
// Optimistically set running state immediately for responsive UI
setIsGraphRunning(true);
await executeGraph({

@@ -10,12 +10,9 @@ import { NodeExecutionResult } from "@/app/api/__generated__/models/nodeExecutio
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
import { useGraphStore } from "../../../stores/graphStore";
import { useEdgeStore } from "../../../stores/edgeStore";
import { useQueryClient } from "@tanstack/react-query";
import { getGetV1GetExecutionDetailsQueryKey } from "@/app/api/__generated__/endpoints/graphs/graphs";

export const useFlowRealtime = () => {
const api = useBackendAPI();
const queryClient = useQueryClient();
const updateNodeExecutionResult = useNodeStore(
useShallow((state) => state.updateNodeExecutionResult),
);
@@ -74,16 +71,6 @@ export const useFlowRealtime = () => {
console.debug(
`Subscribed to updates for execution #${flowExecutionID}`,
);
// Refetch execution details to catch any events that were
// published before the WebSocket subscription was established.
// This closes the race-condition window for fast-completing
// executions like dry-runs / simulations.
void queryClient.invalidateQueries({
queryKey: getGetV1GetExecutionDetailsQueryKey(
flowID!,
flowExecutionID,
),
});
})
.catch((error) =>
console.error(
@@ -100,7 +87,7 @@ export const useFlowRealtime = () => {
deregisterGraphExecutionStatusEvent();
resetEdgeBeads();
};
}, [api, flowExecutionID, resetEdgeBeads, queryClient, flowID]);
}, [api, flowExecutionID, resetEdgeBeads]);

return {};
};

@@ -3,7 +3,6 @@ import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
import { ExclamationMarkIcon } from "@phosphor-icons/react";
import { ToolUIPart, UIDataTypes, UIMessage, UITools } from "ai";
import { useState } from "react";
import { AskQuestionTool } from "../../../tools/AskQuestion/AskQuestion";
import { ConnectIntegrationTool } from "../../../tools/ConnectIntegrationTool/ConnectIntegrationTool";
import { CreateAgentTool } from "../../../tools/CreateAgent/CreateAgent";
import { EditAgentTool } from "../../../tools/EditAgent/EditAgent";
@@ -130,8 +129,6 @@ export function MessagePartRenderer({
</MessageResponse>
);
}
case "tool-ask_question":
return <AskQuestionTool key={key} part={part as ToolUIPart} />;
case "tool-find_block":
return <FindBlocksTool key={key} part={part as ToolUIPart} />;
case "tool-find_agent":

@@ -13,7 +13,6 @@ export type RenderSegment =
| { kind: "collapsed-group"; parts: ToolUIPart[] };

const CUSTOM_TOOL_TYPES = new Set([
"tool-ask_question",
"tool-find_block",
"tool-find_agent",
"tool-find_library_agent",

@@ -13,6 +13,8 @@ import {
getSuggestionThemes,
} from "./helpers";
import { SuggestionThemes } from "./components/SuggestionThemes/SuggestionThemes";
import { PulseChips } from "../PulseChips/PulseChips";
import { usePulseChips } from "../PulseChips/usePulseChips";

interface Props {
inputLayoutId: string;
@@ -34,6 +36,7 @@ export function EmptySession({
}: Props) {
const { user } = useSupabase();
const greetingName = getGreetingName(user);
const pulseChips = usePulseChips();

const { data: suggestedPromptsResponse, isLoading: isLoadingPrompts } =
useGetV2GetSuggestedPrompts({
@@ -80,6 +83,8 @@ export function EmptySession({
Tell me about your work — I'll find what to automate.
</Text>

<PulseChips chips={pulseChips} onChipClick={onSend} />

<div className="mb-6">
<motion.div
layoutId={inputLayoutId}

@@ -0,0 +1,87 @@
"use client";

import { Text } from "@/components/atoms/Text/Text";
import { ArrowRightIcon } from "@phosphor-icons/react";
import NextLink from "next/link";
import { StatusBadge } from "@/app/(platform)/library/components/StatusBadge/StatusBadge";
import type { AgentStatus } from "@/app/(platform)/library/types";

export interface PulseChipData {
id: string;
name: string;
status: AgentStatus;
shortMessage: string;
}

interface Props {
chips: PulseChipData[];
onChipClick?: (prompt: string) => void;
}

export function PulseChips({ chips, onChipClick }: Props) {
if (chips.length === 0) return null;

return (
<div className="mb-6">
<div className="mb-3 flex items-center justify-between">
<Text variant="small-medium" className="text-zinc-600">
What's happening with your agents
</Text>
<NextLink
href="/library"
className="flex items-center gap-1 text-xs text-zinc-500 hover:text-zinc-700"
>
View all <ArrowRightIcon size={12} />
</NextLink>
</div>
<div className="flex flex-wrap gap-2">
{chips.map((chip) => (
<PulseChip key={chip.id} chip={chip} onClick={onChipClick} />
))}
</div>
</div>
);
}

interface ChipProps {
chip: PulseChipData;
onClick?: (prompt: string) => void;
}

function PulseChip({ chip, onClick }: ChipProps) {
function handleClick() {
const prompt = buildChipPrompt(chip);
onClick?.(prompt);
}

return (
<button
type="button"
onClick={handleClick}
className="flex items-center gap-2 rounded-medium border border-zinc-100 bg-white px-3 py-2 text-left transition-all hover:border-zinc-200 hover:shadow-sm"
>
<StatusBadge status={chip.status} />
<div className="min-w-0">
<Text variant="small-medium" className="truncate text-zinc-900">
{chip.name}
</Text>
<Text variant="small" className="truncate text-zinc-500">
{chip.shortMessage}
</Text>
</div>
</button>
);
}

function buildChipPrompt(chip: PulseChipData): string {
switch (chip.status) {
case "error":
return `What happened with ${chip.name}? It has an error — can you check?`;
case "running":
return `Give me a status update on ${chip.name} — what has it done so far?`;
case "idle":
return `${chip.name} hasn't run recently. Should I keep it or update and re-run it?`;
default:
return `Tell me about ${chip.name} — what's its current status?`;
}
}
@@ -0,0 +1,43 @@
"use client";

import { useState } from "react";
import type { PulseChipData } from "./PulseChips";
import type { AgentStatus } from "@/app/(platform)/library/types";

/**
 * Provides a prioritised list of pulse chips for the Home empty state.
 * Errors → running → stale, max 5 chips.
 *
 * TODO: Replace with real API data from `GET /agents/summary` or similar.
 */
export function usePulseChips(): PulseChipData[] {
const [chips] = useState<PulseChipData[]>(() => MOCK_CHIPS);
return chips;
}

const MOCK_CHIPS: PulseChipData[] = [
{
id: "chip-1",
name: "Lead Finder",
status: "error" as AgentStatus,
shortMessage: "API rate limit hit",
},
{
id: "chip-2",
name: "CEO Finder",
status: "running" as AgentStatus,
shortMessage: "72% complete",
},
{
id: "chip-3",
name: "Cart Recovery",
status: "idle" as AgentStatus,
shortMessage: "No runs in 3 weeks",
},
{
id: "chip-4",
name: "Social Collector",
status: "listening" as AgentStatus,
shortMessage: "Waiting for trigger",
},
];
@@ -1,68 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { ChatTeardropDotsIcon, WarningCircleIcon } from "@phosphor-icons/react";
|
||||
import type { ToolUIPart } from "ai";
|
||||
import { ClarificationQuestionsCard } from "../../components/ClarificationQuestionsCard/ClarificationQuestionsCard";
|
||||
import { useCopilotChatActions } from "../../components/CopilotChatActionsProvider/useCopilotChatActions";
|
||||
import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
|
||||
import { normalizeClarifyingQuestions } from "../clarifying-questions";
|
||||
import {
|
||||
getAnimationText,
|
||||
getAskQuestionOutput,
|
||||
isClarificationOutput,
|
||||
isErrorOutput,
|
||||
} from "./helpers";
|
||||
|
||||
interface Props {
|
||||
part: ToolUIPart;
|
||||
}
|
||||
|
||||
export function AskQuestionTool({ part }: Props) {
|
||||
const text = getAnimationText(part);
|
||||
const { onSend } = useCopilotChatActions();
|
||||
|
||||
const isStreaming =
|
||||
part.state === "input-streaming" || part.state === "input-available";
|
||||
const isError = part.state === "output-error";
|
||||
|
||||
const output = getAskQuestionOutput(part);
|
||||
|
||||
function handleAnswers(answers: Record<string, string>) {
|
||||
if (!output || !isClarificationOutput(output)) return;
|
||||
const questions = normalizeClarifyingQuestions(output.questions ?? []);
|
||||
const message = questions
|
||||
.map((q) => {
|
||||
const answer = answers[q.keyword] || "";
|
||||
return `> ${q.question}\n\n${answer}`;
|
||||
})
|
||||
.join("\n\n");
|
||||
onSend(`**Here are my answers:**\n\n${message}\n\nPlease proceed.`);
|
||||
}
|
||||
|
||||
if (output && isClarificationOutput(output)) {
|
||||
return (
|
||||
<ClarificationQuestionsCard
|
||||
questions={normalizeClarifyingQuestions(output.questions ?? [])}
|
||||
message={output.message}
|
||||
sessionId={output.session_id}
|
||||
onSubmitAnswers={handleAnswers}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex items-center gap-2 py-2 text-sm text-muted-foreground">
|
||||
{isError || (output && isErrorOutput(output)) ? (
|
||||
<WarningCircleIcon size={16} className="text-red-500" />
|
||||
) : isStreaming ? (
|
||||
<ChatTeardropDotsIcon size={16} className="animate-pulse" />
|
||||
) : (
|
||||
<ChatTeardropDotsIcon size={16} />
|
||||
)}
|
||||
<MorphingTextAnimation
|
||||
text={text}
|
||||
className={isError ? "text-red-500" : undefined}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,86 +0,0 @@
|
||||
import { ResponseType } from "@/app/api/__generated__/models/responseType";
|
||||
import type { ToolUIPart } from "ai";
|
||||
|
||||
interface ClarifyingQuestionPayload {
|
||||
question: string;
|
||||
keyword: string;
|
||||
example?: string;
|
||||
}
|
||||
|
||||
export interface AskQuestionOutput {
|
||||
type: string;
|
||||
message: string;
|
||||
questions: ClarifyingQuestionPayload[];
|
||||
session_id?: string;
|
||||
}
|
||||
|
||||
interface ErrorOutput {
|
||||
type: "error";
|
||||
message: string;
|
||||
error?: string;
|
||||
}
|
||||
|
||||
export type AskQuestionToolOutput = AskQuestionOutput | ErrorOutput;
|
||||
|
||||
function parseOutput(output: unknown): AskQuestionToolOutput | null {
|
||||
if (!output) return null;
|
||||
if (typeof output === "string") {
|
||||
try {
|
||||
return parseOutput(JSON.parse(output) as unknown);
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
if (typeof output === "object" && output !== null) {
|
||||
const obj = output as Record<string, unknown>;
|
||||
if (
|
||||
obj.type === ResponseType.agent_builder_clarification_needed ||
|
||||
"questions" in obj
|
||||
) {
|
||||
return obj as unknown as AskQuestionOutput;
|
||||
}
|
||||
if (obj.type === "error" || "error" in obj) {
|
||||
return obj as unknown as ErrorOutput;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
export function getAskQuestionOutput(
|
||||
part: ToolUIPart,
|
||||
): AskQuestionToolOutput | null {
|
||||
return parseOutput(part.output);
|
||||
}
|
||||
|
||||
export function isClarificationOutput(
|
||||
output: AskQuestionToolOutput,
|
||||
): output is AskQuestionOutput {
|
||||
return (
|
||||
output.type === ResponseType.agent_builder_clarification_needed ||
|
||||
"questions" in output
|
||||
);
|
||||
}
|
||||
|
||||
export function isErrorOutput(
|
||||
output: AskQuestionToolOutput,
|
||||
): output is ErrorOutput {
|
||||
return output.type === "error" || "error" in output;
|
||||
}
|
||||
|
||||
export function getAnimationText(part: ToolUIPart): string {
|
||||
switch (part.state) {
|
||||
case "input-streaming":
|
||||
case "input-available":
|
||||
return "Asking question...";
|
||||
case "output-available": {
|
||||
const output = parseOutput(part.output);
|
||||
if (output && isClarificationOutput(output)) return "Needs your input";
|
||||
if (output && isErrorOutput(output)) return "Failed to ask question";
|
||||
return "Asking question...";
|
||||
}
|
||||
case "output-error":
|
||||
return "Failed to ask question";
|
||||
default:
|
||||
return "Asking question...";
|
||||
}
|
||||
}
|
||||
@@ -13,8 +13,13 @@ import {
ContentMessage,
} from "../../components/ToolAccordion/AccordionContent";
import { ToolAccordion } from "../../components/ToolAccordion/ToolAccordion";
import { ClarificationQuestionsCard } from "./components/ClarificationQuestionsCard";
import { MiniGame } from "../../components/MiniGame/MiniGame";
import { SuggestedGoalCard } from "./components/SuggestedGoalCard";
import {
buildClarificationAnswersMessage,
normalizeClarifyingQuestions,
} from "../clarifying-questions";
import {
AccordionIcon,
formatMaybeJson,
@@ -22,6 +27,7 @@ import {
getCreateAgentToolOutput,
isAgentPreviewOutput,
isAgentSavedOutput,
isClarificationNeededOutput,
isErrorOutput,
isSuggestedGoalOutput,
ToolIcon,
@@ -60,6 +66,15 @@ function getAccordionMeta(output: CreateAgentToolOutput | null) {
description: `${output.node_count} block${output.node_count === 1 ? "" : "s"}`,
};
}
if (isClarificationNeededOutput(output)) {
const questions = output.questions ?? [];
return {
icon,
title: "Needs clarification",
description: `${questions.length} question${questions.length === 1 ? "" : "s"}`,
expanded: true,
};
}
if (isSuggestedGoalOutput(output)) {
return {
icon,
@@ -92,6 +107,15 @@ export function CreateAgentTool({ part }: Props) {
onSend(`Please create an agent with this goal: ${goal}`);
}

function handleClarificationAnswers(answers: Record<string, string>) {
const questions =
output && isClarificationNeededOutput(output)
? (output.questions ?? [])
: [];

onSend(buildClarificationAnswersMessage(answers, questions, "create"));
}

return (
<div className="py-2">
{isOperating && (
@@ -124,42 +148,44 @@ export function CreateAgentTool({ part }: Props) {
/>
)}

{hasExpandableContent && !(output && isAgentSavedOutput(output)) && (
<ToolAccordion {...getAccordionMeta(output)}>
{isOperating && (
<ContentGrid>
<MiniGame />
<ContentHint>
This could take a few minutes — play while you wait!
</ContentHint>
</ContentGrid>
)}
{hasExpandableContent &&
!(output && isClarificationNeededOutput(output)) &&
!(output && isAgentSavedOutput(output)) && (
<ToolAccordion {...getAccordionMeta(output)}>
{isOperating && (
<ContentGrid>
<MiniGame />
<ContentHint>
This could take a few minutes — play while you wait!
</ContentHint>
</ContentGrid>
)}

{output && isAgentPreviewOutput(output) && (
<ContentGrid>
<ContentMessage>{output.message}</ContentMessage>
{output.description?.trim() && (
<ContentCardDescription>
{output.description}
</ContentCardDescription>
)}
<ContentCodeBlock>
{truncateText(formatMaybeJson(output.agent_json), 1600)}
</ContentCodeBlock>
</ContentGrid>
)}
{output && isAgentPreviewOutput(output) && (
<ContentGrid>
<ContentMessage>{output.message}</ContentMessage>
{output.description?.trim() && (
<ContentCardDescription>
{output.description}
</ContentCardDescription>
)}
<ContentCodeBlock>
{truncateText(formatMaybeJson(output.agent_json), 1600)}
</ContentCodeBlock>
</ContentGrid>
)}

{output && isSuggestedGoalOutput(output) && (
<SuggestedGoalCard
message={output.message}
suggestedGoal={output.suggested_goal}
reason={output.reason}
goalType={output.goal_type ?? "vague"}
onUseSuggestedGoal={handleUseSuggestedGoal}
/>
)}
</ToolAccordion>
)}
{output && isSuggestedGoalOutput(output) && (
<SuggestedGoalCard
message={output.message}
suggestedGoal={output.suggested_goal}
reason={output.reason}
goalType={output.goal_type ?? "vague"}
onUseSuggestedGoal={handleUseSuggestedGoal}
/>
)}
</ToolAccordion>
)}

{output && isAgentSavedOutput(output) && (
<AgentSavedCard
@@ -169,6 +195,14 @@ export function CreateAgentTool({ part }: Props) {
agentPageLink={output.agent_page_link}
/>
)}

{output && isClarificationNeededOutput(output) && (
<ClarificationQuestionsCard
questions={normalizeClarifyingQuestions(output.questions ?? [])}
message={output.message}
onSubmitAnswers={handleClarificationAnswers}
/>
)}
</div>
);
}

Some files were not shown because too many files have changed in this diff.