Mirror of https://github.com/Significant-Gravitas/AutoGPT.git
Synced 2026-04-08 03:00:28 -04:00

Compare commits: `feat/agent...combined-p` (412 commits)
Commit SHAs (author and date columns were not captured):

64d9d6d880, 9fc324e28a, adf66bdd24, dc10ad715a, 493c91e0dd, b278e66f4d, 3e183ed2a3, 82887a2d92,
993c43b623, 13fcc62a31, 8fefa23468, 749a56ca20, a8a62eeefc, 173614bcc5, 3396cb3f4c, 0c5d628b74,
ed40549499, fbe634fb19, a338c72c42, a9d13f0cbf, e83e50a8f1, 7f4398efa3, c2a054c511, b256560619,
c63d5f538b, eeba884671, 90822e3f37, a8bb6b5544, 83b00f4789, 4cd53bb7f6, 96d83e9bbd, e99f4ac767,
67c2540177, 0da949ba42, 95524e94b3, eda02f9ce6, 9ab6082a23, 2c517ff9a1, 7020ae2189, a49ac5ba13,
2a969e5018, 79005b1be5, 4f8cdbee47, 3ed444dd60, 83e747ebcd, 827f2b0f87, b0d5d3b95e, eb9244be1a,
dd17e83299, 74009bedac, 72d0c8dad8, e860f164e4, b9336984be, 9924dedddc, c054799b4f, 004d3957b3,
f3b5d584a3, 476d9dcf80, 072b623f8b, a68f48e6b7, 60e2474640, a892bbd4dd, 538e8619da, 4edb1f6e4a,
480d58607d, 8561eb35f2, 0b4acd73f4, e9fe2991d6, 26b0c95936, 735965bbe5, a8f9ed0f60, 308357de84,
1a6c50c6cc, 9391dfa4b2, 6b031085bd, 6a69d7c68d, ad77e881c9, f1aedfeedd, 49c7ab4011, 2d04584c84,
2578f61abb, 927c6e7db0, f753e6162f, b996bc556b, e4f79261c1, 09bc939498, 79c5a10f75, 2bf5a37646,
d5d24e6e66, c9cbd7531e, 289a19d402, 7800af1835, 114f91ff53, 63a0153e4f, 1364616ff1, 4e9169c1a2,
705e97ec46, 8cea0bede0, 5e52050788, 0b77af29aa, bd4cc21fc6, 19ea753639, 07fd734fa1, 8a4a16ec5c,
5551da674e, e57e48272a, 6f03ceeb88, 554ff0b20b, c2f421cb42, dd228de17d, c26ff22f9c, 760360fbe9,
e3d589b180, 913d93f47c, 03e5d37dc4, 6e2dab413e, b10dc7c2d5, 8de935c84b, dd34b0dc48, 015e0d591e,
2cb65f5c34, 3a49086c3d, 0e567df1da, b5b754d5eb, 456bb1c4d0, 263cd0ecac, 66afca6e0c, 11b846dd49,
a71396ee48, beb43bb847, a55653f8c1, f3dd708cf6, c4ff31c79c, 9f2257daaa, 925e9a047c, 3e6faf2de7,
40a1f504c0, 22e8c5c353, 1de2a7fb09, b3d9e9e856, 48b166a82c, 697b15ce81, 5beabf936c, b9e29c96bd,
32bfe1b209, 62302db470, 89c7f34d26, 543fc2da70, 7f986bc565, f4571cb9e1, 5f41afe748, d046c01a65,
b220fe4347, 7af138adba, 5c406a20ba, 61513b9dad, 6f679a0e32, b8065212b1, d5281a9a13, 05495d8478,
bae409d04e, e11eb2caaa, 2c04768711, c9bf3aa339, 4ac0ba570a, d61a2c6cd0, 1c301b4b61, e753aee7a0,
f76566c834, a58b997141, 3f24a003ad, 1a645e1e37, bee76962b0, 864e68bed1, 7c6201110c, bded680b77,
1e008dc172, 9966e122ab, 65108c31dc, 7767c97f50, 69ab21ebe7, 6fe4e1b774, c778cc9849, 50b635da6d,
08e254143b, 89fcfc4e0a, e7ca07f4bf, c564ac7277, ac3a826ad0, 6f32184019, 6d0eedae83, fb328f9d74,
a369fbe169, 2a0b74cae4, b08f9fc02a, 857acb2bbc, 0cb230c4f0, 2cd5c0eab8, 7bf8e460ea, 84d328517a,
842ff6c600, b510fbee2a, bb7f0ad1f2, 3f8af89b63, 375e5e1f10, fd1d706315, faf2f43f6a, eea230d37f,
76965429f1, eefa60368f, 88fe1e9b5e, 93264b1177, 3269d17880, 1e5788f2cf, ca8214d95f, f58ce5cc70,
bf29801b07, dcc2bdd8ab, e74a918c4a, ff05b5b8d5, 56090f870c, a3e3d3ff6b, cfaa1ff0d4, 6ab9a3285f,
390324d5a1, 9f51796dbe, f42f0013df, 3154e5b87a, 78cd14d501, 137edb3e6e, 449e9b17f1, 5b3f87d7c7,
ee7209a575, 7ea89b07ce, 5324e0cc2f, c7cbb8b02e, d66ffb1ee4, 5d489c72b5, 646ffe1693, 59b1811e8b,
6404e58fb1, e0bfa1524e, ac947a0c11, c9f45f056a, 89264091ad, e3183f1955, 3ea243c760, 991969612c,
8de9880f43, 86d8efe697, 10ec6c7215, 51e5371362, cdd14726ce, 1ebd5635f6, 349b6c63de, 2f7cfa6f1b,
049aa1ad7d, a16be2675b, ac416a561e, c47fcc1925, 77fd8648a7, 4842599bec, 339e155823, 9344e62d66,
ee6cc20cbc, eb96b019c5, 9cf6ac9ad9, d3173605eb, 98c27653f2, dced534df3, 4ebe294707, 2e8e115cd1,
5ca49a8ec9, a9db5af0fa, dcbfcfb158, 723b852ba4, c7e0f8169a, ce1555c07a, 403a36a3fc, 490643d65a,
2b14ecf5ee, 14d6d66bdc, 28443e2e33, 611a20d7df, ce201cd19c, 0c76852768, 414b8bbaac, 4c85f2399a,
db0e5a1b0b, 22a5e76af9, 7919da16b4, 052f953afb, abd9fbe08a, 81308af770, a726c1d1d5, a015bf9e1c,
d99278a40d, bd7d9a5697, 9cfa53a2ff, e6cf899a6d, b655b30aeb, 5b8daf5d4c, 9b74b7bb41, a1578984cc,
c0869e9168, 0db5a6ff9a, 3664624445, f1e2ce0703, c226cf0925, dade634b4a, 34101c4389, 2218254c8a,
4d63cffa7a, ebf3b920d8, 9bd579b041, 41601cbb5c, c636b6f310, 292be77b86, dd3349e6bc, bfdf4b99db,
aba78b0fdd, 12934dfd72, c5507415fd, 7ff096afd9, 38fb504063, b4388a9c93, a7a68e585a, 14ad37b0c7,
24d0c35ed3, 389cd28879, 656858eba1, 8aae7751dc, 725da7e887, bd9e9ec614, 88589764b5, f0a3afda7d,
a9cbb3ee2f, 1810452920, 4f6f3ca240, 9ffecbac02, eb22cf4483, 16636b64c6, c2709fbc28, c659f3b058,
80581a8364, 3c046eb291, 3adbaacc0e, 4da3535a9c, 3e25488b2d, 56e0b568a4, 4acac9ff5b, 0b0777ac87,
698b1599cb, a2f94f08d9, 0c6f20f728, d100b2515b, 14113f96a9, ee40a4b9a8, 0008cafc3b, f55bc84fe7,
3cfee4c4b5, c48b5239b9, e44615f8b8, 22f0da0a03, 9264b42050, 3a40188024, 8d6433c1a5, c7430eaffb,
dc272559c6, a98b0aee95, 264869cab9, a85ba9e36d, 18c5f67107, 0348e7b228, e35376d3ec, 687af1bdc3,
694032e45f, 231a4b6f51, da6f77da47, 1747f4e6f3, 0d6d8e820c, 24c286fbed, c75f1ff749, cfc6d3538c,
e9540041d6, 8ac86a03b5, 2aac78eae4, dbfc791357, 880c957c86, 857a8ef0aa, 1008f9fcd4, c26791e6ae,
cf66c08125, b4362785e4, f38fa96df4, 98c8f94ef2, 7b0111d9b5, 85e9e4c5b7, e900ee615a, e1d5113051,
4963d227ea, 19dea0e4ca, 87d5a39267, 87ac8148e3, 491132f62f, 55815a3207, 5c3aa11600, b5cbf8505b,
f49f63de76, 8f76384942, ffb8d366d6, 432ef5ab5e
`.agents/skills` (new symbolic link, 1 line):

```diff
@@ -0,0 +1 @@
+../.claude/skills
```
```diff
@@ -1,6 +1,6 @@
 # AutoGPT Platform Contribution Guide

-This guide provides context for Codex when updating the **autogpt_platform** folder.
+This guide provides context for coding agents when updating the **autogpt_platform** folder.

 ## Directory overview
```
`autogpt_platform/AGENTS.md` (new file, 120 lines):

@@ -0,0 +1,120 @@

# AutoGPT Platform

This file provides guidance to coding agents when working with code in this repository.

## Repository Overview

AutoGPT Platform is a monorepo containing:

- **Backend** (`backend`): Python FastAPI server with async support
- **Frontend** (`frontend`): Next.js React application
- **Shared Libraries** (`autogpt_libs`): Common Python utilities

## Component Documentation

- **Backend**: See @backend/AGENTS.md for backend-specific commands, architecture, and development tasks
- **Frontend**: See @frontend/AGENTS.md for frontend-specific commands, architecture, and development patterns

## Key Concepts

1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
2. **Blocks**: Reusable components in `backend/backend/blocks/` that perform specific tasks
3. **Integrations**: OAuth and API connections stored per user
4. **Store**: Marketplace for sharing agent templates
5. **Virus Scanning**: ClamAV integration for file upload security

### Environment Configuration

#### Configuration Files

- **Backend**: `backend/.env.default` (defaults) → `backend/.env` (user overrides)
- **Frontend**: `frontend/.env.default` (defaults) → `frontend/.env` (user overrides)
- **Platform**: `.env.default` (Supabase/shared defaults) → `.env` (user overrides)

#### Docker Environment Loading Order

1. `.env.default` files provide base configuration (tracked in git)
2. `.env` files provide user-specific overrides (gitignored)
3. Docker Compose `environment:` sections provide service-specific overrides
4. Shell environment variables have highest precedence (see the sketch below)
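A minimal sketch of that precedence chain, assuming a hypothetical `EXAMPLE_VAR` defined at each layer (the variable and the `backend` service name are illustrative, not taken from the real compose files):

```bash
# .env.default (tracked):      EXAMPLE_VAR=from-default
# .env (gitignored override):  EXAMPLE_VAR=from-env
# docker-compose.yml:          environment: ["EXAMPLE_VAR=from-compose"]

# A shell environment variable wins over all three layers:
EXAMPLE_VAR=from-shell docker compose up -d

# Inspect what the container actually received:
docker compose exec backend printenv EXAMPLE_VAR   # -> from-shell
```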
#### Key Points

- All services use hardcoded defaults in docker-compose files (no `${VARIABLE}` substitutions)
- The `env_file` directive loads variables INTO containers at runtime
- Backend/Frontend services use YAML anchors for consistent configuration
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern

### Branching Strategy

- **`dev`** is the main development branch. All PRs should target `dev`.
- **`master`** is the production branch. Only used for production releases.

### Creating Pull Requests

- Create the PR against the `dev` branch of the repository.
- **Split PRs by concern** — each PR should have a single clear purpose. For example, "usage tracking" and "credit charging" should be separate PRs even if related. Combining multiple concerns makes it harder for reviewers to understand what belongs to what.
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)
- Use conventional commit messages (see below)
- **Structure the PR description with Why / What / How** — Why: the motivation (what problem it solves, what's broken/missing without it); What: high-level summary of changes; How: approach, key implementation details, or architecture decisions. Reviewers need all three to judge whether the approach fits the problem.
- Fill out the `.github/PULL_REQUEST_TEMPLATE.md` template as the PR description
- Always use `--body-file` to pass the PR body — this avoids shell interpretation of backticks and special characters:

```bash
PR_BODY=$(mktemp)
cat > "$PR_BODY" << 'PREOF'
## Summary
- use `backticks` freely here
PREOF
gh pr create --title "..." --body-file "$PR_BODY" --base dev
rm "$PR_BODY"
```

- Run the pre-commit hooks to ensure code quality.

### Test-Driven Development (TDD)

When fixing a bug or adding a feature, follow a test-first approach:

1. **Write a failing test first** — create a test that reproduces the bug or validates the new behavior, marked with `@pytest.mark.xfail` (backend) or `.fixme` (Playwright). Run it to confirm it fails for the right reason.
2. **Implement the fix/feature** — write the minimal code to make the test pass.
3. **Remove the xfail marker** — once the test passes, remove the `xfail`/`.fixme` annotation and run the full test suite to confirm nothing else broke.

This ensures every change is covered by a test and that the test actually validates the intended behavior.

### Reviewing/Revising Pull Requests

Use `/pr-review` to review a PR or `/pr-address` to address comments.

When fetching comments manually (see the jq sketch below):

- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate` — top-level reviews
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments --paginate` — inline review comments (always paginate to avoid missing comments beyond page 1)
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` — PR conversation comments
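For example, to dump just the inline comment bodies with their locations (a sketch; the `jq` filter is an assumption, not part of the guide):

```bash
# Replace {N} with the PR number:
gh api "repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments" --paginate \
  | jq -r '.[] | "\(.path):\(.line // "-") \(.user.login): \(.body)"'
```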
### Conventional Commits

Use this format for commit messages and Pull Request titles:

**Conventional Commit Types:**

- `feat`: Introduces a new feature to the codebase
- `fix`: Patches a bug in the codebase
- `refactor`: Code change that neither fixes a bug nor adds a feature; also applies to removing features
- `ci`: Changes to CI configuration
- `docs`: Documentation-only changes
- `dx`: Improvements to the developer experience

**Recommended Base Scopes:**

- `platform`: Changes affecting both frontend and backend
- `frontend`
- `backend`
- `infra`
- `blocks`: Modifications/additions of individual blocks

**Subscope Examples:**

- `backend/executor`
- `backend/db`
- `frontend/builder` (includes changes to the block UI component)
- `infra/prod`

Use these scopes and subscopes for clarity and consistency in commit messages.
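Putting the types and scopes together, commit subjects look like this (the subjects themselves are invented examples; the `type(scope/subscope):` format is from the guide):

```bash
git commit -m "feat(backend/executor): retry failed node executions with backoff"
git commit -m "fix(frontend/builder): keep node positions stable after undo"
git commit -m "docs(platform): document the .env.default -> .env override chain"
```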
`autogpt_platform/CLAUDE.md` (120 lines reduced to 1; the removed body was the same guide as the new AGENTS.md above, addressed to Claude Code and pointing at @backend/CLAUDE.md and @frontend/CLAUDE.md):

```diff
@@ -1,120 +1 @@
-# CLAUDE.md
-...
+@AGENTS.md
```
`autogpt_platform/backend/AGENTS.md` (new file, 227 lines):

@@ -0,0 +1,227 @@

# Backend

This file provides guidance to coding agents when working with the backend.

## Essential Commands

To run anything that has Python package dependencies, you MUST use `poetry run ...`.

```bash
# Install dependencies
poetry install

# Run database migrations
poetry run prisma migrate dev

# Start all services (database, redis, rabbitmq, clamav)
docker compose up -d

# Run the backend as a whole
poetry run app

# Run tests
poetry run test

# Run a specific test
poetry run pytest path/to/test_file.py::test_function_name

# Run block tests (tests that validate all blocks work correctly)
poetry run pytest backend/blocks/test/test_block.py -xvs

# Run tests for a specific block (e.g., GetCurrentTimeBlock)
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs

# Lint and format
# Prefer `format` if you just want to "fix" things and only see the errors that can't be autofixed
poetry run format  # Black + isort
poetry run lint    # ruff
```

More details can be found in @TESTING.md

### Creating/Updating Snapshots

When you first write a test or when the expected output changes:

```bash
poetry run pytest path/to/test.py --snapshot-update
```

⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.

## Architecture

- **API Layer**: FastAPI with REST and WebSocket endpoints
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
- **Queue System**: RabbitMQ for async task processing
- **Execution Engine**: Separate executor service processes agent workflows
- **Authentication**: JWT-based with Supabase integration
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies

## Code Style

- **Top-level imports only** — no local/inner imports (lazy imports only for heavy optional deps like `openpyxl`)
- **Absolute imports** — use `from backend.module import ...` for cross-package imports. Single-dot relative (`from .sibling import ...`) is acceptable for sibling modules within the same package (e.g., blocks). Avoid double-dot relative imports (`from ..parent import ...`) — use the absolute path instead
- **No duck typing** — no `hasattr`/`getattr`/`isinstance` for type dispatch; use typed interfaces/unions/protocols
- **Pydantic models** over dataclass/namedtuple/dict for structured data
- **No linter suppressors** — no `# type: ignore`, `# noqa`, `# pyright: ignore`; fix the type/code
- **List comprehensions** over manual loop-and-append
- **Early return** — guard clauses first, avoid deep nesting
- **f-strings vs printf syntax in log statements** — use `%s` for deferred interpolation in `debug` statements, f-strings elsewhere for readability: `logger.debug("Processing %s items", count)`, `logger.info(f"Processing {count} items")` (see the sketch below)
- **Sanitize error paths** — `os.path.basename()` in error messages to avoid leaking directory structure
- **TOCTOU awareness** — avoid check-then-act patterns for file access and credit charging
- **`Security()` vs `Depends()`** — use `Security()` for auth deps to get proper OpenAPI security spec
- **Redis pipelines** — `transaction=True` for atomicity on multi-step operations
- **`max(0, value)` guards** — for computed values that should never be negative
- **SSE protocol** — `data:` lines for frontend-parsed events (must match Zod schema), `: comment` lines for heartbeats/status
- **File length** — keep files under ~300 lines; if a file grows beyond this, split by responsibility (e.g. extract helpers, models, or a sub-module into a new file). Never keep appending to a long file.
- **Function length** — keep functions under ~40 lines; extract named helpers when a function grows longer. Long functions are a sign of mixed concerns, not complexity.
- **Top-down ordering** — define the main/public function or class first, then the helpers it uses below. A reader should encounter high-level logic before implementation details.
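A compact sketch showing a few of those rules together (the function and variable names are invented for illustration):

```python
import logging

logger = logging.getLogger(__name__)


def remaining_credits(balance: int, cost: int) -> int:
    # Early return: guard clause first instead of nesting the happy path.
    if cost <= 0:
        return balance

    # max(0, ...) guard: a computed value that should never go negative.
    remaining = max(0, balance - cost)

    # %s for deferred interpolation in debug; f-string elsewhere.
    logger.debug("Charging %s credits against balance %s", cost, balance)
    logger.info(f"Remaining credits: {remaining}")
    return remaining
```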
## Testing Approach

- Uses pytest with snapshot testing for API responses
- Test files are colocated with source files (`*_test.py`)
- Mock at boundaries — mock where the symbol is **used**, not where it's **defined** (see the example below)
- After refactoring, update mock targets to match new module paths
- Use `AsyncMock` for async functions (`from unittest.mock import AsyncMock`)
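Concretely, this patch target (taken from the route tests later in this diff) names the route module that *uses* the function, not the module that defines it:

```python
mocker.patch(
    "backend.api.features.admin.platform_cost_routes.get_platform_cost_logs",
    AsyncMock(return_value=([], 0)),
)
```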
### Test-Driven Development (TDD)

When fixing a bug or adding a feature, write the test **before** the implementation:

```python
# 1. Write a failing test marked xfail
@pytest.mark.xfail(reason="Bug #1234: widget crashes on empty input")
def test_widget_handles_empty_input():
    result = widget.process("")
    assert result == Widget.EMPTY_RESULT

# 2. Run it — confirm it fails (XFAIL)
# poetry run pytest path/to/test.py::test_widget_handles_empty_input -xvs

# 3. Implement the fix

# 4. Remove xfail, run again — confirm it passes
def test_widget_handles_empty_input():
    result = widget.process("")
    assert result == Widget.EMPTY_RESULT
```

This catches regressions and proves the fix actually works. **Every bug fix should include a test that would have caught it.**

## Database Schema

Key models (defined in `schema.prisma`):

- `User`: Authentication and profile data
- `AgentGraph`: Workflow definitions with version control
- `AgentGraphExecution`: Execution history and results
- `AgentNode`: Individual nodes in a workflow
- `StoreListing`: Marketplace listings for sharing agents

## Environment Configuration

- **Backend**: `.env.default` (defaults) → `.env` (user overrides)

## Common Development Tasks

### Adding a new block

Follow the comprehensive [Block SDK Guide](@../../docs/platform/block-sdk-guide.md), which covers:

- Provider configuration with `ProviderBuilder`
- Block schema definition
- Authentication (API keys, OAuth, webhooks)
- Testing and validation
- File organization

Quick steps (see the skeleton below):

1. Create a new file in `backend/blocks/`
2. Configure the provider using `ProviderBuilder` in `_config.py`
3. Inherit from the `Block` base class
4. Define input/output schemas using `BlockSchema`
5. Implement the async `run` method
6. Generate a unique block ID using `uuid.uuid4()`
7. Test with `poetry run pytest backend/blocks/test/test_block.py`

Note: when creating many new blocks, analyze each block's interface and picture whether they would work well together in a graph-based editor or struggle to connect productively. For example: do the inputs and outputs tie together well?

If you get pushback or hit complex block conditions, check the new_blocks guide in the docs.
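A skeletal block following those steps might look like this. This is a hedged sketch: the exact import paths, the `SchemaField` helper, and the `Block.__init__` signature are assumptions inferred from the steps above, so check an existing block in `backend/blocks/` for the authoritative pattern:

```python
# Assumed import locations; verify against a real block in backend/blocks/.
from backend.data.block import Block, BlockOutput, BlockSchema
from backend.data.model import SchemaField


class GreetUserBlock(Block):
    """Toy block: turns a name into a greeting."""

    class Input(BlockSchema):
        name: str = SchemaField(description="Who to greet")

    class Output(BlockSchema):
        greeting: str = SchemaField(description="The rendered greeting")

    def __init__(self):
        super().__init__(
            id="00000000-0000-0000-0000-000000000000",  # generate once with uuid.uuid4()
            description="Generates a greeting for the given name",
            input_schema=GreetUserBlock.Input,
            output_schema=GreetUserBlock.Output,
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        # Blocks yield (output_name, value) pairs rather than returning.
        yield "greeting", f"Hello, {input_data.name}!"
```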
#### Handling files in blocks with `store_media_file()`

When blocks need to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. The `return_format` parameter determines what you get back:

| Format | Use When | Returns |
|--------|----------|---------|
| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) |
| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) |
| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs |

**Examples:**

```python
# INPUT: Need to process file locally with ffmpeg
local_path = await store_media_file(
    file=input_data.video,
    execution_context=execution_context,
    return_format="for_local_processing",
)
# local_path = "video.mp4" - use with Path/ffmpeg/etc

# INPUT: Need to send to external API like Replicate
image_b64 = await store_media_file(
    file=input_data.image,
    execution_context=execution_context,
    return_format="for_external_api",
)
# image_b64 = "data:image/png;base64,iVBORw0..." - send to API

# OUTPUT: Returning result from block
result_url = await store_media_file(
    file=generated_image_url,
    execution_context=execution_context,
    return_format="for_block_output",
)
yield "image_url", result_url
# In CoPilot: result_url = "workspace://abc123"
# In graphs: result_url = "data:image/png;base64,..."
```

**Key points:**

- `for_block_output` is the ONLY format that auto-adapts to execution context
- Always use `for_block_output` for block outputs unless you have a specific reason not to
- Never hardcode workspace checks - let `for_block_output` handle it
### Modifying the API

1. Update the route in `backend/api/features/` (see the sketch below)
2. Add/update Pydantic models in the same directory
3. Write tests alongside the route file
4. Run `poetry run test` to verify
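A minimal route-plus-model pair in that style (the names here are invented; the `platform_cost_routes.py` file later in this diff is a real, complete example of the same layout):

```python
from fastapi import APIRouter
from pydantic import BaseModel

router = APIRouter(prefix="/widgets", tags=["widgets"])


class WidgetResponse(BaseModel):
    id: str
    name: str


@router.get("/{widget_id}", response_model=WidgetResponse)
async def get_widget(widget_id: str) -> WidgetResponse:
    # A colocated widget_routes_test.py would exercise this route.
    return WidgetResponse(id=widget_id, name="example")
```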
## Workspace & Media Files

**Read [Workspace & Media Architecture](../../docs/platform/workspace-media-architecture.md) when:**

- Working on CoPilot file upload/download features
- Building blocks that handle `MediaFileType` inputs/outputs
- Modifying `WorkspaceManager` or `store_media_file()`
- Debugging file persistence or virus scanning issues

Covers: `WorkspaceManager` (persistent storage with session scoping), `store_media_file()` (media normalization pipeline), and responsibility boundaries for virus scanning and persistence.

## Security Implementation

### Cache Protection Middleware

- Located in `backend/api/middleware/security.py`
- Default behavior: disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
- Uses an allow-list approach - only explicitly permitted paths can be cached
- Cacheable paths include: static assets (`static/*`, `_next/static/*`), health checks, public store pages, documentation
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
- Applied to both the main API server and external API applications (a sketch of the mechanism follows)
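The mechanism, sketched. The real implementation lives in `backend/api/middleware/security.py`; the class name, path list, and matching logic here are illustrative assumptions:

```python
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request

NO_STORE = "no-store, no-cache, must-revalidate, private"

# Illustrative allow-list; the real one is CACHEABLE_PATHS in the middleware module.
CACHEABLE_PREFIXES = ("/static/", "/_next/static/", "/health")


class CacheProtectionMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        response = await call_next(request)
        # Deny caching by default; only allow-listed paths keep their cache headers.
        if not request.url.path.startswith(CACHEABLE_PREFIXES):
            response.headers["Cache-Control"] = NO_STORE
        return response
```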
`autogpt_platform/backend/CLAUDE.md` (227 lines reduced to 1; the removed body duplicated the new backend AGENTS.md above, except for its "# CLAUDE.md - Backend" title and a Block SDK Guide link pointing at `@../../docs/content/platform/block-sdk-guide.md`):

```diff
@@ -1,227 +1 @@
-# CLAUDE.md - Backend
-...
+@AGENTS.md
```
```diff
@@ -72,7 +72,7 @@ class RunAgentRequest(BaseModel):

 def _create_ephemeral_session(user_id: str) -> ChatSession:
     """Create an ephemeral session for stateless API requests."""
-    return ChatSession.new(user_id)
+    return ChatSession.new(user_id, dry_run=False)


 @tools_router.post(
```
New file, `backend/api/features/admin/platform_cost_routes.py` (85 lines):

@@ -0,0 +1,85 @@

```python
import logging
import typing
from datetime import datetime

from autogpt_libs.auth import get_user_id, requires_admin_user
from fastapi import APIRouter, Query, Security
from pydantic import BaseModel

from backend.data.platform_cost import (
    CostLogRow,
    PlatformCostDashboard,
    get_platform_cost_dashboard,
    get_platform_cost_logs,
)
from backend.util.models import Pagination

logger = logging.getLogger(__name__)


router = APIRouter(
    prefix="/admin",
    tags=["platform-cost", "admin"],
    dependencies=[Security(requires_admin_user)],
)


class PlatformCostLogsResponse(BaseModel):
    logs: list[CostLogRow]
    pagination: Pagination


@router.get(
    "/platform_costs/dashboard",
    response_model=PlatformCostDashboard,
    summary="Get Platform Cost Dashboard",
)
async def get_cost_dashboard(
    admin_user_id: str = Security(get_user_id),
    start: typing.Optional[datetime] = Query(None),
    end: typing.Optional[datetime] = Query(None),
    provider: typing.Optional[str] = Query(None),
    user_id: typing.Optional[str] = Query(None),
):
    logger.info(f"Admin {admin_user_id} fetching platform cost dashboard")
    return await get_platform_cost_dashboard(
        start=start,
        end=end,
        provider=provider,
        user_id=user_id,
    )


@router.get(
    "/platform_costs/logs",
    response_model=PlatformCostLogsResponse,
    summary="Get Platform Cost Logs",
)
async def get_cost_logs(
    admin_user_id: str = Security(get_user_id),
    start: typing.Optional[datetime] = Query(None),
    end: typing.Optional[datetime] = Query(None),
    provider: typing.Optional[str] = Query(None),
    user_id: typing.Optional[str] = Query(None),
    page: int = Query(1, ge=1),
    page_size: int = Query(50, ge=1, le=200),
):
    logger.info(f"Admin {admin_user_id} fetching platform cost logs")
    logs, total = await get_platform_cost_logs(
        start=start,
        end=end,
        provider=provider,
        user_id=user_id,
        page=page,
        page_size=page_size,
    )
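    # Ceiling division: the number of pages needed to hold `total` items.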
    total_pages = (total + page_size - 1) // page_size
    return PlatformCostLogsResponse(
        logs=logs,
        pagination=Pagination(
            total_items=total,
            total_pages=total_pages,
            current_page=page,
            page_size=page_size,
        ),
    )
```
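Both endpoints sit behind the router-level `Security(requires_admin_user)` dependency, so every call needs an admin JWT. A hypothetical invocation (the host and token handling are assumptions; only the path and query parameters come from the router above):

```bash
curl -H "Authorization: Bearer $ADMIN_JWT" \
  "https://<backend-host>/admin/platform_costs/logs?page=2&page_size=25&provider=openai"
```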
New colocated test file (135 lines), `platform_cost_routes_test.py` per the `*_test.py` convention:

@@ -0,0 +1,135 @@

```python
from unittest.mock import AsyncMock

import fastapi
import fastapi.testclient
import pytest
import pytest_mock
from autogpt_libs.auth.jwt_utils import get_jwt_payload

from .platform_cost_routes import router as platform_cost_router

app = fastapi.FastAPI()
app.include_router(platform_cost_router)

client = fastapi.testclient.TestClient(app)


@pytest.fixture(autouse=True)
def setup_app_admin_auth(mock_jwt_admin):
    """Setup admin auth overrides for all tests in this module"""
    app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
    yield
    app.dependency_overrides.clear()


def test_get_dashboard_success(
    mocker: pytest_mock.MockerFixture,
) -> None:
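    # The AsyncMock stands in for a PlatformCostDashboard: FastAPI's
    # response_model serialization calls model_dump(), so the mock supplies one.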
    mock_dashboard = AsyncMock(
        return_value=AsyncMock(
            by_provider=[],
            by_user=[],
            total_cost_microdollars=0,
            total_requests=0,
            total_users=0,
            model_dump=lambda **_: {
                "by_provider": [],
                "by_user": [],
                "total_cost_microdollars": 0,
                "total_requests": 0,
                "total_users": 0,
            },
        )
    )
    mocker.patch(
        "backend.api.features.admin.platform_cost_routes.get_platform_cost_dashboard",
        mock_dashboard,
    )

    response = client.get("/admin/platform_costs/dashboard")
    assert response.status_code == 200
    data = response.json()
    assert "by_provider" in data
    assert "by_user" in data
    assert data["total_cost_microdollars"] == 0


def test_get_logs_success(
    mocker: pytest_mock.MockerFixture,
) -> None:
    mocker.patch(
        "backend.api.features.admin.platform_cost_routes.get_platform_cost_logs",
        AsyncMock(return_value=([], 0)),
    )

    response = client.get("/admin/platform_costs/logs")
    assert response.status_code == 200
    data = response.json()
    assert data["logs"] == []
    assert data["pagination"]["total_items"] == 0


def test_get_dashboard_with_filters(
    mocker: pytest_mock.MockerFixture,
) -> None:
    mock_dashboard = AsyncMock(
        return_value=AsyncMock(
            by_provider=[],
            by_user=[],
            total_cost_microdollars=0,
            total_requests=0,
            total_users=0,
            model_dump=lambda **_: {
                "by_provider": [],
                "by_user": [],
                "total_cost_microdollars": 0,
                "total_requests": 0,
                "total_users": 0,
            },
        )
    )
    mocker.patch(
        "backend.api.features.admin.platform_cost_routes.get_platform_cost_dashboard",
        mock_dashboard,
    )

    response = client.get(
        "/admin/platform_costs/dashboard",
        params={
            "start": "2026-01-01T00:00:00",
            "end": "2026-04-01T00:00:00",
            "provider": "openai",
            "user_id": "test-user-123",
        },
    )
    assert response.status_code == 200
    mock_dashboard.assert_called_once()
    call_kwargs = mock_dashboard.call_args.kwargs
    assert call_kwargs["provider"] == "openai"
    assert call_kwargs["user_id"] == "test-user-123"
    assert call_kwargs["start"] is not None
    assert call_kwargs["end"] is not None


def test_get_logs_with_pagination(
    mocker: pytest_mock.MockerFixture,
) -> None:
    mocker.patch(
        "backend.api.features.admin.platform_cost_routes.get_platform_cost_logs",
        AsyncMock(return_value=([], 0)),
    )

    response = client.get(
        "/admin/platform_costs/logs",
        params={"page": 2, "page_size": 25, "provider": "anthropic"},
    )
    assert response.status_code == 200
    data = response.json()
    assert data["pagination"]["current_page"] == 2
    assert data["pagination"]["page_size"] == 25


def test_get_dashboard_requires_admin() -> None:
    app.dependency_overrides.clear()
    response = client.get("/admin/platform_costs/dashboard")
    assert response.status_code in (401, 403)
```
Changes to `rate_limit_admin_routes.py`: the imports gain the tier helpers and `search_users`:

```diff
@@ -9,11 +9,14 @@ from pydantic import BaseModel

 from backend.copilot.config import ChatConfig
 from backend.copilot.rate_limit import (
+    SubscriptionTier,
     get_global_rate_limits,
     get_usage_status,
+    get_user_tier,
     reset_user_usage,
+    set_user_tier,
 )
-from backend.data.user import get_user_by_email, get_user_email_by_id
+from backend.data.user import get_user_by_email, get_user_email_by_id, search_users

 logger = logging.getLogger(__name__)
```
```diff
@@ -33,6 +36,17 @@ class UserRateLimitResponse(BaseModel):
     weekly_token_limit: int
     daily_tokens_used: int
     weekly_tokens_used: int
+    tier: SubscriptionTier
+
+
+class UserTierResponse(BaseModel):
+    user_id: str
+    tier: SubscriptionTier
+
+
+class SetUserTierRequest(BaseModel):
+    user_id: str
+    tier: SubscriptionTier


 async def _resolve_user_id(
@@ -86,10 +100,10 @@ async def get_user_rate_limit(

     logger.info("Admin %s checking rate limit for user %s", admin_user_id, resolved_id)

-    daily_limit, weekly_limit = await get_global_rate_limits(
+    daily_limit, weekly_limit, tier = await get_global_rate_limits(
         resolved_id, config.daily_token_limit, config.weekly_token_limit
     )
-    usage = await get_usage_status(resolved_id, daily_limit, weekly_limit)
+    usage = await get_usage_status(resolved_id, daily_limit, weekly_limit, tier=tier)

     return UserRateLimitResponse(
         user_id=resolved_id,
@@ -98,6 +112,7 @@ async def get_user_rate_limit(
         weekly_token_limit=weekly_limit,
         daily_tokens_used=usage.daily.used,
         weekly_tokens_used=usage.weekly.used,
+        tier=tier,
     )
@@ -125,10 +140,10 @@ async def reset_user_rate_limit(
         logger.exception("Failed to reset user usage")
         raise HTTPException(status_code=500, detail="Failed to reset usage") from e

-    daily_limit, weekly_limit = await get_global_rate_limits(
+    daily_limit, weekly_limit, tier = await get_global_rate_limits(
         user_id, config.daily_token_limit, config.weekly_token_limit
     )
-    usage = await get_usage_status(user_id, daily_limit, weekly_limit)
+    usage = await get_usage_status(user_id, daily_limit, weekly_limit, tier=tier)

     try:
         resolved_email = await get_user_email_by_id(user_id)
```
```diff
@@ -143,4 +158,96 @@ async def reset_user_rate_limit(
         weekly_token_limit=weekly_limit,
         daily_tokens_used=usage.daily.used,
         weekly_tokens_used=usage.weekly.used,
+        tier=tier,
     )
+
+
+@router.get(
+    "/rate_limit/tier",
+    response_model=UserTierResponse,
+    summary="Get User Rate Limit Tier",
+)
+async def get_user_rate_limit_tier(
+    user_id: str,
+    admin_user_id: str = Security(get_user_id),
+) -> UserTierResponse:
+    """Get a user's current rate-limit tier. Admin-only.
+
+    Returns 404 if the user does not exist in the database.
+    """
+    logger.info("Admin %s checking tier for user %s", admin_user_id, user_id)
+
+    resolved_email = await get_user_email_by_id(user_id)
+    if resolved_email is None:
+        raise HTTPException(status_code=404, detail=f"User {user_id} not found")
+
+    tier = await get_user_tier(user_id)
+    return UserTierResponse(user_id=user_id, tier=tier)
+
+
+@router.post(
+    "/rate_limit/tier",
+    response_model=UserTierResponse,
+    summary="Set User Rate Limit Tier",
+)
+async def set_user_rate_limit_tier(
+    request: SetUserTierRequest,
+    admin_user_id: str = Security(get_user_id),
+) -> UserTierResponse:
+    """Set a user's rate-limit tier. Admin-only."""
+    old_tier = await get_user_tier(request.user_id)
+
+    # Resolve email for audit logging (non-blocking — don't fail the
+    # tier change if email lookup fails).
+    try:
+        resolved_email = await get_user_email_by_id(request.user_id)
+    except Exception:
+        logger.warning(
+            "Failed to resolve email for user %s", request.user_id, exc_info=True
+        )
+        resolved_email = None
+    logger.info(
+        "Admin %s changing tier for user %s (%s): %s -> %s",
+        admin_user_id,
+        request.user_id,
+        resolved_email or "unknown",
+        old_tier.value,
+        request.tier.value,
+    )
+    try:
+        await set_user_tier(request.user_id, request.tier)
+    except Exception as e:
+        logger.exception("Failed to set user tier")
+        raise HTTPException(status_code=500, detail="Failed to set tier") from e
+
+    return UserTierResponse(user_id=request.user_id, tier=request.tier)
+
+
+class UserSearchResult(BaseModel):
+    user_id: str
+    user_email: Optional[str] = None
+
+
+@router.get(
+    "/rate_limit/search_users",
+    response_model=list[UserSearchResult],
+    summary="Search Users by Name or Email",
+)
+async def admin_search_users(
+    query: str,
+    limit: int = 20,
+    admin_user_id: str = Security(get_user_id),
+) -> list[UserSearchResult]:
+    """Search users by partial email or name. Admin-only.
+
+    Queries the User table directly — returns results even for users
+    without credit transaction history.
+    """
+    if len(query.strip()) < 3:
+        raise HTTPException(
+            status_code=400,
+            detail="Search query must be at least 3 characters.",
+        )
+    logger.info("Admin %s searching users with query=%r", admin_user_id, query)
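+    # Clamp limit into [1, 50] so a caller can't request unbounded pages.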
results = await search_users(query, limit=max(1, min(limit, 50)))
|
||||
return [UserSearchResult(user_id=uid, user_email=email) for uid, email in results]
|
||||
|
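For orientation, the three admin endpoints above can be exercised with a small client script along these lines. This is a minimal sketch: the base URL, mount prefix, and bearer token are assumptions about the deployment, not something this diff specifies (the tests below call the router mounted under /admin).

import httpx

BASE = "http://localhost:8006/admin"  # hypothetical mount point
HEADERS = {"Authorization": "Bearer <admin-jwt>"}  # placeholder token

with httpx.Client(base_url=BASE, headers=HEADERS) as client:
    # GET returns {"user_id": ..., "tier": ...}; unknown users yield 404.
    tier = client.get("/rate_limit/tier", params={"user_id": "user-123"}).json()

    # POST validates the tier against the SubscriptionTier enum (422 on bad values).
    client.post("/rate_limit/tier", json={"user_id": "user-123", "tier": "PRO"})

    # Partial-match search; queries shorter than 3 characters return 400.
    client.get("/rate_limit/search_users", params={"query": "zamil", "limit": 10})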
@@ -9,7 +9,7 @@ import pytest_mock
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from pytest_snapshot.plugin import Snapshot

-from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
+from backend.copilot.rate_limit import CoPilotUsageStatus, SubscriptionTier, UsageWindow

from .rate_limit_admin_routes import router as rate_limit_admin_router

@@ -57,7 +57,7 @@ def _patch_rate_limit_deps(
    mocker.patch(
        f"{_MOCK_MODULE}.get_global_rate_limits",
        new_callable=AsyncMock,
-        return_value=(2_500_000, 12_500_000),
+        return_value=(2_500_000, 12_500_000, SubscriptionTier.FREE),
    )
    mocker.patch(
        f"{_MOCK_MODULE}.get_usage_status",

@@ -89,6 +89,7 @@ def test_get_rate_limit(
    assert data["weekly_token_limit"] == 12_500_000
    assert data["daily_tokens_used"] == 500_000
    assert data["weekly_tokens_used"] == 3_000_000
    assert data["tier"] == "FREE"

    configured_snapshot.assert_match(
        json.dumps(data, indent=2, sort_keys=True) + "\n",

@@ -162,6 +163,7 @@ def test_reset_user_usage_daily_only(
    assert data["daily_tokens_used"] == 0
    # Weekly is untouched
    assert data["weekly_tokens_used"] == 3_000_000
    assert data["tier"] == "FREE"

    mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=False)

@@ -192,6 +194,7 @@ def test_reset_user_usage_daily_and_weekly(
    data = response.json()
    assert data["daily_tokens_used"] == 0
    assert data["weekly_tokens_used"] == 0
    assert data["tier"] == "FREE"

    mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=True)

@@ -228,7 +231,7 @@ def test_get_rate_limit_email_lookup_failure(
    mocker.patch(
        f"{_MOCK_MODULE}.get_global_rate_limits",
        new_callable=AsyncMock,
-        return_value=(2_500_000, 12_500_000),
+        return_value=(2_500_000, 12_500_000, SubscriptionTier.FREE),
    )
    mocker.patch(
        f"{_MOCK_MODULE}.get_usage_status",

@@ -261,3 +264,294 @@ def test_admin_endpoints_require_admin_role(mock_jwt_user) -> None:
        json={"user_id": "test"},
    )
    assert response.status_code == 403


# ---------------------------------------------------------------------------
# Tier management endpoints
# ---------------------------------------------------------------------------


def test_get_user_tier(
    mocker: pytest_mock.MockerFixture,
    target_user_id: str,
) -> None:
    """Test getting a user's rate-limit tier."""
    mocker.patch(
        f"{_MOCK_MODULE}.get_user_email_by_id",
        new_callable=AsyncMock,
        return_value=_TARGET_EMAIL,
    )
    mocker.patch(
        f"{_MOCK_MODULE}.get_user_tier",
        new_callable=AsyncMock,
        return_value=SubscriptionTier.PRO,
    )

    response = client.get("/admin/rate_limit/tier", params={"user_id": target_user_id})

    assert response.status_code == 200
    data = response.json()
    assert data["user_id"] == target_user_id
    assert data["tier"] == "PRO"


def test_get_user_tier_user_not_found(
    mocker: pytest_mock.MockerFixture,
    target_user_id: str,
) -> None:
    """Test that getting tier for a non-existent user returns 404."""
    mocker.patch(
        f"{_MOCK_MODULE}.get_user_email_by_id",
        new_callable=AsyncMock,
        return_value=None,
    )

    response = client.get("/admin/rate_limit/tier", params={"user_id": target_user_id})

    assert response.status_code == 404


def test_set_user_tier(
    mocker: pytest_mock.MockerFixture,
    target_user_id: str,
) -> None:
    """Test setting a user's rate-limit tier (upgrade)."""
    mocker.patch(
        f"{_MOCK_MODULE}.get_user_email_by_id",
        new_callable=AsyncMock,
        return_value=_TARGET_EMAIL,
    )
    mocker.patch(
        f"{_MOCK_MODULE}.get_user_tier",
        new_callable=AsyncMock,
        return_value=SubscriptionTier.FREE,
    )
    mock_set = mocker.patch(
        f"{_MOCK_MODULE}.set_user_tier",
        new_callable=AsyncMock,
    )

    response = client.post(
        "/admin/rate_limit/tier",
        json={"user_id": target_user_id, "tier": "ENTERPRISE"},
    )

    assert response.status_code == 200
    data = response.json()
    assert data["user_id"] == target_user_id
    assert data["tier"] == "ENTERPRISE"
    mock_set.assert_awaited_once_with(target_user_id, SubscriptionTier.ENTERPRISE)


def test_set_user_tier_downgrade(
    mocker: pytest_mock.MockerFixture,
    target_user_id: str,
) -> None:
    """Test downgrading a user's tier from PRO to FREE."""
    mocker.patch(
        f"{_MOCK_MODULE}.get_user_email_by_id",
        new_callable=AsyncMock,
        return_value=_TARGET_EMAIL,
    )
    mocker.patch(
        f"{_MOCK_MODULE}.get_user_tier",
        new_callable=AsyncMock,
        return_value=SubscriptionTier.PRO,
    )
    mock_set = mocker.patch(
        f"{_MOCK_MODULE}.set_user_tier",
        new_callable=AsyncMock,
    )

    response = client.post(
        "/admin/rate_limit/tier",
        json={"user_id": target_user_id, "tier": "FREE"},
    )

    assert response.status_code == 200
    data = response.json()
    assert data["user_id"] == target_user_id
    assert data["tier"] == "FREE"
    mock_set.assert_awaited_once_with(target_user_id, SubscriptionTier.FREE)


def test_set_user_tier_invalid_tier(
    target_user_id: str,
) -> None:
    """Test that setting an invalid tier returns 422."""
    response = client.post(
        "/admin/rate_limit/tier",
        json={"user_id": target_user_id, "tier": "invalid"},
    )

    assert response.status_code == 422


def test_set_user_tier_invalid_tier_uppercase(
    target_user_id: str,
) -> None:
    """Test that setting an unrecognised uppercase tier (e.g. 'INVALID') returns 422.

    Regression: ensures Pydantic enum validation rejects values that are not
    members of SubscriptionTier, even when they look like valid enum names.
    """
    response = client.post(
        "/admin/rate_limit/tier",
        json={"user_id": target_user_id, "tier": "INVALID"},
    )

    assert response.status_code == 422
    body = response.json()
    assert "detail" in body


def test_set_user_tier_email_lookup_failure_non_blocking(
    mocker: pytest_mock.MockerFixture,
    target_user_id: str,
) -> None:
    """Test that email lookup failure doesn't block tier change."""
    mocker.patch(
        f"{_MOCK_MODULE}.get_user_email_by_id",
        new_callable=AsyncMock,
        side_effect=Exception("DB connection failed"),
    )
    mocker.patch(
        f"{_MOCK_MODULE}.get_user_tier",
        new_callable=AsyncMock,
        return_value=SubscriptionTier.FREE,
    )
    mock_set = mocker.patch(
        f"{_MOCK_MODULE}.set_user_tier",
        new_callable=AsyncMock,
    )

    response = client.post(
        "/admin/rate_limit/tier",
        json={"user_id": target_user_id, "tier": "PRO"},
    )

    assert response.status_code == 200
    mock_set.assert_awaited_once()


def test_set_user_tier_db_failure(
    mocker: pytest_mock.MockerFixture,
    target_user_id: str,
) -> None:
    """Test that DB failure on set tier returns 500."""
    mocker.patch(
        f"{_MOCK_MODULE}.get_user_email_by_id",
        new_callable=AsyncMock,
        return_value=_TARGET_EMAIL,
    )
    mocker.patch(
        f"{_MOCK_MODULE}.get_user_tier",
        new_callable=AsyncMock,
        return_value=SubscriptionTier.FREE,
    )
    mocker.patch(
        f"{_MOCK_MODULE}.set_user_tier",
        new_callable=AsyncMock,
        side_effect=Exception("DB connection refused"),
    )

    response = client.post(
        "/admin/rate_limit/tier",
        json={"user_id": target_user_id, "tier": "PRO"},
    )

    assert response.status_code == 500


def test_tier_endpoints_require_admin_role(mock_jwt_user) -> None:
    """Test that tier admin endpoints require admin role."""
    app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]

    response = client.get("/admin/rate_limit/tier", params={"user_id": "test"})
    assert response.status_code == 403

    response = client.post(
        "/admin/rate_limit/tier",
        json={"user_id": "test", "tier": "PRO"},
    )
    assert response.status_code == 403


# ─── search_users endpoint ──────────────────────────────────────────


def test_search_users_returns_matching_users(
    mocker: pytest_mock.MockerFixture,
    admin_user_id: str,
) -> None:
    """Partial search should return all matching users from the User table."""
    mocker.patch(
        _MOCK_MODULE + ".search_users",
        new_callable=AsyncMock,
        return_value=[
            ("user-1", "zamil.majdy@gmail.com"),
            ("user-2", "zamil.majdy@agpt.co"),
        ],
    )

    response = client.get("/admin/rate_limit/search_users", params={"query": "zamil"})

    assert response.status_code == 200
    results = response.json()
    assert len(results) == 2
    assert results[0]["user_email"] == "zamil.majdy@gmail.com"
    assert results[1]["user_email"] == "zamil.majdy@agpt.co"


def test_search_users_empty_results(
    mocker: pytest_mock.MockerFixture,
    admin_user_id: str,
) -> None:
    """Search with no matches returns empty list."""
    mocker.patch(
        _MOCK_MODULE + ".search_users",
        new_callable=AsyncMock,
        return_value=[],
    )

    response = client.get(
        "/admin/rate_limit/search_users", params={"query": "nonexistent"}
    )

    assert response.status_code == 200
    assert response.json() == []


def test_search_users_short_query_rejected(
    admin_user_id: str,
) -> None:
    """Query shorter than 3 characters should return 400."""
    response = client.get("/admin/rate_limit/search_users", params={"query": "ab"})
    assert response.status_code == 400


def test_search_users_negative_limit_clamped(
    mocker: pytest_mock.MockerFixture,
    admin_user_id: str,
) -> None:
    """Negative limit should be clamped to 1, not passed through."""
    mock_search = mocker.patch(
        _MOCK_MODULE + ".search_users",
        new_callable=AsyncMock,
        return_value=[],
    )

    response = client.get(
        "/admin/rate_limit/search_users", params={"query": "test", "limit": -1}
    )

    assert response.status_code == 200
    mock_search.assert_awaited_once_with("test", limit=1)


def test_search_users_requires_admin_role(mock_jwt_user) -> None:
    """Test that the search_users endpoint requires admin role."""
    app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]

    response = client.get("/admin/rate_limit/search_users", params={"query": "test"})
    assert response.status_code == 403
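The negative-limit test above pins down the clamping expression used by the route, max(1, min(limit, 50)). As a standalone illustration of that idiom (plain Python, nothing assumed beyond the route code shown earlier):

def clamp_limit(limit: int, lo: int = 1, hi: int = 50) -> int:
    # Out-of-range values are pulled to the nearest bound; in-range values pass through.
    return max(lo, min(limit, hi))

assert clamp_limit(-1) == 1    # negative input clamped up, as the test asserts
assert clamp_limit(999) == 50  # oversized input clamped down
assert clamp_limit(20) == 20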
@@ -4,14 +4,14 @@ import asyncio
import logging
import re
from collections.abc import AsyncGenerator
-from typing import Annotated
+from typing import Annotated, Literal
from uuid import uuid4

from autogpt_libs import auth
from fastapi import APIRouter, HTTPException, Query, Response, Security
from fastapi.responses import StreamingResponse
from prisma.models import UserWorkspaceFile
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, ConfigDict, Field, field_validator

from backend.copilot import service as chat_service
from backend.copilot import stream_registry

@@ -20,6 +20,7 @@ from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_
from backend.copilot.model import (
    ChatMessage,
    ChatSession,
    ChatSessionMetadata,
    append_and_save_message,
    create_chat_session,
    delete_chat_session,

@@ -110,6 +111,23 @@ class StreamChatRequest(BaseModel):
    file_ids: list[str] | None = Field(
        default=None, max_length=20
    )  # Workspace file IDs attached to this message
    mode: Literal["fast", "extended_thinking"] | None = Field(
        default=None,
        description="Autopilot mode: 'fast' for baseline LLM, 'extended_thinking' for Claude Agent SDK. "
        "If None, uses the server default (extended_thinking).",
    )


class CreateSessionRequest(BaseModel):
    """Request model for creating a new chat session.

    ``dry_run`` is a **top-level** field — do not nest it inside ``metadata``.
    Extra/unknown fields are rejected (422) to prevent silent mis-use.
    """

    model_config = ConfigDict(extra="forbid")

    dry_run: bool = False


class CreateSessionResponse(BaseModel):

@@ -118,6 +136,7 @@ class CreateSessionResponse(BaseModel):
    id: str
    created_at: str
    user_id: str | None
    metadata: ChatSessionMetadata = ChatSessionMetadata()


class ActiveStreamInfo(BaseModel):

@@ -138,6 +157,7 @@ class SessionDetailResponse(BaseModel):
    active_stream: ActiveStreamInfo | None = None  # Present if stream is still active
    total_prompt_tokens: int = 0
    total_completion_tokens: int = 0
    metadata: ChatSessionMetadata = ChatSessionMetadata()


class SessionSummaryResponse(BaseModel):

@@ -248,6 +268,7 @@ async def list_sessions(
)
async def create_session(
    user_id: Annotated[str, Security(auth.get_user_id)],
    request: CreateSessionRequest | None = None,
) -> CreateSessionResponse:
    """
    Create a new chat session.

@@ -256,22 +277,28 @@ async def create_session(

    Args:
        user_id: The authenticated user ID parsed from the JWT (required).
        request: Optional request body. When provided, ``dry_run=True``
            forces run_block and run_agent calls to use dry-run simulation.

    Returns:
        CreateSessionResponse: Details of the created session.

    """
    dry_run = request.dry_run if request else False

    logger.info(
        f"Creating session with user_id: "
        f"...{user_id[-8:] if len(user_id) > 8 else '<redacted>'}"
        f"{', dry_run=True' if dry_run else ''}"
    )

-    session = await create_chat_session(user_id)
+    session = await create_chat_session(user_id, dry_run=dry_run)

    return CreateSessionResponse(
        id=session.session_id,
        created_at=session.started_at.isoformat(),
        user_id=session.user_id,
        metadata=session.metadata,
    )

@@ -420,6 +447,7 @@ async def get_session(
        active_stream=active_stream_info,
        total_prompt_tokens=total_prompt,
        total_completion_tokens=total_completion,
        metadata=session.metadata,
    )

@@ -433,8 +461,9 @@ async def get_copilot_usage(

    Returns current token usage vs limits for daily and weekly windows.
    Global defaults sourced from LaunchDarkly (falling back to config).
    Includes the user's rate-limit tier.
    """
-    daily_limit, weekly_limit = await get_global_rate_limits(
+    daily_limit, weekly_limit, tier = await get_global_rate_limits(
        user_id, config.daily_token_limit, config.weekly_token_limit
    )
    return await get_usage_status(

@@ -442,6 +471,7 @@ async def get_copilot_usage(
        daily_token_limit=daily_limit,
        weekly_token_limit=weekly_limit,
        rate_limit_reset_cost=config.rate_limit_reset_cost,
        tier=tier,
    )

@@ -493,7 +523,7 @@ async def reset_copilot_usage(
            detail="Rate limit reset is not available (credit system is disabled).",
        )

-    daily_limit, weekly_limit = await get_global_rate_limits(
+    daily_limit, weekly_limit, tier = await get_global_rate_limits(
        user_id, config.daily_token_limit, config.weekly_token_limit
    )

@@ -527,10 +557,13 @@ async def reset_copilot_usage(

    try:
        # Verify the user is actually at or over their daily limit.
        # (rate_limit_reset_cost intentionally omitted — this object is only
        # used for limit checks, not returned to the client.)
        usage_status = await get_usage_status(
            user_id=user_id,
            daily_token_limit=daily_limit,
            weekly_token_limit=weekly_limit,
            tier=tier,
        )
        if daily_limit > 0 and usage_status.daily.used < daily_limit:
            raise HTTPException(

@@ -606,6 +639,7 @@ async def reset_copilot_usage(
        daily_token_limit=daily_limit,
        weekly_token_limit=weekly_limit,
        rate_limit_reset_cost=config.rate_limit_reset_cost,
        tier=tier,
    )

    return RateLimitResetResponse(

@@ -716,7 +750,7 @@ async def stream_chat_post(
    # Global defaults sourced from LaunchDarkly, falling back to config.
    if user_id:
        try:
-            daily_limit, weekly_limit = await get_global_rate_limits(
+            daily_limit, weekly_limit, _ = await get_global_rate_limits(
                user_id, config.daily_token_limit, config.weekly_token_limit
            )
            await check_rate_limit(

@@ -811,6 +845,7 @@ async def stream_chat_post(
        is_user_message=request.is_user_message,
        context=request.context,
        file_ids=sanitized_file_ids,
        mode=request.mode,
    )

    setup_time = (time.perf_counter() - stream_start_time) * 1000

@@ -1174,7 +1209,7 @@ async def health_check() -> dict:
        )

        # Create and retrieve session to verify full data layer
-        session = await create_chat_session(health_check_user_id)
+        session = await create_chat_session(health_check_user_id, dry_run=False)
        await get_chat_session(session.session_id, health_check_user_id)

        return {
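Client-side, the two new request fields above change the contract in small but easy-to-miss ways: dry_run travels top-level in the session-creation body, and mode rides on the stream request. A hedged sketch of both calls (httpx; the host, auth header, and stream path are assumptions for illustration):

import httpx

with httpx.Client(
    base_url="http://localhost:8006",           # hypothetical host
    headers={"Authorization": "Bearer <jwt>"},  # placeholder token
) as client:
    # dry_run is top-level; {"metadata": {"dry_run": true}} would be rejected (422).
    session = client.post("/sessions", json={"dry_run": True}).json()

    # mode picks "fast" (baseline LLM) or "extended_thinking" (Claude Agent SDK);
    # omitting it falls back to the server default. The path below is assumed.
    client.post(
        f"/sessions/{session['id']}/stream",
        json={"message": "Wire up my agent", "mode": "fast"},
    )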
@@ -9,6 +9,7 @@ import pytest
import pytest_mock

from backend.api.features.chat import routes as chat_routes
from backend.copilot.rate_limit import SubscriptionTier

app = fastapi.FastAPI()
app.include_router(chat_routes.router)

@@ -331,14 +332,28 @@ def _mock_usage(
    *,
    daily_used: int = 500,
    weekly_used: int = 2000,
    daily_limit: int = 10000,
    weekly_limit: int = 50000,
    tier: "SubscriptionTier" = SubscriptionTier.FREE,
) -> AsyncMock:
-    """Mock get_usage_status to return a predictable CoPilotUsageStatus."""
+    """Mock get_usage_status and get_global_rate_limits for usage endpoint tests.
+
+    Mocks both ``get_global_rate_limits`` (returns the given limits + tier) and
+    ``get_usage_status`` so that tests exercise the endpoint without hitting
+    LaunchDarkly or Prisma.
+    """
    from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow

    mocker.patch(
        "backend.api.features.chat.routes.get_global_rate_limits",
        new_callable=AsyncMock,
        return_value=(daily_limit, weekly_limit, tier),
    )

    resets_at = datetime.now(UTC) + timedelta(days=1)
    status = CoPilotUsageStatus(
-        daily=UsageWindow(used=daily_used, limit=10000, resets_at=resets_at),
-        weekly=UsageWindow(used=weekly_used, limit=50000, resets_at=resets_at),
+        daily=UsageWindow(used=daily_used, limit=daily_limit, resets_at=resets_at),
+        weekly=UsageWindow(used=weekly_used, limit=weekly_limit, resets_at=resets_at),
    )
    return mocker.patch(
        "backend.api.features.chat.routes.get_usage_status",

@@ -369,6 +384,7 @@ def test_usage_returns_daily_and_weekly(
        daily_token_limit=10000,
        weekly_token_limit=50000,
        rate_limit_reset_cost=chat_routes.config.rate_limit_reset_cost,
        tier=SubscriptionTier.FREE,
    )


@@ -376,11 +392,9 @@ def test_usage_uses_config_limits(
    mocker: pytest_mock.MockerFixture,
    test_user_id: str,
) -> None:
-    """The endpoint forwards daily_token_limit and weekly_token_limit from config."""
-    mock_get = _mock_usage(mocker)
+    """The endpoint forwards resolved limits from get_global_rate_limits to get_usage_status."""
+    mock_get = _mock_usage(mocker, daily_limit=99999, weekly_limit=77777)

    mocker.patch.object(chat_routes.config, "daily_token_limit", 99999)
    mocker.patch.object(chat_routes.config, "weekly_token_limit", 77777)
    mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", 500)

    response = client.get("/usage")

@@ -391,6 +405,7 @@ def test_usage_uses_config_limits(
        daily_token_limit=99999,
        weekly_token_limit=77777,
        rate_limit_reset_cost=500,
        tier=SubscriptionTier.FREE,
    )


@@ -469,3 +484,60 @@ def test_suggested_prompts_empty_prompts(

    assert response.status_code == 200
    assert response.json() == {"themes": []}


# ─── Create session: dry_run contract ─────────────────────────────────


def _mock_create_chat_session(mocker: pytest_mock.MockerFixture):
    """Mock create_chat_session to return a fake session."""
    from backend.copilot.model import ChatSession

    async def _fake_create(user_id: str, *, dry_run: bool):
        return ChatSession.new(user_id, dry_run=dry_run)

    return mocker.patch(
        "backend.api.features.chat.routes.create_chat_session",
        new_callable=AsyncMock,
        side_effect=_fake_create,
    )


def test_create_session_dry_run_true(
    mocker: pytest_mock.MockerFixture,
    test_user_id: str,
) -> None:
    """Sending ``{"dry_run": true}`` sets metadata.dry_run to True."""
    _mock_create_chat_session(mocker)

    response = client.post("/sessions", json={"dry_run": True})

    assert response.status_code == 200
    assert response.json()["metadata"]["dry_run"] is True


def test_create_session_dry_run_default_false(
    mocker: pytest_mock.MockerFixture,
    test_user_id: str,
) -> None:
    """Empty body defaults dry_run to False."""
    _mock_create_chat_session(mocker)

    response = client.post("/sessions")

    assert response.status_code == 200
    assert response.json()["metadata"]["dry_run"] is False


def test_create_session_rejects_nested_metadata(
    test_user_id: str,
) -> None:
    """Sending ``{"metadata": {"dry_run": true}}`` must return 422, not silently
    default to ``dry_run=False``. This guards against the common mistake of
    nesting dry_run inside metadata instead of providing it at the top level."""
    response = client.post(
        "/sessions",
        json={"metadata": {"dry_run": True}},
    )

    assert response.status_code == 422
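The 422 asserted in the last test comes from Pydantic itself rather than hand-written validation: model_config = ConfigDict(extra="forbid") turns any unknown key into a validation error, which FastAPI surfaces as HTTP 422. A self-contained sketch of the mechanism (standard Pydantic v2 behaviour):

from pydantic import BaseModel, ConfigDict, ValidationError

class CreateSessionRequest(BaseModel):
    model_config = ConfigDict(extra="forbid")
    dry_run: bool = False

CreateSessionRequest.model_validate({"dry_run": True})  # ok

try:
    CreateSessionRequest.model_validate({"metadata": {"dry_run": True}})
except ValidationError as e:
    # Unknown keys are rejected outright instead of being silently dropped.
    print(e.errors()[0]["type"])  # "extra_forbidden"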
@@ -481,6 +481,11 @@ async def create_library_agent(
                    sensitive_action_safe_mode=sensitive_action_safe_mode,
                ).model_dump()
            ),
            **(
                {"Folder": {"connect": {"id": folder_id}}}
                if folder_id and graph_entry is graph
                else {}
            ),
        },
    },
    include=library_agent_include(
@@ -189,6 +189,7 @@ async def test_create_store_submission(mocker):
        notifyOnAgentApproved=True,
        notifyOnAgentRejected=True,
        timezone="Europe/Delft",
        subscriptionTier=prisma.enums.SubscriptionTier.FREE,  # type: ignore[reportCallIssue,reportAttributeAccessIssue]
    )
    mock_agent = prisma.models.AgentGraph(
        id="agent-id",
@@ -18,6 +18,7 @@ from prisma.errors import PrismaError

import backend.api.features.admin.credit_admin_routes
import backend.api.features.admin.execution_analytics_routes
import backend.api.features.admin.platform_cost_routes
import backend.api.features.admin.rate_limit_admin_routes
import backend.api.features.admin.store_admin_routes
import backend.api.features.builder

@@ -329,6 +330,11 @@ app.include_router(
    tags=["v2", "admin"],
    prefix="/api/copilot",
)
app.include_router(
    backend.api.features.admin.platform_cost_routes.router,
    tags=["v2", "admin"],
    prefix="/api/platform-costs",
)
app.include_router(
    backend.api.features.executions.review.routes.router,
    tags=["v2", "executions", "review"],
@@ -698,13 +698,30 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
        if should_pause:
            return

-        # Validate the input data (original or reviewer-modified) once
-        if error := self.input_schema.validate_data(input_data):
-            raise BlockInputError(
-                message=f"Unable to execute block with invalid input data: {error}",
-                block_name=self.name,
-                block_id=self.id,
-            )
+        # Validate the input data (original or reviewer-modified) once.
+        # In dry-run mode, credential fields may contain sentinel None values
+        # that would fail JSON schema required checks. We still validate the
+        # non-credential fields so blocks that execute for real during dry-run
+        # (e.g. AgentExecutorBlock) get proper input validation.
+        is_dry_run = getattr(kwargs.get("execution_context"), "dry_run", False)
+        if is_dry_run:
+            cred_field_names = set(self.input_schema.get_credentials_fields().keys())
+            non_cred_data = {
+                k: v for k, v in input_data.items() if k not in cred_field_names
+            }
+            if error := self.input_schema.validate_data(non_cred_data):
+                raise BlockInputError(
+                    message=f"Unable to execute block with invalid input data: {error}",
+                    block_name=self.name,
+                    block_id=self.id,
+                )
+        else:
+            if error := self.input_schema.validate_data(input_data):
+                raise BlockInputError(
+                    message=f"Unable to execute block with invalid input data: {error}",
+                    block_name=self.name,
+                    block_id=self.id,
+                )

        # Use the validated input data
        async for output_name, output_data in self.run(
@@ -49,11 +49,17 @@ class AgentExecutorBlock(Block):
    @classmethod
    def get_missing_input(cls, data: BlockInput) -> set[str]:
        required_fields = cls.get_input_schema(data).get("required", [])
-        return set(required_fields) - set(data)
+        # Check against the nested `inputs` dict, not the top-level node
+        # data — required fields like "topic" live inside data["inputs"],
+        # not at data["topic"].
+        provided = data.get("inputs", {})
+        return set(required_fields) - set(provided)

    @classmethod
    def get_mismatch_error(cls, data: BlockInput) -> str | None:
-        return validate_with_jsonschema(cls.get_input_schema(data), data)
+        return validate_with_jsonschema(
+            cls.get_input_schema(data), data.get("inputs", {})
+        )

    class Output(BlockSchema):
        # Use BlockSchema to avoid automatic error field that could clash with graph outputs

@@ -88,6 +94,7 @@ class AgentExecutorBlock(Block):
            execution_context=execution_context.model_copy(
                update={"parent_execution_id": graph_exec_id},
            ),
            dry_run=execution_context.dry_run,
        )

        logger = execution_utils.LogMetadata(

@@ -149,14 +156,19 @@ class AgentExecutorBlock(Block):
                ExecutionStatus.TERMINATED,
                ExecutionStatus.FAILED,
            ]:
-                logger.debug(
-                    f"Execution {log_id} received event {event.event_type} with status {event.status}"
+                logger.info(
+                    f"Execution {log_id} skipping event {event.event_type} status={event.status} "
+                    f"node={getattr(event, 'node_exec_id', '?')}"
                )
                continue

            if event.event_type == ExecutionEventType.GRAPH_EXEC_UPDATE:
                # If the graph execution is COMPLETED, TERMINATED, or FAILED,
                # we can stop listening for further events.
                logger.info(
                    f"Execution {log_id} graph completed with status {event.status}, "
                    f"yielded {len(yielded_node_exec_ids)} outputs"
                )
                self.merge_stats(
                    NodeExecutionStats(
                        extra_cost=event.stats.cost if event.stats else 0,
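The get_missing_input fix above turns on the shape of the node data: per the new comment, graph inputs are nested under an "inputs" key rather than sitting at the top level. A small before/after sketch (the field names are illustrative, following the "topic" example in the diff):

# Node data as AgentExecutorBlock receives it: inputs are nested.
data = {"graph_id": "g1", "graph_version": 1, "inputs": {"topic": "AI"}}
required = ["topic"]

# Old behaviour: checked top-level keys, so "topic" always looked missing.
missing_old = set(required) - set(data)                     # {"topic"}
# New behaviour: checks the nested dict and finds it.
missing_new = set(required) - set(data.get("inputs", {}))   # set()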
@@ -18,6 +18,7 @@ from backend.data.model (
    APIKeyCredentials,
    CredentialsField,
    CredentialsMetaInput,
    NodeExecutionStats,
    SchemaField,
)
from backend.integrations.providers import ProviderName

@@ -358,6 +359,7 @@ class AIShortformVideoCreatorBlock(Block):
            execution_context=execution_context,
            return_format="for_block_output",
        )
        self.merge_stats(NodeExecutionStats(output_size=1))
        yield "video_url", stored_url


@@ -565,6 +567,7 @@ class AIAdMakerVideoCreatorBlock(Block):
            execution_context=execution_context,
            return_format="for_block_output",
        )
        self.merge_stats(NodeExecutionStats(output_size=1))
        yield "video_url", stored_url


@@ -760,4 +763,5 @@ class AIScreenshotToVideoAdBlock(Block):
            execution_context=execution_context,
            return_format="for_block_output",
        )
        self.merge_stats(NodeExecutionStats(output_size=1))
        yield "video_url", stored_url
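The one-line merge_stats(NodeExecutionStats(output_size=...)) call added here is the same instrumentation applied in the Apollo, E2B, LinkedIn, Fal, Google Maps, Ideogram, Mem0, Nvidia, Replicate, ScreenshotOne, weather, and Smartlead hunks that follow. Reduced to a sketch, the pattern looks like this (the block and its fetch helper are invented for illustration; Block, BlockOutput, and NodeExecutionStats are the codebase's own types):

from backend.blocks._base import Block, BlockOutput
from backend.data.model import NodeExecutionStats

class ExampleSearchBlock(Block):  # hypothetical block
    async def run(self, input_data, **kwargs) -> BlockOutput:
        items = await self._search(input_data.query)  # hypothetical helper
        # Record how many outputs this node produced; merge_stats accumulates
        # into the node's execution stats rather than replacing them.
        self.merge_stats(NodeExecutionStats(output_size=len(items)))
        for item in items:
            yield "item", item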
@@ -17,7 +17,7 @@ from backend.blocks.apollo.models (
    PrimaryPhone,
    SearchOrganizationsRequest,
)
-from backend.data.model import CredentialsField, SchemaField
+from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField


class SearchOrganizationsBlock(Block):

@@ -218,6 +218,7 @@ To find IDs, identify the values for organization_id when you call this endpoint
    ) -> BlockOutput:
        query = SearchOrganizationsRequest(**input_data.model_dump())
        organizations = await self.search_organizations(query, credentials)
        self.merge_stats(NodeExecutionStats(output_size=len(organizations)))
        for organization in organizations:
            yield "organization", organization
        yield "organizations", organizations

@@ -21,7 +21,7 @@ from backend.blocks.apollo.models (
    SearchPeopleRequest,
    SenorityLevels,
)
-from backend.data.model import CredentialsField, SchemaField
+from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField


class SearchPeopleBlock(Block):

@@ -366,4 +366,5 @@ class SearchPeopleBlock(Block):
            *(enrich_or_fallback(person) for person in people)
        )

        self.merge_stats(NodeExecutionStats(output_size=len(people)))
        yield "people", people

@@ -13,7 +13,7 @@ from backend.blocks.apollo._auth (
    ApolloCredentialsInput,
)
from backend.blocks.apollo.models import Contact, EnrichPersonRequest
-from backend.data.model import CredentialsField, SchemaField
+from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField


class GetPersonDetailBlock(Block):

@@ -141,4 +141,5 @@ class GetPersonDetailBlock(Block):
        **kwargs,
    ) -> BlockOutput:
        query = EnrichPersonRequest(**input_data.model_dump())
        self.merge_stats(NodeExecutionStats(output_size=1))
        yield "contact", await self.enrich_person(query, credentials)
@@ -146,6 +146,21 @@ class AutoPilotBlock(Block):
            advanced=True,
        )

        dry_run: bool = SchemaField(
            description=(
                "When enabled, run_block and run_agent tool calls in this "
                "autopilot session are forced to use dry-run simulation mode. "
                "No real API calls, side effects, or credits are consumed "
                "by those tools. Useful for testing agent wiring and "
                "previewing outputs. "
                "Only applies when creating a new session (session_id is empty). "
                "When reusing an existing session_id, the session's original "
                "dry_run setting is preserved."
            ),
            default=False,
            advanced=True,
        )

        # timeout_seconds removed: the SDK manages its own heartbeat-based
        # timeouts internally; wrapping with asyncio.timeout corrupts the
        # SDK's internal stream (see service.py CRITICAL comment).

@@ -232,11 +247,11 @@ class AutoPilotBlock(Block):
            },
        )

-    async def create_session(self, user_id: str) -> str:
+    async def create_session(self, user_id: str, *, dry_run: bool) -> str:
        """Create a new chat session and return its ID (mockable for tests)."""
        from backend.copilot.model import create_chat_session  # avoid circular import

-        session = await create_chat_session(user_id)
+        session = await create_chat_session(user_id, dry_run=dry_run)
        return session.session_id

    async def execute_copilot(

@@ -367,7 +382,9 @@ class AutoPilotBlock(Block):
        # even if the downstream stream fails (avoids orphaned sessions).
        sid = input_data.session_id
        if not sid:
-            sid = await self.create_session(execution_context.user_id)
+            sid = await self.create_session(
+                execution_context.user_id, dry_run=input_data.dry_run
+            )

        # NOTE: No asyncio.timeout() here — the SDK manages its own
        # heartbeat-based timeouts internally. Wrapping with asyncio.timeout
@@ -17,6 +17,7 @@ from backend.data.model (
    APIKeyCredentials,
    CredentialsField,
    CredentialsMetaInput,
    NodeExecutionStats,
    SchemaField,
)
from backend.integrations.providers import ProviderName

@@ -342,6 +343,7 @@ class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):

        # Determine result object shape & filter out empty formats
        main_result, results = self.process_execution_results(results)
        self.merge_stats(NodeExecutionStats(output_size=1))
        if main_result:
            yield "main_result", main_result
        yield "results", results

@@ -467,6 +469,7 @@ class InstantiateCodeSandboxBlock(Block, BaseE2BExecutorMixin):
            setup_commands=input_data.setup_commands,
            timeout=input_data.timeout,
        )
        self.merge_stats(NodeExecutionStats(output_size=1))
        if sandbox_id:
            yield "sandbox_id", sandbox_id
        else:

@@ -577,6 +580,7 @@ class ExecuteCodeStepBlock(Block, BaseE2BExecutorMixin):

        # Determine result object shape & filter out empty formats
        main_result, results = self.process_execution_results(results)
        self.merge_stats(NodeExecutionStats(output_size=1))
        if main_result:
            yield "main_result", main_result
        yield "results", results

@@ -15,7 +15,12 @@ from backend.blocks._base (
    BlockSchemaInput,
    BlockSchemaOutput,
)
-from backend.data.model import APIKeyCredentials, CredentialsField, SchemaField
+from backend.data.model import (
+    APIKeyCredentials,
+    CredentialsField,
+    NodeExecutionStats,
+    SchemaField,
+)
from backend.util.type import MediaFileType

from ._api import (

@@ -195,6 +200,7 @@ class GetLinkedinProfileBlock(Block):
                include_social_media=input_data.include_social_media,
                include_extra=input_data.include_extra,
            )
            self.merge_stats(NodeExecutionStats(output_size=1))
            yield "profile", profile
        except Exception as e:
            logger.error(f"Error fetching LinkedIn profile: {str(e)}")

@@ -341,6 +347,7 @@ class LinkedinPersonLookupBlock(Block):
                include_similarity_checks=input_data.include_similarity_checks,
                enrich_profile=input_data.enrich_profile,
            )
            self.merge_stats(NodeExecutionStats(output_size=1))
            yield "lookup_result", lookup_result
        except Exception as e:
            logger.error(f"Error looking up LinkedIn profile: {str(e)}")

@@ -443,6 +450,7 @@ class LinkedinRoleLookupBlock(Block):
                company_name=input_data.company_name,
                enrich_profile=input_data.enrich_profile,
            )
            self.merge_stats(NodeExecutionStats(output_size=1))
            yield "role_lookup_result", role_lookup_result
        except Exception as e:
            logger.error(f"Error looking up role in company: {str(e)}")

@@ -523,6 +531,7 @@ class GetLinkedinProfilePictureBlock(Block):
                credentials=credentials,
                linkedin_profile_url=input_data.linkedin_profile_url,
            )
            self.merge_stats(NodeExecutionStats(output_size=1))
            yield "profile_picture_url", profile_picture
        except Exception as e:
            logger.error(f"Error getting profile picture: {str(e)}")

@@ -4,6 +4,7 @@ from typing import Optional
from exa_py import AsyncExa
from pydantic import BaseModel

from backend.data.model import NodeExecutionStats
from backend.sdk import (
    APIKeyCredentials,
    Block,

@@ -223,3 +224,6 @@ class ExaContentsBlock(Block):

        if response.cost_dollars:
            yield "cost_dollars", response.cost_dollars
            self.merge_stats(
                NodeExecutionStats(provider_cost=response.cost_dollars.total)
            )

@@ -4,6 +4,7 @@ from typing import Optional

from exa_py import AsyncExa

from backend.data.model import NodeExecutionStats
from backend.sdk import (
    APIKeyCredentials,
    Block,

@@ -206,3 +207,6 @@ class ExaSearchBlock(Block):

        if response.cost_dollars:
            yield "cost_dollars", response.cost_dollars
            self.merge_stats(
                NodeExecutionStats(provider_cost=response.cost_dollars.total)
            )

@@ -18,7 +18,7 @@ from backend.blocks.fal._auth (
    FalCredentialsInput,
)
from backend.data.execution import ExecutionContext
-from backend.data.model import SchemaField
+from backend.data.model import NodeExecutionStats, SchemaField
from backend.util.file import store_media_file
from backend.util.request import ClientResponseError, Requests
from backend.util.type import MediaFileType

@@ -230,6 +230,7 @@ class AIVideoGeneratorBlock(Block):
                execution_context=execution_context,
                return_format="for_block_output",
            )
            self.merge_stats(NodeExecutionStats(output_size=1))
            yield "video_url", stored_url
        except Exception as e:
            error_message = str(e)

@@ -14,6 +14,7 @@ from backend.data.model (
    APIKeyCredentials,
    CredentialsField,
    CredentialsMetaInput,
    NodeExecutionStats,
    SchemaField,
)
from backend.integrations.providers import ProviderName

@@ -117,6 +118,7 @@ class GoogleMapsSearchBlock(Block):
            input_data.radius,
            input_data.max_results,
        )
        self.merge_stats(NodeExecutionStats(output_size=len(places)))
        for place in places:
            yield "place", place


@@ -14,6 +14,7 @@ from backend.data.model (
    APIKeyCredentials,
    CredentialsField,
    CredentialsMetaInput,
    NodeExecutionStats,
    SchemaField,
)
from backend.integrations.providers import ProviderName

@@ -227,6 +228,7 @@ class IdeogramModelBlock(Block):
                image_url=result,
            )

        self.merge_stats(NodeExecutionStats(output_size=1))
        yield "result", result

    async def run_model(
@@ -2,6 +2,8 @@ import copy
from datetime import date, time
from typing import Any, Optional

from pydantic import AliasChoices, Field

from backend.blocks._base import (
    Block,
    BlockCategory,

@@ -467,7 +469,8 @@ class AgentFileInputBlock(AgentInputBlock):

class AgentDropdownInputBlock(AgentInputBlock):
    """
-    A specialized text input block that relies on placeholder_values to present a dropdown.
+    A specialized text input block that presents a dropdown selector
+    restricted to a fixed set of values.
    """

    class Input(AgentInputBlock.Input):

@@ -477,16 +480,23 @@ class AgentDropdownInputBlock(AgentInputBlock):
            advanced=False,
            title="Default Value",
        )
-        placeholder_values: list = SchemaField(
-            description="Possible values for the dropdown.",
+        # Use Field() directly (not SchemaField) to pass validation_alias,
+        # which handles backward compat for legacy "placeholder_values" across
+        # all construction paths (model_construct, __init__, model_validate).
+        options: list = Field(
            default_factory=list,
            advanced=False,
            title="Dropdown Options",
            description=(
                "If provided, renders the input as a dropdown selector "
                "restricted to these values. Leave empty for free-text input."
            ),
            validation_alias=AliasChoices("options", "placeholder_values"),
            json_schema_extra={"advanced": False, "secret": False},
        )

    def generate_schema(self):
        schema = super().generate_schema()
-        if possible_values := self.placeholder_values:
+        if possible_values := self.options:
            schema["enum"] = possible_values
        return schema

@@ -504,13 +514,13 @@ class AgentDropdownInputBlock(AgentInputBlock):
                {
                    "value": "Option A",
                    "name": "dropdown_1",
-                    "placeholder_values": ["Option A", "Option B", "Option C"],
+                    "options": ["Option A", "Option B", "Option C"],
                    "description": "Dropdown example 1",
                },
                {
                    "value": "Option C",
                    "name": "dropdown_2",
-                    "placeholder_values": ["Option A", "Option B", "Option C"],
+                    "options": ["Option A", "Option B", "Option C"],
                    "description": "Dropdown example 2",
                },
            ],
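The rename above leans on Pydantic's AliasChoices: the field is exposed as options, but input is accepted under either the new or the legacy key, which is what keeps old saved graphs loading. A self-contained sketch of that compatibility trick (standard Pydantic v2):

from pydantic import AliasChoices, BaseModel, Field

class DropdownInput(BaseModel):
    # Accepts both {"options": [...]} and legacy {"placeholder_values": [...]}.
    options: list = Field(
        default_factory=list,
        validation_alias=AliasChoices("options", "placeholder_values"),
    )

assert DropdownInput.model_validate({"options": ["A", "B"]}).options == ["A", "B"]
assert DropdownInput.model_validate({"placeholder_values": ["A"]}).options == ["A"]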
@@ -10,7 +10,7 @@ from backend.blocks.jina._auth (
    JinaCredentialsField,
    JinaCredentialsInput,
)
-from backend.data.model import SchemaField
+from backend.data.model import NodeExecutionStats, SchemaField
from backend.util.request import Requests


@@ -45,5 +45,13 @@ class JinaEmbeddingBlock(Block):
        }
        data = {"input": input_data.texts, "model": input_data.model}
        response = await Requests().post(url, headers=headers, json=data)
-        embeddings = [e["embedding"] for e in response.json()["data"]]
+        resp_json = response.json()
+        embeddings = [e["embedding"] for e in resp_json["data"]]
+        usage = resp_json.get("usage", {})
+        if usage.get("total_tokens"):
+            self.merge_stats(
+                NodeExecutionStats(
+                    input_token_count=usage.get("total_tokens", 0),
+                )
+            )
        yield "embeddings", embeddings
@@ -687,6 +687,7 @@ class LLMResponse(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    reasoning: Optional[str] = None
    provider_cost: float | None = None


def convert_openai_tool_fmt_to_anthropic(

@@ -1045,6 +1046,16 @@ async def llm_call(
        tool_calls = extract_openai_tool_calls(response)
        reasoning = extract_openai_reasoning(response)

        cost = None
        try:
            raw_resp = getattr(response, "_response", None)
            if raw_resp and hasattr(raw_resp, "headers"):
                cost_header = raw_resp.headers.get("x-total-cost")
                if cost_header:
                    cost = float(cost_header)
        except (ValueError, AttributeError):
            pass

        return LLMResponse(
            raw_response=response.choices[0].message,
            prompt=prompt,

@@ -1053,6 +1064,7 @@ async def llm_call(
            prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
            completion_tokens=response.usage.completion_tokens if response.usage else 0,
            reasoning=reasoning,
            provider_cost=cost,
        )
    elif provider == "llama_api":
        tools_param = tools if tools else openai.NOT_GIVEN

@@ -1377,12 +1389,13 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
                max_tokens=input_data.max_tokens,
            )
            response_text = llm_response.response
-            self.merge_stats(
-                NodeExecutionStats(
-                    input_token_count=llm_response.prompt_tokens,
-                    output_token_count=llm_response.completion_tokens,
-                )
+            cost_stats = NodeExecutionStats(
+                input_token_count=llm_response.prompt_tokens,
+                output_token_count=llm_response.completion_tokens,
            )
+            if llm_response.provider_cost is not None:
+                cost_stats.provider_cost = llm_response.provider_cost
+            self.merge_stats(cost_stats)
            logger.debug(f"LLM attempt-{retry_count} response: {response_text}")

            if input_data.expected_format:
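End to end, the new field threads a per-request cost from the HTTP layer into node stats: an x-total-cost response header (when the provider sends one) becomes LLMResponse.provider_cost, which the block then folds into NodeExecutionStats. The defensive parse at the heart of it, isolated as a sketch (header name as in the diff; absent or malformed values degrade to None rather than raising):

def parse_total_cost(headers: dict[str, str]) -> float | None:
    raw = headers.get("x-total-cost")
    if not raw:
        return None
    try:
        return float(raw)
    except ValueError:
        # Malformed header: treat the cost as unknown, never fail the call.
        return None

assert parse_total_cost({"x-total-cost": "0.0042"}) == 0.0042
assert parse_total_cost({}) is None
assert parse_total_cost({"x-total-cost": "n/a"}) is None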
@@ -89,6 +89,12 @@ class MCPToolBlock(Block):
        default={},
        hidden=True,
    )
    tool_description: str = SchemaField(
        description="Description of the selected MCP tool. "
        "Populated automatically when a tool is selected.",
        default="",
        hidden=True,
    )

    tool_arguments: dict[str, Any] = SchemaField(
        description="Arguments to pass to the selected MCP tool. "

@@ -8,6 +8,7 @@ from backend.data.model (
    APIKeyCredentials,
    CredentialsField,
    CredentialsMetaInput,
    NodeExecutionStats,
    SchemaField,
)
from backend.integrations.providers import ProviderName

@@ -153,6 +154,7 @@ class AddMemoryBlock(Block, Mem0Base):
            messages,
            **params,
        )
        self.merge_stats(NodeExecutionStats(output_size=1))

        results = result.get("results", [])
        yield "results", results

@@ -255,6 +257,7 @@ class SearchMemoryBlock(Block, Mem0Base):
            result: list[dict[str, Any]] = client.search(
                input_data.query, version="v2", filters=filters
            )
            self.merge_stats(NodeExecutionStats(output_size=1))
            yield "memories", result

        except Exception as e:

@@ -340,6 +343,7 @@ class GetAllMemoriesBlock(Block, Mem0Base):
                filters=filters,
                version="v2",
            )
            self.merge_stats(NodeExecutionStats(output_size=1))

            yield "memories", memories

@@ -434,6 +438,7 @@ class GetLatestMemoryBlock(Block, Mem0Base):
                filters=filters,
                version="v2",
            )
            self.merge_stats(NodeExecutionStats(output_size=1))

            if memories:
                # Return the latest memory (first in the list as they're sorted by recency)

@@ -10,7 +10,7 @@ from backend.blocks.nvidia._auth (
    NvidiaCredentialsField,
    NvidiaCredentialsInput,
)
-from backend.data.model import SchemaField
+from backend.data.model import NodeExecutionStats, SchemaField
from backend.util.request import Requests
from backend.util.type import MediaFileType

@@ -69,6 +69,7 @@ class NvidiaDeepfakeDetectBlock(Block):
        data = response.json()

        result = data.get("data", [{}])[0]
        self.merge_stats(NodeExecutionStats(output_size=1))

        # Get deepfake probability from first bounding box if any
        deepfake_prob = 0.0

@@ -17,7 +17,12 @@ from backend.blocks.replicate._auth (
    ReplicateCredentialsInput,
)
from backend.blocks.replicate._helper import ReplicateOutputs, extract_result
-from backend.data.model import APIKeyCredentials, CredentialsField, SchemaField
+from backend.data.model import (
+    APIKeyCredentials,
+    CredentialsField,
+    NodeExecutionStats,
+    SchemaField,
+)
from backend.util.exceptions import BlockExecutionError, BlockInputError

logger = logging.getLogger(__name__)

@@ -108,6 +113,7 @@ class ReplicateModelBlock(Block):
        result = await self.run_model(
            model_ref, input_data.model_inputs, credentials.api_key
        )
        self.merge_stats(NodeExecutionStats(output_size=1))
        yield "result", result
        yield "status", "succeeded"
        yield "model_name", input_data.model_name

@@ -16,6 +16,7 @@ from backend.data.model (
    APIKeyCredentials,
    CredentialsField,
    CredentialsMetaInput,
    NodeExecutionStats,
    SchemaField,
)
from backend.integrations.providers import ProviderName

@@ -185,6 +186,7 @@ class ScreenshotWebPageBlock(Block):
                block_chats=input_data.block_chats,
                cache=input_data.cache,
            )
            self.merge_stats(NodeExecutionStats(output_size=1))
            yield "image", screenshot_data["image"]
        except Exception as e:
            yield "error", str(e)

@@ -15,6 +15,7 @@ from backend.data.model (
    APIKeyCredentials,
    CredentialsField,
    CredentialsMetaInput,
    NodeExecutionStats,
    SchemaField,
)
from backend.integrations.providers import ProviderName

@@ -146,6 +147,7 @@ class GetWeatherInformationBlock(Block, GetRequest):
        weather_data = await self.get_request(url, json=True)

        if "main" in weather_data and "weather" in weather_data:
            self.merge_stats(NodeExecutionStats(output_size=1))
            yield "temperature", str(weather_data["main"]["temp"])
            yield "humidity", str(weather_data["main"]["humidity"])
            yield "condition", weather_data["weather"][0]["description"]

@@ -23,7 +23,7 @@ from backend.blocks.smartlead.models (
    SaveSequencesResponse,
    Sequence,
)
-from backend.data.model import CredentialsField, SchemaField
+from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField


class CreateCampaignBlock(Block):

@@ -100,6 +100,7 @@ class CreateCampaignBlock(Block):
        **kwargs,
    ) -> BlockOutput:
        response = await self.create_campaign(input_data.name, credentials)
        self.merge_stats(NodeExecutionStats(output_size=1))

        yield "id", response.id
        yield "name", response.name

@@ -226,6 +227,7 @@ class AddLeadToCampaignBlock(Block):
        response = await self.add_leads_to_campaign(
            input_data.campaign_id, input_data.lead_list, credentials
        )
        self.merge_stats(NodeExecutionStats(output_size=len(input_data.lead_list)))

        yield "campaign_id", input_data.campaign_id
        yield "upload_count", response.upload_count

@@ -321,6 +323,7 @@ class SaveCampaignSequencesBlock(Block):
        response = await self.save_campaign_sequences(
            input_data.campaign_id, input_data.sequences, credentials
        )
        self.merge_stats(NodeExecutionStats(output_size=1))

        if response.data:
            yield "data", response.data
autogpt_platform/backend/backend/blocks/sql_query_block.py (new file, 304 lines)
@@ -0,0 +1,304 @@
import asyncio
from typing import Any, Literal

from pydantic import SecretStr
from sqlalchemy.engine.url import URL
from sqlalchemy.exc import DBAPIError, OperationalError, ProgrammingError

from backend.blocks._base import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.blocks.sql_query_helpers import (
    _DATABASE_TYPE_DEFAULT_PORT,
    _DATABASE_TYPE_TO_DRIVER,
    DatabaseType,
    _execute_query,
    _sanitize_error,
    _validate_query_is_read_only,
    _validate_single_statement,
)
from backend.data.model import (
    CredentialsField,
    CredentialsMetaInput,
    SchemaField,
    UserPasswordCredentials,
)
from backend.integrations.providers import ProviderName
from backend.util.request import resolve_and_check_blocked

TEST_CREDENTIALS = UserPasswordCredentials(
    id="01234567-89ab-cdef-0123-456789abcdef",
    provider="database",
    username=SecretStr("test_user"),
    password=SecretStr("test_pass"),
    title="Mock Database credentials",
)

TEST_CREDENTIALS_INPUT = {
    "provider": TEST_CREDENTIALS.provider,
    "id": TEST_CREDENTIALS.id,
    "type": TEST_CREDENTIALS.type,
    "title": TEST_CREDENTIALS.title,
}

DatabaseCredentials = UserPasswordCredentials
DatabaseCredentialsInput = CredentialsMetaInput[
    Literal[ProviderName.DATABASE],
    Literal["user_password"],
]


def DatabaseCredentialsField() -> DatabaseCredentialsInput:
    return CredentialsField(
        description="Database username and password",
    )


class SQLQueryBlock(Block):
    class Input(BlockSchemaInput):
        database_type: DatabaseType = SchemaField(
            default=DatabaseType.POSTGRES,
            description="Database engine",
            advanced=False,
        )
        host: SecretStr = SchemaField(
            description="Database hostname or IP address",
            placeholder="db.example.com",
            secret=True,
        )
        port: int | None = SchemaField(
            default=None,
            description=(
                "Database port (leave empty for default: "
                "PostgreSQL: 5432, MySQL: 3306, MSSQL: 1433)"
            ),
            ge=1,
            le=65535,
        )
        database: str = SchemaField(
            description="Name of the database to connect to",
            placeholder="my_database",
        )
        query: str = SchemaField(
            description="SQL query to execute",
            placeholder="SELECT * FROM analytics.daily_active_users LIMIT 10",
        )
        read_only: bool = SchemaField(
            default=True,
            description=(
                "When enabled (default), only SELECT queries are allowed "
                "and the database session is set to read-only mode. "
                "Disable to allow write operations (INSERT, UPDATE, DELETE, etc.)."
            ),
        )
        timeout: int = SchemaField(
            default=30,
            description="Query timeout in seconds (max 120)",
            ge=1,
            le=120,
        )
        max_rows: int = SchemaField(
            default=1000,
            description="Maximum number of rows to return (max 10000)",
            ge=1,
            le=10000,
        )
        credentials: DatabaseCredentialsInput = DatabaseCredentialsField()

    class Output(BlockSchemaOutput):
        results: list[dict[str, Any]] = SchemaField(
            description="Query results as a list of row dictionaries"
        )
        columns: list[str] = SchemaField(
            description="Column names from the query result"
        )
        row_count: int = SchemaField(description="Number of rows returned")
        affected_rows: int = SchemaField(
            description="Number of rows affected by a write query (INSERT/UPDATE/DELETE)"
        )
        error: str = SchemaField(description="Error message if the query failed")

    def __init__(self):
        super().__init__(
            id="4dc35c0f-4fd8-465e-9616-5a216f1ba2bc",
            description=(
                "Execute a SQL query. Read-only by default for safety "
                "-- disable to allow write operations. "
                "Supports PostgreSQL, MySQL, and MSSQL via SQLAlchemy."
            ),
            categories={BlockCategory.DATA},
            input_schema=SQLQueryBlock.Input,
            output_schema=SQLQueryBlock.Output,
            test_input={
                "query": "SELECT 1 AS test_col",
                "database_type": DatabaseType.POSTGRES,
                "host": "localhost",
                "database": "test_db",
                "timeout": 30,
                "max_rows": 1000,
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("results", [{"test_col": 1}]),
                ("columns", ["test_col"]),
                ("row_count", 1),
            ],
            test_mock={
                "execute_query": lambda *_args, **_kwargs: (
                    [{"test_col": 1}],
                    ["test_col"],
                    -1,
                ),
                "check_host_allowed": lambda *_args, **_kwargs: ["127.0.0.1"],
            },
        )

    @staticmethod
    async def check_host_allowed(host: str) -> list[str]:
        """Validate that the given host is not a private/blocked address.

        Returns the list of resolved IP addresses so the caller can pin the
        connection to the validated IP (preventing DNS rebinding / TOCTOU).
        Raises ValueError or OSError if the host is blocked.
        Extracted as a method so it can be mocked during block tests.
        """
        return await resolve_and_check_blocked(host)

    @staticmethod
    def execute_query(
        connection_url: URL | str,
        query: str,
        timeout: int,
        max_rows: int,
        read_only: bool = True,
        database_type: DatabaseType = DatabaseType.POSTGRES,
    ) -> tuple[list[dict[str, Any]], list[str], int]:
        """Execute a SQL query and return (rows, columns, affected_rows).

        Delegates to ``_execute_query`` in ``sql_query_helpers``.
        Extracted as a method so it can be mocked during block tests.
        """
        return _execute_query(
            connection_url=connection_url,
            query=query,
            timeout=timeout,
            max_rows=max_rows,
            read_only=read_only,
            database_type=database_type,
        )

    async def run(
        self,
        input_data: Input,
        *,
        credentials: DatabaseCredentials,
        **_kwargs: Any,
    ) -> BlockOutput:
        # Validate query structure and read-only constraints.
        error = self._validate_query(input_data)
        if error:
            yield "error", error
            return

        # Validate host and resolve for SSRF protection.
        host, pinned_host, error = await self._resolve_host(input_data)
        if error:
            yield "error", error
            return

        # Build connection URL and execute.
        port = input_data.port or _DATABASE_TYPE_DEFAULT_PORT[input_data.database_type]
        username = credentials.username.get_secret_value()
        connection_url = URL.create(
            drivername=_DATABASE_TYPE_TO_DRIVER[input_data.database_type],
            username=username,
            password=credentials.password.get_secret_value(),
            host=pinned_host,
            port=port,
            database=input_data.database,
        )
        conn_str = connection_url.render_as_string(hide_password=True)
        db_name = input_data.database

        def _sanitize(err: Exception) -> str:
            return _sanitize_error(
                str(err).strip(),
                conn_str,
                host=pinned_host,
                original_host=host,
                username=username,
                port=port,
                database=db_name,
            )

        try:
            results, columns, affected = await asyncio.to_thread(
                self.execute_query,
                connection_url=connection_url,
                query=input_data.query,
                timeout=input_data.timeout,
                max_rows=input_data.max_rows,
                read_only=input_data.read_only,
                database_type=input_data.database_type,
            )
            yield "results", results
            yield "columns", columns
            yield "row_count", len(results)
            if affected >= 0:
                yield "affected_rows", affected
        except OperationalError as e:
            yield "error", self._classify_operational_error(
                _sanitize(e),
|
||||
input_data.timeout,
|
||||
)
|
||||
except ProgrammingError as e:
|
||||
yield "error", f"SQL error: {_sanitize(e)}"
|
||||
except DBAPIError as e:
|
||||
yield "error", f"Database error: {_sanitize(e)}"
|
||||
except ModuleNotFoundError:
|
||||
yield "error", (
|
||||
f"Database driver not available for "
|
||||
f"{input_data.database_type.value}. "
|
||||
f"Please contact the platform administrator."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _validate_query(input_data: "SQLQueryBlock.Input") -> str | None:
|
||||
"""Validate query structure and read-only constraints."""
|
||||
stmt_error, parsed_stmt = _validate_single_statement(input_data.query)
|
||||
if stmt_error:
|
||||
return stmt_error
|
||||
assert parsed_stmt is not None
|
||||
if input_data.read_only:
|
||||
return _validate_query_is_read_only(parsed_stmt)
|
||||
return None
|
||||
|
||||
async def _resolve_host(
|
||||
self, input_data: "SQLQueryBlock.Input"
|
||||
) -> tuple[str, str, str | None]:
|
||||
"""Validate and resolve the database host. Returns (host, pinned_ip, error)."""
|
||||
host = input_data.host.get_secret_value().strip()
|
||||
if not host:
|
||||
return "", "", "Database host is required."
|
||||
if host.startswith("/"):
|
||||
return host, "", "Unix socket connections are not allowed."
|
||||
try:
|
||||
resolved_ips = await self.check_host_allowed(host)
|
||||
except (ValueError, OSError) as e:
|
||||
return host, "", f"Blocked host: {str(e).strip()}"
|
||||
return host, resolved_ips[0], None
|
||||
|
||||
@staticmethod
|
||||
def _classify_operational_error(sanitized_msg: str, timeout: int) -> str:
|
||||
"""Classify an already-sanitized OperationalError for user display."""
|
||||
lower = sanitized_msg.lower()
|
||||
if "timeout" in lower or "cancel" in lower:
|
||||
return f"Query timed out after {timeout}s."
|
||||
if "connect" in lower:
|
||||
return f"Failed to connect to database: {sanitized_msg}"
|
||||
return f"Database error: {sanitized_msg}"
|
||||
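
# Editor's sketch (not part of the diff): how the error-classification helper
# above behaves, assuming SQLQueryBlock is importable. The inputs are
# hypothetical, already-sanitized driver messages.
assert SQLQueryBlock._classify_operational_error(
    "canceling statement due to statement timeout", 30
) == "Query timed out after 30s."
assert SQLQueryBlock._classify_operational_error(
    "could not connect to server: <host>", 30
).startswith("Failed to connect to database:")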
autogpt_platform/backend/backend/blocks/sql_query_block_test.py (new file, 1396 lines)
File diff suppressed because it is too large
autogpt_platform/backend/backend/blocks/sql_query_helpers.py (new file, 376 lines)
@@ -0,0 +1,376 @@
import re
from datetime import date, datetime, time
from decimal import Decimal
from enum import Enum
from typing import Any

import sqlparse
from sqlalchemy import create_engine, text
from sqlalchemy.engine.url import URL


class DatabaseType(str, Enum):
    POSTGRES = "postgres"
    MYSQL = "mysql"
    MSSQL = "mssql"


# Defense-in-depth: reject queries containing data-modifying keywords.
# These are checked against parsed SQL tokens (not raw text) so column names
# and string literals do not cause false positives.
_DISALLOWED_KEYWORDS = {
    "INSERT",
    "UPDATE",
    "DELETE",
    "DROP",
    "ALTER",
    "CREATE",
    "TRUNCATE",
    "GRANT",
    "REVOKE",
    "COPY",
    "EXECUTE",
    "CALL",
    "SET",
    "RESET",
    "DISCARD",
    "NOTIFY",
    "DO",
}

# Map DatabaseType enum values to the expected SQLAlchemy driver prefix.
_DATABASE_TYPE_TO_DRIVER = {
    DatabaseType.POSTGRES: "postgresql",
    DatabaseType.MYSQL: "mysql+pymysql",
    DatabaseType.MSSQL: "mssql+pymssql",
}

# Default ports for each database type.
_DATABASE_TYPE_DEFAULT_PORT = {
    DatabaseType.POSTGRES: 5432,
    DatabaseType.MYSQL: 3306,
    DatabaseType.MSSQL: 1433,
}


def _sanitize_error(
    error_msg: str,
    connection_string: str,
    *,
    host: str = "",
    original_host: str = "",
    username: str = "",
    port: int = 0,
    database: str = "",
) -> str:
    """Remove connection string, credentials, and infrastructure details
    from error messages so they are safe to expose to the LLM.

    Scrubs:
    - The full connection string
    - URL-embedded credentials (``://user:pass@``)
    - ``password=<value>`` key-value pairs
    - The database hostname / IP used for the connection
    - The original (pre-resolution) hostname provided by the user
    - Any IPv4 addresses that appear in the message
    - Any bracketed IPv6 addresses (e.g. ``[::1]``, ``[fe80::1%eth0]``)
    - The database username
    - The database port number
    - The database name
    """
    sanitized = error_msg.replace(connection_string, "<connection_string>")
    sanitized = re.sub(r"password=[^\s&]+", "password=***", sanitized)
    sanitized = re.sub(r"://[^@]+@", "://***:***@", sanitized)

    # Replace the known host (may be an IP already) before the generic IP pass.
    # Also replace the original (pre-DNS-resolution) hostname if it differs.
    if original_host and original_host != host:
        sanitized = sanitized.replace(original_host, "<host>")
    if host:
        sanitized = sanitized.replace(host, "<host>")

    # Replace any remaining IPv4 addresses (e.g. resolved IPs the driver logs)
    sanitized = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", "<ip>", sanitized)

    # Replace bracketed IPv6 addresses (e.g. "[::1]", "[fe80::1%eth0]")
    sanitized = re.sub(r"\[[0-9a-fA-F:]+(?:%[^\]]+)?\]", "<ip>", sanitized)

    # Replace the database username (handles double-quoted, single-quoted,
    # and unquoted formats across PostgreSQL, MySQL, and MSSQL error messages).
    if username:
        sanitized = re.sub(
            r"""for user ["']?""" + re.escape(username) + r"""["']?""",
            "for user <user>",
            sanitized,
        )
        # Catch remaining bare occurrences in various quote styles:
        # - PostgreSQL: "FATAL: role "myuser" does not exist"
        # - MySQL: "Access denied for user 'myuser'@'host'"
        # - MSSQL: "Login failed for user 'myuser'"
        sanitized = sanitized.replace(f'"{username}"', "<user>")
        sanitized = sanitized.replace(f"'{username}'", "<user>")

    # Replace the port number (handles "port 5432" and ":5432" formats)
    if port:
        port_str = re.escape(str(port))
        sanitized = re.sub(
            r"(?:port |:)" + port_str + r"(?![0-9])",
            lambda m: ("port " if m.group().startswith("p") else ":") + "<port>",
            sanitized,
        )

    # Replace the database name to avoid leaking internal infrastructure names.
    # Use word-boundary regex to prevent mangling when the database name is a
    # common substring (e.g. "test", "data", "on").
    if database:
        sanitized = re.sub(r"\b" + re.escape(database) + r"\b", "<database>", sanitized)

    return sanitized
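
# Editor's sketch (not part of the diff): _sanitize_error applied to a
# realistic-looking but made-up PostgreSQL failure. Host, user, port, and
# database values are hypothetical.
raw = 'FATAL: password authentication failed for user "analytics_ro" at 10.0.3.7:5432'
safe = _sanitize_error(
    raw,
    "postgresql://analytics_ro:pw@10.0.3.7:5432/metrics",
    host="10.0.3.7",
    username="analytics_ro",
    port=5432,
    database="metrics",
)
assert safe == "FATAL: password authentication failed for user <user> at <host>:<port>"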


def _extract_keyword_tokens(parsed: sqlparse.sql.Statement) -> list[str]:
    """Extract keyword tokens from a parsed SQL statement.

    Uses sqlparse token type classification to collect Keyword/DML/DDL/DCL
    tokens. String literals and identifiers have different token types, so
    they are naturally excluded from the result.
    """
    return [
        token.normalized.upper()
        for token in parsed.flatten()
        if token.ttype
        in (
            sqlparse.tokens.Keyword,
            sqlparse.tokens.Keyword.DML,
            sqlparse.tokens.Keyword.DDL,
            sqlparse.tokens.Keyword.DCL,
        )
    ]


def _has_disallowed_into(stmt: sqlparse.sql.Statement) -> bool:
    """Check if a statement contains a disallowed ``INTO`` clause.

    ``SELECT ... INTO @variable`` is a valid read-only MySQL syntax that stores
    a query result into a session-scoped user variable. All other forms of
    ``INTO`` are data-modifying or file-writing and must be blocked:

    * ``SELECT ... INTO new_table`` (PostgreSQL / MSSQL – creates a table)
    * ``SELECT ... INTO OUTFILE`` (MySQL – writes to the filesystem)
    * ``SELECT ... INTO DUMPFILE`` (MySQL – writes to the filesystem)
    * ``INSERT INTO ...`` (already blocked by INSERT being in the
      disallowed set, but we reject INTO as well for defense-in-depth)

    Returns ``True`` if the statement contains a disallowed ``INTO``.
    """
    flat = list(stmt.flatten())
    for i, token in enumerate(flat):
        if not (
            token.ttype in (sqlparse.tokens.Keyword,)
            and token.normalized.upper() == "INTO"
        ):
            continue

        # Look at the first non-whitespace token after INTO.
        j = i + 1
        while j < len(flat) and flat[j].ttype is sqlparse.tokens.Text.Whitespace:
            j += 1

        if j >= len(flat):
            # INTO at the very end – malformed, block it.
            return True

        next_token = flat[j]
        # MySQL user variable: either a single Name starting with "@"
        # (e.g. ``@total``) or a bare ``@`` Operator token followed by a Name.
        if next_token.ttype is sqlparse.tokens.Name and next_token.value.startswith(
            "@"
        ):
            continue
        if next_token.ttype is sqlparse.tokens.Operator and next_token.value == "@":
            continue

        # Everything else (table name, OUTFILE, DUMPFILE, etc.) is disallowed.
        return True

    return False


def _validate_query_is_read_only(stmt: sqlparse.sql.Statement) -> str | None:
    """Validate that a parsed SQL statement is read-only (SELECT/WITH only).

    Accepts an already-parsed statement from ``_validate_single_statement``
    to avoid re-parsing. Checks:
    1. Statement type must be SELECT (sqlparse classifies WITH...SELECT as SELECT)
    2. No disallowed keywords (INSERT, UPDATE, DELETE, DROP, etc.)
    3. No disallowed INTO clauses (allows MySQL ``SELECT ... INTO @variable``)

    Returns an error message if the query is not read-only, None otherwise.
    """
    # sqlparse returns 'SELECT' for SELECT and WITH...SELECT queries
    if stmt.get_type() != "SELECT":
        return "Only SELECT queries are allowed."

    # Defense-in-depth: check parsed keyword tokens for disallowed keywords
    for kw in _extract_keyword_tokens(stmt):
        # Normalize multi-word tokens (e.g. "SET LOCAL" -> "SET")
        base_kw = kw.split()[0] if " " in kw else kw
        if base_kw in _DISALLOWED_KEYWORDS:
            return f"Disallowed SQL keyword: {kw}"

    # Contextual check for INTO: allow MySQL @variable syntax, block everything else
    if _has_disallowed_into(stmt):
        return "Disallowed SQL keyword: INTO"

    return None


def _validate_single_statement(
    query: str,
) -> tuple[str | None, sqlparse.sql.Statement | None]:
    """Validate that the query contains exactly one non-empty SQL statement.

    Returns (error_message, parsed_statement). If error_message is not None,
    the query is invalid and parsed_statement will be None.
    """
    stripped = query.strip().rstrip(";").strip()
    if not stripped:
        return "Query is empty.", None

    # Parse the SQL using sqlparse for proper tokenization
    statements = sqlparse.parse(stripped)

    # Filter out empty statements and comment-only statements
    statements = [
        s
        for s in statements
        if s.tokens
        and str(s).strip()
        and not all(
            t.is_whitespace or t.ttype in sqlparse.tokens.Comment for t in s.flatten()
        )
    ]

    if not statements:
        return "Query is empty.", None

    # Reject multiple statements -- prevents injection via semicolons
    if len(statements) > 1:
        return "Only single statements are allowed.", None

    return None, statements[0]
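
# Editor's sketch (not part of the diff): composing the two validators the way
# SQLQueryBlock._validate_query does. Queries are hypothetical; the expected
# results assume sqlparse tokenizes them as the docstrings above describe.
err, stmt = _validate_single_statement("SELECT id FROM users; DROP TABLE users")
assert err == "Only single statements are allowed." and stmt is None

err, stmt = _validate_single_statement("WITH t AS (SELECT 1) SELECT * FROM t")
assert err is None and stmt is not None
assert _validate_query_is_read_only(stmt) is None  # WITH...SELECT counts as SELECT

err, stmt = _validate_single_statement("DELETE FROM users")
assert err is None
assert _validate_query_is_read_only(stmt) == "Only SELECT queries are allowed."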


def _serialize_value(value: Any) -> Any:
    """Convert database-specific types to JSON-serializable Python types."""
    if isinstance(value, Decimal):
        # Use int for whole numbers; use str for fractional to preserve exact
        # precision (float would silently round high-precision analytics values).
        if value == value.to_integral_value():
            return int(value)
        return str(value)
    if isinstance(value, (datetime, date, time)):
        return value.isoformat()
    if isinstance(value, memoryview):
        return bytes(value).hex()
    if isinstance(value, bytes):
        return value.hex()
    return value
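
# Editor's sketch (not part of the diff): the Decimal policy in practice.
# Whole values collapse to int; fractional values become strings so that
# high-precision analytics numbers survive JSON round-trips.
assert _serialize_value(Decimal("42.000")) == 42
assert _serialize_value(Decimal("19.990")) == "19.990"
assert _serialize_value(b"\xde\xad") == "dead"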


def _configure_session(
    conn: Any,
    dialect_name: str,
    timeout_ms: str,
    read_only: bool,
) -> None:
    """Set session-level timeout and read-only mode for the given dialect."""
    if dialect_name == "postgresql":
        conn.execute(text("SET statement_timeout = " + timeout_ms))
        if read_only:
            conn.execute(text("SET default_transaction_read_only = ON"))
    elif dialect_name == "mysql":
        # NOTE: MAX_EXECUTION_TIME only applies to SELECT statements.
        # Write queries (INSERT/UPDATE/DELETE) are not bounded by this
        # setting; they rely on the database's wait_timeout instead.
        conn.execute(text("SET SESSION MAX_EXECUTION_TIME = " + timeout_ms))
        if read_only:
            conn.execute(text("SET SESSION TRANSACTION READ ONLY"))
    elif dialect_name == "mssql":
        # MSSQL: SET LOCK_TIMEOUT limits lock-wait time (ms).
        # pymssql's connect_args "login_timeout" handles the connection
        # timeout, but LOCK_TIMEOUT covers in-query lock waits.
        conn.execute(text("SET LOCK_TIMEOUT " + timeout_ms))
        # MSSQL lacks a session-level read-only mode like
        # PostgreSQL/MySQL. Read-only enforcement is handled by
        # the SQL validation layer (_validate_query_is_read_only)
        # and the ROLLBACK in the finally block.


def _run_in_transaction(
    conn: Any,
    dialect_name: str,
    query: str,
    max_rows: int,
    read_only: bool,
) -> tuple[list[dict[str, Any]], list[str], int]:
    """Execute a query inside an explicit transaction, returning results."""
    # MSSQL uses T-SQL "BEGIN TRANSACTION"; others use "BEGIN".
    begin_stmt = "BEGIN TRANSACTION" if dialect_name == "mssql" else "BEGIN"
    conn.execute(text(begin_stmt))
    try:
        result = conn.execute(text(query))
        affected = result.rowcount if not result.returns_rows else -1
        columns = list(result.keys()) if result.returns_rows else []
        rows = result.fetchmany(max_rows) if result.returns_rows else []
        results = [
            {col: _serialize_value(val) for col, val in zip(columns, row)}
            for row in rows
        ]
    except Exception:
        conn.execute(text("ROLLBACK"))
        raise
    else:
        conn.execute(text("ROLLBACK" if read_only else "COMMIT"))
        return results, columns, affected


def _execute_query(
    connection_url: URL | str,
    query: str,
    timeout: int,
    max_rows: int,
    read_only: bool = True,
    database_type: DatabaseType = DatabaseType.POSTGRES,
) -> tuple[list[dict[str, Any]], list[str], int]:
    """Execute a SQL query and return (rows, columns, affected_rows).

    Uses SQLAlchemy to connect to any supported database.
    For SELECT queries, rows are limited to ``max_rows`` via DBAPI fetchmany.
    For write queries, affected_rows contains the rowcount from the driver.
    When ``read_only`` is True, the database session is set to read-only
    mode and the transaction is always rolled back.
    """
    # Determine driver-specific connection timeout argument.
    # pymssql uses "login_timeout", while PostgreSQL/MySQL use "connect_timeout".
    timeout_key = (
        "login_timeout" if database_type == DatabaseType.MSSQL else "connect_timeout"
    )
    engine = create_engine(connection_url, connect_args={timeout_key: 10})
    try:
        with engine.connect() as conn:
            # Use AUTOCOMMIT so SET commands take effect immediately.
            conn = conn.execution_options(isolation_level="AUTOCOMMIT")

            # Compute timeout in milliseconds. The value is Pydantic-validated
            # (ge=1, le=120), but we use int() as defense-in-depth.
            # NOTE: SET commands do not support bind parameters in most
            # databases, so we use str(int(...)) for safe interpolation.
            timeout_ms = str(int(timeout * 1000))

            _configure_session(conn, engine.dialect.name, timeout_ms, read_only)
            return _run_in_transaction(
                conn, engine.dialect.name, query, max_rows, read_only
            )
    finally:
        engine.dispose()
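
# Editor's usage sketch (not part of the diff): exercising _execute_query
# against a hypothetical local read replica. All URL field values are made up.
url = URL.create(
    "postgresql",
    username="reader",
    password="s3cret",
    host="127.0.0.1",
    port=5432,
    database="analytics",
)
rows, cols, affected = _execute_query(url, "SELECT 1 AS one", timeout=30, max_rows=10)
# rows == [{"one": 1}], cols == ["one"]; affected is -1 for row-returning
# queries, and the read-only transaction is rolled back on the server.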
@@ -15,6 +15,7 @@ from backend.data.model import (
     APIKeyCredentials,
     CredentialsField,
     CredentialsMetaInput,
+    NodeExecutionStats,
     SchemaField,
 )
 from backend.integrations.providers import ProviderName
@@ -181,6 +182,7 @@ class CreateTalkingAvatarVideoBlock(Block):
                     execution_context=execution_context,
                     return_format="for_block_output",
                 )
+                self.merge_stats(NodeExecutionStats(output_size=1))
                 yield "video_url", stored_url
                 return
             elif status_response["status"] == "error":
@@ -300,13 +300,27 @@ def test_agent_input_block_ignores_legacy_placeholder_values():


 def test_dropdown_input_block_produces_enum():
-    """Verify AgentDropdownInputBlock.Input.generate_schema() produces enum."""
-    options = ["Option A", "Option B"]
+    """Verify AgentDropdownInputBlock.Input.generate_schema() produces enum
+    using the canonical 'options' field name."""
+    opts = ["Option A", "Option B"]
     instance = AgentDropdownInputBlock.Input.model_construct(
-        name="choice", value=None, placeholder_values=options
+        name="choice", value=None, options=opts
     )
     schema = instance.generate_schema()
-    assert schema.get("enum") == options
+    assert schema.get("enum") == opts
+
+
+def test_dropdown_input_block_legacy_placeholder_values_produces_enum():
+    """Verify backward compat: passing legacy 'placeholder_values' to
+    AgentDropdownInputBlock still produces enum via model_construct remap."""
+    opts = ["Option A", "Option B"]
+    instance = AgentDropdownInputBlock.Input.model_construct(
+        name="choice", value=None, placeholder_values=opts
+    )
+    schema = instance.generate_schema()
+    assert (
+        schema.get("enum") == opts
+    ), "Legacy placeholder_values should be remapped to options"


 def test_generate_schema_integration_legacy_placeholder_values():
@@ -329,11 +343,11 @@
 def test_generate_schema_integration_dropdown_produces_enum():
     """Test the full Graph._generate_schema path with AgentDropdownInputBlock
-    — verifies enum IS produced for dropdown blocks."""
+    — verifies enum IS produced for dropdown blocks using canonical field name."""
     dropdown_input_default = {
         "name": "color",
         "value": None,
-        "placeholder_values": ["Red", "Green", "Blue"],
+        "options": ["Red", "Green", "Blue"],
     }
     result = BaseGraph._generate_schema(
         (AgentDropdownInputBlock.Input, dropdown_input_default),
@@ -344,3 +358,36 @@ def test_generate_schema_integration_dropdown_produces_enum():
         "Green",
         "Blue",
     ], "Graph schema should contain enum from AgentDropdownInputBlock"
+
+
+def test_generate_schema_integration_dropdown_legacy_placeholder_values():
+    """Test the full Graph._generate_schema path with AgentDropdownInputBlock
+    using legacy 'placeholder_values' — verifies backward compat produces enum."""
+    legacy_dropdown_input_default = {
+        "name": "color",
+        "value": None,
+        "placeholder_values": ["Red", "Green", "Blue"],
+    }
+    result = BaseGraph._generate_schema(
+        (AgentDropdownInputBlock.Input, legacy_dropdown_input_default),
+    )
+    color_props = result["properties"]["color"]
+    assert color_props.get("enum") == [
+        "Red",
+        "Green",
+        "Blue",
+    ], "Legacy placeholder_values should still produce enum via model_construct remap"
+
+
+def test_dropdown_input_block_init_legacy_placeholder_values():
+    """Verify backward compat: constructing AgentDropdownInputBlock.Input via
+    model_validate with legacy 'placeholder_values' correctly maps to 'options'."""
+    opts = ["Option A", "Option B"]
+    instance = AgentDropdownInputBlock.Input.model_validate(
+        {"name": "choice", "value": None, "placeholder_values": opts}
+    )
+    assert (
+        instance.options == opts
+    ), "Legacy placeholder_values should be remapped to options via model_validate"
+    schema = instance.generate_schema()
+    assert schema.get("enum") == opts
@@ -13,6 +13,7 @@ from backend.data.model import (
     APIKeyCredentials,
     CredentialsField,
     CredentialsMetaInput,
+    NodeExecutionStats,
     SchemaField,
 )
 from backend.integrations.providers import ProviderName
@@ -104,4 +105,5 @@ class UnrealTextToSpeechBlock(Block):
             input_data.text,
             input_data.voice_id,
         )
+        self.merge_stats(NodeExecutionStats(output_size=len(input_data.text)))
         yield "mp3_url", api_response["OutputUri"]
@@ -19,6 +19,7 @@ from backend.blocks._base import (
 from backend.data.model import (
     CredentialsField,
     CredentialsMetaInput,
+    NodeExecutionStats,
     SchemaField,
     UserPasswordCredentials,
 )
@@ -170,6 +171,7 @@ class TranscribeYoutubeVideoBlock(Block):
         transcript = self.get_transcript(video_id, credentials)
         transcript_text = self.format_transcript(transcript=transcript)

+        self.merge_stats(NodeExecutionStats(output_size=1))
         # Only yield after all operations succeed
         yield "video_id", video_id
         yield "transcript", transcript_text
@@ -21,7 +21,7 @@ from backend.blocks.zerobounce._auth import (
     ZeroBounceCredentials,
     ZeroBounceCredentialsInput,
 )
-from backend.data.model import CredentialsField, SchemaField
+from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField


 class Response(BaseModel):
@@ -177,5 +177,6 @@ class ValidateEmailsBlock(Block):
         )

         response_model = Response(**response.__dict__)
+        self.merge_stats(NodeExecutionStats(output_size=1))

         yield "response", response_model
@@ -18,6 +18,7 @@ import orjson
 from langfuse import propagate_attributes
 from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolParam

+from backend.copilot.context import set_execution_context
 from backend.copilot.model import (
     ChatMessage,
     ChatSession,
@@ -50,6 +51,12 @@ from backend.copilot.service import (
 from backend.copilot.token_tracking import persist_and_record_usage
 from backend.copilot.tools import execute_tool, get_available_tools
 from backend.copilot.tracking import track_user_message
+from backend.copilot.transcript import (
+    download_transcript,
+    upload_transcript,
+    validate_transcript,
+)
+from backend.copilot.transcript_builder import TranscriptBuilder
 from backend.util.exceptions import NotFoundError
 from backend.util.prompt import (
     compress_context,
@@ -107,7 +114,7 @@ async def _baseline_llm_caller(
     if tools:
         typed_tools = cast(list[ChatCompletionToolParam], tools)
         response = await client.chat.completions.create(
-            model=config.model,
+            model=config.fast_model,
             messages=typed_messages,
             tools=typed_tools,
             stream=True,
@@ -115,7 +122,7 @@
         )
     else:
         response = await client.chat.completions.create(
-            model=config.model,
+            model=config.fast_model,
             messages=typed_messages,
             stream=True,
             stream_options={"include_usage": True},
@@ -281,6 +288,9 @@ def _baseline_conversation_updater(
     messages: list[dict[str, Any]],
     response: LLMLoopResponse,
     tool_results: list[ToolCallResult] | None = None,
+    *,
+    transcript_builder: TranscriptBuilder,
+    model: str = "",
 ) -> None:
     """Update OpenAI message list with assistant response + tool results.

@@ -300,6 +310,29 @@
             for tc in response.tool_calls
         ]
         messages.append(assistant_msg)
+        # Record assistant message (with tool_calls) to transcript
+        content_blocks: list[dict[str, Any]] = []
+        if response.response_text:
+            content_blocks.append({"type": "text", "text": response.response_text})
+        for tc in response.tool_calls:
+            try:
+                args = orjson.loads(tc.arguments) if tc.arguments else {}
+            except Exception:
+                args = {}
+            content_blocks.append(
+                {
+                    "type": "tool_use",
+                    "id": tc.id,
+                    "name": tc.name,
+                    "input": args,
+                }
+            )
+        if content_blocks:
+            transcript_builder.append_assistant(
+                content_blocks=content_blocks,
+                model=model,
+                stop_reason="tool_use",
+            )
         for tr in tool_results:
             messages.append(
                 {
@@ -308,9 +341,22 @@
                     "content": tr.content,
                 }
             )
+            # Record tool result to transcript AFTER the assistant tool_use
+            # block to maintain correct Anthropic API ordering:
+            # assistant(tool_use) → user(tool_result)
+            transcript_builder.append_tool_result(
+                tool_use_id=tr.tool_call_id,
+                content=tr.content,
+            )
     else:
         if response.response_text:
             messages.append({"role": "assistant", "content": response.response_text})
+            # Record final text to transcript
+            transcript_builder.append_assistant(
+                content_blocks=[{"type": "text", "text": response.response_text}],
+                model=model,
+                stop_reason="end_turn",
+            )


 async def _update_title_async(
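
# Editor's note (not part of the diff): the ordering invariant the updater
# maintains in the transcript, with each tool_result following the assistant
# tool_use it answers (Anthropic API ordering). Shapes are illustrative:
example_turn = [
    {"role": "assistant", "content": [
        {"type": "text", "text": "Looking that up."},
        {"type": "tool_use", "id": "t1", "name": "search", "input": {}},
    ]},
    {"role": "user", "content": [
        {"type": "tool_result", "tool_use_id": "t1", "content": "..."},
    ]},
]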
@@ -339,19 +385,23 @@ async def _compress_session_messages(
         msg_dict: dict[str, Any] = {"role": msg.role}
         if msg.content:
             msg_dict["content"] = msg.content
+        if msg.tool_calls:
+            msg_dict["tool_calls"] = msg.tool_calls
+        if msg.tool_call_id:
+            msg_dict["tool_call_id"] = msg.tool_call_id
         messages_dict.append(msg_dict)

     try:
         result = await compress_context(
             messages=messages_dict,
-            model=config.model,
+            model=config.fast_model,
             client=_get_openai_client(),
         )
     except Exception as e:
         logger.warning("[Baseline] Context compression with LLM failed: %s", e)
         result = await compress_context(
             messages=messages_dict,
-            model=config.model,
+            model=config.fast_model,
             client=None,
         )
@@ -365,7 +415,12 @@
         result.messages_dropped,
     )
     return [
-        ChatMessage(role=m["role"], content=m.get("content"))
+        ChatMessage(
+            role=m["role"],
+            content=m.get("content"),
+            tool_calls=m.get("tool_calls"),
+            tool_call_id=m.get("tool_call_id"),
+        )
         for m in result.messages
     ]
@@ -396,7 +451,8 @@ async def stream_chat_completion_baseline(
             f"Session {session_id} not found. Please create a new session first."
         )

-    # Append user message
+    # Append user message (skip if it's an exact duplicate of the last message,
+    # e.g. from a network retry)
     new_role = "user" if is_user_message else "assistant"
     if message and (
         len(session.messages) == 0
@@ -415,6 +471,54 @@

     session = await upsert_chat_session(session)

+    # --- Transcript support (feature parity with SDK path) ---
+    transcript_builder = TranscriptBuilder()
+    transcript_covers_prefix = True
+
+    if user_id and len(session.messages) > 1:
+        try:
+            dl = await download_transcript(user_id, session_id, log_prefix="[Baseline]")
+            if dl and validate_transcript(dl.content):
+                # Reject stale transcripts: if msg_count is known and
+                # doesn't cover the current session, loading it would
+                # silently drop intermediate turns from the transcript.
+                session_msg_count = len(session.messages)
+                if dl.message_count and dl.message_count < session_msg_count - 1:
+                    logger.warning(
+                        "[Baseline] Transcript stale: covers %d of %d messages, skipping",
+                        dl.message_count,
+                        session_msg_count,
+                    )
+                    transcript_covers_prefix = False
+                else:
+                    transcript_builder.load_previous(
+                        dl.content, log_prefix="[Baseline]"
+                    )
+                    logger.info(
+                        "[Baseline] Loaded transcript: %dB, msg_count=%d",
+                        len(dl.content),
+                        dl.message_count,
+                    )
+            elif dl:
+                logger.warning("[Baseline] Downloaded transcript but invalid")
+                transcript_covers_prefix = False
+            else:
+                logger.debug("[Baseline] No transcript available")
+                transcript_covers_prefix = False
+        except Exception as e:
+            logger.warning("[Baseline] Transcript download failed: %s", e)
+            transcript_covers_prefix = False
+
+    # Append user message to transcript.
+    # Always append when the message is present and is from the user,
+    # even on duplicate-suppressed retries (is_new_message=False).
+    # The loaded transcript may be stale (uploaded before the previous
+    # attempt stored this message), so skipping it would leave the
+    # transcript without the user turn, creating a malformed
+    # assistant-after-assistant structure when the LLM reply is added.
+    if message and is_user_message:
+        transcript_builder.append_user(content=message)
+
     # Generate title for new sessions
     if is_user_message and not session.title:
         user_messages = [m for m in session.messages if m.role == "user"]
@@ -447,16 +551,37 @@
     # Compress context if approaching the model's token limit
     messages_for_context = await _compress_session_messages(session.messages)

-    # Build OpenAI message list from session history
+    # Build OpenAI message list from session history.
+    # Include tool_calls on assistant messages and tool-role results so the
+    # model retains full context of what tools were invoked and their outcomes.
     openai_messages: list[dict[str, Any]] = [
         {"role": "system", "content": system_prompt}
     ]
     for msg in messages_for_context:
-        if msg.role in ("user", "assistant") and msg.content:
+        if msg.role == "assistant":
+            entry: dict[str, Any] = {"role": "assistant"}
+            if msg.content:
+                entry["content"] = msg.content
+            if msg.tool_calls:
+                entry["tool_calls"] = msg.tool_calls
+            if msg.content or msg.tool_calls:
+                openai_messages.append(entry)
+        elif msg.role == "tool" and msg.tool_call_id:
+            openai_messages.append(
+                {
+                    "role": "tool",
+                    "tool_call_id": msg.tool_call_id,
+                    "content": msg.content or "",
+                }
+            )
+        elif msg.role == "user" and msg.content:
             openai_messages.append({"role": msg.role, "content": msg.content})

     tools = get_available_tools()

+    # Propagate execution context so tool handlers can read session-level flags.
+    set_execution_context(user_id, session)
+
     yield StreamStart(messageId=message_id, sessionId=session_id)

     # Propagate user/session context to Langfuse so all LLM calls within
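
# Editor's note (not part of the diff): with tool_calls and tool results
# preserved, a rebuilt history looks like this (hypothetical single turn,
# standard OpenAI chat-completions message shapes):
example_openai_messages = [
    {"role": "system", "content": "..."},
    {"role": "user", "content": "What's 2 + 2?"},
    {"role": "assistant", "tool_calls": [{
        "id": "c1", "type": "function",
        "function": {"name": "calculator", "arguments": '{"expr": "2+2"}'},
    }]},
    {"role": "tool", "tool_call_id": "c1", "content": "4"},
    {"role": "assistant", "content": "It's 4."},
]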
@@ -483,6 +608,12 @@
         _baseline_tool_executor, state=state, user_id=user_id, session=session
     )

+    _bound_conversation_updater = partial(
+        _baseline_conversation_updater,
+        transcript_builder=transcript_builder,
+        model=config.fast_model,
+    )
+
     try:
         loop_result = None
         async for loop_result in tool_call_loop(
@@ -490,7 +621,7 @@
             tools=tools,
             llm_call=_bound_llm_caller,
             execute_tool=_bound_tool_executor,
-            update_conversation=_baseline_conversation_updater,
+            update_conversation=_bound_conversation_updater,
             max_iterations=_MAX_TOOL_ROUNDS,
         ):
             # Drain buffered events after each iteration (real-time streaming)
@@ -559,10 +690,10 @@
             and not (_stream_error and not state.assistant_text)
         ):
             state.turn_prompt_tokens = max(
-                estimate_token_count(openai_messages, model=config.model), 1
+                estimate_token_count(openai_messages, model=config.fast_model), 1
             )
             state.turn_completion_tokens = estimate_token_count_str(
-                state.assistant_text, model=config.model
+                state.assistant_text, model=config.fast_model
             )
             logger.info(
                 "[Baseline] No streaming usage reported; estimated tokens: "
@@ -593,6 +724,25 @@
         except Exception as persist_err:
             logger.error("[Baseline] Failed to persist session: %s", persist_err)

+        # --- Upload transcript for next-turn continuity ---
+        if user_id and transcript_covers_prefix:
+            try:
+                _transcript_content = transcript_builder.to_jsonl()
+                if _transcript_content and validate_transcript(_transcript_content):
+                    await asyncio.shield(
+                        upload_transcript(
+                            user_id=user_id,
+                            session_id=session_id,
+                            content=_transcript_content,
+                            message_count=len(session.messages),
+                            log_prefix="[Baseline]",
+                        )
+                    )
+                else:
+                    logger.debug("[Baseline] No valid transcript to upload")
+            except Exception as upload_err:
+                logger.error("[Baseline] Transcript upload failed: %s", upload_err)
+
     # Yield usage and finish AFTER try/finally (not inside finally).
     # PEP 525 prohibits yielding from finally in async generators during
     # aclose() — doing so raises RuntimeError on client disconnect.
@@ -31,7 +31,7 @@ async def test_baseline_multi_turn(setup_test_user, test_user_id):
     if not api_key:
         return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")

-    session = await create_chat_session(test_user_id)
+    session = await create_chat_session(test_user_id, dry_run=False)
     session = await upsert_chat_session(session)

     # --- Turn 1: send a message with a unique keyword ---
@@ -14,12 +14,21 @@ class ChatConfig(BaseSettings):

     # OpenAI API Configuration
     model: str = Field(
-        default="anthropic/claude-opus-4.6", description="Default model to use"
+        default="anthropic/claude-opus-4.6",
+        description="Default model for extended thinking mode",
+    )
+    fast_model: str = Field(
+        default="anthropic/claude-sonnet-4",
+        description="Model for fast mode (baseline path). Should be faster/cheaper than the default model.",
     )
     title_model: str = Field(
         default="openai/gpt-4o-mini",
         description="Model to use for generating session titles (should be fast/cheap)",
     )
+    simulation_model: str = Field(
+        default="google/gemini-2.5-flash",
+        description="Model for dry-run block simulation (should be fast/cheap with good JSON output)",
+    )
     api_key: str | None = Field(default=None, description="OpenAI API key")
     base_url: str | None = Field(
         default=OPENROUTER_BASE_URL,
@@ -77,11 +86,11 @@ class ChatConfig(BaseSettings):
     # allows ~70-100 turns/day.
     # Checked at the HTTP layer (routes.py) before each turn.
     #
-    # TODO: These are deploy-time constants applied identically to every user.
-    # If per-user or per-plan limits are needed (e.g., free tier vs paid), these
-    # must move to the database (e.g., a UserPlan table) and get_usage_status /
-    # check_rate_limit would look up each user's specific limits instead of
-    # reading config.daily_token_limit / config.weekly_token_limit.
+    # These are base limits for the FREE tier. Higher tiers (PRO, BUSINESS,
+    # ENTERPRISE) multiply these by their tier multiplier (see
+    # rate_limit.TIER_MULTIPLIERS). User tier is stored in the
+    # User.subscriptionTier DB column and resolved inside
+    # get_global_rate_limits().
     daily_token_limit: int = Field(
         default=2_500_000,
         description="Max tokens per day, resets at midnight UTC (0 = unlimited)",
@@ -129,6 +138,32 @@ class ChatConfig(BaseSettings):
         description="Use --resume for multi-turn conversations instead of "
         "history compression. Falls back to compression when unavailable.",
     )
+    claude_agent_fallback_model: str = Field(
+        default="claude-sonnet-4-20250514",
+        description="Fallback model when the primary model is unavailable (e.g. 529 "
+        "overloaded). The SDK automatically retries with this cheaper model.",
+    )
+    claude_agent_max_turns: int = Field(
+        default=50,
+        ge=1,
+        le=500,
+        description="Maximum number of agentic turns (tool-use loops) per query. "
+        "Prevents runaway tool loops from burning budget.",
+    )
+    claude_agent_max_budget_usd: float = Field(
+        default=5.0,
+        ge=0.01,
+        le=100.0,
+        description="Maximum spend in USD per SDK query. The CLI aborts the "
+        "request if this budget is exceeded.",
+    )
+    claude_agent_max_transient_retries: int = Field(
+        default=3,
+        ge=0,
+        le=10,
+        description="Maximum number of retries for transient API errors "
+        "(429, 5xx, ECONNRESET) before surfacing the error to the user.",
+    )
     use_openrouter: bool = Field(
         default=True,
         description="Enable routing API calls through the OpenRouter proxy. "
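
# Editor's note (not part of the diff): how the tier-multiplier comment above
# plays out in practice. The multiplier values here are made up for
# illustration; the real mapping is rate_limit.TIER_MULTIPLIERS.
TIER_MULTIPLIERS = {"FREE": 1, "PRO": 4}
effective_daily_limit = 2_500_000 * TIER_MULTIPLIERS["PRO"]  # 10,000,000 tokens/day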
@@ -44,12 +44,32 @@ def parse_node_id_from_exec_id(node_exec_id: str) -> str:
 # Transient Anthropic API error detection
 # ---------------------------------------------------------------------------
 # Patterns in error text that indicate a transient Anthropic API error
-# (ECONNRESET / dropped TCP connection) which is retryable.
+# which is retryable. Covers:
+# - Connection-level: ECONNRESET, dropped TCP connections
+# - HTTP 429: rate-limit / too-many-requests
+# - HTTP 5xx: server errors, overloaded
 _TRANSIENT_ERROR_PATTERNS = (
+    # Connection-level
     "socket connection was closed unexpectedly",
     "ECONNRESET",
     "connection was forcibly closed",
     "network socket disconnected",
+    # 429 rate-limit patterns
+    "rate limit",
+    "rate_limit",
+    "too many requests",
+    "status code 429",
+    # 5xx server error patterns
+    "overloaded",
+    "internal server error",
+    "bad gateway",
+    "service unavailable",
+    "gateway timeout",
+    "status code 529",
+    "status code 500",
+    "status code 502",
+    "status code 503",
+    "status code 504",
 )

 FRIENDLY_TRANSIENT_MSG = "Anthropic connection interrupted — please retry"
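
# Editor's sketch (not part of the diff): a pattern tuple like this is
# typically consumed with a case-insensitive substring scan. This helper is
# hypothetical; the real predicate lives outside the hunk shown above.
def _looks_transient(err_text: str) -> bool:
    lowered = err_text.lower()
    return any(pattern.lower() in lowered for pattern in _TRANSIENT_ERROR_PATTERNS)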
@@ -149,7 +149,8 @@ def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:

     Allowed:
     - Files under *sdk_cwd* (``/tmp/copilot-<session>/``)
-    - Files under ``~/.claude/projects/<encoded-cwd>/<uuid>/tool-results/...``.
+    - Files under ``~/.claude/projects/<encoded-cwd>/<uuid>/tool-results/...``
+      or ``tool-outputs/...``.
       The SDK nests tool-results under a conversation UUID directory;
       the UUID segment is validated with ``_UUID_RE``.
     """
@@ -174,17 +175,20 @@ def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
     # Defence-in-depth: ensure project_dir didn't escape the base.
     if not project_dir.startswith(SDK_PROJECTS_DIR + os.sep):
         return False
-    # Only allow: <encoded-cwd>/<uuid>/tool-results/<file>
+    # Only allow: <encoded-cwd>/<uuid>/<tool-dir>/<file>
     # The SDK always creates a conversation UUID directory between
-    # the project dir and tool-results/.
+    # the project dir and the tool directory.
+    # Accept both "tool-results" (SDK's persisted outputs) and
+    # "tool-outputs" (the model sometimes confuses workspace paths
+    # with filesystem paths and generates this variant).
     if resolved.startswith(project_dir + os.sep):
         relative = resolved[len(project_dir) + 1 :]
         parts = relative.split(os.sep)
-        # Require exactly: [<uuid>, "tool-results", <file>, ...]
+        # Require exactly: [<uuid>, "tool-results"|"tool-outputs", <file>, ...]
         if (
             len(parts) >= 3
             and _UUID_RE.match(parts[0])
-            and parts[1] == "tool-results"
+            and parts[1] in ("tool-results", "tool-outputs")
         ):
             return True
@@ -134,6 +134,21 @@ def test_is_allowed_local_path_tool_results_with_uuid():
         _current_project_dir.set("")


+def test_is_allowed_local_path_tool_outputs_with_uuid():
+    """Files under <encoded-cwd>/<uuid>/tool-outputs/ are also allowed."""
+    encoded = "test-encoded-dir"
+    conv_uuid = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
+    path = os.path.join(
+        SDK_PROJECTS_DIR, encoded, conv_uuid, "tool-outputs", "output.json"
+    )
+
+    _current_project_dir.set(encoded)
+    try:
+        assert is_allowed_local_path(path, sdk_cwd=None)
+    finally:
+        _current_project_dir.set("")
+
+
 def test_is_allowed_local_path_tool_results_without_uuid_rejected():
     """Direct <encoded-cwd>/tool-results/ (no UUID) is rejected."""
     encoded = "test-encoded-dir"
@@ -159,7 +174,7 @@ def test_is_allowed_local_path_sibling_of_tool_results_is_rejected():


 def test_is_allowed_local_path_valid_uuid_wrong_segment_name_rejected():
-    """A valid UUID dir but non-'tool-results' second segment is rejected."""
+    """A valid UUID dir but non-'tool-results'/'tool-outputs' second segment is rejected."""
     encoded = "test-encoded-dir"
     uuid_str = "12345678-1234-5678-9abc-def012345678"
     path = os.path.join(
@@ -18,7 +18,13 @@ from prisma.types import (
 from backend.data import db
 from backend.util.json import SafeJson, sanitize_string

-from .model import ChatMessage, ChatSession, ChatSessionInfo, invalidate_session_cache
+from .model import (
+    ChatMessage,
+    ChatSession,
+    ChatSessionInfo,
+    ChatSessionMetadata,
+    invalidate_session_cache,
+)

 logger = logging.getLogger(__name__)

@@ -35,6 +41,7 @@ async def get_chat_session(session_id: str) -> ChatSession | None:
 async def create_chat_session(
     session_id: str,
     user_id: str,
+    metadata: ChatSessionMetadata | None = None,
 ) -> ChatSessionInfo:
     """Create a new chat session in the database."""
     data = ChatSessionCreateInput(
@@ -43,6 +50,7 @@ async def create_chat_session(
         credentials=SafeJson({}),
         successfulAgentRuns=SafeJson({}),
         successfulAgentSchedules=SafeJson({}),
+        metadata=SafeJson((metadata or ChatSessionMetadata()).model_dump()),
     )
     prisma_session = await PrismaChatSession.prisma().create(data=data)
     return ChatSessionInfo.from_db(prisma_session)
@@ -57,7 +65,12 @@ async def update_chat_session(
     total_completion_tokens: int | None = None,
     title: str | None = None,
 ) -> ChatSession | None:
-    """Update a chat session's metadata."""
+    """Update a chat session's mutable fields.
+
+    Note: ``metadata`` (which includes ``dry_run``) is intentionally omitted —
+    it is set once at creation time and treated as immutable for the lifetime
+    of the session.
+    """
     data: ChatSessionUpdateInput = {"updatedAt": datetime.now(UTC)}

     if credentials is not None:
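
# Editor's usage sketch (not part of the diff): metadata is supplied once at
# creation and never modified by update_chat_session(). The dry_run field name
# follows the docstring above; the IDs and keyword usage are assumptions.
async def _example_create() -> None:
    info = await create_chat_session(
        "sess-123", "user-456", metadata=ChatSessionMetadata(dry_run=True)
    )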
@@ -251,20 +251,31 @@ class CoPilotProcessor:
             stream_fn = stream_chat_completion_dummy
             log.warning("Using DUMMY service (CHAT_TEST_MODE=true)")
         else:
-            use_sdk = (
-                config.use_claude_code_subscription
-                or await is_feature_enabled(
-                    Flag.COPILOT_SDK,
-                    entry.user_id or "anonymous",
-                    default=config.use_claude_agent_sdk,
-                )
-            )
+            # Per-request mode override from the frontend takes priority.
+            # 'fast' → baseline (OpenAI-compatible), 'extended_thinking' → SDK.
+            if entry.mode == "fast":
+                use_sdk = False
+            elif entry.mode == "extended_thinking":
+                use_sdk = True
+            else:
+                # No mode specified — fall back to feature flag / config.
+                use_sdk = (
+                    config.use_claude_code_subscription
+                    or await is_feature_enabled(
+                        Flag.COPILOT_SDK,
+                        entry.user_id or "anonymous",
+                        default=config.use_claude_agent_sdk,
+                    )
+                )
             stream_fn = (
                 sdk_service.stream_chat_completion_sdk
                 if use_sdk
                 else stream_chat_completion_baseline
             )
-            log.info(f"Using {'SDK' if use_sdk else 'baseline'} service")
+            log.info(
+                f"Using {'SDK' if use_sdk else 'baseline'} service "
+                f"(mode={entry.mode or 'default'})"
+            )

         # Stream chat completion and publish chunks to Redis.
         # stream_and_publish wraps the raw stream with registry
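
# Editor's sketch (not part of the diff): the resolution order above,
# distilled. `flag_default` stands in for the feature-flag/config fallback
# computed in the real code.
def _resolve_use_sdk(mode: str | None, flag_default: bool) -> bool:
    if mode == "fast":
        return False
    if mode == "extended_thinking":
        return True
    return flag_default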
@@ -6,6 +6,7 @@ Defines two exchanges and queues following the graph executor pattern:
 """

 import logging
+from typing import Literal

 from pydantic import BaseModel

@@ -156,6 +157,9 @@ class CoPilotExecutionEntry(BaseModel):
     file_ids: list[str] | None = None
     """Workspace file IDs attached to the user's message"""

+    mode: Literal["fast", "extended_thinking"] | None = None
+    """Autopilot mode override: 'fast' or 'extended_thinking'. None = server default."""
+

 class CancelCoPilotEvent(BaseModel):
     """Event to cancel a CoPilot operation."""
@@ -175,6 +179,7 @@ async def enqueue_copilot_turn(
     is_user_message: bool = True,
     context: dict[str, str] | None = None,
     file_ids: list[str] | None = None,
+    mode: Literal["fast", "extended_thinking"] | None = None,
 ) -> None:
     """Enqueue a CoPilot task for processing by the executor service.

@@ -186,6 +191,7 @@
         is_user_message: Whether the message is from the user (vs system/assistant)
         context: Optional context for the message (e.g., {url: str, content: str})
         file_ids: Optional workspace file IDs attached to the user's message
+        mode: Autopilot mode override ('fast' or 'extended_thinking'). None = server default.
     """
     from backend.util.clients import get_async_copilot_queue

@@ -197,6 +203,7 @@
         is_user_message=is_user_message,
         context=context,
         file_ids=file_ids,
+        mode=mode,
     )

     queue_client = await get_async_copilot_queue()
@@ -59,6 +59,16 @@ _null_cache: TTLCache[tuple[str, str], bool] = TTLCache(
|
||||
maxsize=_CACHE_MAX_SIZE, ttl=_NULL_CACHE_TTL
|
||||
)
|
||||
|
||||
# GitHub user identity caches (keyed by user_id only, not provider tuple).
|
||||
# Declared here so invalidate_user_provider_cache() can reference them.
|
||||
_GH_IDENTITY_CACHE_TTL = 600.0 # 10 min — profile data rarely changes
|
||||
_gh_identity_cache: TTLCache[str, dict[str, str]] = TTLCache(
|
||||
maxsize=_CACHE_MAX_SIZE, ttl=_GH_IDENTITY_CACHE_TTL
|
||||
)
|
||||
_gh_identity_null_cache: TTLCache[str, bool] = TTLCache(
|
||||
maxsize=_CACHE_MAX_SIZE, ttl=_NULL_CACHE_TTL
|
||||
)
|
||||
|
||||
|
||||
def invalidate_user_provider_cache(user_id: str, provider: str) -> None:
|
||||
"""Remove the cached entry for *user_id*/*provider* from both caches.
|
||||
@@ -66,11 +76,19 @@ def invalidate_user_provider_cache(user_id: str, provider: str) -> None:
|
||||
Call this after storing new credentials so that the next
|
||||
``get_provider_token()`` call performs a fresh DB lookup instead of
|
||||
serving a stale TTL-cached result.
|
||||
|
||||
For GitHub specifically, also clears the git-identity caches so that
|
||||
``get_github_user_git_identity()`` re-fetches the user's profile on
|
||||
the next call instead of serving stale identity data.
|
||||
"""
|
||||
key = (user_id, provider)
|
||||
_token_cache.pop(key, None)
|
||||
_null_cache.pop(key, None)
|
||||
|
||||
if provider == "github":
|
||||
_gh_identity_cache.pop(user_id, None)
|
||||
_gh_identity_null_cache.pop(user_id, None)
|
||||
|
||||
|
||||
# Register this module's cache-bust function with the credentials manager so
|
||||
# that any create/update/delete operation immediately evicts stale cache
|
||||
@@ -123,6 +141,7 @@ async def get_provider_token(user_id: str, provider: str) -> str | None:
|
||||
[c for c in creds_list if c.type == "oauth2"],
|
||||
key=lambda c: 0 if "repo" in (cast(OAuth2Credentials, c).scopes or []) else 1,
|
||||
)
|
||||
refresh_failed = False
|
||||
for creds in oauth2_creds:
|
||||
if creds.type == "oauth2":
|
||||
try:
|
||||
@@ -141,6 +160,7 @@ async def get_provider_token(user_id: str, provider: str) -> str | None:
|
||||
# Do NOT fall back to the stale token — it is likely expired
|
||||
# or revoked. Returning None forces the caller to re-auth,
|
||||
# preventing the LLM from receiving a non-functional token.
|
||||
refresh_failed = True
|
||||
continue
|
||||
_token_cache[cache_key] = token
|
||||
return token
|
||||
@@ -152,8 +172,12 @@ async def get_provider_token(user_id: str, provider: str) -> str | None:
|
||||
_token_cache[cache_key] = token
|
||||
return token
|
||||
|
||||
# No credentials found — cache to avoid repeated DB hits.
|
||||
_null_cache[cache_key] = True
|
||||
# Only cache "not connected" when the user truly has no credentials for this
|
||||
# provider. If we had OAuth credentials but refresh failed (e.g. transient
|
||||
# network error, event-loop mismatch), do NOT cache the negative result —
|
||||
# the next call should retry the refresh instead of being blocked for 60 s.
|
||||
if not refresh_failed:
|
||||
_null_cache[cache_key] = True
|
||||
return None
|
||||
|
||||
|
||||
@@ -171,3 +195,76 @@ async def get_integration_env_vars(user_id: str) -> dict[str, str]:
|
||||
for var in var_names:
|
||||
env[var] = token
|
||||
return env

# ---------------------------------------------------------------------------
# GitHub user identity (for git committer env vars)
# ---------------------------------------------------------------------------


async def get_github_user_git_identity(user_id: str) -> dict[str, str] | None:
    """Fetch the GitHub user's name and email for git committer env vars.

    Uses the ``/user`` GitHub API endpoint with the user's stored token.
    Returns a dict with ``GIT_AUTHOR_NAME``, ``GIT_AUTHOR_EMAIL``,
    ``GIT_COMMITTER_NAME``, and ``GIT_COMMITTER_EMAIL`` if the user has a
    connected GitHub account. Returns ``None`` otherwise.

    Results are cached for 10 minutes; "not connected" results are cached for
    60 s (same as null-token cache).
    """
    if user_id in _gh_identity_null_cache:
        return None
    if cached := _gh_identity_cache.get(user_id):
        return cached

    token = await get_provider_token(user_id, "github")
    if not token:
        _gh_identity_null_cache[user_id] = True
        return None

    import aiohttp

    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(
                "https://api.github.com/user",
                headers={
                    "Authorization": f"token {token}",
                    "Accept": "application/vnd.github+json",
                },
                timeout=aiohttp.ClientTimeout(total=5),
            ) as resp:
                if resp.status != 200:
                    logger.warning(
                        "[git-identity] GitHub /user returned %s for user %s",
                        resp.status,
                        user_id,
                    )
                    return None
                data = await resp.json()
    except Exception as exc:
        logger.warning(
            "[git-identity] Failed to fetch GitHub profile for user %s: %s",
            user_id,
            exc,
        )
        return None

    name = data.get("name") or data.get("login") or "AutoGPT User"
    # GitHub may return email=null if the user has set their email to private.
    # Fall back to the noreply address GitHub generates for every account.
    email = data.get("email")
    if not email:
        gh_id = data.get("id", "")
        login = data.get("login", "user")
        email = f"{gh_id}+{login}@users.noreply.github.com"

    identity = {
        "GIT_AUTHOR_NAME": name,
        "GIT_AUTHOR_EMAIL": email,
        "GIT_COMMITTER_NAME": name,
        "GIT_COMMITTER_EMAIL": email,
    }
    _gh_identity_cache[user_id] = identity
    return identity
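
A hedged usage sketch of the identity dict above — the subprocess plumbing here is illustrative, not code from this PR; git reads the four GIT_* variables to attribute commits:

import os
import subprocess

async def commit_as_connected_user(user_id: str, repo_dir: str, message: str) -> None:
    # Illustrative caller: merge the identity vars into the child process env
    # so commits are attributed to the user's GitHub name/email.
    identity = await get_github_user_git_identity(user_id) or {}
    env = {**os.environ, **identity}
    subprocess.run(["git", "commit", "-m", message], cwd=repo_dir, env=env, check=True)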

@@ -9,6 +9,8 @@ from backend.copilot.integration_creds import (
    _NULL_CACHE_TTL,
    _TOKEN_CACHE_TTL,
    PROVIDER_ENV_VARS,
    _gh_identity_cache,
    _gh_identity_null_cache,
    _null_cache,
    _token_cache,
    get_integration_env_vars,
@@ -49,9 +51,13 @@ def clear_caches():
    """Ensure clean caches before and after every test."""
    _token_cache.clear()
    _null_cache.clear()
    _gh_identity_cache.clear()
    _gh_identity_null_cache.clear()
    yield
    _token_cache.clear()
    _null_cache.clear()
    _gh_identity_cache.clear()
    _gh_identity_null_cache.clear()


class TestInvalidateUserProviderCache:
@@ -77,6 +83,34 @@ class TestInvalidateUserProviderCache:
        invalidate_user_provider_cache(_USER, _PROVIDER)
        assert other_key in _token_cache

    def test_clears_gh_identity_cache_for_github_provider(self):
        """When provider is 'github', identity caches must also be cleared."""
        _gh_identity_cache[_USER] = {
            "GIT_AUTHOR_NAME": "Old Name",
            "GIT_AUTHOR_EMAIL": "old@example.com",
            "GIT_COMMITTER_NAME": "Old Name",
            "GIT_COMMITTER_EMAIL": "old@example.com",
        }
        invalidate_user_provider_cache(_USER, "github")
        assert _USER not in _gh_identity_cache

    def test_clears_gh_identity_null_cache_for_github_provider(self):
        """When provider is 'github', the identity null-cache must also be cleared."""
        _gh_identity_null_cache[_USER] = True
        invalidate_user_provider_cache(_USER, "github")
        assert _USER not in _gh_identity_null_cache

    def test_does_not_clear_gh_identity_cache_for_other_providers(self):
        """When provider is NOT 'github', identity caches must be left alone."""
        _gh_identity_cache[_USER] = {
            "GIT_AUTHOR_NAME": "Some Name",
            "GIT_AUTHOR_EMAIL": "some@example.com",
            "GIT_COMMITTER_NAME": "Some Name",
            "GIT_COMMITTER_EMAIL": "some@example.com",
        }
        invalidate_user_provider_cache(_USER, "some-other-provider")
        assert _USER in _gh_identity_cache


class TestGetProviderToken:
    @pytest.mark.asyncio(loop_scope="session")
@@ -129,8 +163,15 @@ class TestGetProviderToken:
        assert result == "oauth-tok"

    @pytest.mark.asyncio(loop_scope="session")
    async def test_oauth2_refresh_failure_returns_none(self):
        """On refresh failure, return None instead of caching a stale token."""
    async def test_oauth2_refresh_failure_returns_none_without_null_cache(self):
        """On refresh failure, return None but do NOT cache in null_cache.

        The user has credentials — they just couldn't be refreshed right now
        (e.g. transient network error or event-loop mismatch in the copilot
        executor). Caching a negative result would block all credential
        lookups for 60 s even though the creds exist and may refresh fine
        on the next attempt.
        """
        oauth_creds = _make_oauth2_creds("stale-oauth-tok")
        mock_manager = MagicMock()
        mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[oauth_creds])
@@ -141,6 +182,8 @@ class TestGetProviderToken:

        # Stale tokens must NOT be returned — forces re-auth.
        assert result is None
        # Must NOT cache negative result when refresh failed — next call retries.
        assert (_USER, _PROVIDER) not in _null_cache

    @pytest.mark.asyncio(loop_scope="session")
    async def test_no_credentials_caches_null_entry(self):
@@ -176,6 +219,96 @@ class TestGetProviderToken:
        assert _NULL_CACHE_TTL < _TOKEN_CACHE_TTL


class TestThreadSafetyLocks:
    """Bug reproduction: shared AsyncRedisKeyedMutex across threads caused
    'Future attached to a different loop' when copilot workers accessed
    credentials from different event loops."""

    @pytest.mark.asyncio(loop_scope="session")
    async def test_store_locks_returns_per_thread_instance(self):
        """IntegrationCredentialsStore.locks() must return different instances
        for different threads (via @thread_cached)."""
        import asyncio
        import concurrent.futures

        from backend.integrations.credentials_store import IntegrationCredentialsStore

        store = IntegrationCredentialsStore()

        async def get_locks_id():
            mock_redis = AsyncMock()
            with patch(
                "backend.integrations.credentials_store.get_redis_async",
                return_value=mock_redis,
            ):
                locks = await store.locks()
            return id(locks)

        # Get locks from main thread
        main_id = await get_locks_id()

        # Get locks from a worker thread
        def run_in_thread():
            loop = asyncio.new_event_loop()
            try:
                return loop.run_until_complete(get_locks_id())
            finally:
                loop.close()

        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
            worker_id = await asyncio.get_event_loop().run_in_executor(
                pool, run_in_thread
            )

        assert main_id != worker_id, (
            "Store.locks() returned the same instance across threads. "
            "This would cause 'Future attached to a different loop' errors."
        )
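
The per-thread behaviour asserted above comes from the @thread_cached decorator on locks(). A minimal sketch of that idea follows — a synchronous toy written for illustration, not the actual backend.util implementation (which also has to handle async factories):

import threading
from functools import wraps

def thread_cached_sketch(factory):
    """Memoize one value per thread via threading.local()."""
    local = threading.local()

    @wraps(factory)
    def wrapper(*args, **kwargs):
        # Each thread lazily builds and keeps its own instance, so an object
        # bound to thread A's event loop is never handed to thread B.
        if not hasattr(local, "value"):
            local.value = factory(*args, **kwargs)
        return local.value

    return wrapper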

    @pytest.mark.asyncio(loop_scope="session")
    async def test_manager_delegates_to_store_locks(self):
        """IntegrationCredentialsManager.locks() should delegate to store."""
        from backend.integrations.creds_manager import IntegrationCredentialsManager

        manager = IntegrationCredentialsManager()
        mock_redis = AsyncMock()

        with patch(
            "backend.integrations.credentials_store.get_redis_async",
            return_value=mock_redis,
        ):
            locks = await manager.locks()

        # Should have gotten it from the store
        assert locks is not None


class TestRefreshUnlockedPath:
    """Bug reproduction: copilot worker threads need lock-free refresh because
    a Redis-backed asyncio.Lock created on one event loop can't be used on another."""

    @pytest.mark.asyncio(loop_scope="session")
    async def test_refresh_if_needed_lock_false_skips_redis(self):
        """refresh_if_needed(lock=False) must not touch Redis locks at all."""
        from backend.integrations.creds_manager import IntegrationCredentialsManager

        manager = IntegrationCredentialsManager()
        creds = _make_oauth2_creds()

        mock_handler = MagicMock()
        mock_handler.needs_refresh = MagicMock(return_value=False)

        with patch(
            "backend.integrations.creds_manager._get_provider_oauth_handler",
            new_callable=AsyncMock,
            return_value=mock_handler,
        ):
            result = await manager.refresh_if_needed(_USER, creds, lock=False)

        # Should return credentials without touching locks
        assert result.id == creds.id


class TestGetIntegrationEnvVars:
    @pytest.mark.asyncio(loop_scope="session")
    async def test_injects_all_env_vars_for_provider(self):

@@ -46,6 +46,16 @@ def _get_session_cache_key(session_id: str) -> str:
# ===================== Chat data models ===================== #


class ChatSessionMetadata(BaseModel):
    """Typed metadata stored in the ``metadata`` JSON column of ChatSession.

    Add new session-level flags here instead of adding DB columns —
    no migration required for new fields as long as a default is provided.
    """

    dry_run: bool = False
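
To make the "no migration required" claim concrete, a hypothetical future flag would be added like this — beta_tools is an invented name, not part of this PR; rows persisted before the field existed still validate because the default fills in:

from pydantic import BaseModel

class ChatSessionMetadataSketch(BaseModel):
    dry_run: bool = False
    beta_tools: bool = False  # hypothetical new flag; defaults keep old rows valid

# JSON written before the new field existed still parses cleanly:
old_row = ChatSessionMetadataSketch.model_validate({"dry_run": True})
assert old_row.beta_tools is False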


class ChatMessage(BaseModel):
    role: str
    content: str | None = None
@@ -90,6 +100,12 @@ class ChatSessionInfo(BaseModel):
    updated_at: datetime
    successful_agent_runs: dict[str, int] = {}
    successful_agent_schedules: dict[str, int] = {}
    metadata: ChatSessionMetadata = ChatSessionMetadata()

    @property
    def dry_run(self) -> bool:
        """Convenience accessor for ``metadata.dry_run``."""
        return self.metadata.dry_run

    @classmethod
    def from_db(cls, prisma_session: PrismaChatSession) -> Self:
@@ -103,6 +119,10 @@ class ChatSessionInfo(BaseModel):
            prisma_session.successfulAgentSchedules, default={}
        )

        # Parse typed metadata from the JSON column.
        raw_metadata = _parse_json_field(prisma_session.metadata, default={})
        metadata = ChatSessionMetadata.model_validate(raw_metadata)

        # Calculate usage from token counts.
        # NOTE: Per-turn cache_read_tokens / cache_creation_tokens breakdown
        # is lost after persistence — the DB only stores aggregate prompt and
@@ -128,6 +148,7 @@ class ChatSessionInfo(BaseModel):
            updated_at=prisma_session.updatedAt,
            successful_agent_runs=successful_agent_runs,
            successful_agent_schedules=successful_agent_schedules,
            metadata=metadata,
        )


@@ -135,7 +156,7 @@ class ChatSession(ChatSessionInfo):
    messages: list[ChatMessage]

    @classmethod
    def new(cls, user_id: str) -> Self:
    def new(cls, user_id: str, *, dry_run: bool) -> Self:
        return cls(
            session_id=str(uuid.uuid4()),
            user_id=user_id,
@@ -145,6 +166,7 @@ class ChatSession(ChatSessionInfo):
            credentials={},
            started_at=datetime.now(UTC),
            updated_at=datetime.now(UTC),
            metadata=ChatSessionMetadata(dry_run=dry_run),
        )

    @classmethod
@@ -532,6 +554,7 @@ async def _save_session_to_db(
        await db.create_chat_session(
            session_id=session.session_id,
            user_id=session.user_id,
            metadata=session.metadata,
        )
        existing_message_count = 0

@@ -609,21 +632,27 @@ async def append_and_save_message(session_id: str, message: ChatMessage) -> Chat
    return session


async def create_chat_session(user_id: str) -> ChatSession:
async def create_chat_session(user_id: str, *, dry_run: bool) -> ChatSession:
    """Create a new chat session and persist it.

    Args:
        user_id: The authenticated user ID.
        dry_run: When True, run_block and run_agent tool calls in this
            session are forced to use dry-run simulation mode.

    Raises:
        DatabaseError: If the database write fails. We fail fast to ensure
            callers never receive a non-persisted session that only exists
            in cache (which would be lost when the cache expires).
    """
    session = ChatSession.new(user_id)
    session = ChatSession.new(user_id, dry_run=dry_run)

    # Create in database first - fail fast if this fails
    try:
        await chat_db().create_chat_session(
            session_id=session.session_id,
            user_id=user_id,
            metadata=session.metadata,
        )
    except Exception as e:
        logger.error(f"Failed to create session {session.session_id} in database: {e}")

@@ -46,7 +46,7 @@ messages = [

@pytest.mark.asyncio(loop_scope="session")
async def test_chatsession_serialization_deserialization():
    s = ChatSession.new(user_id="abc123")
    s = ChatSession.new(user_id="abc123", dry_run=False)
    s.messages = messages
    s.usage = [Usage(prompt_tokens=100, completion_tokens=200, total_tokens=300)]
    serialized = s.model_dump_json()
@@ -57,7 +57,7 @@ async def test_chatsession_serialization_deserialization():
@pytest.mark.asyncio(loop_scope="session")
async def test_chatsession_redis_storage(setup_test_user, test_user_id):

    s = ChatSession.new(user_id=test_user_id)
    s = ChatSession.new(user_id=test_user_id, dry_run=False)
    s.messages = messages

    s = await upsert_chat_session(s)
@@ -75,7 +75,7 @@ async def test_chatsession_redis_storage_user_id_mismatch(
    setup_test_user, test_user_id
):

    s = ChatSession.new(user_id=test_user_id)
    s = ChatSession.new(user_id=test_user_id, dry_run=False)
    s.messages = messages
    s = await upsert_chat_session(s)

@@ -90,7 +90,7 @@ async def test_chatsession_db_storage(setup_test_user, test_user_id):
    from backend.data.redis_client import get_redis_async

    # Create session with messages including assistant message
    s = ChatSession.new(user_id=test_user_id)
    s = ChatSession.new(user_id=test_user_id, dry_run=False)
    s.messages = messages  # Contains user, assistant, and tool messages
    assert s.session_id is not None, "Session id is not set"
    # Upsert to save to both cache and DB
@@ -241,7 +241,7 @@ _raw_tc2 = {

def test_add_tool_call_appends_to_existing_assistant():
    """When the last assistant is from the current turn, tool_call is added to it."""
    session = ChatSession.new(user_id="u")
    session = ChatSession.new(user_id="u", dry_run=False)
    session.messages = [
        ChatMessage(role="user", content="hi"),
        ChatMessage(role="assistant", content="working on it"),
@@ -254,7 +254,7 @@ def test_add_tool_call_appends_to_existing_assistant():

def test_add_tool_call_creates_assistant_when_none_exists():
    """When there's no current-turn assistant, a new one is created."""
    session = ChatSession.new(user_id="u")
    session = ChatSession.new(user_id="u", dry_run=False)
    session.messages = [
        ChatMessage(role="user", content="hi"),
    ]
@@ -267,7 +267,7 @@ def test_add_tool_call_creates_assistant_when_none_exists():

def test_add_tool_call_does_not_cross_user_boundary():
    """A user message acts as a boundary — previous assistant is not modified."""
    session = ChatSession.new(user_id="u")
    session = ChatSession.new(user_id="u", dry_run=False)
    session.messages = [
        ChatMessage(role="assistant", content="old turn"),
        ChatMessage(role="user", content="new message"),
@@ -282,7 +282,7 @@ def test_add_tool_call_does_not_cross_user_boundary():

def test_add_tool_call_multiple_times():
    """Multiple long-running tool calls accumulate on the same assistant."""
    session = ChatSession.new(user_id="u")
    session = ChatSession.new(user_id="u", dry_run=False)
    session.messages = [
        ChatMessage(role="user", content="hi"),
        ChatMessage(role="assistant", content="doing stuff"),
@@ -300,7 +300,7 @@ def test_add_tool_call_multiple_times():

def test_to_openai_messages_merges_split_assistants():
    """End-to-end: session with split assistants produces valid OpenAI messages."""
    session = ChatSession.new(user_id="u")
    session = ChatSession.new(user_id="u", dry_run=False)
    session.messages = [
        ChatMessage(role="user", content="build agent"),
        ChatMessage(role="assistant", content="Let me build that"),
@@ -352,7 +352,7 @@ async def test_concurrent_saves_collision_detection(setup_test_user, test_user_i
    import asyncio

    # Create a session with initial messages
    session = ChatSession.new(user_id=test_user_id)
    session = ChatSession.new(user_id=test_user_id, dry_run=False)
    for i in range(3):
        session.messages.append(
            ChatMessage(
@@ -66,6 +66,7 @@ from pydantic import BaseModel, PrivateAttr
ToolName = Literal[
    # Platform tools (must match keys in TOOL_REGISTRY)
    "add_understanding",
    "ask_question",
    "bash_exec",
    "browser_act",
    "browser_navigate",
@@ -102,6 +103,7 @@ ToolName = Literal[
    "web_fetch",
    "write_workspace_file",
    # SDK built-ins
    "Agent",
    "Edit",
    "Glob",
    "Grep",


@@ -544,6 +544,7 @@ class TestApplyToolPermissions:
class TestSdkBuiltinToolNames:
    def test_expected_builtins_present(self):
        expected = {
            "Agent",
            "Read",
            "Write",
            "Edit",
@@ -18,6 +18,18 @@ After `write_workspace_file`, embed the `download_url` in Markdown:
- Image: ``
- Video: ``

### Handling binary/image data in tool outputs — CRITICAL
When a tool output contains base64-encoded binary data (images, PDFs, etc.):
1. **NEVER** try to inline or render the base64 content in your response.
2. **Save** the data to workspace using `write_workspace_file` (pass the base64 data URI as content).
3. **Show** the result via the workspace download URL in Markdown: ``.

### Passing large data between tools — CRITICAL
When tool outputs produce large text that you need to feed into another tool:
- **NEVER** copy-paste the full text into the next tool call argument.
- **Save** the output to a file (workspace or local), then use `@@agptfile:` references.
- This avoids token limits and ensures data integrity.

### File references — @@agptfile:
Pass large file content to tools by reference: `@@agptfile:<uri>[<start>-<end>]`
- `workspace://<file_id>` or `workspace:///<path>` — workspace files
@@ -107,6 +119,28 @@ Do not re-fetch or re-generate data you already have from prior tool calls.
After building the file, reference it with `@@agptfile:` in other tools:
`@@agptfile:/home/user/report.md`

### Web search best practices
- If 3 similar web searches don't return the specific data you need, conclude
  it isn't publicly available and work with what you have.
- Prefer fewer, well-targeted searches over many variations of the same query.
- When spawning sub-agents for research, ensure each has a distinct
  non-overlapping scope to avoid redundant searches.


### Tool Discovery Priority

When the user asks to interact with a service or API, follow this order:

1. **find_block first** — Search platform blocks with `find_block`. The platform has hundreds of built-in blocks (Google Sheets, Docs, Calendar, Gmail, Slack, GitHub, etc.) that work without extra setup.

2. **run_mcp_tool** — If no matching block exists, check if a hosted MCP server is available for the service. Only use known MCP server URLs from the registry.

3. **SendAuthenticatedWebRequestBlock** — If no block or MCP server exists, use `SendAuthenticatedWebRequestBlock` with existing host-scoped credentials. Check available credentials via `connect_integration`.

4. **Manual API call** — As a last resort, guide the user to set up credentials and use `SendAuthenticatedWebRequestBlock` with direct API calls.

**Never skip step 1.** Built-in blocks are more reliable, tested, and user-friendly than MCP or raw API calls.

### Sub-agent tasks
- When using the Task tool, NEVER set `run_in_background` to true.
  All tasks must run in the foreground.
@@ -131,6 +165,11 @@ parent autopilot handles orchestration.
# E2B-only notes — E2B has full internet access so gh CLI works there.
# Not shown in local (bubblewrap) mode: --unshare-net blocks all network.
_E2B_TOOL_NOTES = """
### SDK tool-result files in E2B
When you `Read` an SDK tool-result file, it is automatically copied into the
sandbox so `bash_exec` can access it for further processing.
The exact sandbox path is shown in the `[Sandbox copy available at ...]` note.

### GitHub CLI (`gh`) and git
- If the user has connected their GitHub account, both `gh` and `git` are
  pre-authenticated — use them directly without any manual login step.
@@ -196,19 +235,22 @@ def _build_storage_supplement(
- Files here **survive across sessions indefinitely**

### Moving files between storages
- **{file_move_name_1_to_2}**: Copy to persistent workspace
- **{file_move_name_2_to_1}**: Download for processing
- **{file_move_name_1_to_2}**: `write_workspace_file(filename="output.json", source_path="/path/to/local/file")`
- **{file_move_name_2_to_1}**: `read_workspace_file(path="tool-outputs/data.json", save_to_path="{working_dir}/data.json")`

### File persistence
Important files (code, configs, outputs) should be saved to workspace to ensure they persist.

### SDK tool-result files
When tool outputs are large, the SDK truncates them and saves the full output to
a local file under `~/.claude/projects/.../tool-results/`. To read these files,
always use `Read` (NOT `bash_exec`, NOT `read_workspace_file`).
These files are on the host filesystem — `bash_exec` runs in the sandbox and
CANNOT access them. `read_workspace_file` reads from cloud workspace storage,
where SDK tool-results are NOT stored.
a local file under `~/.claude/projects/.../tool-results/` (or `tool-outputs/`).
To read these files, use `Read` — it reads from the host filesystem.

### Large tool outputs saved to workspace
When a tool output contains `<tool-output-truncated workspace_path="...">`, the
full output is in workspace storage (NOT on the local filesystem). To access it:
- Use `read_workspace_file(path="...", offset=..., length=50000)` for reading sections.
- To process in the sandbox, use `read_workspace_file(path="...", save_to_path="{working_dir}/file.json")` first, then use `bash_exec` on the local copy.
{_SHARED_TOOL_NOTES}{extra_notes}"""

28  autogpt_platform/backend/backend/copilot/prompting_test.py  Normal file
@@ -0,0 +1,28 @@
"""Tests for agent generation guide — verifies clarification section."""

from pathlib import Path


class TestAgentGenerationGuideContainsClarifySection:
    """The agent generation guide must include the clarification section."""

    def test_guide_includes_clarify_section(self):
        guide_path = Path(__file__).parent / "sdk" / "agent_generation_guide.md"
        content = guide_path.read_text(encoding="utf-8")
        assert "Before or During Building" in content

    def test_guide_mentions_find_block_for_clarification(self):
        guide_path = Path(__file__).parent / "sdk" / "agent_generation_guide.md"
        content = guide_path.read_text(encoding="utf-8")
        clarify_section = content.split("Before or During Building")[1].split(
            "### Workflow"
        )[0]
        assert "find_block" in clarify_section

    def test_guide_mentions_ask_question_tool(self):
        guide_path = Path(__file__).parent / "sdk" / "agent_generation_guide.md"
        content = guide_path.read_text(encoding="utf-8")
        clarify_section = content.split("Before or During Building")[1].split(
            "### Workflow"
        )[0]
        assert "ask_question" in clarify_section
@@ -9,11 +9,14 @@ UTC). Fails open when Redis is unavailable to avoid blocking users.
import asyncio
import logging
from datetime import UTC, datetime, timedelta
from enum import Enum

from prisma.models import User as PrismaUser
from pydantic import BaseModel, Field
from redis.exceptions import RedisError

from backend.data.redis_client import get_redis_async
from backend.util.cache import cached

logger = logging.getLogger(__name__)

@@ -21,6 +24,40 @@ logger = logging.getLogger(__name__)
_USAGE_KEY_PREFIX = "copilot:usage"


# ---------------------------------------------------------------------------
# Subscription tier definitions
# ---------------------------------------------------------------------------


class SubscriptionTier(str, Enum):
    """Subscription tiers with increasing token allowances.

    Mirrors the ``SubscriptionTier`` enum in ``schema.prisma``.
    Once ``prisma generate`` is run, this can be replaced with::

        from prisma.enums import SubscriptionTier
    """

    FREE = "FREE"
    PRO = "PRO"
    BUSINESS = "BUSINESS"
    ENTERPRISE = "ENTERPRISE"


# Multiplier applied to the base limits (from LD / config) for each tier.
# Intentionally int (not float): keeps limits as whole token counts and avoids
# floating-point rounding. If fractional multipliers are ever needed, change
# the type and round the result in get_global_rate_limits().
TIER_MULTIPLIERS: dict[SubscriptionTier, int] = {
    SubscriptionTier.FREE: 1,
    SubscriptionTier.PRO: 5,
    SubscriptionTier.BUSINESS: 20,
    SubscriptionTier.ENTERPRISE: 60,
}

DEFAULT_TIER = SubscriptionTier.FREE
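
A worked example of the multiplier table, using the 2.5M-daily / 12.5M-weekly base limits that the tests later in this diff assume:

base_daily, base_weekly = 2_500_000, 12_500_000
effective = {
    tier: (base_daily * mult, base_weekly * mult)
    for tier, mult in TIER_MULTIPLIERS.items()
}
# FREE       -> (  2_500_000,  12_500_000)
# PRO        -> ( 12_500_000,  62_500_000)
# BUSINESS   -> ( 50_000_000, 250_000_000)
# ENTERPRISE -> (150_000_000, 750_000_000)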


class UsageWindow(BaseModel):
    """Usage within a single time window."""

@@ -36,6 +73,7 @@ class CoPilotUsageStatus(BaseModel):

    daily: UsageWindow
    weekly: UsageWindow
    tier: SubscriptionTier = DEFAULT_TIER
    reset_cost: int = Field(
        default=0,
        description="Credit cost (in cents) to reset the daily limit. 0 = feature disabled.",
@@ -66,6 +104,7 @@ async def get_usage_status(
    daily_token_limit: int,
    weekly_token_limit: int,
    rate_limit_reset_cost: int = 0,
    tier: SubscriptionTier = DEFAULT_TIER,
) -> CoPilotUsageStatus:
    """Get current usage status for a user.

@@ -74,6 +113,7 @@ async def get_usage_status(
        daily_token_limit: Max tokens per day (0 = unlimited).
        weekly_token_limit: Max tokens per week (0 = unlimited).
        rate_limit_reset_cost: Credit cost (cents) to reset daily limit (0 = disabled).
        tier: The user's rate-limit tier (included in the response).

    Returns:
        CoPilotUsageStatus with current usage and limits.
@@ -103,6 +143,7 @@ async def get_usage_status(
            limit=weekly_token_limit,
            resets_at=_weekly_reset_time(now=now),
        ),
        tier=tier,
        reset_cost=rate_limit_reset_cost,
    )

@@ -161,8 +202,9 @@ async def reset_daily_usage(user_id: str, daily_token_limit: int = 0) -> bool:
        daily_token_limit: The configured daily token limit. When positive,
            the weekly counter is reduced by this amount.

    Fails open: returns False if Redis is unavailable (consistent with
    the fail-open design of this module).
    Returns False if Redis is unavailable so the caller can handle
    compensation (fail-closed for billed operations, unlike the read-only
    rate-limit checks, which fail open).
    """
    now = datetime.now(UTC)
    try:
@@ -342,20 +384,100 @@ async def record_token_usage(
    )


class _UserNotFoundError(Exception):
    """Raised when a user record is missing or has no subscription tier.

    Used internally by ``_fetch_user_tier`` to signal a cache-miss condition:
    by raising instead of returning ``DEFAULT_TIER``, we prevent the ``@cached``
    decorator from storing the fallback value. This avoids a race condition
    where a non-existent user's DEFAULT_TIER is cached, then the user is
    created with a higher tier but receives the stale cached FREE tier for
    up to 5 minutes.
    """


@cached(maxsize=1000, ttl_seconds=300, shared_cache=True)
async def _fetch_user_tier(user_id: str) -> SubscriptionTier:
    """Fetch the user's rate-limit tier from the database (cached via Redis).

    Uses ``shared_cache=True`` so that tier changes propagate across all pods
    immediately when the cache entry is invalidated (via ``cache_delete``).

    Only successful DB lookups of existing users with a valid tier are cached.
    Raises ``_UserNotFoundError`` when the user is missing or has no tier, so
    the ``@cached`` decorator does **not** store a fallback value. This
    prevents a race condition where a non-existent user's ``DEFAULT_TIER`` is
    cached and then persists after the user is created with a higher tier.
    """
    user = await PrismaUser.prisma().find_unique(where={"id": user_id})
    if user and user.subscriptionTier:  # type: ignore[reportAttributeAccessIssue]
        return SubscriptionTier(user.subscriptionTier)  # type: ignore[reportAttributeAccessIssue]
    raise _UserNotFoundError(user_id)
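
The raise-instead-of-return trick above generalizes to any memoizer that stores only successful returns. A self-contained sketch of the pattern (the ttl_cached decorator below is written from scratch for illustration — it is not backend.util.cache.cached):

import time
from functools import wraps


class _NotFound(Exception):
    """Signals 'do not cache this result' to the memoizer."""


def ttl_cached(ttl: float):
    def deco(fn):
        store: dict[tuple, tuple[float, str]] = {}

        @wraps(fn)
        def wrapper(*args):
            hit = store.get(args)
            if hit is not None and time.monotonic() - hit[0] < ttl:
                return hit[1]
            value = fn(*args)  # an exception here skips the store below
            store[args] = (time.monotonic(), value)  # only successes cached
            return value

        return wrapper

    return deco


_FAKE_DB: dict[str, str] = {}  # stands in for the users table


@ttl_cached(ttl=300)
def fetch_tier(user_id: str) -> str:
    tier = _FAKE_DB.get(user_id)
    if tier is None:
        raise _NotFound(user_id)  # miss is NOT cached; the next call retries
    return tier


def get_tier(user_id: str) -> str:
    try:
        return fetch_tier(user_id)
    except Exception:
        return "FREE"  # fallback served without being cached


assert get_tier("u1") == "FREE"  # user missing: fallback, nothing cached
_FAKE_DB["u1"] = "PRO"
assert get_tier("u1") == "PRO"   # immediately visible — no stale FREE entry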


async def get_user_tier(user_id: str) -> SubscriptionTier:
    """Look up the user's rate-limit tier from the database.

    Successful results are cached for 5 minutes (via ``_fetch_user_tier``)
    to avoid a DB round-trip on every rate-limit check.

    Falls back to ``DEFAULT_TIER`` **without caching** when the DB is
    unreachable or returns an unrecognised value, so the next call retries
    the query instead of serving a stale fallback for up to 5 minutes.
    """
    try:
        return await _fetch_user_tier(user_id)
    except Exception as exc:
        logger.warning(
            "Failed to resolve rate-limit tier for user %s, defaulting to %s: %s",
            user_id[:8],
            DEFAULT_TIER.value,
            exc,
        )
        return DEFAULT_TIER


# Expose cache management on the public function so callers (including tests)
# never need to reach into the private ``_fetch_user_tier``.
get_user_tier.cache_clear = _fetch_user_tier.cache_clear  # type: ignore[attr-defined]
get_user_tier.cache_delete = _fetch_user_tier.cache_delete  # type: ignore[attr-defined]


async def set_user_tier(user_id: str, tier: SubscriptionTier) -> None:
    """Persist the user's rate-limit tier to the database.

    Also invalidates the ``get_user_tier`` cache for this user so that
    subsequent rate-limit checks immediately see the new tier.

    Raises:
        prisma.errors.RecordNotFoundError: If the user does not exist.
    """
    await PrismaUser.prisma().update(
        where={"id": user_id},
        data={"subscriptionTier": tier.value},
    )
    # Invalidate cached tier so rate-limit checks pick up the change immediately.
    get_user_tier.cache_delete(user_id)  # type: ignore[attr-defined]


async def get_global_rate_limits(
    user_id: str,
    config_daily: int,
    config_weekly: int,
) -> tuple[int, int]:
) -> tuple[int, int, SubscriptionTier]:
    """Resolve global rate limits from LaunchDarkly, falling back to config.

    The base limits (from LD or config) are multiplied by the user's
    tier multiplier so that higher tiers receive proportionally larger
    allowances.

    Args:
        user_id: User ID for LD flag evaluation context.
        config_daily: Fallback daily limit from ChatConfig.
        config_weekly: Fallback weekly limit from ChatConfig.

    Returns:
        (daily_token_limit, weekly_token_limit) tuple.
        (daily_token_limit, weekly_token_limit, tier) 3-tuple.
    """
    # Lazy import to avoid circular dependency:
    # rate_limit -> feature_flag -> settings -> ... -> rate_limit
@@ -377,7 +499,15 @@ async def get_global_rate_limits(
    except (TypeError, ValueError):
        logger.warning("Invalid LD value for weekly token limit: %r", weekly_raw)
        weekly = config_weekly
    return daily, weekly

    # Apply tier multiplier
    tier = await get_user_tier(user_id)
    multiplier = TIER_MULTIPLIERS.get(tier, 1)
    if multiplier != 1:
        daily = daily * multiplier
        weekly = weekly * multiplier

    return daily, weekly, tier
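
For orientation, how a caller consumes the widened return value — a hedged sketch; the surrounding handler is invented, while get_global_rate_limits, check_rate_limit, and RateLimitExceeded are the names in this module:

daily, weekly, tier = await get_global_rate_limits(user_id, config_daily, config_weekly)
try:
    await check_rate_limit(user_id, daily_token_limit=daily, weekly_token_limit=weekly)
except RateLimitExceeded:
    ...  # e.g. surface `tier` in the error response so upgrades are discoverable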


async def reset_user_usage(user_id: str, *, reset_weekly: bool = False) -> None:


@@ -7,12 +7,19 @@ import pytest
from redis.exceptions import RedisError

from .rate_limit import (
    DEFAULT_TIER,
    TIER_MULTIPLIERS,
    CoPilotUsageStatus,
    RateLimitExceeded,
    SubscriptionTier,
    UsageWindow,
    check_rate_limit,
    get_global_rate_limits,
    get_usage_status,
    get_user_tier,
    record_token_usage,
    reset_daily_usage,
    set_user_tier,
)

_USER = "test-user-rl"
@@ -335,6 +342,524 @@ class TestRecordTokenUsage:
        await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)


# ---------------------------------------------------------------------------
# SubscriptionTier and tier multipliers
# ---------------------------------------------------------------------------


class TestSubscriptionTier:
    def test_tier_values(self):
        assert SubscriptionTier.FREE.value == "FREE"
        assert SubscriptionTier.PRO.value == "PRO"
        assert SubscriptionTier.BUSINESS.value == "BUSINESS"
        assert SubscriptionTier.ENTERPRISE.value == "ENTERPRISE"

    def test_tier_multipliers(self):
        assert TIER_MULTIPLIERS[SubscriptionTier.FREE] == 1
        assert TIER_MULTIPLIERS[SubscriptionTier.PRO] == 5
        assert TIER_MULTIPLIERS[SubscriptionTier.BUSINESS] == 20
        assert TIER_MULTIPLIERS[SubscriptionTier.ENTERPRISE] == 60

    def test_default_tier_is_free(self):
        assert DEFAULT_TIER == SubscriptionTier.FREE

    def test_usage_status_includes_tier(self):
        now = datetime.now(UTC)
        status = CoPilotUsageStatus(
            daily=UsageWindow(used=0, limit=100, resets_at=now + timedelta(hours=1)),
            weekly=UsageWindow(used=0, limit=500, resets_at=now + timedelta(days=1)),
        )
        assert status.tier == SubscriptionTier.FREE

    def test_usage_status_with_custom_tier(self):
        now = datetime.now(UTC)
        status = CoPilotUsageStatus(
            daily=UsageWindow(used=0, limit=100, resets_at=now + timedelta(hours=1)),
            weekly=UsageWindow(used=0, limit=500, resets_at=now + timedelta(days=1)),
            tier=SubscriptionTier.PRO,
        )
        assert status.tier == SubscriptionTier.PRO


# ---------------------------------------------------------------------------
# get_user_tier
# ---------------------------------------------------------------------------


class TestGetUserTier:
    @pytest.fixture(autouse=True)
    def _clear_tier_cache(self):
        """Clear the get_user_tier cache before each test."""
        get_user_tier.cache_clear()  # type: ignore[attr-defined]

    @pytest.mark.asyncio
    async def test_returns_tier_from_db(self):
        """Should return the tier stored in the user record."""
        mock_user = MagicMock()
        mock_user.subscriptionTier = "PRO"

        mock_prisma = AsyncMock()
        mock_prisma.find_unique = AsyncMock(return_value=mock_user)

        with patch(
            "backend.copilot.rate_limit.PrismaUser.prisma",
            return_value=mock_prisma,
        ):
            tier = await get_user_tier(_USER)

        assert tier == SubscriptionTier.PRO

    @pytest.mark.asyncio
    async def test_returns_default_when_user_not_found(self):
        """Should return DEFAULT_TIER when user is not in the DB."""
        mock_prisma = AsyncMock()
        mock_prisma.find_unique = AsyncMock(return_value=None)

        with patch(
            "backend.copilot.rate_limit.PrismaUser.prisma",
            return_value=mock_prisma,
        ):
            tier = await get_user_tier(_USER)

        assert tier == DEFAULT_TIER

    @pytest.mark.asyncio
    async def test_returns_default_when_tier_is_none(self):
        """Should return DEFAULT_TIER when subscriptionTier is None."""
        mock_user = MagicMock()
        mock_user.subscriptionTier = None

        mock_prisma = AsyncMock()
        mock_prisma.find_unique = AsyncMock(return_value=mock_user)

        with patch(
            "backend.copilot.rate_limit.PrismaUser.prisma",
            return_value=mock_prisma,
        ):
            tier = await get_user_tier(_USER)

        assert tier == DEFAULT_TIER

    @pytest.mark.asyncio
    async def test_returns_default_on_db_error(self):
        """Should fall back to DEFAULT_TIER when DB raises."""
        mock_prisma = AsyncMock()
        mock_prisma.find_unique = AsyncMock(side_effect=Exception("DB down"))

        with patch(
            "backend.copilot.rate_limit.PrismaUser.prisma",
            return_value=mock_prisma,
        ):
            tier = await get_user_tier(_USER)

        assert tier == DEFAULT_TIER

    @pytest.mark.asyncio
    async def test_db_error_is_not_cached(self):
        """Transient DB errors should NOT cache the default tier.

        Regression test: a transient DB failure previously cached DEFAULT_TIER
        for 5 minutes, incorrectly downgrading higher-tier users until expiry.
        """
        failing_prisma = AsyncMock()
        failing_prisma.find_unique = AsyncMock(side_effect=Exception("DB down"))

        with patch(
            "backend.copilot.rate_limit.PrismaUser.prisma",
            return_value=failing_prisma,
        ):
            tier1 = await get_user_tier(_USER)
        assert tier1 == DEFAULT_TIER

        # Now DB recovers and returns PRO
        mock_user = MagicMock()
        mock_user.subscriptionTier = "PRO"
        ok_prisma = AsyncMock()
        ok_prisma.find_unique = AsyncMock(return_value=mock_user)

        with patch(
            "backend.copilot.rate_limit.PrismaUser.prisma",
            return_value=ok_prisma,
        ):
            tier2 = await get_user_tier(_USER)

        # Should get PRO now — the error result was not cached
        assert tier2 == SubscriptionTier.PRO

    @pytest.mark.asyncio
    async def test_returns_default_on_invalid_tier_value(self):
        """Should fall back to DEFAULT_TIER when stored value is invalid."""
        mock_user = MagicMock()
        mock_user.subscriptionTier = "invalid-tier"

        mock_prisma = AsyncMock()
        mock_prisma.find_unique = AsyncMock(return_value=mock_user)

        with patch(
            "backend.copilot.rate_limit.PrismaUser.prisma",
            return_value=mock_prisma,
        ):
            tier = await get_user_tier(_USER)

        assert tier == DEFAULT_TIER

    @pytest.mark.asyncio
    async def test_user_not_found_is_not_cached(self):
        """Non-existent user should NOT cache DEFAULT_TIER.

        Regression test: when ``get_user_tier`` is called before a user record
        exists, the DEFAULT_TIER fallback must not be cached. Otherwise, a
        newly created user with a higher tier (e.g. PRO) would receive the
        stale cached FREE tier for up to 5 minutes.
        """
        # First call: user does not exist yet
        missing_prisma = AsyncMock()
        missing_prisma.find_unique = AsyncMock(return_value=None)

        with patch(
            "backend.copilot.rate_limit.PrismaUser.prisma",
            return_value=missing_prisma,
        ):
            tier1 = await get_user_tier(_USER)
        assert tier1 == DEFAULT_TIER

        # Second call: user now exists with PRO tier
        mock_user = MagicMock()
        mock_user.subscriptionTier = "PRO"
        ok_prisma = AsyncMock()
        ok_prisma.find_unique = AsyncMock(return_value=mock_user)

        with patch(
            "backend.copilot.rate_limit.PrismaUser.prisma",
            return_value=ok_prisma,
        ):
            tier2 = await get_user_tier(_USER)

        # Should get PRO — the not-found result was not cached
        assert tier2 == SubscriptionTier.PRO


# ---------------------------------------------------------------------------
# set_user_tier
# ---------------------------------------------------------------------------


class TestSetUserTier:
    @pytest.fixture(autouse=True)
    def _clear_tier_cache(self):
        """Clear the get_user_tier cache before each test."""
        get_user_tier.cache_clear()  # type: ignore[attr-defined]

    @pytest.mark.asyncio
    async def test_updates_db_and_invalidates_cache(self):
        """set_user_tier should persist to DB and invalidate the tier cache."""
        mock_prisma = AsyncMock()
        mock_prisma.update = AsyncMock(return_value=None)

        with patch(
            "backend.copilot.rate_limit.PrismaUser.prisma",
            return_value=mock_prisma,
        ):
            await set_user_tier(_USER, SubscriptionTier.PRO)

        mock_prisma.update.assert_awaited_once_with(
            where={"id": _USER},
            data={"subscriptionTier": "PRO"},
        )

    @pytest.mark.asyncio
    async def test_record_not_found_propagates(self):
        """RecordNotFoundError from Prisma should propagate to callers."""
        import prisma.errors

        mock_prisma = AsyncMock()
        mock_prisma.update = AsyncMock(
            side_effect=prisma.errors.RecordNotFoundError(
                {"error": "Record not found"}
            ),
        )

        with patch(
            "backend.copilot.rate_limit.PrismaUser.prisma",
            return_value=mock_prisma,
        ):
            with pytest.raises(prisma.errors.RecordNotFoundError):
                await set_user_tier(_USER, SubscriptionTier.ENTERPRISE)

    @pytest.mark.asyncio
    async def test_cache_invalidated_after_set(self):
        """After set_user_tier, get_user_tier should query DB again (not cache)."""
        # First, populate the cache with BUSINESS
        mock_user_biz = MagicMock()
        mock_user_biz.subscriptionTier = "BUSINESS"
        mock_prisma_get = AsyncMock()
        mock_prisma_get.find_unique = AsyncMock(return_value=mock_user_biz)

        with patch(
            "backend.copilot.rate_limit.PrismaUser.prisma",
            return_value=mock_prisma_get,
        ):
            tier_before = await get_user_tier(_USER)
        assert tier_before == SubscriptionTier.BUSINESS

        # Now set tier to ENTERPRISE (this should invalidate the cache)
        mock_prisma_set = AsyncMock()
        mock_prisma_set.update = AsyncMock(return_value=None)

        with patch(
            "backend.copilot.rate_limit.PrismaUser.prisma",
            return_value=mock_prisma_set,
        ):
            await set_user_tier(_USER, SubscriptionTier.ENTERPRISE)

        # Now get_user_tier should hit DB again (cache was invalidated)
        mock_user_ent = MagicMock()
        mock_user_ent.subscriptionTier = "ENTERPRISE"
        mock_prisma_get2 = AsyncMock()
        mock_prisma_get2.find_unique = AsyncMock(return_value=mock_user_ent)

        with patch(
            "backend.copilot.rate_limit.PrismaUser.prisma",
            return_value=mock_prisma_get2,
        ):
            tier_after = await get_user_tier(_USER)

        assert tier_after == SubscriptionTier.ENTERPRISE


# ---------------------------------------------------------------------------
# get_global_rate_limits with tiers
# ---------------------------------------------------------------------------


class TestGetGlobalRateLimitsWithTiers:
    @staticmethod
    def _ld_side_effect(daily: int, weekly: int):
        """Return an async side_effect that dispatches by flag_key."""

        async def _side_effect(flag_key: str, _uid: str, default: int) -> int:
            if "daily" in flag_key.lower():
                return daily
            if "weekly" in flag_key.lower():
                return weekly
            return default

        return _side_effect

    @pytest.mark.asyncio
    async def test_free_tier_no_multiplier(self):
        """Free tier should not change limits."""
        with (
            patch(
                "backend.copilot.rate_limit.get_user_tier",
                new_callable=AsyncMock,
                return_value=SubscriptionTier.FREE,
            ),
            patch(
                "backend.util.feature_flag.get_feature_flag_value",
                side_effect=self._ld_side_effect(2_500_000, 12_500_000),
            ),
        ):
            daily, weekly, tier = await get_global_rate_limits(
                _USER, 2_500_000, 12_500_000
            )

        assert daily == 2_500_000
        assert weekly == 12_500_000
        assert tier == SubscriptionTier.FREE

    @pytest.mark.asyncio
    async def test_pro_tier_5x_multiplier(self):
        """Pro tier should multiply limits by 5."""
        with (
            patch(
                "backend.copilot.rate_limit.get_user_tier",
                new_callable=AsyncMock,
                return_value=SubscriptionTier.PRO,
            ),
            patch(
                "backend.util.feature_flag.get_feature_flag_value",
                side_effect=self._ld_side_effect(2_500_000, 12_500_000),
            ),
        ):
            daily, weekly, tier = await get_global_rate_limits(
                _USER, 2_500_000, 12_500_000
            )

        assert daily == 12_500_000
        assert weekly == 62_500_000
        assert tier == SubscriptionTier.PRO

    @pytest.mark.asyncio
    async def test_business_tier_20x_multiplier(self):
        """Business tier should multiply limits by 20."""
        with (
            patch(
                "backend.copilot.rate_limit.get_user_tier",
                new_callable=AsyncMock,
                return_value=SubscriptionTier.BUSINESS,
            ),
            patch(
                "backend.util.feature_flag.get_feature_flag_value",
                side_effect=self._ld_side_effect(2_500_000, 12_500_000),
            ),
        ):
            daily, weekly, tier = await get_global_rate_limits(
                _USER, 2_500_000, 12_500_000
            )

        assert daily == 50_000_000
        assert weekly == 250_000_000
        assert tier == SubscriptionTier.BUSINESS

    @pytest.mark.asyncio
    async def test_enterprise_tier_60x_multiplier(self):
        """Enterprise tier should multiply limits by 60."""
        with (
            patch(
                "backend.copilot.rate_limit.get_user_tier",
                new_callable=AsyncMock,
                return_value=SubscriptionTier.ENTERPRISE,
            ),
            patch(
                "backend.util.feature_flag.get_feature_flag_value",
                side_effect=self._ld_side_effect(2_500_000, 12_500_000),
            ),
        ):
            daily, weekly, tier = await get_global_rate_limits(
                _USER, 2_500_000, 12_500_000
            )

        assert daily == 150_000_000
        assert weekly == 750_000_000
        assert tier == SubscriptionTier.ENTERPRISE


# ---------------------------------------------------------------------------
# End-to-end: tier limits are respected by check_rate_limit
# ---------------------------------------------------------------------------


class TestTierLimitsRespected:
    """Verify that tier-adjusted limits from get_global_rate_limits flow
    correctly into check_rate_limit, so higher tiers allow more usage and
    lower tiers are blocked when they would exceed their allocation."""

    _BASE_DAILY = 2_500_000
    _BASE_WEEKLY = 12_500_000

    @staticmethod
    def _ld_side_effect(daily: int, weekly: int):

        async def _side_effect(flag_key: str, _uid: str, default: int) -> int:
            if "daily" in flag_key.lower():
                return daily
            if "weekly" in flag_key.lower():
                return weekly
            return default

        return _side_effect

    @pytest.mark.asyncio
    async def test_pro_user_allowed_above_free_limit(self):
        """A PRO user with usage above the FREE limit should be allowed."""
        # Usage: 3M tokens (above FREE limit of 2.5M, below PRO limit of 12.5M)
        mock_redis = AsyncMock()
        mock_redis.get = AsyncMock(side_effect=["3000000", "3000000"])

        with (
            patch(
                "backend.copilot.rate_limit.get_user_tier",
                new_callable=AsyncMock,
                return_value=SubscriptionTier.PRO,
            ),
            patch(
                "backend.util.feature_flag.get_feature_flag_value",
                side_effect=self._ld_side_effect(self._BASE_DAILY, self._BASE_WEEKLY),
            ),
            patch(
                "backend.copilot.rate_limit.get_redis_async",
                return_value=mock_redis,
            ),
        ):
            daily, weekly, tier = await get_global_rate_limits(
                _USER, self._BASE_DAILY, self._BASE_WEEKLY
            )
            # PRO: 5x multiplier
            assert daily == 12_500_000
            assert tier == SubscriptionTier.PRO
            # Should NOT raise — 3M < 12.5M
            await check_rate_limit(
                _USER, daily_token_limit=daily, weekly_token_limit=weekly
            )

    @pytest.mark.asyncio
    async def test_free_user_blocked_at_free_limit(self):
        """A FREE user at or above the base limit should be blocked."""
        # Usage: 2.5M tokens (at FREE limit of 2.5M)
        mock_redis = AsyncMock()
        mock_redis.get = AsyncMock(side_effect=["2500000", "2500000"])

        with (
            patch(
                "backend.copilot.rate_limit.get_user_tier",
                new_callable=AsyncMock,
                return_value=SubscriptionTier.FREE,
            ),
            patch(
                "backend.util.feature_flag.get_feature_flag_value",
                side_effect=self._ld_side_effect(self._BASE_DAILY, self._BASE_WEEKLY),
            ),
            patch(
                "backend.copilot.rate_limit.get_redis_async",
                return_value=mock_redis,
            ),
        ):
            daily, weekly, tier = await get_global_rate_limits(
                _USER, self._BASE_DAILY, self._BASE_WEEKLY
            )
            # FREE: 1x multiplier
            assert daily == 2_500_000
            assert tier == SubscriptionTier.FREE
            # Should raise — 2.5M >= 2.5M
            with pytest.raises(RateLimitExceeded):
                await check_rate_limit(
                    _USER, daily_token_limit=daily, weekly_token_limit=weekly
                )

    @pytest.mark.asyncio
    async def test_enterprise_user_has_highest_headroom(self):
        """An ENTERPRISE user should have 60x the base limit."""
        # Usage: 100M tokens (huge, but below ENTERPRISE daily of 150M)
        mock_redis = AsyncMock()
        mock_redis.get = AsyncMock(side_effect=["100000000", "100000000"])

        with (
            patch(
                "backend.copilot.rate_limit.get_user_tier",
                new_callable=AsyncMock,
                return_value=SubscriptionTier.ENTERPRISE,
            ),
            patch(
                "backend.util.feature_flag.get_feature_flag_value",
                side_effect=self._ld_side_effect(self._BASE_DAILY, self._BASE_WEEKLY),
            ),
            patch(
                "backend.copilot.rate_limit.get_redis_async",
                return_value=mock_redis,
            ),
        ):
            daily, weekly, tier = await get_global_rate_limits(
                _USER, self._BASE_DAILY, self._BASE_WEEKLY
            )
            assert daily == 150_000_000
            assert tier == SubscriptionTier.ENTERPRISE
            # Should NOT raise — 100M < 150M
            await check_rate_limit(
                _USER, daily_token_limit=daily, weekly_token_limit=weekly
            )


# ---------------------------------------------------------------------------
# reset_daily_usage
# ---------------------------------------------------------------------------
@@ -421,3 +946,267 @@ class TestResetDailyUsage:
        result = await reset_daily_usage(_USER, daily_token_limit=10000)

        assert result is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tier-limit enforcement (integration-style)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTierLimitsEnforced:
|
||||
"""Verify that tier-multiplied limits are actually respected by
|
||||
``check_rate_limit`` — i.e. that usage within the tier allowance passes
|
||||
and usage at/above the tier allowance is rejected."""
|
||||
|
||||
_BASE_DAILY = 1_000_000
|
||||
_BASE_WEEKLY = 5_000_000
|
||||
|
||||
@staticmethod
|
||||
def _ld_side_effect(daily: int, weekly: int):
|
||||
"""Mock LD flag lookup returning the given raw limits."""
|
||||
|
||||
async def _side_effect(flag_key: str, _uid: str, default: int) -> int:
|
||||
if "daily" in flag_key.lower():
|
||||
return daily
|
||||
if "weekly" in flag_key.lower():
|
||||
return weekly
|
||||
return default
|
||||
|
||||
return _side_effect
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pro_within_limit_allowed(self):
|
||||
"""Usage under PRO daily limit should not raise."""
|
||||
pro_daily = self._BASE_DAILY * TIER_MULTIPLIERS[SubscriptionTier.PRO]
|
||||
mock_redis = AsyncMock()
|
||||
# Simulate usage just under the PRO daily limit
|
||||
mock_redis.get = AsyncMock(side_effect=[str(pro_daily - 1), "0"])
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.rate_limit.get_user_tier",
|
||||
new_callable=AsyncMock,
|
||||
return_value=SubscriptionTier.PRO,
|
||||
),
|
||||
patch(
|
||||
"backend.util.feature_flag.get_feature_flag_value",
|
||||
side_effect=self._ld_side_effect(self._BASE_DAILY, self._BASE_WEEKLY),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
),
|
||||
):
|
||||
daily, weekly, tier = await get_global_rate_limits(
|
||||
_USER, self._BASE_DAILY, self._BASE_WEEKLY
|
||||
)
|
||||
assert tier == SubscriptionTier.PRO
|
||||
assert daily == pro_daily
|
||||
# Should not raise — usage is under the limit
|
||||
await check_rate_limit(_USER, daily, weekly)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pro_at_limit_rejected(self):
|
||||
"""Usage at exactly the PRO daily limit should raise."""
|
||||
pro_daily = self._BASE_DAILY * TIER_MULTIPLIERS[SubscriptionTier.PRO]
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=[str(pro_daily), "0"])
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.rate_limit.get_user_tier",
|
||||
new_callable=AsyncMock,
|
||||
return_value=SubscriptionTier.PRO,
|
||||
),
|
||||
patch(
|
||||
"backend.util.feature_flag.get_feature_flag_value",
|
||||
side_effect=self._ld_side_effect(self._BASE_DAILY, self._BASE_WEEKLY),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
),
|
||||
):
|
||||
daily, weekly, tier = await get_global_rate_limits(
|
||||
_USER, self._BASE_DAILY, self._BASE_WEEKLY
|
||||
)
|
||||
with pytest.raises(RateLimitExceeded) as exc_info:
|
||||
await check_rate_limit(_USER, daily, weekly)
|
||||
assert exc_info.value.window == "daily"
|
||||
|
||||
    @pytest.mark.asyncio
    async def test_business_higher_limit_allows_pro_overflow(self):
        """Usage exceeding PRO but under BUSINESS should pass for BUSINESS."""
        pro_daily = self._BASE_DAILY * TIER_MULTIPLIERS[SubscriptionTier.PRO]
        biz_daily = self._BASE_DAILY * TIER_MULTIPLIERS[SubscriptionTier.BUSINESS]
        # Usage between PRO and BUSINESS limits
        usage = pro_daily + 1_000_000
        assert usage < biz_daily, "test sanity: usage must be under BUSINESS limit"

        mock_redis = AsyncMock()
        mock_redis.get = AsyncMock(side_effect=[str(usage), "0"])

        with (
            patch(
                "backend.copilot.rate_limit.get_user_tier",
                new_callable=AsyncMock,
                return_value=SubscriptionTier.BUSINESS,
            ),
            patch(
                "backend.util.feature_flag.get_feature_flag_value",
                side_effect=self._ld_side_effect(self._BASE_DAILY, self._BASE_WEEKLY),
            ),
            patch(
                "backend.copilot.rate_limit.get_redis_async",
                return_value=mock_redis,
            ),
        ):
            daily, weekly, tier = await get_global_rate_limits(
                _USER, self._BASE_DAILY, self._BASE_WEEKLY
            )
            assert tier == SubscriptionTier.BUSINESS
            assert daily == biz_daily
            # Should not raise — BUSINESS tier can handle this
            await check_rate_limit(_USER, daily, weekly)

    @pytest.mark.asyncio
    async def test_weekly_limit_enforced_for_tier(self):
        """Weekly limit should also be tier-multiplied and enforced."""
        pro_weekly = self._BASE_WEEKLY * TIER_MULTIPLIERS[SubscriptionTier.PRO]
        mock_redis = AsyncMock()
        # Daily usage fine, weekly at limit
        mock_redis.get = AsyncMock(side_effect=["0", str(pro_weekly)])

        with (
            patch(
                "backend.copilot.rate_limit.get_user_tier",
                new_callable=AsyncMock,
                return_value=SubscriptionTier.PRO,
            ),
            patch(
                "backend.util.feature_flag.get_feature_flag_value",
                side_effect=self._ld_side_effect(self._BASE_DAILY, self._BASE_WEEKLY),
            ),
            patch(
                "backend.copilot.rate_limit.get_redis_async",
                return_value=mock_redis,
            ),
        ):
            daily, weekly, tier = await get_global_rate_limits(
                _USER, self._BASE_DAILY, self._BASE_WEEKLY
            )
            with pytest.raises(RateLimitExceeded) as exc_info:
                await check_rate_limit(_USER, daily, weekly)
            assert exc_info.value.window == "weekly"

    @pytest.mark.asyncio
    async def test_free_tier_base_limit_enforced(self):
        """Free tier (1x multiplier) should enforce the base limit exactly."""
        mock_redis = AsyncMock()
        mock_redis.get = AsyncMock(side_effect=[str(self._BASE_DAILY), "0"])

        with (
            patch(
                "backend.copilot.rate_limit.get_user_tier",
                new_callable=AsyncMock,
                return_value=SubscriptionTier.FREE,
            ),
            patch(
                "backend.util.feature_flag.get_feature_flag_value",
                side_effect=self._ld_side_effect(self._BASE_DAILY, self._BASE_WEEKLY),
            ),
            patch(
                "backend.copilot.rate_limit.get_redis_async",
                return_value=mock_redis,
            ),
        ):
            daily, weekly, tier = await get_global_rate_limits(
                _USER, self._BASE_DAILY, self._BASE_WEEKLY
            )
            assert daily == self._BASE_DAILY  # 1x multiplier
            with pytest.raises(RateLimitExceeded):
                await check_rate_limit(_USER, daily, weekly)

    @pytest.mark.asyncio
    async def test_free_tier_cannot_bypass_pro_limit(self):
        """A FREE-tier user whose usage is within PRO limits but over FREE
        limits must still be rejected.

        Negative test: ensures the tier multiplier is applied *before* the
        rate-limit check, so a lower-tier user cannot 'bypass' limits that
        would be acceptable for a higher tier.
        """
        free_daily = self._BASE_DAILY * TIER_MULTIPLIERS[SubscriptionTier.FREE]
        pro_daily = self._BASE_DAILY * TIER_MULTIPLIERS[SubscriptionTier.PRO]
        # Usage above FREE limit but below PRO limit
        usage = free_daily + 500_000
        assert usage < pro_daily, "test sanity: usage must be under PRO limit"

        mock_redis = AsyncMock()
        mock_redis.get = AsyncMock(side_effect=[str(usage), "0"])

        with (
            patch(
                "backend.copilot.rate_limit.get_user_tier",
                new_callable=AsyncMock,
                return_value=SubscriptionTier.FREE,
            ),
            patch(
                "backend.util.feature_flag.get_feature_flag_value",
                side_effect=self._ld_side_effect(self._BASE_DAILY, self._BASE_WEEKLY),
            ),
            patch(
                "backend.copilot.rate_limit.get_redis_async",
                return_value=mock_redis,
            ),
        ):
            daily, weekly, tier = await get_global_rate_limits(
                _USER, self._BASE_DAILY, self._BASE_WEEKLY
            )
            assert tier == SubscriptionTier.FREE
            assert daily == free_daily  # 1x, not 5x
            with pytest.raises(RateLimitExceeded) as exc_info:
                await check_rate_limit(_USER, daily, weekly)
            assert exc_info.value.window == "daily"

    @pytest.mark.asyncio
    async def test_tier_change_updates_effective_limits(self):
        """After upgrading from FREE to BUSINESS, the effective limits must
        increase accordingly.

        Verifies that the tier multiplier is correctly applied after a tier
        change, and that usage that was over the FREE limit is within the new
        BUSINESS limit.
        """
        free_daily = self._BASE_DAILY * TIER_MULTIPLIERS[SubscriptionTier.FREE]
        biz_daily = self._BASE_DAILY * TIER_MULTIPLIERS[SubscriptionTier.BUSINESS]
        # Usage above FREE limit but below BUSINESS limit
        usage = free_daily + 500_000
        assert usage < biz_daily, "test sanity: usage must be under BUSINESS limit"

        mock_redis = AsyncMock()
        mock_redis.get = AsyncMock(side_effect=[str(usage), "0"])

        # Simulate the user having been upgraded to BUSINESS
        with (
            patch(
                "backend.copilot.rate_limit.get_user_tier",
                new_callable=AsyncMock,
                return_value=SubscriptionTier.BUSINESS,
            ),
            patch(
                "backend.util.feature_flag.get_feature_flag_value",
                side_effect=self._ld_side_effect(self._BASE_DAILY, self._BASE_WEEKLY),
            ),
            patch(
                "backend.copilot.rate_limit.get_redis_async",
                return_value=mock_redis,
            ),
        ):
            daily, weekly, tier = await get_global_rate_limits(
                _USER, self._BASE_DAILY, self._BASE_WEEKLY
            )
            assert tier == SubscriptionTier.BUSINESS
            assert daily == biz_daily  # 20x
            # Should NOT raise — usage is within the BUSINESS tier allowance
            await check_rate_limit(_USER, daily, weekly)

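Taken together, these tests pin down a simple contract. A self-contained sketch of that contract (the stand-in enum, exception, and functions below are illustrative, not the real implementations in `backend.copilot.rate_limit`; the 1x/5x/20x multipliers follow the "# 1x, not 5x" and "# 20x" comments in the tests):

```python
# Sketch of the contract asserted above, not the production code.
import enum


class SubscriptionTier(enum.Enum):  # stand-in for the real enum
    FREE = "free"
    PRO = "pro"
    BUSINESS = "business"


# Exact values are illustrative; only PRO=5x and BUSINESS=20x are
# implied by the test comments above.
TIER_MULTIPLIERS = {
    SubscriptionTier.FREE: 1,
    SubscriptionTier.PRO: 5,
    SubscriptionTier.BUSINESS: 20,
}


class RateLimitExceeded(Exception):
    def __init__(self, window: str):
        super().__init__(f"{window} rate limit exceeded")
        self.window = window  # tests assert on .window ("daily"/"weekly")


def effective_limits(base_daily: int, base_weekly: int, tier: SubscriptionTier):
    """Tier multiplier is applied before any usage check."""
    m = TIER_MULTIPLIERS[tier]
    return base_daily * m, base_weekly * m


def check(daily_used: int, weekly_used: int, daily_limit: int, weekly_limit: int):
    # Usage at exactly the limit is rejected (see test_pro_at_limit_rejected).
    if daily_used >= daily_limit:
        raise RateLimitExceeded("daily")
    if weekly_used >= weekly_limit:
        raise RateLimitExceeded("weekly")
```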
@@ -9,7 +9,7 @@ import pytest
from fastapi import HTTPException

from backend.api.features.chat.routes import reset_copilot_usage
-from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
+from backend.copilot.rate_limit import CoPilotUsageStatus, SubscriptionTier, UsageWindow
from backend.util.exceptions import InsufficientBalanceError


@@ -53,6 +53,18 @@ def _mock_settings(enable_credit: bool = True):
    return mock


+def _mock_rate_limits(
+    daily: int = 2_500_000,
+    weekly: int = 12_500_000,
+    tier: SubscriptionTier = SubscriptionTier.PRO,
+):
+    """Mock get_global_rate_limits to return fixed limits (no tier multiplier)."""
+    return patch(
+        f"{_MODULE}.get_global_rate_limits",
+        AsyncMock(return_value=(daily, weekly, tier)),
+    )
+
+
@pytest.mark.asyncio
class TestResetCopilotUsage:
    async def test_feature_disabled_returns_400(self):
@@ -70,6 +82,7 @@ class TestResetCopilotUsage:
        with (
            patch(f"{_MODULE}.config", _make_config(daily_token_limit=0)),
            patch(f"{_MODULE}.settings", _mock_settings()),
+            _mock_rate_limits(daily=0),
        ):
            with pytest.raises(HTTPException) as exc_info:
                await reset_copilot_usage(user_id="user-1")
@@ -83,6 +96,7 @@ class TestResetCopilotUsage:
        with (
            patch(f"{_MODULE}.config", cfg),
            patch(f"{_MODULE}.settings", _mock_settings()),
+            _mock_rate_limits(),
            patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
            patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
            patch(f"{_MODULE}.release_reset_lock", AsyncMock()) as mock_release,
@@ -112,6 +126,7 @@ class TestResetCopilotUsage:
        with (
            patch(f"{_MODULE}.config", cfg),
            patch(f"{_MODULE}.settings", _mock_settings()),
+            _mock_rate_limits(),
            patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
            patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
            patch(f"{_MODULE}.release_reset_lock", AsyncMock()) as mock_release,
@@ -141,6 +156,7 @@ class TestResetCopilotUsage:
        with (
            patch(f"{_MODULE}.config", cfg),
            patch(f"{_MODULE}.settings", _mock_settings()),
+            _mock_rate_limits(),
            patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
            patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
            patch(f"{_MODULE}.release_reset_lock", AsyncMock()),
@@ -171,6 +187,7 @@ class TestResetCopilotUsage:
        with (
            patch(f"{_MODULE}.config", cfg),
            patch(f"{_MODULE}.settings", _mock_settings()),
+            _mock_rate_limits(),
            patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=3)),
        ):
            with pytest.raises(HTTPException) as exc_info:
@@ -208,6 +225,7 @@ class TestResetCopilotUsage:
        with (
            patch(f"{_MODULE}.config", cfg),
            patch(f"{_MODULE}.settings", _mock_settings()),
+            _mock_rate_limits(),
            patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
            patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
            patch(f"{_MODULE}.release_reset_lock", AsyncMock()) as mock_release,
@@ -228,6 +246,7 @@ class TestResetCopilotUsage:
        with (
            patch(f"{_MODULE}.config", _make_config()),
            patch(f"{_MODULE}.settings", _mock_settings()),
+            _mock_rate_limits(),
            patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=None)),
        ):
            with pytest.raises(HTTPException) as exc_info:
@@ -245,6 +264,7 @@ class TestResetCopilotUsage:
        with (
            patch(f"{_MODULE}.config", cfg),
            patch(f"{_MODULE}.settings", _mock_settings()),
+            _mock_rate_limits(),
            patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
            patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
            patch(f"{_MODULE}.release_reset_lock", AsyncMock()),
@@ -275,6 +295,7 @@ class TestResetCopilotUsage:
        with (
            patch(f"{_MODULE}.config", cfg),
            patch(f"{_MODULE}.settings", _mock_settings()),
+            _mock_rate_limits(),
            patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
            patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
            patch(f"{_MODULE}.release_reset_lock", AsyncMock()),

@@ -3,26 +3,62 @@
You can create, edit, and customize agents directly. You ARE the brain —
generate the agent JSON yourself using block schemas, then validate and save.

+### Clarifying — Before or During Building
+
+Use `ask_question` whenever the user's intent is ambiguous — whether
+that's before starting or midway through the workflow. Common moments:
+
+- **Before building**: output format, delivery channel, data source, or
+  trigger is unspecified.
+- **During block discovery**: multiple blocks could fit and the user
+  should choose.
+- **During JSON generation**: a wiring decision depends on user
+  preference.
+
+Steps:
+1. Call `find_block` (or another discovery tool) to learn what the
+   platform actually supports for the ambiguous dimension.
+2. Call `ask_question` with a concrete question listing the discovered
+   options (e.g. "The platform supports Gmail, Slack, and Google Docs —
+   which should the agent use for delivery?").
+3. **Wait for the user's answer** before continuing.
+
+**Skip this** when the goal already specifies all dimensions (e.g.
+"scrape prices from Amazon and email me daily").
+
### Workflow for Creating/Editing Agents

-1. **Discover blocks**: Call `find_block(query, include_schemas=true)` to
+1. **If editing**: First narrow to the specific agent by UUID, then fetch its
+   graph: `find_library_agent(query="<agent_id>", include_graph=true)`. This
+   returns the full graph structure (nodes + links). **Never edit blindly** —
+   always inspect the current graph first so you know exactly what to change.
+   Avoid using `include_graph=true` with broad keyword searches, as fetching
+   multiple graphs at once is expensive and consumes LLM context budget.
+2. **Discover blocks**: Call `find_block(query, include_schemas=true)` to
   search for relevant blocks. This returns block IDs, names, descriptions,
   and full input/output schemas.
-2. **Find library agents**: Call `find_library_agent` to discover reusable
+3. **Find library agents**: Call `find_library_agent` to discover reusable
   agents that can be composed as sub-agents via `AgentExecutorBlock`.
-3. **Generate JSON**: Build the agent JSON using block schemas:
-   - Use block IDs from step 1 as `block_id` in nodes
+4. **Generate/modify JSON**: Build or modify the agent JSON using block schemas:
+   - Use block IDs from step 2 as `block_id` in nodes
   - Wire outputs to inputs using links
   - Set design-time config in `input_default`
   - Use `AgentInputBlock` for values the user provides at runtime
-4. **Write to workspace**: Save the JSON to a workspace file so the user
+   - When editing, apply targeted changes and preserve unchanged parts
+5. **Write to workspace**: Save the JSON to a workspace file so the user
   can review it: `write_workspace_file(filename="agent.json", content=...)`
-5. **Validate**: Call `validate_agent_graph` with the agent JSON to check
+6. **Validate**: Call `validate_agent_graph` with the agent JSON to check
   for errors
-6. **Fix if needed**: Call `fix_agent_graph` to auto-fix common issues,
+7. **Fix if needed**: Call `fix_agent_graph` to auto-fix common issues,
   or fix manually based on the error descriptions. Iterate until valid.
-7. **Save**: Call `create_agent` (new) or `edit_agent` (existing) with
+8. **Save**: Call `create_agent` (new) or `edit_agent` (existing) with
   the final `agent_json`
+9. **Dry-run**: ALWAYS call `run_agent` with `dry_run=True` and
+   `wait_for_result=120` to verify the agent works end-to-end.
+10. **Inspect & fix**: Check the dry-run output for errors. If issues are
+    found, call `edit_agent` to fix and dry-run again. Repeat until the
+    simulation passes or the problems are clearly unfixable.
+    See "REQUIRED: Dry-Run Verification Loop" section below for details.

### Agent JSON Structure

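For orientation, here is a toy example of the JSON shape the workflow above produces (a hypothetical two-node graph; `block_id`, `input_default`, `source_name`, and `sink_name` come from the text above, while the node/link ID fields and the placeholder block IDs are illustrative assumptions):

```python
# Hypothetical minimal agent graph, illustrative shape only.
agent_json = {
    "name": "Echo agent",
    "nodes": [
        {
            "id": "input-1",
            "block_id": "<AgentInputBlock id>",   # placeholder, not a real ID
            "input_default": {"name": "text"},
        },
        {
            "id": "output-1",
            "block_id": "<AgentOutputBlock id>",  # placeholder, not a real ID
            "input_default": {"name": "result"},
        },
    ],
    "links": [
        {
            # Wire the input block's runtime value into the output block.
            "source_id": "input-1",   # ID field names here are assumptions
            "sink_id": "output-1",
            "source_name": "result",
            "sink_name": "value",
        },
    ],
}
```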
@@ -74,8 +110,8 @@ These define the agent's interface — what it accepts and what it produces.

**AgentDropdownInputBlock** (ID: `655d6fdf-a334-421c-b733-520549c07cd1`):
- Specialized input block that presents a dropdown/select to the user
-- Required `input_default` fields: `name` (str), `placeholder_values` (list of options, must have at least one)
-- Optional: `title`, `description`, `value` (default selection)
+- Required `input_default` fields: `name` (str)
+- Optional: `options` (list of dropdown values; when omitted/empty, input behaves as free-text), `title`, `description`, `value` (default selection)
- Output: `result` — the user-selected value at runtime
- Use this instead of AgentInputBlock when the user should pick from a fixed set of options

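A hedged example of the new dropdown shape (the `name`/`options`/`title`/`value` keys come from the bullets above; the chosen values and the `dropdown_node` wrapper are illustrative):

```python
dropdown_node = {
    "block_id": "655d6fdf-a334-421c-b733-520549c07cd1",  # AgentDropdownInputBlock
    "input_default": {
        "name": "delivery_channel",       # required
        "options": ["email", "slack"],    # optional; omitted/empty => free-text
        "title": "Delivery channel",      # optional
        "value": "email",                 # optional default selection
    },
}
```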
@@ -216,19 +252,62 @@ call in a loop until the task is complete:
Regular blocks work exactly like sub-agents as tools — wire each input
field from `source_name: "tools"` on the Orchestrator side.

-### Testing with Dry Run
+### REQUIRED: Dry-Run Verification Loop (create -> dry-run -> fix)

-After saving an agent, suggest a dry run to validate wiring without consuming
-real API calls, credentials, or credits:
+After creating or editing an agent, you MUST dry-run it before telling the
+user the agent is ready. NEVER skip this step.

-1. **Run**: Call `run_agent` or `run_block` with `dry_run=True` and provide
-   sample inputs. This executes the graph with mock outputs, verifying that
-   links resolve correctly and required inputs are satisfied.
-2. **Check results**: Call `view_agent_output` with `show_execution_details=True`
-   to inspect the full node-by-node execution trace. This shows what each node
-   received as input and produced as output, making it easy to spot wiring issues.
-3. **Iterate**: If the dry run reveals wiring issues or missing inputs, fix
-   the agent JSON and re-save before suggesting a real execution.
+#### Step-by-step workflow
+
+1. **Create/Edit**: Call `create_agent` or `edit_agent` to save the agent.
+2. **Dry-run**: Call `run_agent` with `dry_run=True`, `wait_for_result=120`,
+   and realistic sample inputs that exercise every path in the agent. This
+   simulates execution using an LLM for each block — no real API calls,
+   credentials, or credits are consumed.
+3. **Inspect output**: Examine the dry-run result for problems. If
+   `wait_for_result` returns only a summary, call
+   `view_agent_output(execution_id=..., show_execution_details=True)` to
+   see the full node-by-node execution trace. Look for:
+   - **Errors / failed nodes** — a node raised an exception or returned an
+     error status. Common causes: wrong `source_name`/`sink_name` in links,
+     missing `input_default` values, or referencing a nonexistent block output.
+   - **Null / empty outputs** — data did not flow through a link. Verify that
+     `source_name` and `sink_name` match the block schemas exactly (case-
+     sensitive, including nested `_#_` notation).
+   - **Nodes that never executed** — the node was not reached. Likely a
+     missing or broken link from an upstream node.
+   - **Unexpected values** — data arrived but in the wrong type or
+     structure. Check type compatibility between linked ports.
+4. **Fix**: If any issues are found, call `edit_agent` with the corrected
+   agent JSON, then go back to step 2.
+5. **Repeat**: Continue the dry-run -> fix cycle until the simulation passes
+   or the problems are clearly unfixable. If you stop making progress,
+   report the remaining issues to the user and ask for guidance.
+
+#### Good vs bad dry-run output
+
+**Good output** (agent is ready):
+- All nodes executed successfully (no errors in the execution trace)
+- Data flows through every link with non-null, correctly-typed values
+- The final `AgentOutputBlock` contains a meaningful result
+- Status is `COMPLETED`
+
+**Bad output** (needs fixing):
+- Status is `FAILED` — check the error message for the failing node
+- An output node received `null` — trace back to find the broken link
+- A node received data in the wrong format (e.g. string where list expected)
+- Nodes downstream of a failing node were skipped entirely
+
+**Special block behaviour in dry-run mode:**
+- **OrchestratorBlock** and **AgentExecutorBlock** execute for real so the
+  orchestrator can make LLM calls and agent executors can spawn child graphs.
+  Their downstream tool blocks and child-graph blocks are still simulated.
+  Note: real LLM inference calls are made (consuming API quota), even though
+  platform credits are not charged. Agent-mode iterations are capped at 1 in
+  dry-run to keep it fast.
+- **MCPToolBlock** is simulated using the selected tool's name and JSON Schema
+  so the LLM can produce a realistic mock response without connecting to the
+  MCP server.

### Example: Simple AI Text Processor

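The loop above reduces to simple control flow. A runnable sketch with stubbed tool calls (tool names mirror the prompt's vocabulary; the signatures, return shapes, and `fix_from_trace` helper are assumptions made for illustration):

```python
from dataclasses import dataclass


@dataclass
class DryRunResult:
    status: str        # "COMPLETED" or "FAILED"
    execution_id: str


def run_agent(agent_id: str, inputs: dict, dry_run: bool, wait_for_result: int) -> DryRunResult:
    return DryRunResult(status="COMPLETED", execution_id="exec-1")  # stub


def view_agent_output(execution_id: str, show_execution_details: bool) -> dict:
    return {}  # stub: would return the node-by-node execution trace


def edit_agent(agent_id: str, agent_json: dict) -> None:
    pass  # stub


def fix_from_trace(agent_json: dict, trace: dict) -> dict:
    return agent_json  # stub: apply targeted fixes based on the trace


def dry_run_loop(agent_id: str, agent_json: dict, sample_inputs: dict, max_rounds: int = 5) -> bool:
    """Dry-run -> inspect -> fix cycle; True once the simulation passes."""
    for _ in range(max_rounds):
        run = run_agent(agent_id, sample_inputs, dry_run=True, wait_for_result=120)
        if run.status == "COMPLETED":
            return True
        trace = view_agent_output(run.execution_id, show_execution_details=True)
        agent_json = fix_from_trace(agent_json, trace)
        edit_agent(agent_id, agent_json)
    return False  # report remaining issues to the user
```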
@@ -25,7 +25,7 @@ from backend.copilot.sdk.compaction import (


def _make_session() -> ChatSession:
-    return ChatSession.new(user_id="test-user")
+    return ChatSession.new(user_id="test-user", dry_run=False)


# ---------------------------------------------------------------------------

@@ -2,14 +2,30 @@

from __future__ import annotations

+from collections.abc import AsyncIterator
from unittest.mock import patch
from uuid import uuid4

import pytest
+import pytest_asyncio

from backend.util import json


+@pytest_asyncio.fixture(scope="session", loop_scope="session", name="server")
+async def _server_noop() -> None:
+    """No-op server stub — SDK tests don't need the full backend."""
+    return None
+
+
+@pytest_asyncio.fixture(
+    scope="session", loop_scope="session", autouse=True, name="graph_cleanup"
+)
+async def _graph_cleanup_noop() -> AsyncIterator[None]:
+    """No-op graph cleanup stub."""
+    yield
+
+
@pytest.fixture()
def mock_chat_config():
    """Mock ChatConfig so compact_transcript tests skip real config lookup."""

@@ -8,6 +8,9 @@ SDK-internal paths (``~/.claude/projects/…/tool-results/``) are handled
by the separate ``Read`` MCP tool registered in ``tool_adapter.py``.
"""

+import asyncio
+import base64
+import hashlib
import itertools
import json
import logging
@@ -28,6 +31,12 @@ from backend.copilot.context import (

logger = logging.getLogger(__name__)

+# Default number of lines returned by ``read_file`` when the caller does not
+# specify a limit. Also used as the threshold in ``bridge_to_sandbox`` to
+# decide whether the model is requesting the full file (and thus whether the
+# bridge copy is worthwhile).
+_DEFAULT_READ_LIMIT = 2000
+

async def _check_sandbox_symlink_escape(
    sandbox: Any,

@@ -89,7 +98,7 @@ def _get_sandbox_and_path(
    return sandbox, remote


-async def _sandbox_write(sandbox: Any, path: str, content: str) -> None:
+async def _sandbox_write(sandbox: Any, path: str, content: str | bytes) -> None:
    """Write *content* to *path* inside the sandbox.

    The E2B filesystem API (``sandbox.files.write``) and the command API
@@ -102,11 +111,14 @@ async def _sandbox_write(sandbox: Any, path: str, content: str) -> None:
    To work around this, writes targeting ``/tmp`` are performed via
    ``tee`` through the command API, which runs as the sandbox ``user``
    and can therefore always overwrite user-owned files.
+
+    *content* may be ``str`` (text) or ``bytes`` (binary). Both paths
+    are handled correctly: text is encoded to bytes for the base64 shell
+    pipe, and raw bytes are passed through without any encoding.
    """
    if path == "/tmp" or path.startswith("/tmp/"):
-        import base64 as _b64
-
-        encoded = _b64.b64encode(content.encode()).decode()
+        raw = content.encode() if isinstance(content, str) else content
+        encoded = base64.b64encode(raw).decode()
        result = await sandbox.commands.run(
            f"echo {shlex.quote(encoded)} | base64 -d > {shlex.quote(path)}",
            cwd=E2B_WORKDIR,

@@ -128,14 +140,25 @@ async def _handle_read_file(args: dict[str, Any]) -> dict[str, Any]:
    """Read lines from a sandbox file, falling back to the local host for SDK-internal paths."""
    file_path: str = args.get("file_path", "")
    offset: int = max(0, int(args.get("offset", 0)))
-    limit: int = max(1, int(args.get("limit", 2000)))
+    limit: int = max(1, int(args.get("limit", _DEFAULT_READ_LIMIT)))

    if not file_path:
        return _mcp("file_path is required", error=True)

-    # SDK-internal paths (tool-results, ephemeral working dir) stay on the host.
+    # SDK-internal paths (tool-results/tool-outputs, ephemeral working dir)
+    # stay on the host. When E2B is active, also copy the file into the
+    # sandbox so bash_exec can access it for further processing.
    if _is_allowed_local(file_path):
-        return _read_local(file_path, offset, limit)
+        result = _read_local(file_path, offset, limit)
+        if not result.get("isError"):
+            sandbox = _get_sandbox()
+            if sandbox is not None:
+                annotation = await bridge_and_annotate(
+                    sandbox, file_path, offset, limit
+                )
+                if annotation:
+                    result["content"][0]["text"] += annotation
+        return result

    result = _get_sandbox_and_path(file_path)
    if isinstance(result, dict):

@@ -302,6 +325,103 @@ async def _handle_grep(args: dict[str, Any]) -> dict[str, Any]:
    return _mcp(output if output else "No matches found.")


+# Bridging: copy SDK-internal files into E2B sandbox
+
+# Files larger than this are written to /home/user/ via sandbox.files.write()
+# instead of /tmp/ via shell base64, to avoid shell argument length limits
+# and E2B command timeouts. Base64 expands content by ~33%, so keep this
+# well under the typical Linux ARG_MAX (128 KB).
+_BRIDGE_SHELL_MAX_BYTES = 32 * 1024  # 32 KB
+# Files larger than this are skipped entirely to avoid excessive transfer times.
+_BRIDGE_SKIP_BYTES = 50 * 1024 * 1024  # 50 MB
+
+
+async def bridge_to_sandbox(
+    sandbox: Any, file_path: str, offset: int, limit: int
+) -> str | None:
+    """Best-effort copy of a host-side SDK file into the E2B sandbox.
+
+    When the model reads an SDK-internal file (e.g. tool-results), it often
+    wants to process the data with bash. Copying the file into the sandbox
+    under a stable name lets ``bash_exec`` access it without extra steps.
+
+    Only copies when offset=0 and limit is large enough to indicate the model
+    wants the full file. Errors are logged but never propagated.
+
+    Returns the sandbox path on success, or ``None`` on skip/failure.
+
+    Size handling:
+    - <= 32 KB: written to ``/tmp/<hash>-<basename>`` via shell base64
+      (``_sandbox_write``). Kept small to stay within ARG_MAX.
+    - 32 KB - 50 MB: written to ``/home/user/<hash>-<basename>`` via
+      ``sandbox.files.write()`` to avoid shell argument length limits.
+    - > 50 MB: skipped entirely with a warning.
+
+    The sandbox filename is prefixed with a short hash of the full source
+    path to avoid collisions when different source files share the same
+    basename (e.g. multiple ``result.json`` files).
+    """
+    if offset != 0 or limit < _DEFAULT_READ_LIMIT:
+        return None
+    try:
+        expanded = os.path.realpath(os.path.expanduser(file_path))
+        basename = os.path.basename(expanded)
+        source_id = hashlib.sha256(expanded.encode()).hexdigest()[:12]
+        unique_name = f"{source_id}-{basename}"
+        file_size = os.path.getsize(expanded)
+        if file_size > _BRIDGE_SKIP_BYTES:
+            logger.warning(
+                "[E2B] Skipping bridge for large file (%d bytes): %s",
+                file_size,
+                basename,
+            )
+            return None
+
+        def _read_bytes() -> bytes:
+            with open(expanded, "rb") as fh:
+                return fh.read()
+
+        raw_content = await asyncio.to_thread(_read_bytes)
+        try:
+            text_content: str | None = raw_content.decode("utf-8")
+        except UnicodeDecodeError:
+            text_content = None
+        data: str | bytes = text_content if text_content is not None else raw_content
+        if file_size <= _BRIDGE_SHELL_MAX_BYTES:
+            sandbox_path = f"/tmp/{unique_name}"
+            await _sandbox_write(sandbox, sandbox_path, data)
+        else:
+            sandbox_path = f"/home/user/{unique_name}"
+            await sandbox.files.write(sandbox_path, data)
+        logger.info(
+            "[E2B] Bridged SDK file to sandbox: %s -> %s", basename, sandbox_path
+        )
+        return sandbox_path
+    except Exception:
+        logger.warning(
+            "[E2B] Failed to bridge SDK file to sandbox: %s",
+            file_path,
+            exc_info=True,
+        )
+        return None
+
+
+async def bridge_and_annotate(
+    sandbox: Any, file_path: str, offset: int, limit: int
+) -> str | None:
+    """Bridge a host file to the sandbox and return a newline-prefixed annotation.
+
+    Combines ``bridge_to_sandbox`` with the standard annotation suffix so
+    callers don't need to duplicate the pattern. Returns a string like
+    ``"\\n[Sandbox copy available at /tmp/abc-file.txt]"`` on success, or
+    ``None`` if bridging was skipped or failed.
+    """
+    sandbox_path = await bridge_to_sandbox(sandbox, file_path, offset, limit)
+    if sandbox_path is None:
+        return None
+    return f"\n[Sandbox copy available at {sandbox_path}]"
+
+
# Local read (for SDK-internal paths)

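A small worked example of the size routing described in `bridge_to_sandbox`'s docstring (constants copied from the diff above; the routing function is a restatement for illustration, not the production code path):

```python
_BRIDGE_SHELL_MAX_BYTES = 32 * 1024        # 32 KB
_BRIDGE_SKIP_BYTES = 50 * 1024 * 1024      # 50 MB


def route(file_size: int) -> str:
    """Restates the docstring's three size tiers."""
    if file_size > _BRIDGE_SKIP_BYTES:
        return "skip (too large)"
    if file_size <= _BRIDGE_SHELL_MAX_BYTES:
        return "/tmp via shell base64 (_sandbox_write)"
    return "/home/user via sandbox.files.write()"


assert route(10 * 1024) == "/tmp via shell base64 (_sandbox_write)"
assert route(1024 * 1024) == "/home/user via sandbox.files.write()"
assert route(60 * 1024 * 1024) == "skip (too large)"
```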
@@ -3,6 +3,7 @@
Pure unit tests with no external dependencies (no E2B, no sandbox).
"""

+import hashlib
import os
import shutil
from types import SimpleNamespace
@@ -13,12 +14,26 @@ import pytest
from backend.copilot.context import E2B_WORKDIR, SDK_PROJECTS_DIR, _current_project_dir

from .e2b_file_tools import (
+    _BRIDGE_SHELL_MAX_BYTES,
+    _BRIDGE_SKIP_BYTES,
+    _DEFAULT_READ_LIMIT,
    _check_sandbox_symlink_escape,
    _read_local,
    _sandbox_write,
+    bridge_and_annotate,
+    bridge_to_sandbox,
    resolve_sandbox_path,
)


+def _expected_bridge_path(file_path: str, prefix: str = "/tmp") -> str:
+    """Compute the expected sandbox path for a bridged file."""
+    expanded = os.path.realpath(os.path.expanduser(file_path))
+    basename = os.path.basename(expanded)
+    source_id = hashlib.sha256(expanded.encode()).hexdigest()[:12]
+    return f"{prefix}/{source_id}-{basename}"
+
+
# ---------------------------------------------------------------------------
# resolve_sandbox_path — sandbox path normalisation & boundary enforcement
# ---------------------------------------------------------------------------
@@ -91,9 +106,9 @@ class TestResolveSandboxPath:
# ---------------------------------------------------------------------------
# _read_local — host filesystem reads with allowlist enforcement
#
-# In E2B mode, _read_local only allows tool-results paths (via
-# is_allowed_local_path without sdk_cwd). Regular files live on the
-# sandbox, not the host.
+# In E2B mode, _read_local only allows tool-results/tool-outputs paths
+# (via is_allowed_local_path without sdk_cwd). Regular files live on
+# the sandbox, not the host.
# ---------------------------------------------------------------------------


@@ -119,7 +134,7 @@ class TestReadLocal:
        )
        token = _current_project_dir.set(encoded)
        try:
-            result = _read_local(filepath, offset=0, limit=2000)
+            result = _read_local(filepath, offset=0, limit=_DEFAULT_READ_LIMIT)
            assert result["isError"] is False
            assert "line 1" in result["content"][0]["text"]
            assert "line 2" in result["content"][0]["text"]
@@ -127,6 +142,25 @@ class TestReadLocal:
            _current_project_dir.reset(token)
            os.unlink(filepath)

+    def test_read_tool_outputs_file(self):
+        """Reading a tool-outputs file should also succeed."""
+        encoded = "-tmp-copilot-e2b-test-read-outputs"
+        tool_outputs_dir = os.path.join(
+            SDK_PROJECTS_DIR, encoded, self._CONV_UUID, "tool-outputs"
+        )
+        os.makedirs(tool_outputs_dir, exist_ok=True)
+        filepath = os.path.join(tool_outputs_dir, "sdk-abc123.json")
+        with open(filepath, "w") as f:
+            f.write('{"data": "test"}\n')
+        token = _current_project_dir.set(encoded)
+        try:
+            result = _read_local(filepath, offset=0, limit=_DEFAULT_READ_LIMIT)
+            assert result["isError"] is False
+            assert "test" in result["content"][0]["text"]
+        finally:
+            _current_project_dir.reset(token)
+            shutil.rmtree(os.path.join(SDK_PROJECTS_DIR, encoded), ignore_errors=True)
+
    def test_read_disallowed_path_blocked(self):
        """Reading /etc/passwd should be blocked by the allowlist."""
        result = _read_local("/etc/passwd", offset=0, limit=10)
@@ -335,3 +369,199 @@ class TestSandboxWrite:
        encoded_in_cmd = call_args.split("echo ")[1].split(" |")[0].strip("'")
        decoded = base64.b64decode(encoded_in_cmd).decode()
        assert decoded == content
+
+
+# ---------------------------------------------------------------------------
+# bridge_to_sandbox — copy SDK-internal files into E2B sandbox
+# ---------------------------------------------------------------------------
+
+
+def _make_bridge_sandbox() -> SimpleNamespace:
+    """Build a sandbox mock suitable for bridge_to_sandbox tests."""
+    run_result = SimpleNamespace(stdout="", stderr="", exit_code=0)
+    commands = SimpleNamespace(run=AsyncMock(return_value=run_result))
+    files = SimpleNamespace(write=AsyncMock())
+    return SimpleNamespace(commands=commands, files=files)
+
+
+class TestBridgeToSandbox:
+    @pytest.mark.asyncio
+    async def test_happy_path_small_file(self, tmp_path):
+        """A small file is bridged to /tmp/<hash>-<basename> via _sandbox_write."""
+        f = tmp_path / "result.json"
+        f.write_text('{"ok": true}')
+        sandbox = _make_bridge_sandbox()
+
+        result = await bridge_to_sandbox(
+            sandbox, str(f), offset=0, limit=_DEFAULT_READ_LIMIT
+        )
+
+        expected = _expected_bridge_path(str(f))
+        assert result == expected
+        sandbox.commands.run.assert_called_once()
+        cmd = sandbox.commands.run.call_args[0][0]
+        assert "result.json" in cmd
+        sandbox.files.write.assert_not_called()
+
+    @pytest.mark.asyncio
+    async def test_skip_when_offset_nonzero(self, tmp_path):
+        """Bridging is skipped when offset != 0 (partial read)."""
+        f = tmp_path / "data.txt"
+        f.write_text("content")
+        sandbox = _make_bridge_sandbox()
+
+        result = await bridge_to_sandbox(
+            sandbox, str(f), offset=10, limit=_DEFAULT_READ_LIMIT
+        )
+
+        assert result is None
+        sandbox.commands.run.assert_not_called()
+        sandbox.files.write.assert_not_called()
+
+    @pytest.mark.asyncio
+    async def test_skip_when_limit_too_small(self, tmp_path):
+        """Bridging is skipped when limit < _DEFAULT_READ_LIMIT (partial read)."""
+        f = tmp_path / "data.txt"
+        f.write_text("content")
+        sandbox = _make_bridge_sandbox()
+
+        await bridge_to_sandbox(sandbox, str(f), offset=0, limit=100)
+
+        sandbox.commands.run.assert_not_called()
+        sandbox.files.write.assert_not_called()
+
+    @pytest.mark.asyncio
+    async def test_nonexistent_file_does_not_raise(self, tmp_path):
+        """Bridging a non-existent file logs but does not propagate errors."""
+        sandbox = _make_bridge_sandbox()
+
+        await bridge_to_sandbox(
+            sandbox, str(tmp_path / "ghost.txt"), offset=0, limit=_DEFAULT_READ_LIMIT
+        )
+
+        sandbox.commands.run.assert_not_called()
+        sandbox.files.write.assert_not_called()
+
+    @pytest.mark.asyncio
+    async def test_sandbox_write_failure_returns_none(self, tmp_path):
+        """If sandbox write fails, returns None (best-effort)."""
+        f = tmp_path / "data.txt"
+        f.write_text("content")
+        sandbox = _make_bridge_sandbox()
+        sandbox.commands.run.side_effect = RuntimeError("E2B timeout")
+
+        result = await bridge_to_sandbox(
+            sandbox, str(f), offset=0, limit=_DEFAULT_READ_LIMIT
+        )
+
+        assert result is None
+
+    @pytest.mark.asyncio
+    async def test_large_file_uses_files_api(self, tmp_path):
+        """Files > 32 KB but <= 50 MB are written to /home/user/ via files.write."""
+        f = tmp_path / "big.json"
+        f.write_bytes(b"x" * (_BRIDGE_SHELL_MAX_BYTES + 1))
+        sandbox = _make_bridge_sandbox()
+
+        result = await bridge_to_sandbox(
+            sandbox, str(f), offset=0, limit=_DEFAULT_READ_LIMIT
+        )
+
+        expected = _expected_bridge_path(str(f), prefix="/home/user")
+        assert result == expected
+        sandbox.files.write.assert_called_once()
+        call_args = sandbox.files.write.call_args[0]
+        assert call_args[0] == expected
+        sandbox.commands.run.assert_not_called()
+
+    @pytest.mark.asyncio
+    async def test_small_binary_file_preserves_bytes(self, tmp_path):
+        """A small binary file is bridged to /tmp via base64 without corruption."""
+        binary_data = bytes(range(256))
+        f = tmp_path / "image.png"
+        f.write_bytes(binary_data)
+        sandbox = _make_bridge_sandbox()
+
+        result = await bridge_to_sandbox(
+            sandbox, str(f), offset=0, limit=_DEFAULT_READ_LIMIT
+        )
+
+        expected = _expected_bridge_path(str(f))
+        assert result == expected
+        sandbox.commands.run.assert_called_once()
+        cmd = sandbox.commands.run.call_args[0][0]
+        assert "base64" in cmd
+        sandbox.files.write.assert_not_called()
+
+    @pytest.mark.asyncio
+    async def test_large_binary_file_writes_raw_bytes(self, tmp_path):
+        """A large binary file is bridged to /home/user/ as raw bytes."""
+        binary_data = bytes(range(256)) * 200
+        f = tmp_path / "photo.jpg"
+        f.write_bytes(binary_data)
+        sandbox = _make_bridge_sandbox()
+
+        result = await bridge_to_sandbox(
+            sandbox, str(f), offset=0, limit=_DEFAULT_READ_LIMIT
+        )
+
+        expected = _expected_bridge_path(str(f), prefix="/home/user")
+        assert result == expected
+        sandbox.files.write.assert_called_once()
+        call_args = sandbox.files.write.call_args[0]
+        assert call_args[0] == expected
+        assert call_args[1] == binary_data
+        sandbox.commands.run.assert_not_called()
+
+    @pytest.mark.asyncio
+    async def test_very_large_file_skipped(self, tmp_path):
+        """Files > 50 MB are skipped entirely."""
+        f = tmp_path / "huge.bin"
+        # Create a sparse file to avoid actually writing 50 MB
+        with open(f, "wb") as fh:
+            fh.seek(_BRIDGE_SKIP_BYTES + 1)
+            fh.write(b"\0")
+        sandbox = _make_bridge_sandbox()
+
+        result = await bridge_to_sandbox(
+            sandbox, str(f), offset=0, limit=_DEFAULT_READ_LIMIT
+        )
+
+        assert result is None
+
+        sandbox.commands.run.assert_not_called()
+        sandbox.files.write.assert_not_called()
+
+
+# ---------------------------------------------------------------------------
+# bridge_and_annotate — shared helper wrapping bridge_to_sandbox + annotation
+# ---------------------------------------------------------------------------
+
+
+class TestBridgeAndAnnotate:
+    @pytest.mark.asyncio
+    async def test_returns_annotation_on_success(self, tmp_path):
+        """On success, returns a newline-prefixed annotation with the sandbox path."""
+        f = tmp_path / "data.json"
+        f.write_text('{"ok": true}')
+        sandbox = _make_bridge_sandbox()
+
+        annotation = await bridge_and_annotate(
+            sandbox, str(f), offset=0, limit=_DEFAULT_READ_LIMIT
+        )
+
+        expected_path = _expected_bridge_path(str(f))
+        assert annotation == f"\n[Sandbox copy available at {expected_path}]"
+
+    @pytest.mark.asyncio
+    async def test_returns_none_when_skipped(self, tmp_path):
+        """When bridging is skipped (e.g. offset != 0), returns None."""
+        f = tmp_path / "data.json"
+        f.write_text("content")
+        sandbox = _make_bridge_sandbox()
+
+        annotation = await bridge_and_annotate(
+            sandbox, str(f), offset=10, limit=_DEFAULT_READ_LIMIT
+        )
+
+        assert annotation is None

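For reference, the naming scheme that `_expected_bridge_path` reproduces is easy to compute by hand: a 12-hex-character SHA-256 prefix of the resolved source path, joined to the basename (the path below is illustrative; the printed hash depends on the input):

```python
import hashlib
import os

path = "/host/tool-results/result.json"  # illustrative path
expanded = os.path.realpath(os.path.expanduser(path))
source_id = hashlib.sha256(expanded.encode()).hexdigest()[:12]
print(f"/tmp/{source_id}-{os.path.basename(expanded)}")
# -> /tmp/<12 hex chars>-result.json
```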
@@ -275,7 +275,7 @@ class TestCompactionE2E:

        # --- Step 7: CompactionTracker receives PreCompact hook ---
        tracker = CompactionTracker()
-        session = ChatSession.new(user_id="test-user")
+        session = ChatSession.new(user_id="test-user", dry_run=False)
        tracker.on_compact(str(session_file))

        # --- Step 8: Next SDK message arrives → emit_start ---
@@ -376,7 +376,7 @@ class TestCompactionE2E:
        monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))

        tracker = CompactionTracker()
-        session = ChatSession.new(user_id="test")
+        session = ChatSession.new(user_id="test", dry_run=False)
        builder = TranscriptBuilder()

        # --- First query with compaction ---

@@ -20,6 +20,7 @@ config = ChatConfig()
def build_sdk_env(
    session_id: str | None = None,
    user_id: str | None = None,
+    sdk_cwd: str | None = None,
) -> dict[str, str]:
    """Build env vars for the SDK CLI subprocess.

@@ -29,25 +30,35 @@ def build_sdk_env(
       ``ANTHROPIC_API_KEY`` from the parent environment.
    3. **OpenRouter** (default) — overrides base URL and auth token to
       route through the proxy, with Langfuse trace headers.
+
+    When *sdk_cwd* is provided, ``CLAUDE_CODE_TMPDIR`` is set so that
+    the CLI writes temp/sub-agent output inside the per-session workspace
+    directory rather than an inaccessible system temp path.
    """
    # --- Mode 1: Claude Code subscription auth ---
    if config.use_claude_code_subscription:
        validate_subscription()
-        return {
+        env: dict[str, str] = {
            "ANTHROPIC_API_KEY": "",
            "ANTHROPIC_AUTH_TOKEN": "",
            "ANTHROPIC_BASE_URL": "",
        }
+        if sdk_cwd:
+            env["CLAUDE_CODE_TMPDIR"] = sdk_cwd
+        return env

    # --- Mode 2: Direct Anthropic (no proxy hop) ---
    if not config.openrouter_active:
-        return {}
+        env = {}
+        if sdk_cwd:
+            env["CLAUDE_CODE_TMPDIR"] = sdk_cwd
+        return env

    # --- Mode 3: OpenRouter proxy ---
    base = (config.base_url or "").rstrip("/")
    if base.endswith("/v1"):
        base = base[:-3]
-    env: dict[str, str] = {
+    env = {
        "ANTHROPIC_BASE_URL": base,
        "ANTHROPIC_AUTH_TOKEN": config.api_key or "",
        "ANTHROPIC_API_KEY": "",  # force CLI to use AUTH_TOKEN
@@ -65,4 +76,7 @@ def build_sdk_env(
    if parts:
        env["ANTHROPIC_CUSTOM_HEADERS"] = "\n".join(parts)

+    if sdk_cwd:
+        env["CLAUDE_CODE_TMPDIR"] = sdk_cwd
+
    return env

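A hedged sketch of a call site (the import path matches the tests below, so it runs inside the platform repo; the subprocess command itself is a placeholder, since the real launcher lives in the SDK service):

```python
import os
import subprocess

from backend.copilot.sdk.env import build_sdk_env

# Overlay the mode-specific vars onto the parent environment.
overlay = build_sdk_env(session_id="sess-1", user_id="user-1", sdk_cwd="/tmp/ws")
env = {**os.environ, **overlay}
subprocess.run(["echo", "launch the SDK CLI here"], env=env)  # placeholder command
```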
@@ -240,3 +240,54 @@ class TestBuildSdkEnvModePriority:
            "ANTHROPIC_AUTH_TOKEN": "",
            "ANTHROPIC_BASE_URL": "",
        }
+
+
+# ---------------------------------------------------------------------------
+# CLAUDE_CODE_TMPDIR integration
+# ---------------------------------------------------------------------------
+
+
+class TestClaudeCodeTmpdir:
+    """Verify build_sdk_env() sets CLAUDE_CODE_TMPDIR from *sdk_cwd*."""
+
+    def test_tmpdir_set_when_sdk_cwd_is_truthy(self):
+        """CLAUDE_CODE_TMPDIR is set to sdk_cwd when sdk_cwd is truthy."""
+        cfg = _make_config(use_openrouter=False)
+        with patch("backend.copilot.sdk.env.config", cfg):
+            from backend.copilot.sdk.env import build_sdk_env
+
+            result = build_sdk_env(sdk_cwd="/tmp/copilot-workspace")
+
+        assert result["CLAUDE_CODE_TMPDIR"] == "/tmp/copilot-workspace"
+
+    def test_tmpdir_not_set_when_sdk_cwd_is_none(self):
+        """CLAUDE_CODE_TMPDIR is NOT in the env when sdk_cwd is None."""
+        cfg = _make_config(use_openrouter=False)
+        with patch("backend.copilot.sdk.env.config", cfg):
+            from backend.copilot.sdk.env import build_sdk_env
+
+            result = build_sdk_env(sdk_cwd=None)
+
+        assert "CLAUDE_CODE_TMPDIR" not in result
+
+    def test_tmpdir_not_set_when_sdk_cwd_is_empty_string(self):
+        """CLAUDE_CODE_TMPDIR is NOT in the env when sdk_cwd is empty string."""
+        cfg = _make_config(use_openrouter=False)
+        with patch("backend.copilot.sdk.env.config", cfg):
+            from backend.copilot.sdk.env import build_sdk_env
+
+            result = build_sdk_env(sdk_cwd="")
+
+        assert "CLAUDE_CODE_TMPDIR" not in result
+
+    @patch("backend.copilot.sdk.env.validate_subscription")
+    def test_tmpdir_set_in_subscription_mode(self, mock_validate):
+        """CLAUDE_CODE_TMPDIR is set even in subscription mode."""
+        cfg = _make_config(use_claude_code_subscription=True)
+        with patch("backend.copilot.sdk.env.config", cfg):
+            from backend.copilot.sdk.env import build_sdk_env
+
+            result = build_sdk_env(sdk_cwd="/tmp/sub-workspace")
+
+        assert result["CLAUDE_CODE_TMPDIR"] == "/tmp/sub-workspace"
+        assert result["ANTHROPIC_API_KEY"] == ""

@@ -28,13 +28,12 @@ Each result includes a `remotes` array with the exact server URL to use.

### Important: Check blocks first

-Before using `run_mcp_tool`, always check if the platform already has blocks for the service
-using `find_block`. The platform has hundreds of built-in blocks (Google Sheets, Google Docs,
-Google Calendar, Gmail, etc.) that work without MCP setup.
+Always follow the **Tool Discovery Priority** described in the tool notes:
+call `find_block` before resorting to `run_mcp_tool`.

Only use `run_mcp_tool` when:
-- The service is in the known hosted MCP servers list above, OR
-- You searched `find_block` first and found no matching blocks
+- You searched `find_block` first and found no matching blocks, AND
+- The service is in the known hosted MCP servers list above or found via the registry API

**Never guess or construct MCP server URLs.** Only use URLs from the known servers list above
or from the `remotes[].url` field in MCP registry search results.

@@ -0,0 +1,391 @@
|
||||
"""Tests for P0 guardrails: _resolve_fallback_model, security env vars, TMPDIR."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.constants import is_transient_api_error
|
||||
|
||||
|
||||
def _make_config(**overrides) -> ChatConfig:
|
||||
"""Create a ChatConfig with safe defaults, applying *overrides*."""
|
||||
defaults = {
|
||||
"use_claude_code_subscription": False,
|
||||
"use_openrouter": False,
|
||||
"api_key": None,
|
||||
"base_url": None,
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return ChatConfig(**defaults)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _resolve_fallback_model
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SVC = "backend.copilot.sdk.service"
|
||||
|
||||
|
||||
class TestResolveFallbackModel:
|
||||
"""Provider-aware fallback model resolution."""
|
||||
|
||||
def test_returns_none_when_empty(self):
|
||||
cfg = _make_config(claude_agent_fallback_model="")
|
||||
with patch(f"{_SVC}.config", cfg):
|
||||
from backend.copilot.sdk.service import _resolve_fallback_model
|
||||
|
||||
assert _resolve_fallback_model() is None
|
||||
|
||||
def test_strips_provider_prefix(self):
|
||||
"""OpenRouter-style 'anthropic/claude-sonnet-4-...' is stripped."""
|
||||
cfg = _make_config(
|
||||
claude_agent_fallback_model="anthropic/claude-sonnet-4-20250514",
|
||||
use_openrouter=True,
|
||||
api_key="sk-test",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
with patch(f"{_SVC}.config", cfg):
|
||||
from backend.copilot.sdk.service import _resolve_fallback_model
|
||||
|
||||
result = _resolve_fallback_model()
|
||||
|
||||
assert result == "claude-sonnet-4-20250514"
|
||||
assert "/" not in result
|
||||
|
||||
def test_dots_replaced_for_direct_anthropic(self):
|
||||
"""Direct Anthropic requires hyphen-separated versions."""
|
||||
cfg = _make_config(
|
||||
claude_agent_fallback_model="claude-sonnet-4.5-20250514",
|
||||
use_openrouter=False,
|
||||
)
|
||||
with patch(f"{_SVC}.config", cfg):
|
||||
from backend.copilot.sdk.service import _resolve_fallback_model
|
||||
|
||||
result = _resolve_fallback_model()
|
||||
|
||||
assert result is not None
|
||||
assert "." not in result
|
||||
assert result == "claude-sonnet-4-5-20250514"
|
||||
|
||||
def test_dots_preserved_for_openrouter(self):
|
||||
"""OpenRouter uses dot-separated versions — don't normalise."""
|
||||
cfg = _make_config(
|
||||
claude_agent_fallback_model="claude-sonnet-4.5-20250514",
|
||||
use_openrouter=True,
|
||||
api_key="sk-test",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
with patch(f"{_SVC}.config", cfg):
|
||||
from backend.copilot.sdk.service import _resolve_fallback_model
|
||||
|
||||
result = _resolve_fallback_model()
|
||||
|
||||
assert result == "claude-sonnet-4.5-20250514"
|
||||
|
||||
def test_default_value(self):
|
||||
"""Default fallback model resolves to a valid string."""
|
||||
cfg = _make_config()
|
||||
with patch(f"{_SVC}.config", cfg):
|
||||
from backend.copilot.sdk.service import _resolve_fallback_model
|
||||
|
||||
result = _resolve_fallback_model()
|
||||
|
||||
assert result is not None
|
||||
assert "sonnet" in result.lower() or "claude" in result.lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Security & isolation env vars
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSecurityEnvVars:
|
||||
"""Verify the env-var contract in the service module.
|
||||
|
||||
The production code sets CLAUDE_CODE_TMPDIR and security env vars
|
||||
inline after ``build_sdk_env()`` returns. We grep for these string
|
||||
literals in ``service.py`` to ensure they aren't accidentally removed.
|
||||
"""
|
||||
|
||||
_SERVICE_PATH = "autogpt_platform/backend/backend/copilot/sdk/service.py"
|
||||
|
||||
@staticmethod
|
||||
def _read_service_source() -> str:
|
||||
import pathlib
|
||||
|
||||
# Walk up from this test file to the repo root
|
||||
repo = pathlib.Path(__file__).resolve().parents[5]
|
||||
return (repo / TestSecurityEnvVars._SERVICE_PATH).read_text()
|
||||
|
||||
def test_tmpdir_env_var_present_in_source(self):
|
||||
"""CLAUDE_CODE_TMPDIR must be set when sdk_cwd is provided."""
|
||||
src = self._read_service_source()
|
||||
assert 'sdk_env["CLAUDE_CODE_TMPDIR"]' in src
|
||||
|
||||
def test_home_not_overridden_in_source(self):
|
||||
"""HOME must NOT be overridden — would break git/ssh/npm."""
|
||||
src = self._read_service_source()
|
||||
assert 'sdk_env["HOME"]' not in src
|
||||
|
||||
def test_security_env_vars_present_in_source(self):
|
||||
"""All four security env vars must be set in the service module."""
|
||||
src = self._read_service_source()
|
||||
for var in (
|
||||
"CLAUDE_CODE_DISABLE_CLAUDE_MDS",
|
||||
"CLAUDE_CODE_SKIP_PROMPT_HISTORY",
|
||||
"CLAUDE_CODE_DISABLE_AUTO_MEMORY",
|
||||
"CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC",
|
||||
):
|
||||
assert var in src, f"{var} not found in service.py"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config defaults
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestConfigDefaults:
|
||||
"""Verify ChatConfig P0 fields have correct defaults."""
|
||||
|
||||
def test_fallback_model_default(self):
|
||||
cfg = _make_config()
|
||||
assert cfg.claude_agent_fallback_model
|
||||
assert "sonnet" in cfg.claude_agent_fallback_model.lower()
|
||||
|
||||
def test_max_turns_default(self):
|
||||
cfg = _make_config()
|
||||
assert cfg.claude_agent_max_turns == 50
|
||||
|
||||
def test_max_budget_usd_default(self):
|
||||
cfg = _make_config()
|
||||
assert cfg.claude_agent_max_budget_usd == 5.0
|
||||
|
||||
def test_max_transient_retries_default(self):
|
||||
cfg = _make_config()
|
||||
assert cfg.claude_agent_max_transient_retries == 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_sdk_env — all 3 auth modes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_ENV = "backend.copilot.sdk.env"
|
||||
|
||||
|
||||
class TestBuildSdkEnv:
|
||||
"""Verify build_sdk_env returns correct dicts for each auth mode."""
|
||||
|
||||
def test_subscription_mode_clears_keys(self):
|
||||
"""Mode 1: subscription clears API key / auth token / base URL."""
|
||||
cfg = _make_config(use_claude_code_subscription=True)
|
||||
with (
|
||||
patch(f"{_ENV}.config", cfg),
|
||||
patch(f"{_ENV}.validate_subscription"),
|
||||
):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
env = build_sdk_env(session_id="s1", user_id="u1")
|
||||
|
||||
assert env["ANTHROPIC_API_KEY"] == ""
|
||||
assert env["ANTHROPIC_AUTH_TOKEN"] == ""
|
||||
assert env["ANTHROPIC_BASE_URL"] == ""
|
||||
|
||||
def test_direct_anthropic_returns_empty_dict(self):
|
||||
"""Mode 2: direct Anthropic returns {} (inherits from parent env)."""
|
||||
cfg = _make_config(
|
||||
use_claude_code_subscription=False,
|
||||
use_openrouter=False,
|
||||
)
|
||||
with patch(f"{_ENV}.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
env = build_sdk_env()
|
||||
|
||||
assert env == {}
|
||||
|
||||
def test_openrouter_sets_base_url_and_auth(self):
|
||||
"""Mode 3: OpenRouter sets base URL, auth token, and clears API key."""
|
||||
cfg = _make_config(
|
||||
use_claude_code_subscription=False,
|
||||
use_openrouter=True,
|
||||
api_key="sk-or-test",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
with patch(f"{_ENV}.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
env = build_sdk_env(session_id="sess-1", user_id="user-1")
|
||||
|
||||
assert env["ANTHROPIC_BASE_URL"] == "https://openrouter.ai/api"
|
||||
assert env["ANTHROPIC_AUTH_TOKEN"] == "sk-or-test"
|
||||
assert env["ANTHROPIC_API_KEY"] == ""
|
||||
assert "x-session-id: sess-1" in env["ANTHROPIC_CUSTOM_HEADERS"]
|
||||
assert "x-user-id: user-1" in env["ANTHROPIC_CUSTOM_HEADERS"]
|
||||
|
||||
def test_openrouter_no_headers_when_ids_empty(self):
|
||||
"""Mode 3: No custom headers when session_id/user_id are not given."""
|
||||
cfg = _make_config(
|
||||
use_claude_code_subscription=False,
|
||||
use_openrouter=True,
|
||||
api_key="sk-or-test",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
with patch(f"{_ENV}.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
env = build_sdk_env()
|
||||
|
||||
assert "ANTHROPIC_CUSTOM_HEADERS" not in env
|
||||
|
||||
def test_all_modes_return_mutable_dict(self):
|
||||
"""build_sdk_env must return a mutable dict (not None) so callers
|
||||
can add security env vars like CLAUDE_CODE_TMPDIR."""
|
||||
for cfg in (
|
||||
_make_config(use_claude_code_subscription=True),
|
||||
_make_config(use_openrouter=False),
|
||||
_make_config(
|
||||
use_openrouter=True,
|
||||
api_key="k",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
),
|
||||
):
|
||||
with (
|
||||
patch(f"{_ENV}.config", cfg),
|
||||
patch(f"{_ENV}.validate_subscription"),
|
||||
):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
env = build_sdk_env()
|
||||
|
||||
assert isinstance(env, dict)
|
||||
env["CLAUDE_CODE_TMPDIR"] = "/tmp/test"
|
||||
assert env["CLAUDE_CODE_TMPDIR"] == "/tmp/test"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# is_transient_api_error
# ---------------------------------------------------------------------------


class TestIsTransientApiError:
    """Verify that is_transient_api_error detects all transient patterns."""

    @pytest.mark.parametrize(
        "error_text",
        [
            "socket connection was closed unexpectedly",
            "ECONNRESET",
            "connection was forcibly closed",
            "network socket disconnected",
        ],
    )
    def test_connection_level_errors(self, error_text: str):
        assert is_transient_api_error(error_text)

    @pytest.mark.parametrize(
        "error_text",
        [
            "rate limit exceeded",
            "rate_limit_error",
            "Too Many Requests",
            "status code 429",
        ],
    )
    def test_429_rate_limit_errors(self, error_text: str):
        assert is_transient_api_error(error_text)

    @pytest.mark.parametrize(
        "error_text",
        [
            "API is overloaded",
            "Internal Server Error",
            "Bad Gateway",
            "Service Unavailable",
            "Gateway Timeout",
            "status code 529",
            "status code 500",
            "status code 502",
            "status code 503",
            "status code 504",
        ],
    )
    def test_5xx_server_errors(self, error_text: str):
        assert is_transient_api_error(error_text)

    @pytest.mark.parametrize(
        "error_text",
        [
            "invalid_api_key",
            "Authentication failed",
            "prompt is too long",
            "model not found",
            "",
        ],
    )
    def test_non_transient_errors(self, error_text: str):
        assert not is_transient_api_error(error_text)

    def test_case_insensitive(self):
        assert is_transient_api_error("SOCKET CONNECTION WAS CLOSED UNEXPECTEDLY")
        assert is_transient_api_error("econnreset")
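
# --- Illustrative sketch (not part of the diff): one way to satisfy the
# matrix above with a case-insensitive substring scan. The exact pattern
# list inside the real is_transient_api_error is an assumption.
_SKETCH_TRANSIENT_PATTERNS = (
    "socket connection", "econnreset", "forcibly closed", "socket disconnected",
    "rate limit", "rate_limit_error", "too many requests", "status code 429",
    "overloaded", "internal server error", "bad gateway", "service unavailable",
    "gateway timeout", "status code 529", "status code 500", "status code 502",
    "status code 503", "status code 504",
)


def _sketch_is_transient_api_error(error_text: str) -> bool:
    lowered = error_text.lower()
    return any(pattern in lowered for pattern in _SKETCH_TRANSIENT_PATTERNS)
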
# ---------------------------------------------------------------------------
# Config validators for max_turns / max_budget_usd
# ---------------------------------------------------------------------------


class TestConfigValidators:
    """Verify ge/le bounds on max_turns and max_budget_usd."""

    def test_max_turns_rejects_zero(self):
        with pytest.raises(ValidationError):
            _make_config(claude_agent_max_turns=0)

    def test_max_turns_rejects_negative(self):
        with pytest.raises(ValidationError):
            _make_config(claude_agent_max_turns=-1)

    def test_max_turns_rejects_above_500(self):
        with pytest.raises(ValidationError):
            _make_config(claude_agent_max_turns=501)

    def test_max_turns_accepts_boundary_values(self):
        cfg_low = _make_config(claude_agent_max_turns=1)
        assert cfg_low.claude_agent_max_turns == 1
        cfg_high = _make_config(claude_agent_max_turns=500)
        assert cfg_high.claude_agent_max_turns == 500

    def test_max_budget_rejects_zero(self):
        with pytest.raises(ValidationError):
            _make_config(claude_agent_max_budget_usd=0.0)

    def test_max_budget_rejects_negative(self):
        with pytest.raises(ValidationError):
            _make_config(claude_agent_max_budget_usd=-1.0)

    def test_max_budget_rejects_above_100(self):
        with pytest.raises(ValidationError):
            _make_config(claude_agent_max_budget_usd=100.01)

    def test_max_budget_accepts_boundary_values(self):
        cfg_low = _make_config(claude_agent_max_budget_usd=0.01)
        assert cfg_low.claude_agent_max_budget_usd == 0.01
        cfg_high = _make_config(claude_agent_max_budget_usd=100.0)
        assert cfg_high.claude_agent_max_budget_usd == 100.0

    def test_max_transient_retries_rejects_negative(self):
        with pytest.raises(ValidationError):
            _make_config(claude_agent_max_transient_retries=-1)

    def test_max_transient_retries_rejects_above_10(self):
        with pytest.raises(ValidationError):
            _make_config(claude_agent_max_transient_retries=11)

    def test_max_transient_retries_accepts_boundary_values(self):
        cfg_low = _make_config(claude_agent_max_transient_retries=0)
        assert cfg_low.claude_agent_max_transient_retries == 0
        cfg_high = _make_config(claude_agent_max_transient_retries=10)
        assert cfg_high.claude_agent_max_transient_retries == 10
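
# --- Illustrative sketch (not part of the diff): the boundary behaviour
# tested above is what pydantic's ge/le field constraints produce. The
# defaults and the surrounding settings class are assumptions; only the
# bounds are taken from the tests.
from pydantic import BaseModel, Field


class _SketchAgentLimits(BaseModel):
    claude_agent_max_turns: int = Field(default=100, ge=1, le=500)
    claude_agent_max_budget_usd: float = Field(default=10.0, ge=0.01, le=100.0)
    claude_agent_max_transient_retries: int = Field(default=3, ge=0, le=10)


# _SketchAgentLimits(claude_agent_max_turns=0) raises ValidationError, while
# claude_agent_max_turns=500 is accepted: ge/le bounds are inclusive.
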
@@ -38,7 +38,7 @@ class TestFlattenAssistantContent:

    def test_tool_use_blocks(self):
        blocks = [{"type": "tool_use", "name": "read_file", "input": {}}]
-       assert _flatten_assistant_content(blocks) == "[tool_use: read_file]"
+       assert _flatten_assistant_content(blocks) == ""

    def test_mixed_blocks(self):
        blocks = [
@@ -47,19 +47,22 @@ class TestFlattenAssistantContent:
        ]
        result = _flatten_assistant_content(blocks)
        assert "Let me read that." in result
-       assert "[tool_use: Read]" in result
+       # tool_use blocks are dropped entirely to prevent model mimicry
+       assert "Read" not in result

    def test_raw_strings(self):
        assert _flatten_assistant_content(["hello", "world"]) == "hello\nworld"

-   def test_unknown_block_type_preserved_as_placeholder(self):
+   def test_unknown_block_type_dropped(self):
        blocks = [
            {"type": "text", "text": "See this image:"},
            {"type": "image", "source": {"type": "base64", "data": "..."}},
        ]
        result = _flatten_assistant_content(blocks)
        assert "See this image:" in result
-       assert "[__image__]" in result
+       # Unknown block types are dropped to prevent model mimicry
+       assert "[__image__]" not in result
+       assert "base64" not in result

    def test_empty(self):
        assert _flatten_assistant_content([]) == ""
@@ -279,7 +282,8 @@ class TestTranscriptToMessages:
        messages = _transcript_to_messages(content)
        assert len(messages) == 2
        assert "Let me check." in messages[0]["content"]
-       assert "[tool_use: read_file]" in messages[0]["content"]
+       # tool_use blocks are dropped entirely to prevent model mimicry
+       assert "read_file" not in messages[0]["content"]
        assert messages[1]["content"] == "file contents"

@@ -399,7 +403,7 @@ class TestCompactTranscript:
            },
        )()
        with patch(
-           "backend.copilot.sdk.transcript._run_compression",
+           "backend.copilot.transcript._run_compression",
            new_callable=AsyncMock,
            return_value=mock_result,
        ):
@@ -434,7 +438,7 @@ class TestCompactTranscript:
            },
        )()
        with patch(
-           "backend.copilot.sdk.transcript._run_compression",
+           "backend.copilot.transcript._run_compression",
            new_callable=AsyncMock,
            return_value=mock_result,
        ):
@@ -458,7 +462,7 @@ class TestCompactTranscript:
            ]
        )
        with patch(
-           "backend.copilot.sdk.transcript._run_compression",
+           "backend.copilot.transcript._run_compression",
            new_callable=AsyncMock,
            side_effect=RuntimeError("LLM unavailable"),
        ):
@@ -564,11 +568,11 @@ class TestRunCompressionTimeout:

        with (
            patch(
-               "backend.copilot.sdk.transcript.get_openai_client",
+               "backend.copilot.transcript.get_openai_client",
                return_value="fake-client",
            ),
            patch(
-               "backend.copilot.sdk.transcript.compress_context",
+               "backend.copilot.transcript.compress_context",
                side_effect=_mock_compress,
            ),
        ):
@@ -598,11 +602,11 @@ class TestRunCompressionTimeout:

        with (
            patch(
-               "backend.copilot.sdk.transcript.get_openai_client",
+               "backend.copilot.transcript.get_openai_client",
                return_value=None,
            ),
            patch(
-               "backend.copilot.sdk.transcript.compress_context",
+               "backend.copilot.transcript.compress_context",
                new_callable=AsyncMock,
                return_value=truncation_result,
            ) as mock_compress,

@@ -49,22 +49,22 @@ def test_format_assistant_tool_calls():
        )
    ]
    result = _format_conversation_context(msgs)
-   assert result is not None
-   assert 'You called tool: search({"q": "test"})' in result
+   # Assistant with no content and tool_calls omitted produces no lines
+   assert result is None


def test_format_tool_result():
    msgs = [ChatMessage(role="tool", content='{"result": "ok"}')]
    result = _format_conversation_context(msgs)
    assert result is not None
-   assert 'Tool result: {"result": "ok"}' in result
+   assert 'Tool output: {"result": "ok"}' in result


def test_format_tool_result_none_content():
    msgs = [ChatMessage(role="tool", content=None)]
    result = _format_conversation_context(msgs)
    assert result is not None
-   assert "Tool result: " in result
+   assert "Tool output: " in result


def test_format_full_conversation():
@@ -84,8 +84,8 @@ def test_format_full_conversation():
    assert result is not None
    assert "User: find agents" in result
    assert "You responded: I'll search for agents." in result
-   assert "You called tool: find_agents" in result
-   assert "Tool result:" in result
+   # tool_calls are omitted to prevent model mimicry
+   assert "Tool output:" in result
    assert "You responded: Found Agent1." in result

@@ -27,6 +27,7 @@ from backend.copilot.response_model import (
    StreamError,
    StreamFinish,
    StreamFinishStep,
+   StreamHeartbeat,
    StreamStart,
    StreamStartStep,
    StreamTextDelta,
@@ -76,6 +77,12 @@ class SDKResponseAdapter:
            # Open the first step (matches non-SDK: StreamStart then StreamStartStep)
            responses.append(StreamStartStep())
            self.step_open = True
+       elif sdk_message.subtype == "task_progress":
+           # Emit a heartbeat so publish_chunk is called during long
+           # sub-agent runs. Without this, the Redis stream and meta
+           # key TTLs expire during gaps where no real chunks are
+           # produced (task_progress events were previously silent).
+           responses.append(StreamHeartbeat())

    elif isinstance(sdk_message, AssistantMessage):
        # Flush any SDK built-in tool calls that didn't get a UserMessage
@@ -18,6 +18,7 @@ from backend.copilot.response_model import (
    StreamError,
    StreamFinish,
    StreamFinishStep,
+   StreamHeartbeat,
    StreamStart,
    StreamStartStep,
    StreamTextDelta,
@@ -28,6 +29,7 @@ from backend.copilot.response_model import (
    StreamToolOutputAvailable,
)

+from .compaction import compaction_events
from .response_adapter import SDKResponseAdapter
from .tool_adapter import MCP_TOOL_PREFIX
from .tool_adapter import _pending_tool_outputs as _pto
@@ -59,6 +61,14 @@ def test_system_non_init_emits_nothing():
    assert results == []


+def test_task_progress_emits_heartbeat():
+    """task_progress events emit a StreamHeartbeat to keep Redis TTL alive."""
+    adapter = _adapter()
+    results = adapter.convert_message(SystemMessage(subtype="task_progress", data={}))
+    assert len(results) == 1
+    assert isinstance(results[0], StreamHeartbeat)
+
+
# -- AssistantMessage with TextBlock -----------------------------------------

@@ -250,13 +260,13 @@ def test_result_error_emits_error_and_finish():
        is_error=True,
        num_turns=0,
        session_id="s1",
-       result="API rate limited",
+       result="Invalid API key provided",
    )
    results = adapter.convert_message(msg)
    # No step was open, so no FinishStep — just Error + Finish
    assert len(results) == 2
    assert isinstance(results[0], StreamError)
-   assert "API rate limited" in results[0].errorText
+   assert "Invalid API key provided" in results[0].errorText
    assert isinstance(results[1], StreamFinish)

@@ -680,3 +690,102 @@ def test_already_resolved_tool_skipped_in_user_message():
    assert (
        len(output_events) == 0
    ), "Already-resolved tool should not emit duplicate output"


# -- _end_text_if_open before compaction -------------------------------------


def test_end_text_if_open_emits_text_end_before_finish_step():
    """StreamTextEnd must be emitted before StreamFinishStep during compaction.

    When ``emit_end_if_ready`` fires compaction events while a text block is
    still open, ``_end_text_if_open`` must close it first. If StreamFinishStep
    arrives before StreamTextEnd, the Vercel AI SDK clears ``activeTextParts``
    and raises "Received text-end for missing text part".
    """
    adapter = _adapter()

    # Open a text block by processing an AssistantMessage with text
    msg = AssistantMessage(content=[TextBlock(text="partial response")], model="test")
    adapter.convert_message(msg)
    assert adapter.has_started_text
    assert not adapter.has_ended_text

    # Simulate what service.py does before yielding compaction events
    pre_close: list[StreamBaseResponse] = []
    adapter._end_text_if_open(pre_close)
    combined = pre_close + list(compaction_events("Compacted transcript"))

    text_end_idx = next(
        (i for i, e in enumerate(combined) if isinstance(e, StreamTextEnd)), None
    )
    finish_step_idx = next(
        (i for i, e in enumerate(combined) if isinstance(e, StreamFinishStep)), None
    )

    assert text_end_idx is not None, "StreamTextEnd must be present"
    assert finish_step_idx is not None, "StreamFinishStep must be present"
    assert text_end_idx < finish_step_idx, (
        f"StreamTextEnd (idx={text_end_idx}) must precede "
        f"StreamFinishStep (idx={finish_step_idx}) — otherwise the Vercel AI SDK "
        "clears activeTextParts before text-end arrives"
    )


def test_step_open_must_reset_after_compaction_finish_step():
    """Adapter step_open must be reset when compaction emits StreamFinishStep.

    Compaction events bypass the adapter, so service.py must explicitly clear
    step_open after yielding a StreamFinishStep from compaction. Without this,
    the next AssistantMessage skips StreamStartStep because the adapter still
    thinks a step is open.
    """
    adapter = _adapter()

    # Open a step + text block via an AssistantMessage
    msg = AssistantMessage(content=[TextBlock(text="thinking...")], model="test")
    adapter.convert_message(msg)
    assert adapter.step_open is True

    # Simulate what service.py does: close text, then check compaction events
    pre_close: list[StreamBaseResponse] = []
    adapter._end_text_if_open(pre_close)

    events = list(compaction_events("Compacted transcript"))
    if any(isinstance(ev, StreamFinishStep) for ev in events):
        adapter.step_open = False

    assert (
        adapter.step_open is False
    ), "step_open must be False after compaction emits StreamFinishStep"

    # Next AssistantMessage must open a new step
    msg2 = AssistantMessage(content=[TextBlock(text="continued")], model="test")
    results = adapter.convert_message(msg2)
    assert any(
        isinstance(r, StreamStartStep) for r in results
    ), "A new StreamStartStep must be emitted after compaction closed the step"


def test_end_text_if_open_no_op_when_no_text_open():
    """_end_text_if_open emits nothing when no text block is open."""
    adapter = _adapter()
    results: list[StreamBaseResponse] = []
    adapter._end_text_if_open(results)
    assert results == []


def test_end_text_if_open_no_op_after_text_already_ended():
    """_end_text_if_open emits nothing when the text block is already closed."""
    adapter = _adapter()
    msg = AssistantMessage(content=[TextBlock(text="hello")], model="test")
    adapter.convert_message(msg)
    # Close it once
    first: list[StreamBaseResponse] = []
    adapter._end_text_if_open(first)
    assert len(first) == 1
    assert isinstance(first[0], StreamTextEnd)
    # Second call must be a no-op
    second: list[StreamBaseResponse] = []
    adapter._end_text_if_open(second)
    assert second == []
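
# --- Illustrative sketch (not part of the diff): the idempotent close
# behaviour the four tests above pin down. Attribute names come from the
# test assertions; the real adapter method may differ in detail.
def _sketch_end_text_if_open(adapter, out: list) -> None:
    # Emit StreamTextEnd at most once per opened text block: calling this
    # before any text started, or after the block was closed, is a no-op.
    if adapter.has_started_text and not adapter.has_ended_text:
        out.append(StreamTextEnd())
        adapter.has_ended_text = True
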
@@ -113,7 +113,7 @@ class TestScenarioCompactAndRetry:
                )(),
            ),
            patch(
-               "backend.copilot.sdk.transcript._run_compression",
+               "backend.copilot.transcript._run_compression",
                new_callable=AsyncMock,
                return_value=mock_result,
            ),
@@ -170,7 +170,7 @@ class TestScenarioCompactFailsFallback:
                )(),
            ),
            patch(
-               "backend.copilot.sdk.transcript._run_compression",
+               "backend.copilot.transcript._run_compression",
                new_callable=AsyncMock,
                side_effect=RuntimeError("LLM unavailable"),
            ),
@@ -261,7 +261,7 @@ class TestScenarioDoubleFailDBFallback:
                )(),
            ),
            patch(
-               "backend.copilot.sdk.transcript._run_compression",
+               "backend.copilot.transcript._run_compression",
                new_callable=AsyncMock,
                return_value=mock_result,
            ),
@@ -337,7 +337,7 @@ class TestScenarioCompactionIdentical:
                )(),
            ),
            patch(
-               "backend.copilot.sdk.transcript._run_compression",
+               "backend.copilot.transcript._run_compression",
                new_callable=AsyncMock,
                return_value=mock_result,
            ),
@@ -730,7 +730,7 @@ class TestRetryEdgeCases:
                )(),
            ),
            patch(
-               "backend.copilot.sdk.transcript._run_compression",
+               "backend.copilot.transcript._run_compression",
                new_callable=AsyncMock,
                return_value=mock_result,
            ),
@@ -841,7 +841,7 @@ class TestRetryStateReset:
                )(),
            ),
            patch(
-               "backend.copilot.sdk.transcript._run_compression",
+               "backend.copilot.transcript._run_compression",
                new_callable=AsyncMock,
                side_effect=RuntimeError("boom"),
            ),
@@ -904,14 +904,14 @@ class TestTranscriptEdgeCases:
        assert restored[1]["content"] == "Second"

    def test_flatten_assistant_with_only_tool_use(self):
-       """Assistant message with only tool_use blocks (no text)."""
+       """Assistant message with only tool_use blocks (no text) flattens to empty."""
        blocks = [
            {"type": "tool_use", "name": "bash", "input": {"cmd": "ls"}},
            {"type": "tool_use", "name": "read", "input": {"path": "/f"}},
        ]
        result = _flatten_assistant_content(blocks)
-       assert "[tool_use: bash]" in result
-       assert "[tool_use: read]" in result
+       # tool_use blocks are dropped entirely to prevent model mimicry
+       assert result == ""

    def test_flatten_tool_result_nested_image(self):
        """Tool result containing image blocks uses placeholder."""
@@ -1010,7 +1010,7 @@ def _make_sdk_patches(
        (f"{_SVC}.create_security_hooks", dict(return_value=MagicMock())),
        (f"{_SVC}.get_copilot_tool_names", dict(return_value=[])),
        (f"{_SVC}.get_sdk_disallowed_tools", dict(return_value=[])),
-       (f"{_SVC}.build_sdk_env", dict(return_value=None)),
+       (f"{_SVC}.build_sdk_env", dict(return_value={})),
        (f"{_SVC}._resolve_sdk_model", dict(return_value=None)),
        (f"{_SVC}.set_execution_context", {}),
        (
@@ -1414,3 +1414,261 @@ class TestStreamChatCompletionRetryIntegration:
        # Verify user-friendly message (not raw SDK text)
        assert "Authentication" in errors[0].errorText
        assert any(isinstance(e, StreamStart) for e in events)

    @pytest.mark.asyncio
    async def test_result_message_prompt_too_long_triggers_compaction(self):
        """CLI returns ResultMessage(subtype="error") with "Prompt is too long".

        When the Claude CLI rejects the prompt pre-API (model=<synthetic>,
        duration_api_ms=0), it sends a ResultMessage with is_error=True
        instead of raising a Python exception. The retry loop must still
        detect this as a context-length error and trigger compaction.
        """
        import contextlib

        from claude_agent_sdk import ResultMessage

        from backend.copilot.response_model import StreamError, StreamStart
        from backend.copilot.sdk.service import stream_chat_completion_sdk

        session = self._make_session()
        success_result = self._make_result_message()
        attempt_count = [0]

        error_result = ResultMessage(
            subtype="error",
            result="Prompt is too long",
            duration_ms=100,
            duration_api_ms=0,
            is_error=True,
            num_turns=0,
            session_id="test-session-id",
        )

        def _client_factory(*args, **kwargs):
            attempt_count[0] += 1
            if attempt_count[0] == 1:
                # First attempt: CLI returns error ResultMessage
                return self._make_client_mock(result_message=error_result)
            # Second attempt (after compaction): succeeds
            return self._make_client_mock(result_message=success_result)

        original_transcript = _build_transcript(
            [("user", "prior question"), ("assistant", "prior answer")]
        )
        compacted_transcript = _build_transcript(
            [("user", "[summary]"), ("assistant", "summary reply")]
        )

        patches = _make_sdk_patches(
            session,
            original_transcript=original_transcript,
            compacted_transcript=compacted_transcript,
            client_side_effect=_client_factory,
        )

        events = []
        with contextlib.ExitStack() as stack:
            for target, kwargs in patches:
                stack.enter_context(patch(target, **kwargs))
            async for event in stream_chat_completion_sdk(
                session_id="test-session-id",
                message="hello",
                is_user_message=True,
                user_id="test-user",
                session=session,
            ):
                events.append(event)

        assert attempt_count[0] == 2, (
            f"Expected 2 SDK attempts (CLI error ResultMessage "
            f"should trigger compaction retry), got {attempt_count[0]}"
        )
        errors = [e for e in events if isinstance(e, StreamError)]
        assert not errors, f"Unexpected StreamError: {errors}"
        assert any(isinstance(e, StreamStart) for e in events)

    @pytest.mark.asyncio
    async def test_result_message_success_subtype_prompt_too_long_triggers_compaction(
        self,
    ):
        """CLI returns ResultMessage(subtype="success") with result="Prompt is too long".

        The SDK internally compacts but the transcript is still too long. It
        returns subtype="success" (process completed) with result="Prompt is
        too long" (the actual rejection message). The retry loop must detect
        this as a context-length error and trigger compaction — the subtype
        "success" must not fool it into treating this as a real response.
        """
        import contextlib

        from claude_agent_sdk import ResultMessage

        from backend.copilot.response_model import StreamError, StreamStart
        from backend.copilot.sdk.service import stream_chat_completion_sdk

        session = self._make_session()
        success_result = self._make_result_message()
        attempt_count = [0]

        error_result = ResultMessage(
            subtype="success",
            result="Prompt is too long",
            duration_ms=100,
            duration_api_ms=0,
            is_error=False,
            num_turns=1,
            session_id="test-session-id",
        )

        def _client_factory(*args, **kwargs):
            attempt_count[0] += 1

            async def _receive_error():
                yield error_result

            async def _receive_success():
                yield success_result

            client = MagicMock()
            client._transport = MagicMock()
            client._transport.write = AsyncMock()
            client.query = AsyncMock()
            if attempt_count[0] == 1:
                client.receive_response = _receive_error
            else:
                client.receive_response = _receive_success
            cm = AsyncMock()
            cm.__aenter__.return_value = client
            cm.__aexit__.return_value = None
            return cm

        original_transcript = _build_transcript(
            [("user", "prior question"), ("assistant", "prior answer")]
        )
        compacted_transcript = _build_transcript(
            [("user", "[summary]"), ("assistant", "summary reply")]
        )

        patches = _make_sdk_patches(
            session,
            original_transcript=original_transcript,
            compacted_transcript=compacted_transcript,
            client_side_effect=_client_factory,
        )

        events = []
        with contextlib.ExitStack() as stack:
            for target, kwargs in patches:
                stack.enter_context(patch(target, **kwargs))
            async for event in stream_chat_completion_sdk(
                session_id="test-session-id",
                message="hello",
                is_user_message=True,
                user_id="test-user",
                session=session,
            ):
                events.append(event)

        assert attempt_count[0] == 2, (
            f"Expected 2 SDK attempts (subtype='success' with 'Prompt is too long' "
            f"result should trigger compaction retry), got {attempt_count[0]}"
        )
        errors = [e for e in events if isinstance(e, StreamError)]
        assert not errors, f"Unexpected StreamError: {errors}"
        assert any(isinstance(e, StreamStart) for e in events)

    @pytest.mark.asyncio
    async def test_assistant_message_error_content_prompt_too_long_triggers_compaction(
        self,
    ):
        """AssistantMessage.error="invalid_request" with content "Prompt is too long".

        The SDK returns error type "invalid_request" but puts the actual
        rejection message ("Prompt is too long") in the content blocks.
        The retry loop must detect this via content inspection (sdk_error
        being set confirms it's an error message, not user content).
        """
        import contextlib

        from claude_agent_sdk import AssistantMessage, ResultMessage, TextBlock

        from backend.copilot.response_model import StreamError, StreamStart
        from backend.copilot.sdk.service import stream_chat_completion_sdk

        session = self._make_session()
        success_result = self._make_result_message()
        attempt_count = [0]

        def _client_factory(*args, **kwargs):
            attempt_count[0] += 1

            async def _receive_error():
                # SDK returns invalid_request with "Prompt is too long" in content.
                # ResultMessage.result is a non-PTL value ("done") to isolate
                # the AssistantMessage content detection path exclusively.
                yield AssistantMessage(
                    content=[TextBlock(text="Prompt is too long")],
                    model="<synthetic>",
                    error="invalid_request",
                )
                yield ResultMessage(
                    subtype="success",
                    result="done",
                    duration_ms=100,
                    duration_api_ms=0,
                    is_error=False,
                    num_turns=1,
                    session_id="test-session-id",
                )

            async def _receive_success():
                yield success_result

            client = MagicMock()
            client._transport = MagicMock()
            client._transport.write = AsyncMock()
            client.query = AsyncMock()
            if attempt_count[0] == 1:
                client.receive_response = _receive_error
            else:
                client.receive_response = _receive_success
            cm = AsyncMock()
            cm.__aenter__.return_value = client
            cm.__aexit__.return_value = None
            return cm

        original_transcript = _build_transcript(
            [("user", "prior question"), ("assistant", "prior answer")]
        )
        compacted_transcript = _build_transcript(
            [("user", "[summary]"), ("assistant", "summary reply")]
        )

        patches = _make_sdk_patches(
            session,
            original_transcript=original_transcript,
            compacted_transcript=compacted_transcript,
            client_side_effect=_client_factory,
        )

        events = []
        with contextlib.ExitStack() as stack:
            for target, kwargs in patches:
                stack.enter_context(patch(target, **kwargs))
            async for event in stream_chat_completion_sdk(
                session_id="test-session-id",
                message="hello",
                is_user_message=True,
                user_id="test-user",
                session=session,
            ):
                events.append(event)

        assert attempt_count[0] == 2, (
            f"Expected 2 SDK attempts (AssistantMessage error content 'Prompt is "
            f"too long' should trigger compaction retry), got {attempt_count[0]}"
        )
        errors = [e for e in events if isinstance(e, StreamError)]
        assert not errors, f"Unexpected StreamError: {errors}"
        assert any(isinstance(e, StreamStart) for e in events)
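
# --- Illustrative sketch (not part of the diff): all three retry tests
# above funnel through a predicate like this. Matching on the stringified
# exception is an assumption; the real _is_prompt_too_long may inspect
# more error shapes than the message text.
def _sketch_is_prompt_too_long(exc: BaseException) -> bool:
    return "prompt is too long" in str(exc).lower()
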
@@ -105,6 +105,10 @@ def test_agent_options_accepts_all_our_fields():
        "env",
        "resume",
        "max_buffer_size",
+       "stderr",
+       "fallback_model",
+       "max_turns",
+       "max_budget_usd",
    ]
    sig = inspect.signature(ClaudeAgentOptions)
    for field in fields_we_use:

@@ -22,6 +22,38 @@ from .tool_adapter import (

logger = logging.getLogger(__name__)

+# The SDK CLI uses "Task" in older versions and "Agent" in v2.x+.
+# Shared across all sessions — used by security hooks for sub-agent detection.
+_SUBAGENT_TOOLS: frozenset[str] = frozenset({"Task", "Agent"})
+
+# Unicode ranges stripped by _sanitize():
+# - BiDi overrides (U+202A-U+202E, U+2066-U+2069) can trick reviewers
+#   into misreading code/logs.
+# - Zero-width characters (U+200B-U+200F, U+FEFF) can hide content.
+_BIDI_AND_ZW_CHARS = set(
+    chr(c)
+    for r in (range(0x202A, 0x202F), range(0x2066, 0x206A), range(0x200B, 0x2010))
+    for c in r
+) | {"\ufeff"}
+
+
+def _sanitize(value: str, max_len: int = 200) -> str:
+    """Strip control characters and truncate for safe logging.
+
+    Removes C0 (U+0000-U+001F), DEL (U+007F), C1 (U+0080-U+009F),
+    Unicode BiDi overrides, and zero-width characters to prevent
+    log injection and visual spoofing.
+    """
+    cleaned = "".join(
+        c
+        for c in value
+        if c >= " "
+        and c != "\x7f"
+        and not ("\x80" <= c <= "\x9f")
+        and c not in _BIDI_AND_ZW_CHARS
+    )
+    return cleaned[:max_len]
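
# --- Illustrative usage (not part of the diff): what _sanitize guarantees
# for log-injection payloads. Printable text survives; C0/C1 controls, DEL,
# tabs, BiDi overrides, and zero-width characters are stripped, and the
# result is truncated to max_len.
#
#   _sanitize("agent\n-injected\r\x00\x7f")  -> "agent-injected"
#   _sanitize("safe\u202etype\u200b")        -> "safetype"
#   len(_sanitize("x" * 1000))               -> 200
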

def _deny(reason: str) -> dict[str, Any]:
    """Return a hook denial response."""
@@ -136,11 +168,13 @@ def create_security_hooks(
    - PostToolUse: Log successful tool executions
    - PostToolUseFailure: Log and handle failed tool executions
    - PreCompact: Log context compaction events (SDK handles compaction automatically)
+   - SubagentStart: Log sub-agent lifecycle start
+   - SubagentStop: Log sub-agent lifecycle end

    Args:
        user_id: Current user ID for isolation validation
        sdk_cwd: SDK working directory for workspace-scoped tool validation
-       max_subtasks: Maximum concurrent Task (sub-agent) spawns allowed per session
+       max_subtasks: Maximum concurrent sub-agent spawns allowed per session
        on_compact: Callback invoked when SDK starts compacting context.
            Receives the transcript_path from the hook input.

@@ -151,9 +185,19 @@ def create_security_hooks(
    from claude_agent_sdk import HookMatcher
    from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput

-   # Per-session tracking for Task sub-agent concurrency.
+   # Per-session tracking for sub-agent concurrency.
    # Set of tool_use_ids that consumed a slot — len() is the active count.
-   task_tool_use_ids: set[str] = set()
+   #
+   # LIMITATION: For background (async) agents the SDK returns the
+   # Agent/Task tool immediately with {isAsync: true}, which triggers
+   # PostToolUse and releases the slot while the agent is still running.
+   # SubagentStop fires later when the background process finishes but
+   # does not currently hold a slot. This means the concurrency limit
+   # only gates *launches*, not true concurrent execution. To fix this
+   # we would need to track background agent_ids separately and release
+   # in SubagentStop, but the SDK does not guarantee SubagentStop fires
+   # for every background agent (e.g. on session abort).
+   subagent_tool_use_ids: set[str] = set()

    async def pre_tool_use_hook(
        input_data: HookInput,
@@ -165,29 +209,22 @@ def create_security_hooks(
        tool_name = cast(str, input_data.get("tool_name", ""))
        tool_input = cast(dict[str, Any], input_data.get("tool_input", {}))

-       # Rate-limit Task (sub-agent) spawns per session
-       if tool_name == "Task":
-           # Block background task execution first — denied calls
-           # should not consume a subtask slot.
-           if tool_input.get("run_in_background"):
-               logger.info(f"[SDK] Blocked background Task, user={user_id}")
-               return cast(
-                   SyncHookJSONOutput,
-                   _deny(
-                       "Background task execution is not supported. "
-                       "Run tasks in the foreground instead "
-                       "(remove the run_in_background parameter)."
-                   ),
-               )
-           if len(task_tool_use_ids) >= max_subtasks:
+       # Rate-limit sub-agent spawns per session.
+       # The SDK CLI renamed "Task" → "Agent" in v2.x; handle both.
+       if tool_name in _SUBAGENT_TOOLS:
+           # Background agents are allowed — the SDK returns immediately
+           # with {isAsync: true} and the model polls via TaskOutput.
+           # Still count them against the concurrency limit.
+           if len(subagent_tool_use_ids) >= max_subtasks:
                logger.warning(
-                   f"[SDK] Task limit reached ({max_subtasks}), user={user_id}"
+                   f"[SDK] Sub-agent limit reached ({max_subtasks}), "
+                   f"user={user_id}"
                )
                return cast(
                    SyncHookJSONOutput,
                    _deny(
-                       f"Maximum {max_subtasks} concurrent sub-tasks. "
-                       "Wait for running sub-tasks to finish, "
+                       f"Maximum {max_subtasks} concurrent sub-agents. "
+                       "Wait for running sub-agents to finish, "
                        "or continue in the main conversation."
                    ),
                )
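
# --- Illustrative sketch (not part of the diff): the slot accounting used
# above, reduced to its core. A set of tool_use_ids is the whole state:
# membership marks a reserved slot, len() is the active count, and
# set.discard() makes release idempotent even if a hook fires twice.
def _sketch_try_reserve(active: set, tool_use_id: str, limit: int) -> bool:
    if len(active) >= limit:
        return False  # deny: at capacity
    active.add(tool_use_id)
    return True


def _sketch_release(active: set, tool_use_id: str) -> None:
    active.discard(tool_use_id)  # no-op if this id never reserved a slot
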
@@ -208,20 +245,20 @@ def create_security_hooks(
        if result:
            return cast(SyncHookJSONOutput, result)

-       # Reserve the Task slot only after all validations pass
-       if tool_name == "Task" and tool_use_id is not None:
-           task_tool_use_ids.add(tool_use_id)
+       # Reserve the sub-agent slot only after all validations pass
+       if tool_name in _SUBAGENT_TOOLS and tool_use_id is not None:
+           subagent_tool_use_ids.add(tool_use_id)

        logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}")
        return cast(SyncHookJSONOutput, {})

-   def _release_task_slot(tool_name: str, tool_use_id: str | None) -> None:
-       """Release a Task concurrency slot if one was reserved."""
-       if tool_name == "Task" and tool_use_id in task_tool_use_ids:
-           task_tool_use_ids.discard(tool_use_id)
+   def _release_subagent_slot(tool_name: str, tool_use_id: str | None) -> None:
+       """Release a sub-agent concurrency slot if one was reserved."""
+       if tool_name in _SUBAGENT_TOOLS and tool_use_id in subagent_tool_use_ids:
+           subagent_tool_use_ids.discard(tool_use_id)
            logger.info(
-               "[SDK] Task slot released, active=%d/%d, user=%s",
-               len(task_tool_use_ids),
+               "[SDK] Sub-agent slot released, active=%d/%d, user=%s",
+               len(subagent_tool_use_ids),
                max_subtasks,
                user_id,
            )
@@ -241,13 +278,14 @@ def create_security_hooks(
        _ = context
        tool_name = cast(str, input_data.get("tool_name", ""))

-       _release_task_slot(tool_name, tool_use_id)
+       _release_subagent_slot(tool_name, tool_use_id)
        is_builtin = not tool_name.startswith(MCP_TOOL_PREFIX)
+       safe_tool_use_id = _sanitize(str(tool_use_id or ""), max_len=12)
        logger.info(
            "[SDK] PostToolUse: %s (builtin=%s, tool_use_id=%s)",
            tool_name,
            is_builtin,
-           (tool_use_id or "")[:12],
+           safe_tool_use_id,
        )

        # Stash output for SDK built-in tools so the response adapter can
@@ -256,7 +294,7 @@ def create_security_hooks(
        if is_builtin:
            tool_response = input_data.get("tool_response")
            if tool_response is not None:
-               resp_preview = str(tool_response)[:100]
+               resp_preview = _sanitize(str(tool_response), max_len=100)
                logger.info(
                    "[SDK] Stashing builtin output for %s (%d chars): %s...",
                    tool_name,
@@ -280,13 +318,17 @@ def create_security_hooks(
        """Log failed tool executions for debugging."""
        _ = context
        tool_name = cast(str, input_data.get("tool_name", ""))
-       error = input_data.get("error", "Unknown error")
+       error = _sanitize(str(input_data.get("error", "Unknown error")))
+       safe_tool_use_id = _sanitize(str(tool_use_id or ""))
        logger.warning(
-           f"[SDK] Tool failed: {tool_name}, error={error}, "
-           f"user={user_id}, tool_use_id={tool_use_id}"
+           "[SDK] Tool failed: %s, error=%s, user=%s, tool_use_id=%s",
+           tool_name,
+           error,
+           user_id,
+           safe_tool_use_id,
        )

-       _release_task_slot(tool_name, tool_use_id)
+       _release_subagent_slot(tool_name, tool_use_id)

        return cast(SyncHookJSONOutput, {})

@@ -301,20 +343,17 @@ def create_security_hooks(
        This hook provides visibility into when compaction happens.
        """
        _ = context, tool_use_id
-       trigger = input_data.get("trigger", "auto")
+       trigger = _sanitize(str(input_data.get("trigger", "auto")), max_len=50)
+       # Sanitize untrusted input: strip control chars for logging AND
+       # for the value passed downstream. read_compacted_entries()
+       # validates against _projects_base() as defence-in-depth, but
+       # sanitizing here prevents log injection and rejects obviously
+       # malformed paths early.
-       transcript_path = (
-           str(input_data.get("transcript_path", ""))
-           .replace("\n", "")
-           .replace("\r", "")
+       transcript_path = _sanitize(
+           str(input_data.get("transcript_path", "")), max_len=500
        )
        logger.info(
-           "[SDK] Context compaction triggered: %s, user=%s, "
-           "transcript_path=%s",
+           "[SDK] Context compaction triggered: %s, user=%s, transcript_path=%s",
            trigger,
            user_id,
            transcript_path,
@@ -323,6 +362,44 @@ def create_security_hooks(
            on_compact(transcript_path)
        return cast(SyncHookJSONOutput, {})

    async def subagent_start_hook(
        input_data: HookInput,
        tool_use_id: str | None,
        context: HookContext,
    ) -> SyncHookJSONOutput:
        """Log when a sub-agent starts execution."""
        _ = context, tool_use_id
        agent_id = _sanitize(str(input_data.get("agent_id", "?")))
        agent_type = _sanitize(str(input_data.get("agent_type", "?")))
        logger.info(
            "[SDK] SubagentStart: agent_id=%s, type=%s, user=%s",
            agent_id,
            agent_type,
            user_id,
        )
        return cast(SyncHookJSONOutput, {})

    async def subagent_stop_hook(
        input_data: HookInput,
        tool_use_id: str | None,
        context: HookContext,
    ) -> SyncHookJSONOutput:
        """Log when a sub-agent stops."""
        _ = context, tool_use_id
        agent_id = _sanitize(str(input_data.get("agent_id", "?")))
        agent_type = _sanitize(str(input_data.get("agent_type", "?")))
        transcript = _sanitize(
            str(input_data.get("agent_transcript_path", "")), max_len=500
        )
        logger.info(
            "[SDK] SubagentStop: agent_id=%s, type=%s, user=%s, transcript=%s",
            agent_id,
            agent_type,
            user_id,
            transcript,
        )
        return cast(SyncHookJSONOutput, {})

    hooks: dict[str, Any] = {
        "PreToolUse": [HookMatcher(matcher="*", hooks=[pre_tool_use_hook])],
        "PostToolUse": [HookMatcher(matcher="*", hooks=[post_tool_use_hook])],
@@ -330,6 +407,8 @@ def create_security_hooks(
            HookMatcher(matcher="*", hooks=[post_tool_failure_hook])
        ],
        "PreCompact": [HookMatcher(matcher="*", hooks=[pre_compact_hook])],
+       "SubagentStart": [HookMatcher(matcher="*", hooks=[subagent_start_hook])],
+       "SubagentStop": [HookMatcher(matcher="*", hooks=[subagent_stop_hook])],
    }

    return hooks

@@ -5,13 +5,18 @@ They validate that the security hooks correctly block unauthorized paths,
tool access, and dangerous input patterns.
"""

+import logging
import os

import pytest

from backend.copilot.context import _current_project_dir

-from .security_hooks import _validate_tool_access, _validate_user_isolation
+from .security_hooks import (
+    _validate_tool_access,
+    _validate_user_isolation,
+    create_security_hooks,
+)

SDK_CWD = "/tmp/copilot-abc123"

@@ -132,8 +137,20 @@ def test_read_tool_results_allowed():
        _current_project_dir.reset(token)


+def test_read_tool_outputs_allowed():
+    """tool-outputs/ paths should be allowed, same as tool-results/."""
+    home = os.path.expanduser("~")
+    path = f"{home}/.claude/projects/-tmp-copilot-abc123/a1b2c3d4-e5f6-7890-abcd-ef1234567890/tool-outputs/12345.txt"
+    token = _current_project_dir.set("-tmp-copilot-abc123")
+    try:
+        result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
+        assert result == {}
+    finally:
+        _current_project_dir.reset(token)
+
+
def test_read_claude_projects_settings_json_denied():
-   """SDK-internal artifacts like settings.json are NOT accessible — only tool-results/ is."""
+   """SDK-internal artifacts like settings.json are NOT accessible — only tool-results/tool-outputs is."""
    home = os.path.expanduser("~")
    path = f"{home}/.claude/projects/-tmp-copilot-abc123/settings.json"
    token = _current_project_dir.set("-tmp-copilot-abc123")
@@ -220,8 +237,6 @@ def test_bash_builtin_blocked_message_clarity():
@pytest.fixture()
def _hooks():
    """Create security hooks and return (pre, post, post_failure) handlers."""
-   from .security_hooks import create_security_hooks
-
    hooks = create_security_hooks(user_id="u1", sdk_cwd=SDK_CWD, max_subtasks=2)
    pre = hooks["PreToolUse"][0].hooks[0]
    post = hooks["PostToolUse"][0].hooks[0]
@@ -231,16 +246,15 @@ def _hooks():

@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
-async def test_task_background_blocked(_hooks):
-   """Task with run_in_background=true must be denied."""
+async def test_task_background_allowed(_hooks):
+   """Task with run_in_background=true is allowed (SDK handles async lifecycle)."""
    pre, _, _ = _hooks
    result = await pre(
        {"tool_name": "Task", "tool_input": {"run_in_background": True, "prompt": "x"}},
-       tool_use_id=None,
+       tool_use_id="tu-bg-1",
        context={},
    )
-   assert _is_denied(result)
-   assert "foreground" in _reason(result).lower()
+   assert not _is_denied(result)

@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@@ -354,3 +368,303 @@ async def test_task_slot_released_on_failure(_hooks):
        context={},
    )
    assert not _is_denied(result)


# ---------------------------------------------------------------------------
# "Agent" tool name (SDK v2.x+ renamed "Task" → "Agent")
# ---------------------------------------------------------------------------


@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_agent_background_allowed(_hooks):
    """Agent with run_in_background=true is allowed (SDK handles async lifecycle)."""
    pre, _, _ = _hooks
    result = await pre(
        {
            "tool_name": "Agent",
            "tool_input": {"run_in_background": True, "prompt": "x"},
        },
        tool_use_id="tu-agent-bg-1",
        context={},
    )
    assert not _is_denied(result)


@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_agent_foreground_allowed(_hooks):
    """Agent without run_in_background should be allowed."""
    pre, _, _ = _hooks
    result = await pre(
        {"tool_name": "Agent", "tool_input": {"prompt": "do stuff"}},
        tool_use_id="tu-agent-1",
        context={},
    )
    assert not _is_denied(result)


@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_background_agent_counts_against_limit(_hooks):
    """Background agents still consume concurrency slots."""
    pre, _, _ = _hooks
    # Two background agents fill the limit
    for i in range(2):
        result = await pre(
            {
                "tool_name": "Agent",
                "tool_input": {"run_in_background": True, "prompt": "bg"},
            },
            tool_use_id=f"tu-bglimit-{i}",
            context={},
        )
        assert not _is_denied(result)
    # Third (background or foreground) should be denied
    result = await pre(
        {
            "tool_name": "Agent",
            "tool_input": {"run_in_background": True, "prompt": "over"},
        },
        tool_use_id="tu-bglimit-2",
        context={},
    )
    assert _is_denied(result)
    assert "Maximum" in _reason(result)


@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_agent_limit_enforced(_hooks):
    """Agent spawns beyond max_subtasks should be denied."""
    pre, _, _ = _hooks
    # First two should pass
    for i in range(2):
        result = await pre(
            {"tool_name": "Agent", "tool_input": {"prompt": "ok"}},
            tool_use_id=f"tu-agent-limit-{i}",
            context={},
        )
        assert not _is_denied(result)

    # Third should be denied (limit=2)
    result = await pre(
        {"tool_name": "Agent", "tool_input": {"prompt": "over limit"}},
        tool_use_id="tu-agent-limit-2",
        context={},
    )
    assert _is_denied(result)
    assert "Maximum" in _reason(result)

@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_agent_slot_released_on_completion(_hooks):
    """Completing an Agent should free a slot so new Agents can be spawned."""
    pre, post, _ = _hooks
    # Fill both slots
    for i in range(2):
        result = await pre(
            {"tool_name": "Agent", "tool_input": {"prompt": "ok"}},
            tool_use_id=f"tu-agent-comp-{i}",
            context={},
        )
        assert not _is_denied(result)

    # Third should be denied — at capacity
    result = await pre(
        {"tool_name": "Agent", "tool_input": {"prompt": "over"}},
        tool_use_id="tu-agent-comp-2",
        context={},
    )
    assert _is_denied(result)

    # Complete first agent — frees a slot
    await post(
        {"tool_name": "Agent", "tool_input": {}},
        tool_use_id="tu-agent-comp-0",
        context={},
    )

    # Now a new Agent should be allowed
    result = await pre(
        {"tool_name": "Agent", "tool_input": {"prompt": "after release"}},
        tool_use_id="tu-agent-comp-3",
        context={},
    )
    assert not _is_denied(result)


@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_agent_slot_released_on_failure(_hooks):
    """A failed Agent should also free its concurrency slot."""
    pre, _, post_failure = _hooks
    # Fill both slots
    for i in range(2):
        result = await pre(
            {"tool_name": "Agent", "tool_input": {"prompt": "ok"}},
            tool_use_id=f"tu-agent-fail-{i}",
            context={},
        )
        assert not _is_denied(result)

    # At capacity
    result = await pre(
        {"tool_name": "Agent", "tool_input": {"prompt": "over"}},
        tool_use_id="tu-agent-fail-2",
        context={},
    )
    assert _is_denied(result)

    # Fail first agent — should free a slot
    await post_failure(
        {"tool_name": "Agent", "tool_input": {}, "error": "something broke"},
        tool_use_id="tu-agent-fail-0",
        context={},
    )

    # New Agent should be allowed
    result = await pre(
        {"tool_name": "Agent", "tool_input": {"prompt": "after failure"}},
        tool_use_id="tu-agent-fail-3",
        context={},
    )
    assert not _is_denied(result)


@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_mixed_task_agent_share_slots(_hooks):
    """Task and Agent share the same concurrency pool."""
    pre, post, _ = _hooks
    # Fill one slot with Task, one with Agent
    result = await pre(
        {"tool_name": "Task", "tool_input": {"prompt": "ok"}},
        tool_use_id="tu-mix-task",
        context={},
    )
    assert not _is_denied(result)

    result = await pre(
        {"tool_name": "Agent", "tool_input": {"prompt": "ok"}},
        tool_use_id="tu-mix-agent",
        context={},
    )
    assert not _is_denied(result)

    # Third (either name) should be denied
    result = await pre(
        {"tool_name": "Agent", "tool_input": {"prompt": "over"}},
        tool_use_id="tu-mix-over",
        context={},
    )
    assert _is_denied(result)

    # Release the Task slot
    await post(
        {"tool_name": "Task", "tool_input": {}},
        tool_use_id="tu-mix-task",
        context={},
    )

    # Now an Agent should be allowed
    result = await pre(
        {"tool_name": "Agent", "tool_input": {"prompt": "after task release"}},
        tool_use_id="tu-mix-new",
        context={},
    )
    assert not _is_denied(result)


# ---------------------------------------------------------------------------
# SubagentStart / SubagentStop hooks
# ---------------------------------------------------------------------------


@pytest.fixture()
def _subagent_hooks():
    """Create hooks and return (subagent_start, subagent_stop) handlers."""
    hooks = create_security_hooks(user_id="u1", sdk_cwd=SDK_CWD, max_subtasks=2)
    start = hooks["SubagentStart"][0].hooks[0]
    stop = hooks["SubagentStop"][0].hooks[0]
    return start, stop


@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_subagent_start_hook_returns_empty(_subagent_hooks):
    """SubagentStart hook should return an empty dict (logging only)."""
    start, _ = _subagent_hooks
    result = await start(
        {"agent_id": "sa-123", "agent_type": "research"},
        tool_use_id=None,
        context={},
    )
    assert result == {}


@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_subagent_stop_hook_returns_empty(_subagent_hooks):
    """SubagentStop hook should return an empty dict (logging only)."""
    _, stop = _subagent_hooks
    result = await stop(
        {
            "agent_id": "sa-123",
            "agent_type": "research",
            "agent_transcript_path": "/tmp/transcript.txt",
        },
        tool_use_id=None,
        context={},
    )
    assert result == {}


@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_subagent_hooks_sanitize_inputs(_subagent_hooks, caplog):
    """SubagentStart/Stop should sanitize control chars from inputs."""
    start, stop = _subagent_hooks
    # Inject control characters (C0, DEL, C1, BiDi overrides, zero-width)
    # — hook should not raise AND logs must be clean
    with caplog.at_level(logging.DEBUG, logger="backend.copilot.sdk.security_hooks"):
        result = await start(
            {
                "agent_id": "sa\n-injected\r\x00\x7f",
                "agent_type": "safe\x80_type\x9f\ttab",
            },
            tool_use_id=None,
            context={},
        )
    assert result == {}
    # Control chars must be stripped from the logged values
    for record in caplog.records:
        assert "\x00" not in record.message
        assert "\r" not in record.message
        assert "\n" not in record.message
        assert "\x7f" not in record.message
        assert "\x80" not in record.message
        assert "\x9f" not in record.message
    assert "safe_type" in caplog.text

    caplog.clear()
    with caplog.at_level(logging.DEBUG, logger="backend.copilot.sdk.security_hooks"):
        result = await stop(
            {
                "agent_id": "sa\n-injected\x7f",
                "agent_type": "type\r\x80\x9f",
                "agent_transcript_path": "/tmp/\x00malicious\npath\u202a\u200b",
            },
            tool_use_id=None,
            context={},
        )
    assert result == {}
    for record in caplog.records:
        assert "\x00" not in record.message
        assert "\r" not in record.message
        assert "\n" not in record.message
        assert "\x7f" not in record.message
        assert "\u202a" not in record.message
        assert "\u200b" not in record.message
    assert "/tmp/maliciouspath" in caplog.text

@@ -59,11 +59,14 @@ from ..response_model import (
    StreamBaseResponse,
    StreamError,
    StreamFinish,
+   StreamFinishStep,
+   StreamHeartbeat,
    StreamStart,
    StreamStartStep,
+   StreamStatus,
    StreamTextDelta,
    StreamToolInputAvailable,
    StreamToolInputStart,
    StreamToolOutputAvailable,
    StreamUsage,
)
@@ -81,11 +84,9 @@ from .env import build_sdk_env  # noqa: F401 — re-export for backward compat
from .response_adapter import SDKResponseAdapter
from .security_hooks import create_security_hooks
from .tool_adapter import (
-   cancel_pending_tool_tasks,
    create_copilot_mcp_server,
    get_copilot_tool_names,
    get_sdk_disallowed_tools,
-   pre_launch_tool_call,
    reset_stash_event,
    reset_tool_failure_counters,
    set_execution_context,
@@ -115,9 +116,10 @@ _MAX_STREAM_ATTEMPTS = 3

# Hard circuit breaker: abort the stream if the model sends this many
# consecutive tool calls with empty parameters (a sign of context
-# saturation or serialization failure). Empty input ({}) is never
-# legitimate — even one is suspicious, three is conclusive.
-_EMPTY_TOOL_CALL_LIMIT = 3
+# saturation or serialization failure). The MCP wrapper now returns
+# guidance on the first empty call, giving the model a chance to
+# self-correct. The limit is generous to allow recovery attempts.
+_EMPTY_TOOL_CALL_LIMIT = 5

# User-facing error shown when the empty-tool-call circuit breaker trips.
_CIRCUIT_BREAKER_ERROR_MSG = (
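
# --- Illustrative sketch (not part of the diff): how a consecutive-empty-call
# circuit breaker of this shape typically works. Resetting the streak on any
# non-empty call is an assumption consistent with "consecutive" above.
#
#   if not tool_input:  # empty parameters ({})
#       consecutive_empty_tool_calls += 1
#       if consecutive_empty_tool_calls >= _EMPTY_TOOL_CALL_LIMIT:
#           raise RuntimeError(_CIRCUIT_BREAKER_ERROR_MSG)
#   else:
#       consecutive_empty_tool_calls = 0  # a valid call resets the streak
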
@@ -746,15 +748,11 @@ def _format_conversation_context(messages: list[ChatMessage]) -> str | None:
        elif msg.role == "assistant":
            if msg.content:
                lines.append(f"You responded: {msg.content}")
-           if msg.tool_calls:
-               for tc in msg.tool_calls:
-                   func = tc.get("function", {})
-                   tool_name = func.get("name", "unknown")
-                   tool_args = func.get("arguments", "")
-                   lines.append(f"You called tool: {tool_name}({tool_args})")
+           # Omit tool_calls — any text representation gets mimicked
+           # by the model. Tool results below provide the context.
        elif msg.role == "tool":
            content = msg.content or ""
-           lines.append(f"Tool result: {content}")
+           lines.append(f"Tool output: {content[:500]}")

    if not lines:
        return None
||||
@@ -1214,6 +1212,14 @@ async def _run_stream_attempt(
|
||||
|
||||
consecutive_empty_tool_calls = 0
|
||||
|
||||
# --- Intermediate persistence tracking ---
|
||||
# Flush session messages to DB periodically so page reloads show progress
|
||||
# during long-running turns (see incident d2f7cba3: 82-min turn lost on refresh).
|
||||
_last_flush_time = time.monotonic()
|
||||
_msgs_since_flush = 0
|
||||
_FLUSH_INTERVAL_SECONDS = 30.0
|
||||
_FLUSH_MESSAGE_THRESHOLD = 10
|
||||
|
||||
# Use manual __aenter__/__aexit__ instead of ``async with`` so we can
|
||||
# suppress SDK cleanup errors that occur when the SSE client disconnects
|
||||
# mid-stream. GeneratorExit causes the SDK's ``__aexit__`` to run in a
|
||||
@@ -1300,6 +1306,27 @@
                        error_preview,
                    )

                    # Intercept prompt-too-long errors surfaced as
                    # AssistantMessage.error (not as a Python exception).
                    # Re-raise so the outer retry loop can compact the
                    # transcript and retry with reduced context.
                    # Check both error_text and error_preview: sdk_error
                    # being set confirms this is an error message (not user
                    # content), so checking content is safe. The actual
                    # error description (e.g. "Prompt is too long") may be
                    # in the content, not the error type field
                    # (e.g. error="invalid_request", content="Prompt is
                    # too long").
                    if _is_prompt_too_long(Exception(error_text)) or _is_prompt_too_long(
                        Exception(error_preview)
                    ):
                        logger.warning(
                            "%s Prompt-too-long detected via AssistantMessage "
                            "error — raising for retry",
                            ctx.log_prefix,
                        )
                        raise RuntimeError("Prompt is too long")
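_is_prompt_too_long itself does not appear in this diff; given that the call sites above feed it ad-hoc Exception and RuntimeError wrappers around raw strings, a plausible reading is a case-insensitive substring predicate (an assumption, not the repository's code):

# Assumed shape of the helper, inferred from its call sites above.
def _is_prompt_too_long_sketch(exc: BaseException) -> bool:
    # Callers wrap raw error strings in Exception(...)/RuntimeError(...),
    # so str(exc) recovers the original text.
    return "prompt is too long" in str(exc).lower()

assert _is_prompt_too_long_sketch(Exception("Prompt is too long"))
assert not _is_prompt_too_long_sketch(RuntimeError("socket closed"))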

                    # Intercept transient API errors (socket closed,
                    # ECONNRESET) — replace the raw message with a
                    # user-friendly error text and use the retryable
@@ -1327,28 +1354,17 @@
                    ended_with_stream_error = True
                    break

                # Parallel tool execution: pre-launch every ToolUseBlock as an
                # asyncio.Task the moment its AssistantMessage arrives. The SDK
                # sends one AssistantMessage per tool call when issuing parallel
                # calls, so each message is pre-launched independently. The MCP
                # handlers will await the already-running task instead of executing
                # fresh, making all concurrent tool calls run in parallel.
                #
                # Also determine if the message is a tool-only batch (all content
                # Determine if the message is a tool-only batch (all content
                # items are ToolUseBlocks) — such messages have no text output yet,
                # so we skip the wait_for_stash flush below.
                #
                # Note: parallel execution of tools is handled natively by the
                # SDK CLI via readOnlyHint annotations on tool definitions.
                is_tool_only = False
                if isinstance(sdk_msg, AssistantMessage) and sdk_msg.content:
                    is_tool_only = True
                    # NOTE: Pre-launches are sequential (each await completes
                    # file-ref expansion before the next starts). This is fine
                    # since expansion is typically sub-ms; a future optimisation
                    # could gather all pre-launches concurrently.
                    for tool_use in sdk_msg.content:
                        if isinstance(tool_use, ToolUseBlock):
                            await pre_launch_tool_call(tool_use.name, tool_use.input)
                        else:
                            is_tool_only = False
                    is_tool_only = all(
                        isinstance(item, ToolUseBlock) for item in sdk_msg.content
                    )

                # Race-condition fix: SDK hooks (PostToolUse) are
                # executed asynchronously via start_soon() — the next
@@ -1405,6 +1421,16 @@
                        sdk_msg.result or "(no error message provided)",
                    )

                    # Check for prompt-too-long regardless of subtype — the
                    # SDK may return subtype="success" with result="Prompt is
                    # too long" when the CLI rejects the prompt before calling
                    # the API (cost_usd=0, no tokens consumed). If we only
                    # check the "error" subtype path, the stream appears to
                    # complete normally, the synthetic error text is stored
                    # in the transcript, and the session grows without bound.
                    if _is_prompt_too_long(RuntimeError(sdk_msg.result or "")):
                        raise RuntimeError("Prompt is too long")

                    # Capture token usage from ResultMessage.
                    # Anthropic reports cached tokens separately:
                    # input_tokens = uncached only
@@ -1436,6 +1462,23 @@
                # Emit compaction end if SDK finished compacting.
                # Sync TranscriptBuilder with the CLI's active context.
                compact_result = await ctx.compaction.emit_end_if_ready(ctx.session)
                if compact_result.events:
                    # Compaction events end with StreamFinishStep, which maps to
                    # Vercel AI SDK's "finish-step" — that clears activeTextParts.
                    # Close any open text block BEFORE the compaction events so
                    # the text-end arrives before finish-step, preventing
                    # "text-end for missing text part" errors on the frontend.
                    pre_close: list[StreamBaseResponse] = []
                    state.adapter._end_text_if_open(pre_close)
                    # Compaction events bypass the adapter, so sync step state
                    # when a StreamFinishStep is present — otherwise the adapter
                    # will skip StreamStartStep on the next AssistantMessage.
                    if any(
                        isinstance(ev, StreamFinishStep) for ev in compact_result.events
                    ):
                        state.adapter.step_open = False
                    for r in pre_close:
                        yield r
                    for ev in compact_result.events:
                        yield ev
                    entries_replaced = False
@@ -1482,6 +1525,34 @@
                        model=sdk_msg.model,
                    )

                    # --- Intermediate persistence ---
                    # Flush session messages to DB periodically so page reloads
                    # show progress during long-running turns.
                    _msgs_since_flush += 1
                    now = time.monotonic()
                    if (
                        _msgs_since_flush >= _FLUSH_MESSAGE_THRESHOLD
                        or (now - _last_flush_time) >= _FLUSH_INTERVAL_SECONDS
                    ):
                        try:
                            await asyncio.shield(upsert_chat_session(ctx.session))
                            logger.debug(
                                "%s Intermediate flush: %d messages "
                                "(msgs_since=%d, elapsed=%.1fs)",
                                ctx.log_prefix,
                                len(ctx.session.messages),
                                _msgs_since_flush,
                                now - _last_flush_time,
                            )
                        except Exception as flush_err:
                            logger.warning(
                                "%s Intermediate flush failed: %s",
                                ctx.log_prefix,
                                flush_err,
                            )
                        _last_flush_time = now
                        _msgs_since_flush = 0

            if acc.stream_completed:
                break
        finally:
@@ -1813,7 +1884,10 @@ async def stream_chat_completion_sdk(
    )

    # Fail fast when no API credentials are available at all.
    sdk_env = build_sdk_env(session_id=session_id, user_id=user_id)
    # sdk_cwd routes the CLI's temp dir into the per-session workspace
    # so sub-agent output files land inside sdk_cwd (see build_sdk_env).
    sdk_env = build_sdk_env(session_id=session_id, user_id=user_id, sdk_cwd=sdk_cwd)

    if not config.api_key and not config.use_claude_code_subscription:
        raise RuntimeError(
            "No API key configured. Set OPEN_ROUTER_API_KEY, "
@@ -2008,13 +2082,22 @@

        try:
            async for event in _run_stream_attempt(stream_ctx, state):
                if not isinstance(event, StreamHeartbeat):
                if not isinstance(
                    event,
                    (
                        StreamHeartbeat,
                        # Compaction UI events are cosmetic and must not
                        # block retry — they're emitted before the SDK
                        # query on compacted attempts.
                        StreamStartStep,
                        StreamFinishStep,
                        StreamToolInputStart,
                        StreamToolInputAvailable,
                        StreamToolOutputAvailable,
                    ),
                ):
                    events_yielded += 1
                yield event
            # Cancel any pre-launched tasks that were never dispatched
            # by the SDK (e.g. edge-case SDK behaviour changes). Symmetric
            # with the three error-path await cancel_pending_tool_tasks() calls.
            await cancel_pending_tool_tasks()
            break  # Stream completed — exit retry loop
        except asyncio.CancelledError:
            logger.warning(
@@ -2023,9 +2106,6 @@
                attempt + 1,
                _MAX_STREAM_ATTEMPTS,
            )
            # Cancel any pre-launched tasks so they don't continue executing
            # against a rolled-back or abandoned session.
            await cancel_pending_tool_tasks()
            raise
        except _HandledStreamError as exc:
            # _run_stream_attempt already yielded a StreamError and
@@ -2057,8 +2137,6 @@
                retryable=True,
            )
            ended_with_stream_error = True
            # Cancel any pre-launched tasks from the failed attempt.
            await cancel_pending_tool_tasks()
            break
        except Exception as e:
            stream_err = e
@@ -2075,9 +2153,6 @@
                exc_info=True,
            )
            session.messages = session.messages[:pre_attempt_msg_count]
            # Cancel any pre-launched tasks from the failed attempt so they
            # don't continue executing against the rolled-back session.
            await cancel_pending_tool_tasks()
            if events_yielded > 0:
                # Events were already sent to the frontend and cannot be
                # unsent. Retrying would produce duplicate/inconsistent

@@ -10,6 +10,7 @@ import pytest

from .service import (
    _is_sdk_disconnect_error,
    _normalize_model_name,
    _prepare_file_attachments,
    _resolve_sdk_model,
    _safe_close_sdk_client,
@@ -405,6 +406,49 @@ def _clean_config_env(monkeypatch: pytest.MonkeyPatch) -> None:
        monkeypatch.delenv(var, raising=False)


class TestNormalizeModelName:
    """Tests for _normalize_model_name — shared provider-aware normalization."""

    def test_strips_provider_prefix(self, monkeypatch, _clean_config_env):
        from backend.copilot import config as cfg_mod

        cfg = cfg_mod.ChatConfig(
            use_openrouter=False,
            api_key=None,
            base_url=None,
            use_claude_code_subscription=False,
        )
        monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
        assert _normalize_model_name("anthropic/claude-opus-4.6") == "claude-opus-4-6"

    def test_dots_preserved_for_openrouter(self, monkeypatch, _clean_config_env):
        from backend.copilot import config as cfg_mod

        cfg = cfg_mod.ChatConfig(
            use_openrouter=True,
            api_key="or-key",
            base_url="https://openrouter.ai/api/v1",
            use_claude_code_subscription=False,
        )
        monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
        assert _normalize_model_name("anthropic/claude-opus-4.6") == "claude-opus-4.6"

    def test_no_prefix_no_dots(self, monkeypatch, _clean_config_env):
        from backend.copilot import config as cfg_mod

        cfg = cfg_mod.ChatConfig(
            use_openrouter=False,
            api_key=None,
            base_url=None,
            use_claude_code_subscription=False,
        )
        monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
        assert (
            _normalize_model_name("claude-sonnet-4-20250514")
            == "claude-sonnet-4-20250514"
        )
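Read together, the three tests pin down the normalisation contract. A sketch consistent with them (the real _normalize_model_name reads the provider from module config rather than a parameter):

def normalize_model_name_sketch(model: str, use_openrouter: bool) -> str:
    model = model.split("/", 1)[-1]      # drop any "provider/" prefix
    if not use_openrouter:
        model = model.replace(".", "-")  # CLI model IDs use dashes, not dots
    return model

assert normalize_model_name_sketch("anthropic/claude-opus-4.6", False) == "claude-opus-4-6"
assert normalize_model_name_sketch("anthropic/claude-opus-4.6", True) == "claude-opus-4.6"
assert normalize_model_name_sketch("claude-sonnet-4-20250514", False) == "claude-sonnet-4-20250514"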


class TestResolveSdkModel:
    """Tests for _resolve_sdk_model — model ID resolution for the SDK CLI."""


@@ -392,7 +392,7 @@ class TestFlattenThinkingBlocks:
        assert result == ""

    def test_mixed_thinking_text_tool(self):
        """Mixed blocks: only text and tool_use survive flattening."""
        """Mixed blocks: only text survives flattening; thinking and tool_use dropped."""
        blocks = [
            {"type": "thinking", "thinking": "hmm", "signature": "sig"},
            {"type": "redacted_thinking", "data": "xyz"},
@@ -403,7 +403,8 @@
        assert "hmm" not in result
        assert "xyz" not in result
        assert "I'll read the file." in result
        assert "[tool_use: Read]" in result
        # tool_use blocks are dropped entirely to prevent model mimicry
        assert "Read" not in result


# ---------------------------------------------------------------------------
@@ -438,7 +439,7 @@ class TestCompactTranscriptThinkingBlocks:
            },
        )()
        with patch(
            "backend.copilot.sdk.transcript._run_compression",
            "backend.copilot.transcript._run_compression",
            new_callable=AsyncMock,
            return_value=mock_result,
        ):
@@ -497,7 +498,7 @@
        )()

        with patch(
            "backend.copilot.sdk.transcript._run_compression",
            "backend.copilot.transcript._run_compression",
            side_effect=mock_compression,
        ):
            await compact_transcript(transcript, model="test-model")
@@ -550,7 +551,7 @@
            },
        )()
        with patch(
            "backend.copilot.sdk.transcript._run_compression",
            "backend.copilot.transcript._run_compression",
            new_callable=AsyncMock,
            return_value=mock_result,
        ):
@@ -600,7 +601,7 @@
            },
        )()
        with patch(
            "backend.copilot.sdk.transcript._run_compression",
            "backend.copilot.transcript._run_compression",
            new_callable=AsyncMock,
            return_value=mock_result,
        ):
@@ -637,7 +638,7 @@
            },
        )()
        with patch(
            "backend.copilot.sdk.transcript._run_compression",
            "backend.copilot.transcript._run_compression",
            new_callable=AsyncMock,
            return_value=mock_result,
        ):
@@ -698,7 +699,7 @@
            },
        )()
        with patch(
            "backend.copilot.sdk.transcript._run_compression",
            "backend.copilot.transcript._run_compression",
            new_callable=AsyncMock,
            return_value=mock_result,
        ):

@@ -14,6 +14,7 @@ from contextvars import ContextVar
from typing import TYPE_CHECKING, Any

from claude_agent_sdk import create_sdk_mcp_server, tool
from mcp.types import ToolAnnotations

from backend.copilot.context import (
    _current_permissions,
@@ -37,7 +38,7 @@ from backend.copilot.tools import TOOL_REGISTRY
from backend.copilot.tools.base import BaseTool
from backend.util.truncate import truncate

from .e2b_file_tools import E2B_FILE_TOOL_NAMES, E2B_FILE_TOOLS
from .e2b_file_tools import E2B_FILE_TOOL_NAMES, E2B_FILE_TOOLS, bridge_and_annotate

if TYPE_CHECKING:
    from e2b import AsyncSandbox
@@ -53,14 +54,6 @@ _MCP_MAX_CHARS = 500_000
MCP_SERVER_NAME = "copilot"
MCP_TOOL_PREFIX = f"mcp__{MCP_SERVER_NAME}__"

# Map from tool_name -> Queue of pre-launched (task, args) pairs.
# Initialised per-session in set_execution_context() so concurrent sessions
# never share the same dict.
_TaskQueueItem = tuple[asyncio.Task[dict[str, Any]], dict[str, Any]]
_tool_task_queues: ContextVar[dict[str, asyncio.Queue[_TaskQueueItem]] | None] = (
    ContextVar("_tool_task_queues", default=None)
)
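The per-session isolation the comment relies on comes from contextvars: every asyncio task runs in a copy of the context that created it, so a dict set in one session's task tree is invisible to another's. A self-contained illustration (standard library only):

import asyncio
from contextvars import ContextVar

_queues: ContextVar[dict | None] = ContextVar("_queues", default=None)

async def session(name: str) -> list[str]:
    _queues.set({})              # fresh dict per session, as in set_execution_context()
    _queues.get()[name] = object()
    await asyncio.sleep(0)       # yield so the sessions interleave
    return list(_queues.get())   # only ever sees its own key

async def main() -> None:
    assert await asyncio.gather(session("a"), session("b")) == [["a"], ["b"]]

asyncio.run(main())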

# Stash for MCP tool outputs before the SDK potentially truncates them.
# Keyed by tool_name → full output string. Consumed (popped) by the
# response adapter when it builds StreamToolOutputAvailable.
@@ -115,7 +108,6 @@ def set_execution_context(
    _current_permissions.set(permissions)
    _pending_tool_outputs.set({})
    _stash_event.set(asyncio.Event())
    _tool_task_queues.set({})
    _consecutive_tool_failures.set({})


@@ -132,48 +124,6 @@ def reset_stash_event() -> None:
        event.clear()


async def cancel_pending_tool_tasks() -> None:
    """Cancel all queued pre-launched tasks for the current execution context.

    Call this when a stream attempt aborts (error, cancellation) to prevent
    pre-launched tasks from continuing to execute against a rolled-back session.
    Tasks that are already done are skipped; in-flight tasks are cancelled and
    awaited so that any cleanup (``finally`` blocks, DB rollbacks) completes
    before the next retry starts.
    """
    queues = _tool_task_queues.get()
    if not queues:
        return
    cancelled_tasks: list[asyncio.Task] = []
    for tool_name, queue in list(queues.items()):
        cancelled = 0
        while not queue.empty():
            task, _args = queue.get_nowait()
            if not task.done():
                task.cancel()
                cancelled_tasks.append(task)
                cancelled += 1
        if cancelled:
            logger.debug(
                "Cancelled %d pre-launched task(s) for tool '%s'", cancelled, tool_name
            )
    queues.clear()
    # Await all cancelled tasks so their cleanup (finally blocks, DB rollbacks)
    # completes before the next retry attempt starts new pre-launches.
    # Use a timeout to prevent hanging indefinitely if a task's cleanup is stuck.
    if cancelled_tasks:
        try:
            await asyncio.wait_for(
                asyncio.gather(*cancelled_tasks, return_exceptions=True),
                timeout=5.0,
            )
        except TimeoutError:
            logger.warning(
                "Timed out waiting for %d cancelled task(s) to clean up",
                len(cancelled_tasks),
            )


def reset_tool_failure_counters() -> None:
    """Reset all tool-level circuit breaker counters.

@@ -249,10 +199,6 @@ async def wait_for_stash(timeout: float = 2.0) -> bool:
    Uses ``asyncio.Event.wait()`` so it returns the instant the hook signals —
    the timeout is purely a safety net for the case where the hook never fires.
    Returns ``True`` if the stash signal was received, ``False`` on timeout.

    The 2.0 s default was chosen to accommodate slower tool startup in cloud
    sandboxes while still failing fast when the hook genuinely will not fire.
    With the parallel pre-launch path, hooks typically fire well under 1 ms.
    """
    event = _stash_event.get(None)
    if event is None:
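The remainder of wait_for_stash is cut off by the hunk, but the wait pattern its docstring describes reduces to a few lines (a sketch of the pattern, not the missing body):

import asyncio

async def wait_with_timeout(event: asyncio.Event, timeout: float = 2.0) -> bool:
    # Returns the moment the event is set; the timeout only bounds the
    # case where the signalling hook never fires.
    try:
        await asyncio.wait_for(event.wait(), timeout=timeout)
        return True
    except asyncio.TimeoutError:
        return False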
@@ -271,95 +217,13 @@
    return False


async def pre_launch_tool_call(tool_name: str, args: dict[str, Any]) -> None:
    """Pre-launch a tool as a background task so parallel calls run concurrently.

    Called when an AssistantMessage with ToolUseBlocks is received, before the
    SDK dispatches the MCP tool/call requests. The tool_handler will await the
    pre-launched task instead of executing fresh.

    The tool_name may include an MCP prefix (e.g. ``mcp__copilot__run_block``);
    the prefix is stripped automatically before looking up the tool.

    Ordering guarantee: the Claude Agent SDK dispatches MCP ``tools/call`` requests
    in the same order as the ToolUseBlocks appear in the AssistantMessage.
    Pre-launched tasks are queued FIFO per tool name, so the N-th handler for a
    given tool name dequeues the N-th pre-launched task — result and args always
    correspond when the SDK preserves order (which it does in the current SDK).
    """
    queues = _tool_task_queues.get()
    if queues is None:
        return

    # Strip the MCP server prefix (e.g. "mcp__copilot__") to get the bare tool name.
    # Use removeprefix so tool names that themselves contain "__" are handled correctly.
    bare_name = tool_name.removeprefix(MCP_TOOL_PREFIX)

    base_tool = TOOL_REGISTRY.get(bare_name)
    if base_tool is None:
        return

    user_id, session = get_execution_context()
    if session is None:
        return

    # Expand @@agptfile: references before launching the task.
    # The _truncating wrapper (which normally handles expansion) runs AFTER
    # pre_launch_tool_call — the pre-launched task would otherwise receive raw
    # @@agptfile: tokens and fail to resolve them inside _execute_tool_sync.
    # Use _build_input_schema (same path as _truncating) for schema-aware expansion.
    input_schema: dict[str, Any] | None
    try:
        input_schema = _build_input_schema(base_tool)
    except Exception:
        input_schema = None  # schema unavailable — skip schema-aware expansion
    try:
        args = await expand_file_refs_in_args(
            args, user_id, session, input_schema=input_schema
        )
    except FileRefExpansionError as exc:
        logger.warning(
            "pre_launch_tool_call: @@agptfile expansion failed for %s: %s — skipping pre-launch",
            bare_name,
            exc,
        )
        return

    task = asyncio.create_task(_execute_tool_sync(base_tool, user_id, session, args))
    # Log unhandled exceptions so "Task exception was never retrieved" warnings
    # do not pollute stderr when a task is pre-launched but never dequeued.
    task.add_done_callback(
        lambda t, name=bare_name: (
            logger.warning(
                "Pre-launched task for %s raised unhandled: %s",
                name,
                t.exception(),
            )
            if not t.cancelled() and t.exception()
            else None
        )
    )

    if bare_name not in queues:
        queues[bare_name] = asyncio.Queue[_TaskQueueItem]()
    # Store (task, args) so the handler can log a warning if the SDK dispatches
    # calls in a different order than the ToolUseBlocks appeared in the message.
    queues[bare_name].put_nowait((task, args))


async def _execute_tool_sync(
    base_tool: BaseTool,
    user_id: str | None,
    session: ChatSession,
    args: dict[str, Any],
) -> dict[str, Any]:
    """Execute a tool synchronously and return MCP-formatted response.

    Note: ``@@agptfile:`` expansion should be performed by the caller before
    invoking this function. For the normal (non-parallel) path it is handled
    by the ``_truncating`` wrapper; for the pre-launched parallel path it is
    handled in :func:`pre_launch_tool_call` before the task is created.
    """
    """Execute a tool synchronously and return MCP-formatted response."""
    effective_id = f"sdk-{uuid.uuid4().hex[:12]}"
    result = await base_tool.execute(
        user_id=user_id,
@@ -455,83 +319,7 @@ def create_tool_handler(base_tool: BaseTool):
    """

    async def tool_handler(args: dict[str, Any]) -> dict[str, Any]:
        """Execute the wrapped tool and return MCP-formatted response.

        If a pre-launched task exists (from parallel tool pre-launch in the
        message loop), await it instead of executing fresh.
        """
        queues = _tool_task_queues.get()
        if queues and base_tool.name in queues:
            queue = queues[base_tool.name]
            if not queue.empty():
                task, launch_args = queue.get_nowait()
                # Sanity-check: warn if the args don't match — this can happen
                # if the SDK dispatches tool calls in a different order than the
                # ToolUseBlocks appeared in the AssistantMessage (unlikely but
                # could occur in future SDK versions or with SDK bugs).
                # We compare full values (not just keys) so that two run_block
                # calls with different block_id values are caught even though
                # both have the same key set.
                if launch_args != args:
                    logger.warning(
                        "Pre-launched task for %s: arg mismatch "
                        "(launch_keys=%s, call_keys=%s) — cancelling "
                        "pre-launched task and falling back to direct execution",
                        base_tool.name,
                        (
                            sorted(launch_args.keys())
                            if isinstance(launch_args, dict)
                            else type(launch_args).__name__
                        ),
                        (
                            sorted(args.keys())
                            if isinstance(args, dict)
                            else type(args).__name__
                        ),
                    )
                    if not task.done():
                        task.cancel()
                        # Await cancellation to prevent duplicate concurrent
                        # execution for blocks with side effects.
                        try:
                            await task
                        except (asyncio.CancelledError, Exception):
                            pass
                    # Fall through to the direct-execution path below.
                else:
                    # Args match — await the pre-launched task.
                    try:
                        result = await task
                    except asyncio.CancelledError:
                        # Re-raise: CancelledError may be propagating from the
                        # outer streaming loop being cancelled — swallowing it
                        # would mask the cancellation and prevent proper cleanup.
                        logger.warning(
                            "Pre-launched tool %s was cancelled — re-raising",
                            base_tool.name,
                        )
                        raise
                    except Exception as e:
                        logger.error(
                            "Pre-launched tool %s failed: %s",
                            base_tool.name,
                            e,
                            exc_info=True,
                        )
                        return _mcp_error(
                            f"Failed to execute {base_tool.name}. "
                            "Check server logs for details."
                        )

                    # Pre-truncate the result so the _truncating wrapper (which
                    # wraps this handler) receives an already-within-budget
                    # value. _truncating handles stashing — we must NOT stash
                    # here or the output will be appended twice to the FIFO
                    # queue and pop_pending_tool_output would return a duplicate
                    # entry on the second call for the same tool.
                    return truncate(result, _MCP_MAX_CHARS)

        # No pre-launched task — execute directly (fallback for non-parallel calls).
        """Execute the wrapped tool and return MCP-formatted response."""
        user_id, session = get_execution_context()

        if session is None:
@@ -599,7 +387,16 @@ async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
            selected = list(itertools.islice(f, offset, offset + limit))
        # Cleanup happens in _cleanup_sdk_tool_results after session ends;
        # don't delete here — the SDK may read in multiple chunks.
        return _mcp_ok("".join(selected))
        #
        # When E2B is active, also copy the file into the sandbox so
        # bash_exec can process it (the model often uses Read then bash).
        text = "".join(selected)
        sandbox = _current_sandbox.get(None)
        if sandbox is not None:
            annotation = await bridge_and_annotate(sandbox, resolved, offset, limit)
            if annotation:
                text += annotation
        return _mcp_ok(text)
    except FileNotFoundError:
        return _mcp_err(f"File not found: {file_path}")
    except Exception as e:
@@ -648,9 +445,19 @@ def _text_from_mcp_result(result: dict[str, Any]) -> str:
    )


_PARALLEL_ANNOTATION = ToolAnnotations(readOnlyHint=True)


def create_copilot_mcp_server(*, use_e2b: bool = False):
    """Create an in-process MCP server configuration for CoPilot tools.

    All tools are annotated with ``readOnlyHint=True`` so the SDK CLI
    dispatches concurrent tool calls in parallel rather than sequentially.
    This is a deliberate override: even side-effect tools use the hint
    because the MCP tools are already individually sandboxed and the
    pre-launch duplicate-execution bug (SECRT-2204) is worse than
    sequential dispatch.

    When *use_e2b* is True, five additional MCP file tools are registered
    that route directly to the E2B sandbox filesystem, and the caller should
    disable the corresponding SDK built-in tools via
@@ -668,6 +475,28 @@
    Applied once to every registered tool."""

    async def wrapper(args: dict[str, Any]) -> dict[str, Any]:
        # Empty tool args = model's output was truncated by the API's
        # max_tokens limit. Instead of letting the tool fail with a
        # confusing error (and eventually tripping the circuit breaker),
        # return clear guidance so the model can self-correct.
        if not args and input_schema and input_schema.get("required"):
            logger.warning(
                "[MCP] %s called with empty args (likely output "
                "token truncation) — returning guidance",
                tool_name,
            )
            return _mcp_error(
                f"Your call to {tool_name} had empty arguments — "
                f"this means your previous response was too long and "
                f"the tool call input was truncated by the API. "
                f"To fix this: break your work into smaller steps. "
                f"For large content, first write it to a file using "
                f"bash_exec with cat >> (append section by section), "
                f"then pass it via @@agptfile:filename reference. "
                f"Do NOT retry with the same approach — it will "
                f"be truncated again."
            )

        # Circuit breaker: stop infinite retry loops with identical args.
        # Use the original (pre-expansion) args for fingerprinting so
        # check and record always use the same key — @@agptfile:
@@ -718,24 +547,35 @@
    for tool_name, base_tool in TOOL_REGISTRY.items():
        handler = create_tool_handler(base_tool)
        schema = _build_input_schema(base_tool)
        # All tools annotated readOnlyHint=True to enable parallel dispatch.
        # The SDK CLI uses this hint to dispatch concurrent tool calls in
        # parallel rather than sequentially. Side-effect safety is ensured
        # by the tool implementations themselves (idempotency, credentials).
        decorated = tool(
            tool_name,
            base_tool.description,
            schema,
            annotations=_PARALLEL_ANNOTATION,
        )(_truncating(handler, tool_name, input_schema=schema))
        sdk_tools.append(decorated)

    # E2B file tools replace SDK built-in Read/Write/Edit/Glob/Grep.
    if use_e2b:
        for name, desc, schema, handler in E2B_FILE_TOOLS:
            decorated = tool(name, desc, schema)(_truncating(handler, name))
            decorated = tool(
                name,
                desc,
                schema,
                annotations=_PARALLEL_ANNOTATION,
            )(_truncating(handler, name))
            sdk_tools.append(decorated)

    # Read tool for SDK-truncated tool results (always needed).
    # Read tool for SDK-truncated tool results (always needed, read-only).
    read_tool = tool(
        _READ_TOOL_NAME,
        _READ_TOOL_DESCRIPTION,
        _READ_TOOL_SCHEMA,
        annotations=_PARALLEL_ANNOTATION,
    )(_truncating(_read_file_handler, _READ_TOOL_NAME))
    sdk_tools.append(read_tool)
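End to end, registering one annotated tool follows the same calls the loop above makes. A compact sketch (the echo tool, its schema, and the server kwargs are placeholders; tool(...) and create_sdk_mcp_server come straight from this file's claude_agent_sdk imports, though exact signatures may vary by SDK version):

from claude_agent_sdk import create_sdk_mcp_server, tool
from mcp.types import ToolAnnotations

_PARALLEL_ANNOTATION = ToolAnnotations(readOnlyHint=True)

@tool(
    "echo",  # placeholder tool name
    "Echo the given text back.",
    {"type": "object", "properties": {"text": {"type": "string"}}, "required": ["text"]},
    annotations=_PARALLEL_ANNOTATION,  # lets the CLI dispatch calls in parallel
)
async def echo(args: dict) -> dict:
    # MCP result shape used throughout this module.
    return {"content": [{"type": "text", "text": args["text"]}], "isError": False}

server = create_sdk_mcp_server(name="copilot", tools=[echo])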

@@ -750,13 +590,14 @@
# Security hooks validate that file paths stay within sdk_cwd.
# Bash is NOT included — use the sandboxed MCP bash_exec tool instead,
# which provides kernel-level network isolation via unshare --net.
# Task allows spawning sub-agents (rate-limited by security hooks).
# Task/Agent allows spawning sub-agents (rate-limited by security hooks).
# The CLI renamed "Task" → "Agent" in v2.x; both are listed for compat.
# WebSearch uses Brave Search via Anthropic's API — safe, no SSRF risk.
# TodoWrite manages the task checklist shown in the UI — no security concern.
# In E2B mode, all five are disabled — MCP equivalents provide direct sandbox
# access. read_file also handles local tool-results and ephemeral reads.
_SDK_BUILTIN_FILE_TOOLS = ["Read", "Write", "Edit", "Glob", "Grep"]
_SDK_BUILTIN_ALWAYS = ["Task", "WebSearch", "TodoWrite"]
_SDK_BUILTIN_ALWAYS = ["Task", "Agent", "WebSearch", "TodoWrite"]
_SDK_BUILTIN_TOOLS = [*_SDK_BUILTIN_FILE_TOOLS, *_SDK_BUILTIN_ALWAYS]

# SDK built-in tools that must be explicitly blocked.

@@ -1,22 +1,21 @@
"""Tests for tool_adapter helpers: truncation, stash, context vars, parallel pre-launch."""
"""Tests for tool_adapter: truncation, stash, context vars, readOnlyHint annotations."""

import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, MagicMock

import pytest
from mcp.types import ToolAnnotations

from backend.copilot.context import get_sdk_cwd
from backend.copilot.response_model import StreamToolOutputAvailable
from backend.copilot.sdk.file_ref import FileRefExpansionError
from backend.util.truncate import truncate

from .tool_adapter import (
    _MCP_MAX_CHARS,
    SDK_DISALLOWED_TOOLS,
    _text_from_mcp_result,
    cancel_pending_tool_tasks,
    create_tool_handler,
    pop_pending_tool_output,
    pre_launch_tool_call,
    reset_stash_event,
    set_execution_context,
    stash_pending_tool_output,
@@ -244,7 +243,7 @@


# ---------------------------------------------------------------------------
# Parallel pre-launch infrastructure
# create_tool_handler (direct execution, no pre-launch)
# ---------------------------------------------------------------------------


@@ -277,169 +276,18 @@ def _init_ctx(session=None):
    )


class TestPreLaunchToolCall:
    """Tests for pre_launch_tool_call and the queue-based parallel dispatch."""
class TestCreateToolHandler:
    """Tests for create_tool_handler — direct tool execution."""

    @pytest.fixture(autouse=True)
    def _init(self):
        _init_ctx(session=_make_mock_session())

    @pytest.mark.asyncio
    async def test_unknown_tool_is_silently_ignored(self):
        """pre_launch_tool_call does nothing for tools not in TOOL_REGISTRY."""
        # Should not raise even if the tool name is completely unknown
        await pre_launch_tool_call("nonexistent_tool", {})

    @pytest.mark.asyncio
    async def test_mcp_prefix_stripped_before_registry_lookup(self):
        """mcp__copilot__run_block is looked up as 'run_block'."""
        mock_tool = _make_mock_tool("run_block")
        with patch(
            "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
            {"run_block": mock_tool},
        ):
            await pre_launch_tool_call("mcp__copilot__run_block", {"block_id": "b1"})

            # The task was enqueued — mock_tool.execute should be called once
            # (may not complete immediately but should start)
            await asyncio.sleep(0)  # yield to event loop
            mock_tool.execute.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_bare_tool_name_without_prefix(self):
        """Tool names without __ separator are looked up as-is."""
        mock_tool = _make_mock_tool("run_block")
        with patch(
            "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
            {"run_block": mock_tool},
        ):
            await pre_launch_tool_call("run_block", {"block_id": "b1"})

            await asyncio.sleep(0)
            mock_tool.execute.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_task_enqueued_fifo_for_same_tool(self):
        """Two pre-launched calls for the same tool name are enqueued FIFO."""
        results = []

        async def slow_execute(*args, **kwargs):
            results.append(len(results))
            return StreamToolOutputAvailable(
                toolCallId="id",
                output=str(len(results) - 1),
                toolName="t",
                success=True,
            )

        mock_tool = _make_mock_tool("t")
        mock_tool.execute = AsyncMock(side_effect=slow_execute)

        with patch(
            "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
            {"t": mock_tool},
        ):
            await pre_launch_tool_call("t", {"n": 1})
            await pre_launch_tool_call("t", {"n": 2})
            await asyncio.sleep(0)

        assert mock_tool.execute.await_count == 2

    @pytest.mark.asyncio
    async def test_file_ref_expansion_failure_skips_pre_launch(self):
        """When @@agptfile: expansion fails, pre_launch_tool_call skips the task.

        The handler should then fall back to direct execution (which will also
        fail with a proper MCP error via _truncating's own expansion).
        """
        mock_tool = _make_mock_tool("run_block", output="should-not-execute")

        with (
            patch(
                "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
                {"run_block": mock_tool},
            ),
            patch(
                "backend.copilot.sdk.tool_adapter.expand_file_refs_in_args",
                AsyncMock(side_effect=FileRefExpansionError("@@agptfile:missing.txt")),
            ),
        ):
            # Should not raise — expansion failure is handled gracefully
            await pre_launch_tool_call("run_block", {"text": "@@agptfile:missing.txt"})
            await asyncio.sleep(0)

            # No task was pre-launched — execute was not called
            mock_tool.execute.assert_not_awaited()


class TestCreateToolHandlerParallel:
    """Tests for create_tool_handler using pre-launched tasks."""

    @pytest.fixture(autouse=True)
    def _init(self):
        _init_ctx(session=_make_mock_session())

    @pytest.mark.asyncio
    async def test_handler_uses_prelaunched_task(self):
        """Handler pops and awaits the pre-launched task rather than re-executing."""
        mock_tool = _make_mock_tool("run_block", output="pre-launched result")

        with patch(
            "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
            {"run_block": mock_tool},
        ):
            await pre_launch_tool_call("run_block", {"block_id": "b1"})
            await asyncio.sleep(0)  # let task start

            handler = create_tool_handler(mock_tool)
            result = await handler({"block_id": "b1"})

        assert result["isError"] is False
        text = result["content"][0]["text"]
        assert "pre-launched result" in text
        # Should only have been called once (the pre-launched task), not twice
        mock_tool.execute.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_handler_does_not_double_stash_for_prelaunched_task(self):
        """Pre-launched task result must NOT be stashed by tool_handler directly.

        The _truncating wrapper wraps tool_handler and handles stashing after
        tool_handler returns. If tool_handler also stashed, the output would be
        appended twice to the FIFO queue and pop_pending_tool_output would return
        a duplicate on the second call.

        This test calls tool_handler directly (without _truncating) and asserts
        that nothing was stashed — confirming stashing is deferred to _truncating.
        """
        mock_tool = _make_mock_tool("run_block", output="stash-me")

        with patch(
            "backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
            {"run_block": mock_tool},
        ):
            await pre_launch_tool_call("run_block", {"block_id": "b1"})
            await asyncio.sleep(0)

            handler = create_tool_handler(mock_tool)
            result = await handler({"block_id": "b1"})

        assert result["isError"] is False
        assert "stash-me" in result["content"][0]["text"]
        # tool_handler must NOT stash — _truncating (which wraps handler) does it.
        # Calling pop here (without going through _truncating) should return None.
        not_stashed = pop_pending_tool_output("run_block")
        assert not_stashed is None, (
            "tool_handler must not stash directly — _truncating handles stashing "
            "to prevent double-stash in the FIFO queue"
        )

    @pytest.mark.asyncio
    async def test_handler_falls_back_when_queue_empty(self):
        """When no pre-launched task exists, handler executes directly."""
    async def test_handler_executes_tool_directly(self):
        """Handler executes the tool and returns MCP-formatted result."""
        mock_tool = _make_mock_tool("run_block", output="direct result")

        # Don't call pre_launch_tool_call — queue is empty
        handler = create_tool_handler(mock_tool)
        result = await handler({"block_id": "b1"})

@@ -449,104 +297,9 @@ class TestCreateToolHandlerParallel:
|
||||
mock_tool.execute.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_cancelled_error_propagates(self):
|
||||
"""CancelledError from a pre-launched task is re-raised to preserve cancellation semantics."""
|
||||
async def test_handler_returns_error_on_no_session(self):
|
||||
"""When session is None, handler returns MCP error."""
|
||||
mock_tool = _make_mock_tool("run_block")
|
||||
mock_tool.execute = AsyncMock(side_effect=asyncio.CancelledError())
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": mock_tool},
|
||||
):
|
||||
await pre_launch_tool_call("run_block", {"block_id": "b1"})
|
||||
await asyncio.sleep(0)
|
||||
|
||||
handler = create_tool_handler(mock_tool)
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await handler({"block_id": "b1"})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_exception_returns_mcp_error(self):
|
||||
"""Exception from a pre-launched task is caught and returned as MCP error."""
|
||||
mock_tool = _make_mock_tool("run_block")
|
||||
mock_tool.execute = AsyncMock(side_effect=RuntimeError("block exploded"))
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": mock_tool},
|
||||
):
|
||||
await pre_launch_tool_call("run_block", {"block_id": "b1"})
|
||||
await asyncio.sleep(0)
|
||||
|
||||
handler = create_tool_handler(mock_tool)
|
||||
result = await handler({"block_id": "b1"})
|
||||
|
||||
assert result["isError"] is True
|
||||
assert "Failed to execute run_block" in result["content"][0]["text"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_two_same_tool_calls_dispatched_in_order(self):
|
||||
"""Two pre-launched tasks for the same tool are consumed in FIFO order."""
|
||||
call_order = []
|
||||
|
||||
async def execute_with_tag(*args, **kwargs):
|
||||
tag = kwargs.get("block_id", "?")
|
||||
call_order.append(tag)
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId="id", output=f"out-{tag}", toolName="run_block", success=True
|
||||
)
|
||||
|
||||
mock_tool = _make_mock_tool("run_block")
|
||||
mock_tool.execute = AsyncMock(side_effect=execute_with_tag)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": mock_tool},
|
||||
):
|
||||
await pre_launch_tool_call("run_block", {"block_id": "first"})
|
||||
await pre_launch_tool_call("run_block", {"block_id": "second"})
|
||||
await asyncio.sleep(0)
|
||||
|
||||
handler = create_tool_handler(mock_tool)
|
||||
r1 = await handler({"block_id": "first"})
|
||||
r2 = await handler({"block_id": "second"})
|
||||
|
||||
assert "out-first" in r1["content"][0]["text"]
|
||||
assert "out-second" in r2["content"][0]["text"]
|
||||
assert call_order == [
|
||||
"first",
|
||||
"second",
|
||||
], f"Expected FIFO dispatch order but got {call_order}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_arg_mismatch_falls_back_to_direct_execution(self):
|
||||
"""When pre-launched args differ from SDK args, handler cancels pre-launched
|
||||
task and falls back to direct execution with the correct args."""
|
||||
mock_tool = _make_mock_tool("run_block", output="direct-result")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": mock_tool},
|
||||
):
|
||||
# Pre-launch with args {"block_id": "wrong"}
|
||||
await pre_launch_tool_call("run_block", {"block_id": "wrong"})
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# SDK dispatches with different args
|
||||
handler = create_tool_handler(mock_tool)
|
||||
result = await handler({"block_id": "correct"})
|
||||
|
||||
assert result["isError"] is False
|
||||
# The tool was called twice: once by pre-launch (wrong args), once by
|
||||
# direct fallback (correct args). The result should come from the
|
||||
# direct execution path.
|
||||
assert mock_tool.execute.await_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_session_falls_back_gracefully(self):
|
||||
"""When session is None and no pre-launched task, handler returns MCP error."""
|
||||
mock_tool = _make_mock_tool("run_block")
|
||||
# session=None means get_execution_context returns (user_id, None)
|
||||
set_execution_context(user_id="u", session=None, sandbox=None) # type: ignore[arg-type]
|
||||
|
||||
handler = create_tool_handler(mock_tool)
|
||||
@@ -555,220 +308,406 @@ class TestCreateToolHandlerParallel:
|
||||
assert result["isError"] is True
|
||||
assert "session" in result["content"][0]["text"].lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# cancel_pending_tool_tasks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCancelPendingToolTasks:
|
||||
"""Tests for cancel_pending_tool_tasks — the stream-abort cleanup helper."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _init(self):
|
||||
_init_ctx(session=_make_mock_session())
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancels_queued_tasks(self):
|
||||
"""Queued tasks are cancelled and the queue is cleared."""
|
||||
ran = False
|
||||
|
||||
async def never_run(*_args, **_kwargs):
|
||||
nonlocal ran
|
||||
await asyncio.sleep(10) # long enough to still be pending
|
||||
ran = True
|
||||
|
||||
async def test_handler_returns_error_on_exception(self):
|
||||
"""Exception from tool execution is caught and returned as MCP error."""
|
||||
mock_tool = _make_mock_tool("run_block")
|
||||
mock_tool.execute = AsyncMock(side_effect=never_run)
|
||||
mock_tool.execute = AsyncMock(side_effect=RuntimeError("block exploded"))
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": mock_tool},
|
||||
):
|
||||
await pre_launch_tool_call("run_block", {"block_id": "b1"})
|
||||
await asyncio.sleep(0) # let task start
|
||||
await cancel_pending_tool_tasks()
|
||||
await asyncio.sleep(0) # let cancellation propagate
|
||||
handler = create_tool_handler(mock_tool)
|
||||
result = await handler({"block_id": "b1"})
|
||||
|
||||
assert not ran, "Task should have been cancelled before completing"
|
||||
assert result["isError"] is True
|
||||
assert "Failed to execute run_block" in result["content"][0]["text"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_noop_when_no_tasks_queued(self):
|
||||
"""cancel_pending_tool_tasks does not raise when queues are empty."""
|
||||
await cancel_pending_tool_tasks() # should not raise
|
||||
async def test_handler_executes_once_per_call(self):
|
||||
"""Each handler call executes the tool exactly once — no duplicate execution."""
|
||||
mock_tool = _make_mock_tool("run_block", output="single-execution")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_does_not_find_cancelled_task(self):
|
||||
"""After cancel, tool_handler falls back to direct execution."""
|
||||
mock_tool = _make_mock_tool("run_block", output="direct-fallback")
|
||||
handler = create_tool_handler(mock_tool)
|
||||
await handler({"block_id": "b1"})
|
||||
await handler({"block_id": "b2"})
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": mock_tool},
|
||||
):
|
||||
await pre_launch_tool_call("run_block", {"block_id": "b1"})
|
||||
await asyncio.sleep(0)
|
||||
await cancel_pending_tool_tasks()
|
||||
|
||||
# Queue is now empty — handler should execute directly
|
||||
handler = create_tool_handler(mock_tool)
|
||||
result = await handler({"block_id": "b1"})
|
||||
|
||||
assert result["isError"] is False
|
||||
assert "direct-fallback" in result["content"][0]["text"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Concurrent / parallel pre-launch scenarios
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAllParallelToolsPrelaunchedIndependently:
|
||||
"""Simulate SDK sending N separate AssistantMessages for the same tool concurrently."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _init(self):
|
||||
_init_ctx(session=_make_mock_session())
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_parallel_tools_prelaunched_independently(self):
|
||||
"""5 pre-launches for the same tool all enqueue independently and run concurrently.
|
||||
|
||||
Each task sleeps for PER_TASK_S seconds. If they ran sequentially the total
|
||||
wall time would be ~5*PER_TASK_S. Running concurrently it should finish in
|
||||
roughly PER_TASK_S (plus scheduling overhead).
|
||||
"""
|
||||
PER_TASK_S = 0.05
|
||||
N = 5
|
||||
started: list[int] = []
|
||||
finished: list[int] = []
|
||||
|
||||
async def slow_execute(*args, **kwargs):
|
||||
idx = len(started)
|
||||
started.append(idx)
|
||||
await asyncio.sleep(PER_TASK_S)
|
||||
finished.append(idx)
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId=f"id-{idx}",
|
||||
output=f"result-{idx}",
|
||||
toolName="bash_exec",
|
||||
success=True,
|
||||
)
|
||||
|
||||
mock_tool = _make_mock_tool("bash_exec")
|
||||
mock_tool.execute = AsyncMock(side_effect=slow_execute)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"bash_exec": mock_tool},
|
||||
):
|
||||
for i in range(N):
|
||||
await pre_launch_tool_call("bash_exec", {"cmd": f"echo {i}"})
|
||||
|
||||
# Measure only the concurrent execution window, not pre-launch overhead.
|
||||
# Starting the timer here avoids false failures on slow CI runners where
|
||||
# the pre_launch_tool_call setup takes longer than the concurrent sleep.
|
||||
t0 = asyncio.get_running_loop().time()
|
||||
await asyncio.sleep(PER_TASK_S * 2)
|
||||
elapsed = asyncio.get_running_loop().time() - t0
|
||||
|
||||
assert mock_tool.execute.await_count == N
|
||||
assert len(finished) == N
|
||||
# Wall time of the sleep window should be well under N * PER_TASK_S
|
||||
# (sequential would be ~0.25s; concurrent finishes in ~PER_TASK_S = 0.05s)
|
||||
assert elapsed < N * PER_TASK_S, (
|
||||
f"Expected concurrent execution (<{N * PER_TASK_S:.2f}s) "
|
||||
f"but sleep window took {elapsed:.2f}s"
|
||||
)
|
||||
|
||||
|
||||
class TestHandlerReturnsResultFromCorrectPrelaunchedTask:
|
||||
"""Pop pre-launched tasks in order and verify each returns its own result."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _init(self):
|
||||
_init_ctx(session=_make_mock_session())
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_returns_result_from_correct_prelaunched_task(self):
|
||||
"""Two pre-launches for the same tool: first handler gets first result, second gets second."""
|
||||
|
||||
async def execute_with_cmd(*args, **kwargs):
|
||||
cmd = kwargs.get("cmd", "?")
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId="id",
|
||||
output=f"output-for-{cmd}",
|
||||
toolName="bash_exec",
|
||||
success=True,
|
||||
)
|
||||
|
||||
mock_tool = _make_mock_tool("bash_exec")
|
||||
mock_tool.execute = AsyncMock(side_effect=execute_with_cmd)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"bash_exec": mock_tool},
|
||||
):
|
||||
await pre_launch_tool_call("bash_exec", {"cmd": "alpha"})
|
||||
await pre_launch_tool_call("bash_exec", {"cmd": "beta"})
|
||||
await asyncio.sleep(0) # let both tasks start
|
||||
|
||||
handler = create_tool_handler(mock_tool)
|
||||
r1 = await handler({"cmd": "alpha"})
|
||||
r2 = await handler({"cmd": "beta"})
|
||||
|
||||
text1 = r1["content"][0]["text"]
|
||||
text2 = r2["content"][0]["text"]
|
||||
assert "output-for-alpha" in text1, f"Expected alpha result, got: {text1}"
|
||||
assert "output-for-beta" in text2, f"Expected beta result, got: {text2}"
|
||||
assert mock_tool.execute.await_count == 2
|
||||
|
||||
|
||||
class TestFiveConcurrentPrelaunchAllComplete:
|
||||
"""Pre-launch 5 tasks; consume all 5 via handlers; assert all succeed."""
|
||||
# ---------------------------------------------------------------------------
|
||||
# Regression tests: bugs fixed by removing pre-launch mechanism
|
||||
#
|
||||
# Each test class includes a _buggy_handler fixture that reproduces the old
|
||||
# pre-launch implementation inline. Tests run against BOTH the buggy handler
|
||||
# (xfail — proves the bug exists) and the current clean handler (must pass).
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_execute_fn(tool_name: str = "run_block"):
|
||||
"""Return (execute_fn, call_log) — execute_fn records every call."""
|
||||
call_log: list[dict] = []
|
||||
|
||||
async def execute_fn(*args, **kwargs):
|
||||
call_log.append(kwargs)
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId=f"id-{len(call_log)}",
|
||||
output=f"result-{len(call_log)}",
|
||||
toolName=tool_name,
|
||||
success=True,
|
||||
)
|
||||
|
||||
return execute_fn, call_log
|
||||
|
||||
|
||||
async def _buggy_prelaunch_handler(mock_tool, pre_launch_args, dispatch_args):
|
||||
"""Simulate the OLD buggy pre-launch flow.
|
||||
|
||||
1. pre_launch_tool_call fires _execute_tool_sync with pre_launch_args
|
||||
2. SDK dispatches handler with dispatch_args
|
||||
3. Handler compares args — on mismatch, cancels + re-executes (BUG)
|
||||
|
||||
Returns the handler result.
|
||||
"""
|
||||
from backend.copilot.sdk.tool_adapter import _execute_tool_sync
|
||||
|
||||
user_id, session = "user-1", _make_mock_session()
|
||||
|
||||
# Step 1: pre-launch fires immediately (speculative)
|
||||
task = asyncio.create_task(
|
||||
_execute_tool_sync(mock_tool, user_id, session, pre_launch_args)
|
||||
)
|
||||
await asyncio.sleep(0) # let task start
|
||||
|
||||
# Step 2: SDK dispatches with (potentially different) args
|
||||
if pre_launch_args != dispatch_args:
|
||||
# Arg mismatch path: cancel pre-launched task + re-execute
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
# Fall through to direct execution (duplicate!)
|
||||
return await _execute_tool_sync(mock_tool, user_id, session, dispatch_args)
|
||||
else:
|
||||
return await task
|
||||
|
||||
|
||||
class TestBug1DuplicateExecution:
|
||||
"""Bug 1 (SECRT-2204): arg mismatch causes duplicate execution.
|
||||
|
||||
Pre-launch fires with raw args, SDK dispatches with normalised args.
|
||||
Mismatch → cancel (too late) + re-execute → 2 API calls.
|
||||
"""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _init(self):
|
||||
_init_ctx(session=_make_mock_session())
|
||||
|
||||
@pytest.mark.xfail(reason="Old pre-launch code causes duplicate execution")
|
||||
@pytest.mark.asyncio
|
||||
async def test_five_concurrent_prelaunch_all_complete(self):
|
||||
"""All 5 pre-launched tasks complete and return successful results."""
|
||||
N = 5
|
||||
call_count = 0
|
||||
async def test_old_code_duplicates_on_arg_mismatch(self):
|
||||
"""OLD CODE: pre-launch with args A, dispatch with args B → 2 calls."""
|
||||
execute_fn, call_log = _make_execute_fn()
|
||||
mock_tool = _make_mock_tool("run_block")
|
||||
mock_tool.execute = AsyncMock(side_effect=execute_fn)
|
||||
|
||||
async def counting_execute(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
n = call_count
|
||||
pre_launch_args = {"block_id": "b1", "input_data": {"title": "Test"}}
|
||||
dispatch_args = {
|
||||
"block_id": "b1",
|
||||
"input_data": {"title": "Test", "priority": None},
|
||||
}
|
||||
|
||||
await _buggy_prelaunch_handler(mock_tool, pre_launch_args, dispatch_args)
|
||||
|
||||
# BUG: pre-launch executed once + fallback executed again = 2
|
||||
assert len(call_log) == 1, (
|
||||
f"Expected 1 execution but got {len(call_log)} — "
|
||||
f"duplicate execution bug!"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_current_code_no_duplicate(self):
|
||||
"""FIXED: handler executes exactly once regardless of arg shape."""
|
||||
execute_fn, call_log = _make_execute_fn()
|
||||
mock_tool = _make_mock_tool("run_block")
|
||||
mock_tool.execute = AsyncMock(side_effect=execute_fn)
|
||||
|
||||
handler = create_tool_handler(mock_tool)
|
||||
await handler({"block_id": "b1", "input_data": {"title": "Test"}})
|
||||
|
||||
assert len(call_log) == 1, f"Expected 1 execution but got {len(call_log)}"


class TestBug2FIFODesync:
    """Bug 2: FIFO desync when security hook denies a tool.

    Pre-launch queues [task_A, task_B]. Tool A denied (no MCP dispatch).
    Tool B's handler dequeues task_A → returns wrong result.
    """

    @pytest.fixture(autouse=True)
    def _init(self):
        _init_ctx(session=_make_mock_session())

    @pytest.mark.xfail(reason="Old FIFO queue returns wrong result on denial")
    @pytest.mark.asyncio
    async def test_old_code_fifo_desync_on_denial(self):
        """OLD CODE: denied tool's task stays in queue, next tool gets wrong result."""
        from backend.copilot.sdk.tool_adapter import _execute_tool_sync

        call_log: list[str] = []

        async def tagged_execute(*args, **kwargs):
            tag = kwargs.get("block_id", "?")
            call_log.append(tag)
            return StreamToolOutputAvailable(
                toolCallId="id",
                output=f"result-for-{tag}",
                toolName="run_block",
                success=True,
            )

        mock_tool = _make_mock_tool("run_block")
        mock_tool.execute = AsyncMock(side_effect=tagged_execute)
        user_id, session = "user-1", _make_mock_session()

        # Simulate old FIFO queue
        queue: asyncio.Queue = asyncio.Queue()

        # Pre-launch for tool A and tool B
        task_a = asyncio.create_task(
            _execute_tool_sync(mock_tool, user_id, session, {"block_id": "A"})
        )
        task_b = asyncio.create_task(
            _execute_tool_sync(mock_tool, user_id, session, {"block_id": "B"})
        )
        queue.put_nowait(task_a)
        queue.put_nowait(task_b)
        await asyncio.sleep(0)  # let both tasks run

        # Tool A is DENIED by security hook — no MCP dispatch, no dequeue.
        # Tool B's handler dequeues from FIFO → gets task_A!
        dequeued_task = queue.get_nowait()
        result = await dequeued_task
        result_text = result["content"][0]["text"]

        # BUG: handler for B got task_A's result
        assert "result-for-B" in result_text, (
            f"Expected result for B but got: {result_text} — "
            f"FIFO desync: B got A's result!"
        )

    @pytest.mark.asyncio
    async def test_current_code_no_fifo_desync(self):
        """FIXED: each handler call executes independently, no shared queue."""
        call_log: list[str] = []

        async def tagged_execute(*args, **kwargs):
            tag = kwargs.get("block_id", "?")
            call_log.append(tag)
            return StreamToolOutputAvailable(
                toolCallId="id",
                output=f"result-for-{tag}",
                toolName="run_block",
                success=True,
            )

        mock_tool = _make_mock_tool("run_block")
        mock_tool.execute = AsyncMock(side_effect=tagged_execute)

        handler = create_tool_handler(mock_tool)

        # Tool A denied (never called). Tool B dispatched normally.
        result_b = await handler({"block_id": "B"})

        assert "result-for-B" in result_b["content"][0]["text"]
        assert call_log == ["B"]
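

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original test file): the FIFO pairing
# problem in isolation. A shared queue matches results to consumers purely by
# arrival order, so a consumer that never dequeues (a denied tool) shifts
# every later consumer onto the wrong result. Values are made up.
# ---------------------------------------------------------------------------


def _sketch_fifo_desync() -> None:
    q: asyncio.Queue[str] = asyncio.Queue()
    q.put_nowait("result-for-A")  # queued when tool A was pre-launched
    q.put_nowait("result-for-B")  # queued when tool B was pre-launched

    # Tool A is denied, so its handler never dequeues. Tool B's handler
    # then takes the head of the queue and observes A's result:
    assert q.get_nowait() == "result-for-A"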


class TestBug3CancelRace:
    """Bug 3: cancel race — task completes before cancel arrives.

    Pre-launch fires fast HTTP call (< 1s). By the time handler detects
    mismatch and calls task.cancel(), the API call already completed.
    Side effect (Linear issue created) is irreversible.
    """

    @pytest.fixture(autouse=True)
    def _init(self):
        _init_ctx(session=_make_mock_session())

    @pytest.mark.xfail(reason="Old code: cancel arrives after task completes")
    @pytest.mark.asyncio
    async def test_old_code_cancel_arrives_too_late(self):
        """OLD CODE: fast task completes before cancel, side effect persists."""
        side_effects: list[str] = []

        async def fast_execute_with_side_effect(*args, **kwargs):
            # Side effect happens immediately (like an HTTP POST to Linear)
            side_effects.append("created-issue")
            return StreamToolOutputAvailable(
                toolCallId="id",
                output="issue-created",
                toolName="run_block",
                success=True,
            )

        mock_tool = _make_mock_tool("run_block")
        mock_tool.execute = AsyncMock(side_effect=fast_execute_with_side_effect)

        # Pre-launch fires immediately
        pre_launch_args = {"block_id": "b1"}
        dispatch_args = {"block_id": "b1", "extra": "normalised"}

        await _buggy_prelaunch_handler(mock_tool, pre_launch_args, dispatch_args)

        # BUG: side effect happened TWICE (pre-launch + fallback)
        assert len(side_effects) == 1, (
            f"Expected 1 side effect but got {len(side_effects)} — "
            f"cancel race: pre-launch completed before cancel!"
        )

    @pytest.mark.asyncio
    async def test_current_code_single_side_effect(self):
        """FIXED: no speculative execution, exactly 1 side effect per call."""
        side_effects: list[str] = []

        async def execute_with_side_effect(*args, **kwargs):
            side_effects.append("created-issue")
            return StreamToolOutputAvailable(
                toolCallId="id",
                output="issue-created",
                toolName="run_block",
                success=True,
            )

        mock_tool = _make_mock_tool("run_block")
        mock_tool.execute = AsyncMock(side_effect=execute_with_side_effect)

        handler = create_tool_handler(mock_tool)
        await handler({"block_id": "b1"})

        assert len(side_effects) == 1


# ---------------------------------------------------------------------------
# readOnlyHint annotations
# ---------------------------------------------------------------------------


class TestReadOnlyAnnotations:
    """Tests that all tools get readOnlyHint=True for parallel dispatch."""

    def test_parallel_annotation_constant(self):
        """_PARALLEL_ANNOTATION is a ToolAnnotations with readOnlyHint=True."""
        from .tool_adapter import _PARALLEL_ANNOTATION

        assert isinstance(_PARALLEL_ANNOTATION, ToolAnnotations)
        assert _PARALLEL_ANNOTATION.readOnlyHint is True
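

# Illustrative note (not part of the original file): constructing such an
# annotation is a one-liner. The import path below is an assumption based on
# the MCP Python SDK; the field name comes from the assertions above.
#
#     from mcp.types import ToolAnnotations
#     _PARALLEL_ANNOTATION = ToolAnnotations(readOnlyHint=True)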


# ---------------------------------------------------------------------------
# SDK_DISALLOWED_TOOLS
# ---------------------------------------------------------------------------


class TestSDKDisallowedTools:
    """Verify that dangerous SDK built-in tools are in the disallowed list."""

    def test_bash_tool_is_disallowed(self):
        assert "Bash" in SDK_DISALLOWED_TOOLS

    def test_webfetch_tool_is_disallowed(self):
        """WebFetch is disallowed due to SSRF risk."""
        assert "WebFetch" in SDK_DISALLOWED_TOOLS


# ---------------------------------------------------------------------------
# _read_file_handler — bridge_and_annotate integration
# ---------------------------------------------------------------------------


class TestReadFileHandlerBridge:
    """Verify that _read_file_handler calls bridge_and_annotate when a sandbox is active."""

    @pytest.fixture(autouse=True)
    def _init_context(self):
        set_execution_context(
            user_id="test",
            session=None,  # type: ignore[arg-type]
            sandbox=None,
            sdk_cwd="/tmp/copilot-bridge-test",
        )

    @pytest.mark.asyncio
    async def test_bridge_called_when_sandbox_active(self, tmp_path, monkeypatch):
        """When a sandbox is set, bridge_and_annotate is called and its annotation appended."""
        from backend.copilot.context import _current_sandbox

        from .tool_adapter import _read_file_handler

        test_file = tmp_path / "tool-results" / "data.json"
        test_file.parent.mkdir(parents=True, exist_ok=True)
        test_file.write_text('{"ok": true}\n')

        monkeypatch.setattr(
            "backend.copilot.sdk.tool_adapter.is_allowed_local_path",
            lambda path, cwd: True,
        )

        fake_sandbox = object()
        token = _current_sandbox.set(fake_sandbox)  # type: ignore[arg-type]
        try:
            bridge_calls: list[tuple] = []

            async def fake_bridge_and_annotate(sandbox, file_path, offset, limit):
                bridge_calls.append((sandbox, file_path, offset, limit))
                return "\n[Sandbox copy available at /tmp/abc-data.json]"

            monkeypatch.setattr(
                "backend.copilot.sdk.tool_adapter.bridge_and_annotate",
                fake_bridge_and_annotate,
            )

            result = await _read_file_handler(
                {"file_path": str(test_file), "offset": 0, "limit": 2000}
            )

            assert result["isError"] is False
            assert len(bridge_calls) == 1
            assert bridge_calls[0][0] is fake_sandbox
            assert "/tmp/abc-data.json" in result["content"][0]["text"]
        finally:
            _current_sandbox.reset(token)

    @pytest.mark.asyncio
    async def test_bridge_not_called_without_sandbox(self, tmp_path, monkeypatch):
        """When no sandbox is set, bridge_and_annotate is not called."""
        from .tool_adapter import _read_file_handler

        test_file = tmp_path / "tool-results" / "data.json"
        test_file.parent.mkdir(parents=True, exist_ok=True)
        test_file.write_text('{"ok": true}\n')

        monkeypatch.setattr(
            "backend.copilot.sdk.tool_adapter.is_allowed_local_path",
            lambda path, cwd: True,
        )

        bridge_calls: list[tuple] = []

        async def fake_bridge_and_annotate(sandbox, file_path, offset, limit):
            bridge_calls.append((sandbox, file_path, offset, limit))
            return "\n[Sandbox copy available at /tmp/abc-data.json]"

        monkeypatch.setattr(
            "backend.copilot.sdk.tool_adapter.bridge_and_annotate",
            fake_bridge_and_annotate,
        )

        result = await _read_file_handler(
            {"file_path": str(test_file), "offset": 0, "limit": 2000}
        )

        assert result["isError"] is False
        assert len(bridge_calls) == 0
        assert "Sandbox copy" not in result["content"][0]["text"]

@@ -1,235 +1,10 @@
"""Build complete JSONL transcript from SDK messages.
"""Re-export from shared ``backend.copilot.transcript_builder`` for backward compat.

The transcript represents the FULL active context at any point in time.
Each upload REPLACES the previous transcript atomically.

Flow:
    Turn 1: Upload [msg1, msg2]
    Turn 2: Download [msg1, msg2] → Upload [msg1, msg2, msg3, msg4] (REPLACE)
    Turn 3: Download [msg1, msg2, msg3, msg4] → Upload [all messages] (REPLACE)

The transcript is never incremental - always the complete atomic state.
The canonical implementation now lives at ``backend.copilot.transcript_builder``
so both the SDK and baseline paths can import without cross-package
dependencies.
"""

import logging
from typing import Any
from uuid import uuid4
from backend.copilot.transcript_builder import TranscriptBuilder, TranscriptEntry

from pydantic import BaseModel

from backend.util import json

from .transcript import STRIPPABLE_TYPES

logger = logging.getLogger(__name__)


class TranscriptEntry(BaseModel):
    """Single transcript entry (user or assistant turn)."""

    type: str
    uuid: str
    parentUuid: str | None
    isCompactSummary: bool | None = None
    message: dict[str, Any]


class TranscriptBuilder:
    """Build complete JSONL transcript from SDK messages.

    This builder maintains the FULL conversation state, not incremental changes.
    The output is always the complete active context.
    """

    def __init__(self) -> None:
        self._entries: list[TranscriptEntry] = []
        self._last_uuid: str | None = None

    def _last_is_assistant(self) -> bool:
        return bool(self._entries) and self._entries[-1].type == "assistant"

    def _last_message_id(self) -> str:
        """Return the message.id of the last entry, or '' if none."""
        if self._entries:
            return self._entries[-1].message.get("id", "")
        return ""

    @staticmethod
    def _parse_entry(data: dict) -> TranscriptEntry | None:
        """Parse a single transcript entry, filtering strippable types.

        Returns ``None`` for entries that should be skipped (strippable types
        that are not compaction summaries).
        """
        entry_type = data.get("type", "")
        if entry_type in STRIPPABLE_TYPES and not data.get("isCompactSummary"):
            return None
        return TranscriptEntry(
            type=entry_type,
            uuid=data.get("uuid") or str(uuid4()),
            parentUuid=data.get("parentUuid"),
            isCompactSummary=data.get("isCompactSummary"),
            message=data.get("message", {}),
        )

    def load_previous(self, content: str, log_prefix: str = "[Transcript]") -> None:
        """Load complete previous transcript.

        This loads the FULL previous context. As new messages come in,
        we append to this state. The final output is the complete context
        (previous + new), not just the delta.
        """
        if not content or not content.strip():
            return

        lines = content.strip().split("\n")
        for line_num, line in enumerate(lines, 1):
            if not line.strip():
                continue

            data = json.loads(line, fallback=None)
            if data is None:
                logger.warning(
                    "%s Failed to parse transcript line %d/%d",
                    log_prefix,
                    line_num,
                    len(lines),
                )
                continue

            entry = self._parse_entry(data)
            if entry is None:
                continue
            self._entries.append(entry)
            self._last_uuid = entry.uuid

        logger.info(
            "%s Loaded %d entries from previous transcript (last_uuid=%s)",
            log_prefix,
            len(self._entries),
            self._last_uuid[:12] if self._last_uuid else None,
        )

    def append_user(self, content: str | list[dict], uuid: str | None = None) -> None:
        """Append a user entry."""
        msg_uuid = uuid or str(uuid4())

        self._entries.append(
            TranscriptEntry(
                type="user",
                uuid=msg_uuid,
                parentUuid=self._last_uuid,
                message={"role": "user", "content": content},
            )
        )
        self._last_uuid = msg_uuid

    def append_tool_result(self, tool_use_id: str, content: str) -> None:
        """Append a tool result as a user entry (one per tool call)."""
        self.append_user(
            content=[
                {"type": "tool_result", "tool_use_id": tool_use_id, "content": content}
            ]
        )

    def append_assistant(
        self,
        content_blocks: list[dict],
        model: str = "",
        stop_reason: str | None = None,
    ) -> None:
        """Append an assistant entry.

        Consecutive assistant entries automatically share the same message ID
        so the CLI can merge them (thinking → text → tool_use) into a single
        API message on ``--resume``. A new ID is assigned whenever an
        assistant entry follows a non-assistant entry (user message or tool
        result), because that marks the start of a new API response.
        """
        message_id = (
            self._last_message_id()
            if self._last_is_assistant()
            else f"msg_sdk_{uuid4().hex[:24]}"
        )

        msg_uuid = str(uuid4())

        self._entries.append(
            TranscriptEntry(
                type="assistant",
                uuid=msg_uuid,
                parentUuid=self._last_uuid,
                message={
                    "role": "assistant",
                    "model": model,
                    "id": message_id,
                    "type": "message",
                    "content": content_blocks,
                    "stop_reason": stop_reason,
                    "stop_sequence": None,
                },
            )
        )
        self._last_uuid = msg_uuid

    def replace_entries(
        self, compacted_entries: list[dict], log_prefix: str = "[Transcript]"
    ) -> None:
        """Replace all entries with compacted entries from the CLI session file.

        Called after mid-stream compaction so TranscriptBuilder mirrors the
        CLI's active context (compaction summary + post-compaction entries).

        Builds the new list first and validates it's non-empty before swapping,
        so corrupt input cannot wipe the conversation history.
        """
        new_entries: list[TranscriptEntry] = []
        for data in compacted_entries:
            entry = self._parse_entry(data)
            if entry is not None:
                new_entries.append(entry)

        if not new_entries:
            logger.warning(
                "%s replace_entries produced 0 entries from %d inputs, keeping old (%d entries)",
                log_prefix,
                len(compacted_entries),
                len(self._entries),
            )
            return

        old_count = len(self._entries)
        self._entries = new_entries
        self._last_uuid = new_entries[-1].uuid

        logger.info(
            "%s TranscriptBuilder compacted: %d entries -> %d entries",
            log_prefix,
            old_count,
            len(self._entries),
        )

    def to_jsonl(self) -> str:
        """Export complete context as JSONL.

        Consecutive assistant entries are kept separate to match the
        native CLI format — the SDK merges them internally on resume.

        Returns the FULL conversation state (all entries), not incremental.
        This output REPLACES any previous transcript.
        """
        if not self._entries:
            return ""

        lines = [entry.model_dump_json(exclude_none=True) for entry in self._entries]
        return "\n".join(lines) + "\n"

    @property
    def entry_count(self) -> int:
        """Total number of entries in the complete context."""
        return len(self._entries)

    @property
    def is_empty(self) -> bool:
        """Whether this builder has any entries."""
        return len(self._entries) == 0
__all__ = ["TranscriptBuilder", "TranscriptEntry"]
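

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the diff): how the builder above is
# driven across one turn. `previous_jsonl` and the model name are placeholder
# values, not project constants.
# ---------------------------------------------------------------------------
#
#     builder = TranscriptBuilder()
#     builder.load_previous(previous_jsonl)       # FULL prior context, if any
#     builder.append_user("hello")                # new user turn
#     builder.append_assistant(
#         [{"type": "text", "text": "hi there"}],
#         model="some-model-name",
#     )
#     complete_state = builder.to_jsonl()         # REPLACES the previous upload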


@@ -13,6 +13,7 @@ from .transcript import (
    delete_transcript,
    read_compacted_entries,
    strip_progress_entries,
    strip_stale_thinking_blocks,
    validate_transcript,
    write_transcript_to_tempfile,
)
@@ -302,7 +303,7 @@ class TestDeleteTranscript:
        mock_storage.delete = AsyncMock()

        with patch(
            "backend.copilot.sdk.transcript.get_workspace_storage",
            "backend.copilot.transcript.get_workspace_storage",
            new_callable=AsyncMock,
            return_value=mock_storage,
        ):
@@ -322,7 +323,7 @@ class TestDeleteTranscript:
        )

        with patch(
            "backend.copilot.sdk.transcript.get_workspace_storage",
            "backend.copilot.transcript.get_workspace_storage",
            new_callable=AsyncMock,
            return_value=mock_storage,
        ):
@@ -340,7 +341,7 @@ class TestDeleteTranscript:
        )

        with patch(
            "backend.copilot.sdk.transcript.get_workspace_storage",
            "backend.copilot.transcript.get_workspace_storage",
            new_callable=AsyncMock,
            return_value=mock_storage,
        ):
@@ -857,11 +858,11 @@ class TestRunCompression:

        with (
            patch(
                "backend.copilot.sdk.transcript.get_openai_client",
                "backend.copilot.transcript.get_openai_client",
                return_value=None,
            ),
            patch(
                "backend.copilot.sdk.transcript.compress_context",
                "backend.copilot.transcript.compress_context",
                new_callable=AsyncMock,
                return_value=truncation_result,
            ) as mock_compress,
@@ -893,11 +894,11 @@ class TestRunCompression:

        with (
            patch(
                "backend.copilot.sdk.transcript.get_openai_client",
                "backend.copilot.transcript.get_openai_client",
                return_value=mock_client,
            ),
            patch(
                "backend.copilot.sdk.transcript.compress_context",
                "backend.copilot.transcript.compress_context",
                new_callable=AsyncMock,
                return_value=llm_result,
            ) as mock_compress,
@@ -931,11 +932,11 @@ class TestRunCompression:

        with (
            patch(
                "backend.copilot.sdk.transcript.get_openai_client",
                "backend.copilot.transcript.get_openai_client",
                return_value=mock_client,
            ),
            patch(
                "backend.copilot.sdk.transcript.compress_context",
                "backend.copilot.transcript.compress_context",
                side_effect=_compress_side_effect,
            ),
        ):
@@ -969,19 +970,19 @@ class TestRunCompression:
        fake_client = MagicMock()
        with (
            patch(
                "backend.copilot.sdk.transcript.get_openai_client",
                "backend.copilot.transcript.get_openai_client",
                return_value=fake_client,
            ),
            patch(
                "backend.copilot.sdk.transcript.compress_context",
                "backend.copilot.transcript.compress_context",
                side_effect=_compress_side_effect,
            ),
            patch(
                "backend.copilot.sdk.transcript._COMPACTION_TIMEOUT_SECONDS",
                "backend.copilot.transcript._COMPACTION_TIMEOUT_SECONDS",
                0.05,
            ),
            patch(
                "backend.copilot.sdk.transcript._TRUNCATION_TIMEOUT_SECONDS",
                "backend.copilot.transcript._TRUNCATION_TIMEOUT_SECONDS",
                5,
            ),
        ):
@@ -1014,7 +1015,7 @@ class TestCleanupStaleProjectDirs:
        projects_dir = tmp_path / "projects"
        projects_dir.mkdir()
        monkeypatch.setattr(
            "backend.copilot.sdk.transcript._projects_base",
            "backend.copilot.transcript._projects_base",
            lambda: str(projects_dir),
        )

@@ -1043,7 +1044,7 @@ class TestCleanupStaleProjectDirs:
        projects_dir = tmp_path / "projects"
        projects_dir.mkdir()
        monkeypatch.setattr(
            "backend.copilot.sdk.transcript._projects_base",
            "backend.copilot.transcript._projects_base",
            lambda: str(projects_dir),
        )

@@ -1069,7 +1070,7 @@ class TestCleanupStaleProjectDirs:
        projects_dir = tmp_path / "projects"
        projects_dir.mkdir()
        monkeypatch.setattr(
            "backend.copilot.sdk.transcript._projects_base",
            "backend.copilot.transcript._projects_base",
            lambda: str(projects_dir),
        )

@@ -1095,7 +1096,7 @@ class TestCleanupStaleProjectDirs:
        projects_dir = tmp_path / "projects"
        projects_dir.mkdir()
        monkeypatch.setattr(
            "backend.copilot.sdk.transcript._projects_base",
            "backend.copilot.transcript._projects_base",
            lambda: str(projects_dir),
        )

@@ -1117,7 +1118,7 @@ class TestCleanupStaleProjectDirs:

        nonexistent = str(tmp_path / "does-not-exist" / "projects")
        monkeypatch.setattr(
            "backend.copilot.sdk.transcript._projects_base",
            "backend.copilot.transcript._projects_base",
            lambda: nonexistent,
        )

@@ -1136,7 +1137,7 @@ class TestCleanupStaleProjectDirs:
        projects_dir = tmp_path / "projects"
        projects_dir.mkdir()
        monkeypatch.setattr(
            "backend.copilot.sdk.transcript._projects_base",
            "backend.copilot.transcript._projects_base",
            lambda: str(projects_dir),
        )

@@ -1164,7 +1165,7 @@ class TestCleanupStaleProjectDirs:
        projects_dir = tmp_path / "projects"
        projects_dir.mkdir()
        monkeypatch.setattr(
            "backend.copilot.sdk.transcript._projects_base",
            "backend.copilot.transcript._projects_base",
            lambda: str(projects_dir),
        )

@@ -1188,7 +1189,7 @@ class TestCleanupStaleProjectDirs:
        projects_dir = tmp_path / "projects"
        projects_dir.mkdir()
        monkeypatch.setattr(
            "backend.copilot.sdk.transcript._projects_base",
            "backend.copilot.transcript._projects_base",
            lambda: str(projects_dir),
        )

@@ -1200,3 +1201,170 @@ class TestCleanupStaleProjectDirs:
        removed = cleanup_stale_project_dirs(encoded_cwd="some-other-project")
        assert removed == 0
        assert non_copilot.exists()


# ---------------------------------------------------------------------------
# strip_stale_thinking_blocks
# ---------------------------------------------------------------------------


class TestStripStaleThinkingBlocks:
    """Tests for strip_stale_thinking_blocks — removes thinking/redacted_thinking
    blocks from non-last assistant entries to reduce transcript bloat."""

    def _asst_entry(
        self, msg_id: str, content: list, uuid: str = "u1", parent: str = ""
    ) -> dict:
        return {
            "type": "assistant",
            "uuid": uuid,
            "parentUuid": parent,
            "message": {
                "role": "assistant",
                "id": msg_id,
                "type": "message",
                "content": content,
            },
        }

    def _user_entry(self, text: str, uuid: str = "u0", parent: str = "") -> dict:
        return {
            "type": "user",
            "uuid": uuid,
            "parentUuid": parent,
            "message": {"role": "user", "content": text},
        }

    def test_strips_thinking_from_older_assistant(self) -> None:
        """Thinking blocks in non-last assistant entries should be removed."""
        old_asst = self._asst_entry(
            "msg_old",
            [
                {"type": "thinking", "thinking": "deep thoughts..."},
                {"type": "text", "text": "hello"},
                {"type": "redacted_thinking", "data": "secret"},
            ],
            uuid="a1",
        )
        new_asst = self._asst_entry(
            "msg_new",
            [
                {"type": "thinking", "thinking": "latest thoughts"},
                {"type": "text", "text": "world"},
            ],
            uuid="a2",
            parent="a1",
        )
        content = _make_jsonl(old_asst, new_asst)
        result = strip_stale_thinking_blocks(content)
        lines = [json.loads(ln) for ln in result.strip().split("\n")]

        # Old assistant should have thinking blocks stripped
        old_content = lines[0]["message"]["content"]
        assert len(old_content) == 1
        assert old_content[0]["type"] == "text"

        # New (last) assistant should be untouched
        new_content = lines[1]["message"]["content"]
        assert len(new_content) == 2
        assert new_content[0]["type"] == "thinking"
        assert new_content[1]["type"] == "text"

    def test_preserves_last_assistant_thinking(self) -> None:
        """The last assistant entry's thinking blocks must be preserved."""
        entry = self._asst_entry(
            "msg_only",
            [
                {"type": "thinking", "thinking": "must keep"},
                {"type": "text", "text": "response"},
            ],
        )
        content = _make_jsonl(entry)
        result = strip_stale_thinking_blocks(content)
        lines = [json.loads(ln) for ln in result.strip().split("\n")]
        assert len(lines[0]["message"]["content"]) == 2

    def test_no_assistant_entries_returns_unchanged(self) -> None:
        """Transcripts with only user entries should pass through unchanged."""
        user = self._user_entry("hello")
        content = _make_jsonl(user)
        assert strip_stale_thinking_blocks(content) == content

    def test_empty_content_returns_unchanged(self) -> None:
        assert strip_stale_thinking_blocks("") == ""

    def test_multiple_turns_strips_all_but_last(self) -> None:
        """With 3 assistant turns, only the last keeps thinking blocks."""
        entries = [
            self._asst_entry(
                "msg_1",
                [
                    {"type": "thinking", "thinking": "t1"},
                    {"type": "text", "text": "a1"},
                ],
                uuid="a1",
            ),
            self._user_entry("q2", uuid="u2", parent="a1"),
            self._asst_entry(
                "msg_2",
                [
                    {"type": "thinking", "thinking": "t2"},
                    {"type": "text", "text": "a2"},
                ],
                uuid="a2",
                parent="u2",
            ),
            self._user_entry("q3", uuid="u3", parent="a2"),
            self._asst_entry(
                "msg_3",
                [
                    {"type": "thinking", "thinking": "t3"},
                    {"type": "text", "text": "a3"},
                ],
                uuid="a3",
                parent="u3",
            ),
        ]
        content = _make_jsonl(*entries)
        result = strip_stale_thinking_blocks(content)
        lines = [json.loads(ln) for ln in result.strip().split("\n")]

        # msg_1: thinking stripped
        assert len(lines[0]["message"]["content"]) == 1
        assert lines[0]["message"]["content"][0]["type"] == "text"
        # msg_2: thinking stripped
        assert len(lines[2]["message"]["content"]) == 1
        # msg_3 (last): thinking preserved
        assert len(lines[4]["message"]["content"]) == 2
        assert lines[4]["message"]["content"][0]["type"] == "thinking"

    def test_same_msg_id_multi_entry_turn(self) -> None:
        """Multiple entries sharing the same message.id (same turn) are preserved."""
        entries = [
            self._asst_entry(
                "msg_old",
                [{"type": "thinking", "thinking": "old"}],
                uuid="a1",
            ),
            self._asst_entry(
                "msg_last",
                [{"type": "thinking", "thinking": "t_part1"}],
                uuid="a2",
                parent="a1",
            ),
            self._asst_entry(
                "msg_last",
                [{"type": "text", "text": "response"}],
                uuid="a3",
                parent="a2",
            ),
        ]
        content = _make_jsonl(*entries)
        result = strip_stale_thinking_blocks(content)
        lines = [json.loads(ln) for ln in result.strip().split("\n")]

        # Old entry stripped
        assert lines[0]["message"]["content"] == []
        # Both entries of last turn (msg_last) preserved
        assert lines[1]["message"]["content"][0]["type"] == "thinking"
        assert lines[2]["message"]["content"][0]["type"] == "text"

@@ -7,7 +7,7 @@ import pytest
from .model import create_chat_session, get_chat_session, upsert_chat_session
from .response_model import StreamError, StreamTextDelta
from .sdk import service as sdk_service
from .sdk.transcript import download_transcript
from .transcript import download_transcript

logger = logging.getLogger(__name__)

@@ -30,7 +30,7 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
    if not cfg.claude_agent_use_resume:
        return pytest.skip("CLAUDE_AGENT_USE_RESUME is not enabled, skipping test")

    session = await create_chat_session(test_user_id)
    session = await create_chat_session(test_user_id, dry_run=False)
    session = await upsert_chat_session(session)

    # --- Turn 1: send a message with a unique keyword ---

@@ -221,9 +221,21 @@ async def create_session(
    return session


_meta_ttl_refresh_at: dict[str, float] = {}
"""Tracks the last time the session meta key TTL was refreshed.

Used by `publish_chunk` to avoid refreshing on every single chunk
(expensive). Refreshes at most once every 60 seconds per session.
"""

_META_TTL_REFRESH_INTERVAL = 60  # seconds


async def publish_chunk(
    turn_id: str,
    chunk: StreamBaseResponse,
    *,
    session_id: str | None = None,
) -> str:
    """Publish a chunk to Redis Stream.

@@ -232,6 +244,9 @@ async def publish_chunk(
    Args:
        turn_id: Turn ID (per-turn UUID) identifying the stream
        chunk: The stream response chunk to publish
        session_id: Chat session ID — when provided, the session meta key
            TTL is refreshed periodically to prevent expiration during
            long-running turns (see SECRT-2178).

    Returns:
        The Redis Stream message ID
@@ -265,6 +280,23 @@ async def publish_chunk(
    # Set TTL on stream to match session metadata TTL
    await redis.expire(stream_key, config.stream_ttl)

    # Periodically refresh session-related TTLs so they don't expire
    # during long-running turns. Without this, turns exceeding stream_ttl
    # (default 1h) lose their "running" status and stream data, making
    # the session invisible to the resume endpoint (empty on page reload).
    # Both meta key AND stream key are refreshed: the stream key's expire
    # above only fires when publish_chunk is called, but during long
    # sub-agent gaps (task_progress events don't produce chunks), neither
    # key gets refreshed.
    if session_id:
        now = time.perf_counter()
        last_refresh = _meta_ttl_refresh_at.get(session_id, 0)
        if now - last_refresh >= _META_TTL_REFRESH_INTERVAL:
            meta_key = _get_session_meta_key(session_id)
            await redis.expire(meta_key, config.stream_ttl)
            await redis.expire(stream_key, config.stream_ttl)
            _meta_ttl_refresh_at[session_id] = now

    total_time = (time.perf_counter() - start_time) * 1000
    # Only log timing for significant chunks or slow operations
    if (
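

# Illustrative sketch (not part of the diff): the TTL-refresh throttle above,
# in isolation. At most one refresh per session per interval; the names here
# are hypothetical stand-ins for the module's _meta_ttl_refresh_at machinery.
#
#     import time
#
#     _refresh_at: dict[str, float] = {}
#     INTERVAL = 60.0  # seconds, mirrors _META_TTL_REFRESH_INTERVAL
#
#     def should_refresh(session_id: str) -> bool:
#         now = time.perf_counter()
#         if now - _refresh_at.get(session_id, 0.0) >= INTERVAL:
#             _refresh_at[session_id] = now
#             return True
#         return False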
@@ -331,7 +363,7 @@ async def stream_and_publish(
    async for event in stream:
        if turn_id and not isinstance(event, (StreamFinish, StreamError)):
            try:
                await publish_chunk(turn_id, event)
                await publish_chunk(turn_id, event, session_id=session_id)
            except (RedisError, ConnectionError, OSError):
                if not publish_failed_once:
                    publish_failed_once = True
@@ -800,6 +832,9 @@ async def mark_session_completed(
    # Atomic compare-and-swap: only update if status is "running"
    result = await redis.eval(COMPLETE_SESSION_SCRIPT, 1, meta_key, status)  # type: ignore[misc]

    # Clean up the in-memory TTL refresh tracker to prevent unbounded growth.
    _meta_ttl_refresh_at.pop(session_id, None)

    if result == 0:
        logger.debug(f"Session {session_id} already completed/failed, skipping")
        return False

@@ -4,12 +4,15 @@ Both the baseline (OpenRouter) and SDK (Anthropic) service layers need to:
1. Append a ``Usage`` record to the session.
2. Log the turn's token counts.
3. Record weighted usage in Redis for rate-limiting.
4. Write a PlatformCostLog entry for admin cost tracking.

This module extracts that common logic so both paths stay in sync.
"""

import logging

from backend.data.platform_cost import PlatformCostEntry, log_platform_cost_safe

from .model import ChatSession, Usage
from .rate_limit import record_token_usage

@@ -95,4 +98,47 @@ async def persist_and_record_usage(
    except Exception as usage_err:
        logger.warning(f"{log_prefix} Failed to record token usage: {usage_err}")

    # Log to PlatformCostLog for admin cost dashboard
    if user_id and total_tokens > 0:
        cost_float = None
        if cost_usd is not None:
            try:
                cost_float = float(cost_usd)
            except (ValueError, TypeError):
                pass

        cost_microdollars = (
            int(cost_float * 1_000_000) if cost_float is not None else None
        )
        session_id = session.session_id if session else None

        if cost_float is not None:
            tracking_type = "cost_usd"
            tracking_amount = cost_float
        else:
            tracking_type = "tokens"
            tracking_amount = total_tokens

        await log_platform_cost_safe(
            PlatformCostEntry(
                user_id=user_id,
                graph_exec_id=session_id,
                block_id="copilot",
                block_name=f"copilot:{log_prefix.strip(' []')}".rstrip(":"),
                provider="open_router",
                credential_id="copilot_system",
                cost_microdollars=cost_microdollars,
                input_tokens=prompt_tokens,
                output_tokens=completion_tokens,
                model=None,
                metadata={
                    "tracking_type": tracking_type,
                    "tracking_amount": tracking_amount,
                    "cache_read_tokens": cache_read_tokens,
                    "cache_creation_tokens": cache_creation_tokens,
                    "source": "copilot",
                },
            )
        )

    return total_tokens
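

# Worked example (not part of the diff) of the cost conversion above:
# cost_usd arrives as a string or Decimal, is coerced to float, then stored
# as integer microdollars. The figure below is made up.
#
#     cost_float = float("0.004217")                    # 0.004217 USD
#     cost_microdollars = int(cost_float * 1_000_000)   # 4217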

@@ -10,6 +10,7 @@ from backend.copilot.tracking import track_tool_called
from .add_understanding import AddUnderstandingTool
from .agent_browser import BrowserActTool, BrowserNavigateTool, BrowserScreenshotTool
from .agent_output import AgentOutputTool
from .ask_question import AskQuestionTool
from .base import BaseTool
from .bash_exec import BashExecTool
from .connect_integration import ConnectIntegrationTool
@@ -55,6 +56,7 @@ logger = logging.getLogger(__name__)
# Single source of truth for all tools
TOOL_REGISTRY: dict[str, BaseTool] = {
    "add_understanding": AddUnderstandingTool(),
    "ask_question": AskQuestionTool(),
    "create_agent": CreateAgentTool(),
    "customize_agent": CustomizeAgentTool(),
    "edit_agent": EditAgentTool(),

@@ -68,6 +68,9 @@ class AddUnderstandingTool(BaseTool):
    Each call merges new data with existing understanding:
    - String fields are overwritten if provided
    - List fields are appended (with deduplication)

    Note: This tool accepts **kwargs because its parameters are derived
    dynamically from the BusinessUnderstandingInput model schema.
    """
        session_id = session.session_id

@@ -77,23 +80,21 @@ class AddUnderstandingTool(BaseTool):
                session_id=session_id,
            )

        # Build input model from kwargs (only include fields defined in the model)
        valid_fields = set(BusinessUnderstandingInput.model_fields.keys())
        filtered = {k: v for k, v in kwargs.items() if k in valid_fields}

        # Check if any data was provided
        if not any(v is not None for v in kwargs.values()):
        if not any(v is not None for v in filtered.values()):
            return ErrorResponse(
                message="Please provide at least one field to update.",
                session_id=session_id,
            )

        # Build input model from kwargs (only include fields defined in the model)
        valid_fields = set(BusinessUnderstandingInput.model_fields.keys())
        input_data = BusinessUnderstandingInput(
            **{k: v for k, v in kwargs.items() if k in valid_fields}
        )
        input_data = BusinessUnderstandingInput(**filtered)

        # Track which fields were updated
        updated_fields = [
            k for k, v in kwargs.items() if k in valid_fields and v is not None
        ]
        updated_fields = [k for k, v in filtered.items() if v is not None]

        # Upsert with merge
        understanding = await understanding_db().upsert_business_understanding(
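

# Why the reordering above matters (illustrative; hypothetical field name):
# the old emptiness check looked at raw kwargs, so a single unknown kwarg was
# enough to pass it, while the new check runs on the filtered dict.
#
#     kwargs = {"not_a_real_field": "x"}
#     filtered = {}  # nothing survives the valid_fields filter
#     any(v is not None for v in kwargs.values())    # True  (old: accepted)
#     any(v is not None for v in filtered.values())  # False (new: rejected)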

Some files were not shown because too many files have changed in this diff