fix(platform): resolve dev merge conflicts for integration link image PR
1
.agents/skills
Symbolic link
@@ -0,0 +1 @@
../.claude/skills
10
.claude/settings.json
Normal file
@@ -0,0 +1,10 @@
{
  "permissions": {
    "allowedTools": [
      "Read", "Grep", "Glob",
      "Bash(ls:*)", "Bash(cat:*)", "Bash(grep:*)", "Bash(find:*)",
      "Bash(git status:*)", "Bash(git diff:*)", "Bash(git log:*)", "Bash(git worktree:*)",
      "Bash(tmux:*)", "Bash(sleep:*)", "Bash(branchlet:*)"
    ]
  }
}
@@ -95,6 +95,28 @@ Address comments **one at a time**: fix → commit → push → inline reply →

| Inline review (`pulls/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID}/replies -f body="🤖 Fixed in <commit-sha>: <description>"` |
| Conversation (`issues/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments -f body="🤖 Fixed in <commit-sha>: <description>"` |
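
For example, one complete cycle on an inline review comment might look like the following sketch — the PR number, comment ID, and commit message are placeholders:

```bash
# Hypothetical IDs and message — substitute the real PR number, comment ID, and fix description
git commit -am "fix: handle empty agent list"
git push
sha=$(git rev-parse HEAD)
gh api repos/Significant-Gravitas/AutoGPT/pulls/1234/comments/987654/replies \
  -f body="🤖 Fixed in $sha: handle empty agent list"
```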
## Codecov coverage

Codecov patch target is **80%** on changed lines. Checks are **informational** (not blocking) but should be green.

### Running coverage locally

**Backend** (from `autogpt_platform/backend/`):

```bash
poetry run pytest -s -vv --cov=backend --cov-branch --cov-report term-missing
```

**Frontend** (from `autogpt_platform/frontend/`):

```bash
pnpm vitest run --coverage
```

### When codecov/patch fails

1. Find uncovered files: `git diff --name-only $(gh pr view --json baseRefName --jq '.baseRefName')...HEAD`
2. For each uncovered file — extract inline logic to `helpers.ts`/`helpers.py` and test those (highest ROI). Colocate tests as `*_test.py` (backend) or `__tests__/*.test.ts` (frontend). See the sketch after this list.
3. Run coverage locally to verify, commit, push.
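
As a minimal sketch of the extract-and-test pattern from step 2 — the `formatRunDuration` helper and its location are hypothetical, purely for illustration:

```ts
// src/app/(platform)/library/helpers.ts — extracted pure logic (hypothetical example)
export function formatRunDuration(totalSeconds: number): string {
  const mins = Math.floor(totalSeconds / 60);
  const secs = Math.round(totalSeconds % 60);
  return mins > 0 ? `${mins}m ${secs}s` : `${secs}s`;
}
```

```ts
// src/app/(platform)/library/__tests__/helpers.test.ts — colocated unit test
import { formatRunDuration } from "../helpers";

describe("formatRunDuration", () => {
  test("formats sub-minute durations", () => {
    expect(formatRunDuration(42)).toBe("42s");
  });

  test("formats durations over a minute", () => {
    expect(formatRunDuration(90)).toBe("1m 30s");
  });
});
```

Covering the extracted helper directly usually recovers most of the patch coverage without having to drive the whole page.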

## Format and commit

After fixing, format the changed code:
224
.claude/skills/write-frontend-tests/SKILL.md
Normal file
@@ -0,0 +1,224 @@
---
name: write-frontend-tests
description: "Analyze the current branch diff against dev, plan integration tests for changed frontend pages/components, and write them. TRIGGER when user asks to write frontend tests, add test coverage, or 'write tests for my changes'."
user-invocable: true
args: "[base branch] — defaults to dev. Optionally pass a specific base branch to diff against."
metadata:
  author: autogpt-team
  version: "1.0.0"
---

# Write Frontend Tests

Analyze the current branch's frontend changes, plan integration tests, and write them.

## References

Before writing any tests, read the testing rules and conventions:

- `autogpt_platform/frontend/TESTING.md` — testing strategy, file locations, examples
- `autogpt_platform/frontend/src/tests/AGENTS.md` — detailed testing rules, MSW patterns, decision flowchart
- `autogpt_platform/frontend/src/tests/integrations/test-utils.tsx` — custom render with providers
- `autogpt_platform/frontend/src/tests/integrations/vitest.setup.tsx` — MSW server setup

## Step 1: Identify changed frontend files

```bash
BASE_BRANCH="${ARGUMENTS:-dev}"
cd autogpt_platform/frontend

# Get changed frontend files (excluding generated, config, and test files)
git diff "$BASE_BRANCH"...HEAD --name-only -- src/ \
  | grep -v '__generated__' \
  | grep -v '__tests__' \
  | grep -v '\.test\.' \
  | grep -v '\.stories\.' \
  | grep -v '\.spec\.'
```

Also read the diff to understand what changed:

```bash
git diff "$BASE_BRANCH"...HEAD --stat -- src/
git diff "$BASE_BRANCH"...HEAD -- src/ | head -500
```

## Step 2: Categorize changes and find test targets

For each changed file, determine:

1. **Is it a page?** (`page.tsx`) — these are the primary test targets
2. **Is it a hook?** (`use*.ts`) — test via the page that uses it
3. **Is it a component?** (`.tsx` in `components/`) — test via the parent page unless it's complex enough to warrant isolation
4. **Is it a helper?** (`helpers.ts`, `utils.ts`) — unit test directly if pure logic

**Priority order:**
1. Pages with new/changed data fetching or user interactions
2. Components with complex internal logic (modals, forms, wizards)
3. Hooks with non-trivial business logic
4. Pure helper functions

Skip: styling-only changes, type-only changes, config changes.

## Step 3: Check for existing tests

For each test target, check if tests already exist:

```bash
# For a page at src/app/(platform)/library/page.tsx
ls src/app/\(platform\)/library/__tests__/ 2>/dev/null

# For a component at src/app/(platform)/library/components/AgentCard/AgentCard.tsx
ls src/app/\(platform\)/library/components/AgentCard/__tests__/ 2>/dev/null
```

Note which targets have no tests (need new files) vs which have tests that need updating.

## Step 4: Identify API endpoints used

For each test target, find which API hooks are used:

```bash
# Find generated API hook imports in the changed files
grep -rn 'from.*__generated__/endpoints' src/app/\(platform\)/library/
grep -rn 'use[A-Z].*V[12]' src/app/\(platform\)/library/
```

For each API hook found, locate the corresponding MSW handler:

```bash
# If the page uses useGetV2ListLibraryAgents, find its MSW handlers
grep -rn 'getGetV2ListLibraryAgents.*Handler' src/app/api/__generated__/endpoints/library/library.msw.ts
```

List every MSW handler you will need (200 for happy path, 4xx for error paths).

## Step 5: Write the test plan

Before writing code, output a plan as a numbered list:

```
Test plan for [branch name]:

1. src/app/(platform)/library/__tests__/main.test.tsx (NEW)
   - Renders page with agent list (MSW 200)
   - Shows loading state
   - Shows error state (MSW 422)
   - Handles empty agent list

2. src/app/(platform)/library/__tests__/search.test.tsx (NEW)
   - Filters agents by search query
   - Shows no results message
   - Clears search

3. src/app/(platform)/library/components/AgentCard/__tests__/AgentCard.test.tsx (UPDATE)
   - Add test for new "duplicate" action
```

Present this plan to the user. Wait for confirmation before proceeding. If the user has feedback, adjust the plan.

## Step 6: Write the tests

For each test file in the plan, follow these conventions:

### File structure

```tsx
import { render, screen, waitFor } from "@/tests/integrations/test-utils";
import { server } from "@/mocks/mock-server";
// Import MSW handlers for endpoints the page uses
import {
  getGetV2ListLibraryAgentsMockHandler200,
  getGetV2ListLibraryAgentsMockHandler422,
} from "@/app/api/__generated__/endpoints/library/library.msw";
// Import the component under test
import LibraryPage from "../page";

describe("LibraryPage", () => {
  test("renders agent list from API", async () => {
    server.use(getGetV2ListLibraryAgentsMockHandler200());

    render(<LibraryPage />);

    expect(await screen.findByText(/my agents/i)).toBeDefined();
  });

  test("shows error state on API failure", async () => {
    server.use(getGetV2ListLibraryAgentsMockHandler422());

    render(<LibraryPage />);

    expect(await screen.findByText(/error/i)).toBeDefined();
  });
});
```

### Rules

- Use `render()` from `@/tests/integrations/test-utils` (NOT from `@testing-library/react` directly)
- Use `server.use()` to set up MSW handlers BEFORE rendering
- Use `findBy*` (async) for elements that appear after data fetching — NOT `getBy*`
- Use `getBy*` only for elements that are immediately present in the DOM
- Use `screen` queries — do NOT destructure from `render()`
- Use `waitFor` when asserting side effects or state changes after interactions
- Import `fireEvent` or `userEvent` from the test-utils for interactions (see the interaction sketch after this list)
- Do NOT mock internal hooks or functions — mock at the API boundary via MSW
- Do NOT use `act()` manually — `render` and `fireEvent` handle it
- Keep tests focused: one behavior per test
- Use descriptive test names that read like sentences
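
To illustrate the interaction rules, here is a minimal sketch of a user-interaction test. It assumes `userEvent` is re-exported from the test-utils (as the rules above direct), and the search placeholder and empty-state copy are hypothetical — match whatever the page actually renders:

```tsx
import { render, screen, waitFor, userEvent } from "@/tests/integrations/test-utils";
import { server } from "@/mocks/mock-server";
import { getGetV2ListLibraryAgentsMockHandler200 } from "@/app/api/__generated__/endpoints/library/library.msw";
import LibraryPage from "../page";

describe("LibraryPage search", () => {
  test("filters agents by search query", async () => {
    server.use(getGetV2ListLibraryAgentsMockHandler200());

    render(<LibraryPage />);

    // findBy*: the search box appears only after the initial fetch resolves
    const input = await screen.findByPlaceholderText(/search/i); // hypothetical placeholder
    await userEvent.type(input, "scraper");

    // waitFor: assert the state change caused by the interaction
    await waitFor(() => {
      expect(screen.queryByText(/no results/i)).toBeNull(); // hypothetical empty-state copy
    });
  });
});
```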

### Test location

```
# For pages: __tests__/ next to page.tsx
src/app/(platform)/library/__tests__/main.test.tsx

# For complex standalone components: __tests__/ inside component folder
src/app/(platform)/library/components/AgentCard/__tests__/AgentCard.test.tsx

# For pure helpers: co-located .test.ts
src/app/(platform)/library/helpers.test.ts
```

### Custom MSW overrides

When the auto-generated faker data is not enough, override with specific data:

```tsx
import { http, HttpResponse } from "msw";

server.use(
  http.get("http://localhost:3000/api/proxy/api/v2/library/agents", () => {
    return HttpResponse.json({
      agents: [
        { id: "1", name: "Test Agent", description: "A test agent" },
      ],
      pagination: { total_items: 1, total_pages: 1, page: 1, page_size: 10 },
    });
  }),
);
```

Use the proxy URL pattern: `http://localhost:3000/api/proxy/api/v{version}/{path}` — this matches the MSW base URL configured in `orval.config.ts`.

## Step 7: Run and verify

After writing all tests:

```bash
cd autogpt_platform/frontend
pnpm test:unit --reporter=verbose
```

If tests fail:
1. Read the error output carefully
2. Fix the test (not the source code, unless there is a genuine bug)
3. Re-run until all pass

Then run the full checks:

```bash
pnpm format
pnpm lint
pnpm types
```
78
.github/workflows/classic-autogpt-ci.yml
vendored
@@ -6,11 +6,19 @@ on:
paths:
- '.github/workflows/classic-autogpt-ci.yml'
- 'classic/original_autogpt/**'
- 'classic/direct_benchmark/**'
- 'classic/forge/**'
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
pull_request:
branches: [ master, dev, release-* ]
paths:
- '.github/workflows/classic-autogpt-ci.yml'
- 'classic/original_autogpt/**'
- 'classic/direct_benchmark/**'
- 'classic/forge/**'
- 'classic/pyproject.toml'
- 'classic/poetry.lock'

concurrency:
group: ${{ format('classic-autogpt-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
@@ -19,47 +27,22 @@ concurrency:
defaults:
run:
shell: bash
working-directory: classic/original_autogpt
working-directory: classic

jobs:
test:
permissions:
contents: read
timeout-minutes: 30
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]
platform-os: [ubuntu, macos, macos-arm64, windows]
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
runs-on: ubuntu-latest

steps:
# Quite slow on macOS (2~4 minutes to set up Docker)
# - name: Set up Docker (macOS)
# if: runner.os == 'macOS'
# uses: crazy-max/ghaction-setup-docker@v3

- name: Start MinIO service (Linux)
if: runner.os == 'Linux'
- name: Start MinIO service
working-directory: '.'
run: |
docker pull minio/minio:edge-cicd
docker run -d -p 9000:9000 minio/minio:edge-cicd

- name: Start MinIO service (macOS)
if: runner.os == 'macOS'
working-directory: ${{ runner.temp }}
run: |
brew install minio/stable/minio
mkdir data
minio server ./data &

# No MinIO on Windows:
# - Windows doesn't support running Linux Docker containers
# - It doesn't seem possible to start background processes on Windows. They are
# killed after the step returns.
# See: https://github.com/actions/runner/issues/598#issuecomment-2011890429

- name: Checkout repository
uses: actions/checkout@v4
with:
@@ -71,41 +54,23 @@ jobs:
git config --global user.name "Auto-GPT-Bot"
git config --global user.email "github-bot@agpt.co"

- name: Set up Python ${{ matrix.python-version }}
- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
python-version: "3.12"

- id: get_date
name: Get date
run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT

- name: Set up Python dependency cache
# On Windows, unpacking cached dependencies takes longer than just installing them
if: runner.os != 'Windows'
uses: actions/cache@v4
with:
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
key: poetry-${{ runner.os }}-${{ hashFiles('classic/original_autogpt/poetry.lock') }}
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}

- name: Install Poetry (Unix)
if: runner.os != 'Windows'
run: |
curl -sSL https://install.python-poetry.org | python3 -

if [ "${{ runner.os }}" = "macOS" ]; then
PATH="$HOME/.local/bin:$PATH"
echo "$HOME/.local/bin" >> $GITHUB_PATH
fi

- name: Install Poetry (Windows)
if: runner.os == 'Windows'
shell: pwsh
run: |
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -

$env:PATH += ";$env:APPDATA\Python\Scripts"
echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
- name: Install Poetry
run: curl -sSL https://install.python-poetry.org | python3 -

- name: Install Python dependencies
run: poetry install
@@ -116,12 +81,13 @@ jobs:
--cov=autogpt --cov-branch --cov-report term-missing --cov-report xml \
--numprocesses=logical --durations=10 \
--junitxml=junit.xml -o junit_family=legacy \
tests/unit tests/integration
original_autogpt/tests/unit original_autogpt/tests/integration
env:
CI: true
PLAIN_OUTPUT: True
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
S3_ENDPOINT_URL: ${{ runner.os != 'Windows' && 'http://127.0.0.1:9000' || '' }}
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
S3_ENDPOINT_URL: http://127.0.0.1:9000
AWS_ACCESS_KEY_ID: minioadmin
AWS_SECRET_ACCESS_KEY: minioadmin

@@ -135,11 +101,11 @@ jobs:
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: autogpt-agent,${{ runner.os }}
flags: autogpt-agent

- name: Upload logs to artifact
if: always()
uses: actions/upload-artifact@v4
with:
name: test-logs
path: classic/original_autogpt/logs/
path: classic/logs/

@@ -148,7 +148,7 @@ jobs:
--entrypoint poetry ${{ env.IMAGE_NAME }} run \
pytest -v --cov=autogpt --cov-branch --cov-report term-missing \
--numprocesses=4 --durations=10 \
tests/unit tests/integration 2>&1 | tee test_output.txt
original_autogpt/tests/unit original_autogpt/tests/integration 2>&1 | tee test_output.txt

test_failure=${PIPESTATUS[0]}
44
.github/workflows/classic-autogpts-ci.yml
vendored
@@ -10,10 +10,9 @@ on:
- '.github/workflows/classic-autogpts-ci.yml'
- 'classic/original_autogpt/**'
- 'classic/forge/**'
- 'classic/benchmark/**'
- 'classic/run'
- 'classic/cli.py'
- 'classic/setup.py'
- 'classic/direct_benchmark/**'
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
- '!**/*.md'
pull_request:
branches: [ master, dev, release-* ]
@@ -21,10 +20,9 @@ on:
- '.github/workflows/classic-autogpts-ci.yml'
- 'classic/original_autogpt/**'
- 'classic/forge/**'
- 'classic/benchmark/**'
- 'classic/run'
- 'classic/cli.py'
- 'classic/setup.py'
- 'classic/direct_benchmark/**'
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
- '!**/*.md'

defaults:
@@ -35,13 +33,9 @@ defaults:
jobs:
serve-agent-protocol:
runs-on: ubuntu-latest
strategy:
matrix:
agent-name: [ original_autogpt ]
fail-fast: false
timeout-minutes: 20
env:
min-python-version: '3.10'
min-python-version: '3.12'
steps:
- name: Checkout repository
uses: actions/checkout@v4
@@ -55,22 +49,22 @@ jobs:
python-version: ${{ env.min-python-version }}

- name: Install Poetry
working-directory: ./classic/${{ matrix.agent-name }}/
run: |
curl -sSL https://install.python-poetry.org | python -

- name: Run regression tests
- name: Install dependencies
run: poetry install

- name: Run smoke tests with direct-benchmark
run: |
./run agent start ${{ matrix.agent-name }}
cd ${{ matrix.agent-name }}
poetry run agbenchmark --mock --test=BasicRetrieval --test=Battleship --test=WebArenaTask_0
poetry run agbenchmark --test=WriteFile
poetry run direct-benchmark run \
--strategies one_shot \
--models claude \
--tests ReadFile,WriteFile \
--json
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
AGENT_NAME: ${{ matrix.agent-name }}
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
REQUESTS_CA_BUNDLE: /etc/ssl/certs/ca-certificates.crt
HELICONE_CACHE_ENABLED: false
HELICONE_PROPERTY_AGENT: ${{ matrix.agent-name }}
REPORTS_FOLDER: ${{ format('../../reports/{0}', matrix.agent-name) }}
TELEMETRY_ENVIRONMENT: autogpt-ci
TELEMETRY_OPT_IN: ${{ github.ref_name == 'master' }}
NONINTERACTIVE_MODE: "true"
CI: true
256
.github/workflows/classic-benchmark-ci.yml
vendored
@@ -1,18 +1,24 @@
name: Classic - AGBenchmark CI
name: Classic - Direct Benchmark CI

on:
push:
branches: [ master, dev, ci-test* ]
paths:
- 'classic/benchmark/**'
- '!classic/benchmark/reports/**'
- 'classic/direct_benchmark/**'
- 'classic/original_autogpt/**'
- 'classic/forge/**'
- .github/workflows/classic-benchmark-ci.yml
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
pull_request:
branches: [ master, dev, release-* ]
paths:
- 'classic/benchmark/**'
- '!classic/benchmark/reports/**'
- 'classic/direct_benchmark/**'
- 'classic/original_autogpt/**'
- 'classic/forge/**'
- .github/workflows/classic-benchmark-ci.yml
- 'classic/pyproject.toml'
- 'classic/poetry.lock'

concurrency:
group: ${{ format('benchmark-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
@@ -23,95 +29,16 @@ defaults:
shell: bash

env:
min-python-version: '3.10'
min-python-version: '3.12'

jobs:
test:
permissions:
contents: read
benchmark-tests:
runs-on: ubuntu-latest
timeout-minutes: 30
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]
platform-os: [ubuntu, macos, macos-arm64, windows]
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
defaults:
run:
shell: bash
working-directory: classic/benchmark
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0
submodules: true

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: Set up Python dependency cache
# On Windows, unpacking cached dependencies takes longer than just installing them
if: runner.os != 'Windows'
uses: actions/cache@v4
with:
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
key: poetry-${{ runner.os }}-${{ hashFiles('classic/benchmark/poetry.lock') }}

- name: Install Poetry (Unix)
if: runner.os != 'Windows'
run: |
curl -sSL https://install.python-poetry.org | python3 -

if [ "${{ runner.os }}" = "macOS" ]; then
PATH="$HOME/.local/bin:$PATH"
echo "$HOME/.local/bin" >> $GITHUB_PATH
fi

- name: Install Poetry (Windows)
if: runner.os == 'Windows'
shell: pwsh
run: |
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -

$env:PATH += ";$env:APPDATA\Python\Scripts"
echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH

- name: Install Python dependencies
run: poetry install

- name: Run pytest with coverage
run: |
poetry run pytest -vv \
--cov=agbenchmark --cov-branch --cov-report term-missing --cov-report xml \
--durations=10 \
--junitxml=junit.xml -o junit_family=legacy \
tests
env:
CI: true
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}

- name: Upload test results to Codecov
if: ${{ !cancelled() }} # Run even if tests fail
uses: codecov/test-results-action@v1
with:
token: ${{ secrets.CODECOV_TOKEN }}

- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: agbenchmark,${{ runner.os }}

self-test-with-agent:
runs-on: ubuntu-latest
strategy:
matrix:
agent-name: [forge]
fail-fast: false
timeout-minutes: 20
working-directory: classic
steps:
- name: Checkout repository
uses: actions/checkout@v4
@@ -124,53 +51,120 @@ jobs:
with:
python-version: ${{ env.min-python-version }}

- name: Set up Python dependency cache
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}

- name: Install Poetry
run: |
curl -sSL https://install.python-poetry.org | python -
curl -sSL https://install.python-poetry.org | python3 -

- name: Install dependencies
run: poetry install

- name: Run basic benchmark tests
run: |
echo "Testing ReadFile challenge with one_shot strategy..."
poetry run direct-benchmark run \
--fresh \
--strategies one_shot \
--models claude \
--tests ReadFile \
--json

echo "Testing WriteFile challenge..."
poetry run direct-benchmark run \
--fresh \
--strategies one_shot \
--models claude \
--tests WriteFile \
--json
env:
CI: true
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
NONINTERACTIVE_MODE: "true"

- name: Test category filtering
run: |
echo "Testing coding category..."
poetry run direct-benchmark run \
--fresh \
--strategies one_shot \
--models claude \
--categories coding \
--tests ReadFile,WriteFile \
--json
env:
CI: true
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
NONINTERACTIVE_MODE: "true"

- name: Test multiple strategies
run: |
echo "Testing multiple strategies..."
poetry run direct-benchmark run \
--fresh \
--strategies one_shot,plan_execute \
--models claude \
--tests ReadFile \
--parallel 2 \
--json
env:
CI: true
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
NONINTERACTIVE_MODE: "true"

# Run regression tests on maintain challenges
regression-tests:
runs-on: ubuntu-latest
timeout-minutes: 45
if: github.ref == 'refs/heads/master' || github.ref == 'refs/heads/dev'
defaults:
run:
shell: bash
working-directory: classic
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0
submodules: true

- name: Set up Python ${{ env.min-python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ env.min-python-version }}

- name: Set up Python dependency cache
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}

- name: Install Poetry
run: |
curl -sSL https://install.python-poetry.org | python3 -

- name: Install dependencies
run: poetry install

- name: Run regression tests
working-directory: classic
run: |
./run agent start ${{ matrix.agent-name }}
cd ${{ matrix.agent-name }}

set +e # Ignore non-zero exit codes and continue execution
echo "Running the following command: poetry run agbenchmark --maintain --mock"
poetry run agbenchmark --maintain --mock
EXIT_CODE=$?
set -e # Stop ignoring non-zero exit codes
# Check if the exit code was 5, and if so, exit with 0 instead
if [ $EXIT_CODE -eq 5 ]; then
echo "regression_tests.json is empty."
fi

echo "Running the following command: poetry run agbenchmark --mock"
poetry run agbenchmark --mock

echo "Running the following command: poetry run agbenchmark --mock --category=data"
poetry run agbenchmark --mock --category=data

echo "Running the following command: poetry run agbenchmark --mock --category=coding"
poetry run agbenchmark --mock --category=coding

# echo "Running the following command: poetry run agbenchmark --test=WriteFile"
# poetry run agbenchmark --test=WriteFile
cd ../benchmark
poetry install
echo "Adding the BUILD_SKILL_TREE environment variable. This will attempt to add new elements in the skill tree. If new elements are added, the CI fails because they should have been pushed"
export BUILD_SKILL_TREE=true

# poetry run agbenchmark --mock

# CHANGED=$(git diff --name-only | grep -E '(agbenchmark/challenges)|(../classic/frontend/assets)') || echo "No diffs"
# if [ ! -z "$CHANGED" ]; then
# echo "There are unstaged changes please run agbenchmark and commit those changes since they are needed."
# echo "$CHANGED"
# exit 1
# else
# echo "No unstaged changes."
# fi
echo "Running regression tests (previously beaten challenges)..."
poetry run direct-benchmark run \
--fresh \
--strategies one_shot \
--models claude \
--maintain \
--parallel 4 \
--json
env:
CI: true
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
TELEMETRY_ENVIRONMENT: autogpt-benchmark-ci
TELEMETRY_OPT_IN: ${{ github.ref_name == 'master' }}
NONINTERACTIVE_MODE: "true"
189
.github/workflows/classic-forge-ci.yml
vendored
@@ -6,13 +6,15 @@ on:
paths:
- '.github/workflows/classic-forge-ci.yml'
- 'classic/forge/**'
- '!classic/forge/tests/vcr_cassettes'
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
pull_request:
branches: [ master, dev, release-* ]
paths:
- '.github/workflows/classic-forge-ci.yml'
- 'classic/forge/**'
- '!classic/forge/tests/vcr_cassettes'
- 'classic/pyproject.toml'
- 'classic/poetry.lock'

concurrency:
group: ${{ format('forge-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
@@ -21,131 +23,60 @@ concurrency:
defaults:
run:
shell: bash
working-directory: classic/forge
working-directory: classic

jobs:
test:
permissions:
contents: read
timeout-minutes: 30
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]
platform-os: [ubuntu, macos, macos-arm64, windows]
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
runs-on: ubuntu-latest

steps:
# Quite slow on macOS (2~4 minutes to set up Docker)
# - name: Set up Docker (macOS)
# if: runner.os == 'macOS'
# uses: crazy-max/ghaction-setup-docker@v3

- name: Start MinIO service (Linux)
if: runner.os == 'Linux'
- name: Start MinIO service
working-directory: '.'
run: |
docker pull minio/minio:edge-cicd
docker run -d -p 9000:9000 minio/minio:edge-cicd

- name: Start MinIO service (macOS)
if: runner.os == 'macOS'
working-directory: ${{ runner.temp }}
run: |
brew install minio/stable/minio
mkdir data
minio server ./data &

# No MinIO on Windows:
# - Windows doesn't support running Linux Docker containers
# - It doesn't seem possible to start background processes on Windows. They are
# killed after the step returns.
# See: https://github.com/actions/runner/issues/598#issuecomment-2011890429

- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0
submodules: true

- name: Checkout cassettes
if: ${{ startsWith(github.event_name, 'pull_request') }}
env:
PR_BASE: ${{ github.event.pull_request.base.ref }}
PR_BRANCH: ${{ github.event.pull_request.head.ref }}
PR_AUTHOR: ${{ github.event.pull_request.user.login }}
run: |
cassette_branch="${PR_AUTHOR}-${PR_BRANCH}"
cassette_base_branch="${PR_BASE}"
cd tests/vcr_cassettes

if ! git ls-remote --exit-code --heads origin $cassette_base_branch ; then
cassette_base_branch="master"
fi

if git ls-remote --exit-code --heads origin $cassette_branch ; then
git fetch origin $cassette_branch
git fetch origin $cassette_base_branch

git checkout $cassette_branch

# Pick non-conflicting cassette updates from the base branch
git merge --no-commit --strategy-option=ours origin/$cassette_base_branch
echo "Using cassettes from mirror branch '$cassette_branch'," \
"synced to upstream branch '$cassette_base_branch'."
else
git checkout -b $cassette_branch
echo "Branch '$cassette_branch' does not exist in cassette submodule." \
"Using cassettes from '$cassette_base_branch'."
fi

- name: Set up Python ${{ matrix.python-version }}
- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
python-version: "3.12"

- name: Set up Python dependency cache
# On Windows, unpacking cached dependencies takes longer than just installing them
if: runner.os != 'Windows'
uses: actions/cache@v4
with:
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
key: poetry-${{ runner.os }}-${{ hashFiles('classic/forge/poetry.lock') }}
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}

- name: Install Poetry (Unix)
if: runner.os != 'Windows'
run: |
curl -sSL https://install.python-poetry.org | python3 -

if [ "${{ runner.os }}" = "macOS" ]; then
PATH="$HOME/.local/bin:$PATH"
echo "$HOME/.local/bin" >> $GITHUB_PATH
fi

- name: Install Poetry (Windows)
if: runner.os == 'Windows'
shell: pwsh
run: |
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -

$env:PATH += ";$env:APPDATA\Python\Scripts"
echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
- name: Install Poetry
run: curl -sSL https://install.python-poetry.org | python3 -

- name: Install Python dependencies
run: poetry install

- name: Install Playwright browsers
run: poetry run playwright install chromium

- name: Run pytest with coverage
run: |
poetry run pytest -vv \
--cov=forge --cov-branch --cov-report term-missing --cov-report xml \
--durations=10 \
--junitxml=junit.xml -o junit_family=legacy \
forge
forge/forge forge/tests
env:
CI: true
PLAIN_OUTPUT: True
# API keys - tests that need these will skip if not available
# Secrets are not available to fork PRs (GitHub security feature)
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
S3_ENDPOINT_URL: ${{ runner.os != 'Windows' && 'http://127.0.0.1:9000' || '' }}
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
S3_ENDPOINT_URL: http://127.0.0.1:9000
AWS_ACCESS_KEY_ID: minioadmin
AWS_SECRET_ACCESS_KEY: minioadmin

@@ -159,85 +90,11 @@ jobs:
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: forge,${{ runner.os }}

- id: setup_git_auth
name: Set up git token authentication
# Cassettes may be pushed even when tests fail
if: success() || failure()
run: |
config_key="http.${{ github.server_url }}/.extraheader"
if [ "${{ runner.os }}" = 'macOS' ]; then
base64_pat=$(echo -n "pat:${{ secrets.PAT_REVIEW }}" | base64)
else
base64_pat=$(echo -n "pat:${{ secrets.PAT_REVIEW }}" | base64 -w0)
fi

git config "$config_key" \
"Authorization: Basic $base64_pat"

cd tests/vcr_cassettes
git config "$config_key" \
"Authorization: Basic $base64_pat"

echo "config_key=$config_key" >> $GITHUB_OUTPUT

- id: push_cassettes
name: Push updated cassettes
# For pull requests, push updated cassettes even when tests fail
if: github.event_name == 'push' || (! github.event.pull_request.head.repo.fork && (success() || failure()))
env:
PR_BRANCH: ${{ github.event.pull_request.head.ref }}
PR_AUTHOR: ${{ github.event.pull_request.user.login }}
run: |
if [ "${{ startsWith(github.event_name, 'pull_request') }}" = "true" ]; then
is_pull_request=true
cassette_branch="${PR_AUTHOR}-${PR_BRANCH}"
else
cassette_branch="${{ github.ref_name }}"
fi

cd tests/vcr_cassettes
# Commit & push changes to cassettes if any
if ! git diff --quiet; then
git add .
git commit -m "Auto-update cassettes"
git push origin HEAD:$cassette_branch
if [ ! $is_pull_request ]; then
cd ../..
git add tests/vcr_cassettes
git commit -m "Update cassette submodule"
git push origin HEAD:$cassette_branch
fi
echo "updated=true" >> $GITHUB_OUTPUT
else
echo "updated=false" >> $GITHUB_OUTPUT
echo "No cassette changes to commit"
fi

- name: Post Set up git token auth
if: steps.setup_git_auth.outcome == 'success'
run: |
git config --unset-all '${{ steps.setup_git_auth.outputs.config_key }}'
git submodule foreach git config --unset-all '${{ steps.setup_git_auth.outputs.config_key }}'

- name: Apply "behaviour change" label and comment on PR
if: ${{ startsWith(github.event_name, 'pull_request') }}
run: |
PR_NUMBER="${{ github.event.pull_request.number }}"
TOKEN="${{ secrets.PAT_REVIEW }}"
REPO="${{ github.repository }}"

if [[ "${{ steps.push_cassettes.outputs.updated }}" == "true" ]]; then
echo "Adding label and comment..."
echo $TOKEN | gh auth login --with-token
gh issue edit $PR_NUMBER --add-label "behaviour change"
gh issue comment $PR_NUMBER --body "You changed AutoGPT's behaviour on ${{ runner.os }}. The cassettes have been updated and will be merged to the submodule when this Pull Request gets merged."
fi
flags: forge

- name: Upload logs to artifact
if: always()
uses: actions/upload-artifact@v4
with:
name: test-logs
path: classic/forge/logs/
path: classic/logs/
60
.github/workflows/classic-frontend-ci.yml
vendored
@@ -1,60 +0,0 @@
name: Classic - Frontend CI/CD

on:
push:
branches:
- master
- dev
- 'ci-test*' # This will match any branch that starts with "ci-test"
paths:
- 'classic/frontend/**'
- '.github/workflows/classic-frontend-ci.yml'
pull_request:
paths:
- 'classic/frontend/**'
- '.github/workflows/classic-frontend-ci.yml'

jobs:
build:
permissions:
contents: write
pull-requests: write
runs-on: ubuntu-latest
env:
BUILD_BRANCH: ${{ format('classic-frontend-build/{0}', github.ref_name) }}

steps:
- name: Checkout Repo
uses: actions/checkout@v4

- name: Setup Flutter
uses: subosito/flutter-action@v2
with:
flutter-version: '3.13.2'

- name: Build Flutter to Web
run: |
cd classic/frontend
flutter build web --base-href /app/

# - name: Commit and Push to ${{ env.BUILD_BRANCH }}
# if: github.event_name == 'push'
# run: |
# git config --local user.email "action@github.com"
# git config --local user.name "GitHub Action"
# git add classic/frontend/build/web
# git checkout -B ${{ env.BUILD_BRANCH }}
# git commit -m "Update frontend build to ${GITHUB_SHA:0:7}" -a
# git push -f origin ${{ env.BUILD_BRANCH }}

- name: Create PR ${{ env.BUILD_BRANCH }} -> ${{ github.ref_name }}
if: github.event_name == 'push'
uses: peter-evans/create-pull-request@v8
with:
add-paths: classic/frontend/build/web
base: ${{ github.ref_name }}
branch: ${{ env.BUILD_BRANCH }}
delete-branch: true
title: "Update frontend build in `${{ github.ref_name }}`"
body: "This PR updates the frontend build based on commit ${{ github.sha }}."
commit-message: "Update frontend build based on commit ${{ github.sha }}"
67
.github/workflows/classic-python-checks.yml
vendored
@@ -7,7 +7,9 @@ on:
- '.github/workflows/classic-python-checks-ci.yml'
- 'classic/original_autogpt/**'
- 'classic/forge/**'
- 'classic/benchmark/**'
- 'classic/direct_benchmark/**'
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
- '**.py'
- '!classic/forge/tests/vcr_cassettes'
pull_request:
@@ -16,7 +18,9 @@ on:
- '.github/workflows/classic-python-checks-ci.yml'
- 'classic/original_autogpt/**'
- 'classic/forge/**'
- 'classic/benchmark/**'
- 'classic/direct_benchmark/**'
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
- '**.py'
- '!classic/forge/tests/vcr_cassettes'

@@ -27,44 +31,13 @@ concurrency:
defaults:
run:
shell: bash
working-directory: classic

jobs:
get-changed-parts:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4

- id: changes-in
name: Determine affected subprojects
uses: dorny/paths-filter@v3
with:
filters: |
original_autogpt:
- classic/original_autogpt/autogpt/**
- classic/original_autogpt/tests/**
- classic/original_autogpt/poetry.lock
forge:
- classic/forge/forge/**
- classic/forge/tests/**
- classic/forge/poetry.lock
benchmark:
- classic/benchmark/agbenchmark/**
- classic/benchmark/tests/**
- classic/benchmark/poetry.lock
outputs:
changed-parts: ${{ steps.changes-in.outputs.changes }}

lint:
needs: get-changed-parts
runs-on: ubuntu-latest
env:
min-python-version: "3.10"

strategy:
matrix:
sub-package: ${{ fromJson(needs.get-changed-parts.outputs.changed-parts) }}
fail-fast: false
min-python-version: "3.12"

steps:
- name: Checkout repository
@@ -81,42 +54,31 @@ jobs:
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: ${{ runner.os }}-poetry-${{ hashFiles(format('{0}/poetry.lock', matrix.sub-package)) }}
key: ${{ runner.os }}-poetry-${{ hashFiles('classic/poetry.lock') }}

- name: Install Poetry
run: curl -sSL https://install.python-poetry.org | python3 -

# Install dependencies

- name: Install Python dependencies
run: poetry -C classic/${{ matrix.sub-package }} install
run: poetry install

# Lint

- name: Lint (isort)
run: poetry run isort --check .
working-directory: classic/${{ matrix.sub-package }}

- name: Lint (Black)
if: success() || failure()
run: poetry run black --check .
working-directory: classic/${{ matrix.sub-package }}

- name: Lint (Flake8)
if: success() || failure()
run: poetry run flake8 .
working-directory: classic/${{ matrix.sub-package }}

types:
needs: get-changed-parts
runs-on: ubuntu-latest
env:
min-python-version: "3.10"

strategy:
matrix:
sub-package: ${{ fromJson(needs.get-changed-parts.outputs.changed-parts) }}
fail-fast: false
min-python-version: "3.12"

steps:
- name: Checkout repository
@@ -133,19 +95,16 @@ jobs:
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: ${{ runner.os }}-poetry-${{ hashFiles(format('{0}/poetry.lock', matrix.sub-package)) }}
key: ${{ runner.os }}-poetry-${{ hashFiles('classic/poetry.lock') }}

- name: Install Poetry
run: curl -sSL https://install.python-poetry.org | python3 -

# Install dependencies

- name: Install Python dependencies
run: poetry -C classic/${{ matrix.sub-package }} install
run: poetry install

# Typecheck

- name: Typecheck
if: success() || failure()
run: poetry run pyright
working-directory: classic/${{ matrix.sub-package }}
20
.github/workflows/platform-backend-ci.yml
vendored
@@ -269,12 +269,14 @@ jobs:
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}

- name: Run pytest
- name: Run pytest with coverage
run: |
if [[ "${{ runner.debug }}" == "1" ]]; then
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG \
--cov=backend --cov-branch --cov-report term-missing --cov-report xml
else
poetry run pytest -s -vv
poetry run pytest -s -vv \
--cov=backend --cov-branch --cov-report term-missing --cov-report xml
fi
env:
LOG_LEVEL: ${{ runner.debug && 'DEBUG' || 'INFO' }}
@@ -287,11 +289,13 @@ jobs:
REDIS_PORT: "6379"
ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=" # DO NOT USE IN PRODUCTION!!

# - name: Upload coverage reports to Codecov
# uses: codecov/codecov-action@v4
# with:
# token: ${{ secrets.CODECOV_TOKEN }}
# flags: backend,${{ runner.os }}
- name: Upload coverage reports to Codecov
if: ${{ !cancelled() }}
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: platform-backend
files: ./autogpt_platform/backend/coverage.xml

env:
CI: true
8
.github/workflows/platform-frontend-ci.yml
vendored
@@ -148,3 +148,11 @@ jobs:

- name: Run Integration Tests
run: pnpm test:unit

- name: Upload coverage reports to Codecov
if: ${{ !cancelled() }}
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: platform-frontend
files: ./autogpt_platform/frontend/coverage/cobertura-coverage.xml
25
.github/workflows/platform-fullstack-ci.yml
vendored
@@ -179,21 +179,30 @@ jobs:
pip install pyyaml

# Resolve extends and generate a flat compose file that bake can understand
export NEXT_PUBLIC_SOURCEMAPS NEXT_PUBLIC_PW_TEST
docker compose -f docker-compose.yml config > docker-compose.resolved.yml

# Ensure NEXT_PUBLIC_SOURCEMAPS is in resolved compose
# (docker compose config on some versions drops this arg)
if ! grep -q "NEXT_PUBLIC_SOURCEMAPS" docker-compose.resolved.yml; then
echo "Injecting NEXT_PUBLIC_SOURCEMAPS into resolved compose (docker compose config dropped it)"
sed -i '/NEXT_PUBLIC_PW_TEST/a\ NEXT_PUBLIC_SOURCEMAPS: "true"' docker-compose.resolved.yml
fi

# Add cache configuration to the resolved compose file
python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \
--source docker-compose.resolved.yml \
--cache-from "type=gha" \
--cache-to "type=gha,mode=max" \
--backend-hash "${{ hashFiles('autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/poetry.lock', 'autogpt_platform/backend/backend/**') }}" \
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src/**') }}" \
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src/**') }}-sourcemaps" \
--git-ref "${{ github.ref }}"

# Build with bake using the resolved compose file (now includes cache config)
docker buildx bake --allow=fs.read=.. -f docker-compose.resolved.yml --load
env:
NEXT_PUBLIC_PW_TEST: true
NEXT_PUBLIC_SOURCEMAPS: true

- name: Set up tests - Cache E2E test data
id: e2e-data-cache
@@ -279,6 +288,11 @@ jobs:
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml

- name: Copy source maps from Docker for E2E coverage
run: |
FRONTEND_CONTAINER=$(docker compose -f ../docker-compose.resolved.yml ps -q frontend)
docker cp "$FRONTEND_CONTAINER":/app/.next/static .next-static-coverage

- name: Set up tests - Install dependencies
run: pnpm install --frozen-lockfile

@@ -289,6 +303,15 @@ jobs:
run: pnpm test:no-build
continue-on-error: false

- name: Upload E2E coverage to Codecov
if: ${{ !cancelled() }}
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: platform-frontend-e2e
files: ./autogpt_platform/frontend/coverage/e2e/cobertura-coverage.xml
disable_search: true

- name: Upload Playwright report
if: always()
uses: actions/upload-artifact@v4
10
.gitignore
vendored
@@ -3,6 +3,7 @@
classic/original_autogpt/keys.py
classic/original_autogpt/*.json
auto_gpt_workspace/*
.autogpt/
*.mpeg
.env
# Root .env files
@@ -16,6 +17,7 @@ log-ingestion.txt
/logs
*.log
*.mp3
!autogpt_platform/frontend/public/notification.mp3
mem.sqlite3
venvAutoGPT

@@ -159,6 +161,10 @@ CURRENT_BULLETIN.md

# AgBenchmark
classic/benchmark/agbenchmark/reports/
classic/reports/
classic/direct_benchmark/reports/
classic/.benchmark_workspaces/
classic/direct_benchmark/.benchmark_workspaces/

# Nodejs
package-lock.json
@@ -177,9 +183,13 @@ autogpt_platform/backend/settings.py

*.ign.*
.test-contents
**/.claude/settings.local.json
.claude/settings.local.json
CLAUDE.local.md
/autogpt_platform/backend/logs

# Test database
test.db
.next
# Implementation plans (generated by AI agents)
plans/
36
.gitleaks.toml
Normal file
@@ -0,0 +1,36 @@
title = "AutoGPT Gitleaks Config"

[extend]
useDefault = true

[allowlist]
description = "Global allowlist"
paths = [
  # Template/example env files (no real secrets)
  '''\.env\.(default|example|template)$''',
  # Lock files
  '''pnpm-lock\.yaml$''',
  '''poetry\.lock$''',
  # Secrets baseline
  '''\.secrets\.baseline$''',
  # Build artifacts and caches (should not be committed)
  '''__pycache__/''',
  '''classic/frontend/build/''',
  # Docker dev setup (local dev JWTs/keys only)
  '''autogpt_platform/db/docker/''',
  # Load test configs (dev JWTs)
  '''load-tests/configs/''',
  # Test files with fake/fixture keys (_test.py, test_*.py, conftest.py)
  '''(_test|test_.*|conftest)\.py$''',
  # Documentation (only contains placeholder keys in curl/API examples)
  '''docs/.*\.md$''',
  # Firebase config (public API keys by design)
  '''google-services\.json$''',
  '''classic/frontend/(lib|web)/''',
]
# CI test-only encryption key (marked DO NOT USE IN PRODUCTION)
regexes = [
  '''dvziYgz0KSK8FENhju0ZYi8''',
  # LLM model name enum values falsely flagged as API keys
  '''Llama-\d.*Instruct''',
]
3
.gitmodules
vendored
@@ -1,3 +0,0 @@
[submodule "classic/forge/tests/vcr_cassettes"]
	path = classic/forge/tests/vcr_cassettes
	url = https://github.com/Significant-Gravitas/Auto-GPT-test-cassettes
@@ -23,9 +23,15 @@ repos:
- id: detect-secrets
name: Detect secrets
description: Detects high entropy strings that are likely to be passwords.
args: ["--baseline", ".secrets.baseline"]
files: ^autogpt_platform/
exclude: pnpm-lock\.yaml$
stages: [pre-push]
exclude: (pnpm-lock\.yaml|\.env\.(default|example|template))$

- repo: https://github.com/gitleaks/gitleaks
rev: v8.24.3
hooks:
- id: gitleaks
name: Detect secrets (gitleaks)

- repo: local
# For proper type checking, all dependencies need to be up-to-date.
@@ -84,51 +90,16 @@ repos:
stages: [pre-commit, post-checkout]

- id: poetry-install
name: Check & Install dependencies - Classic - AutoGPT
alias: poetry-install-classic-autogpt
name: Check & Install dependencies - Classic
alias: poetry-install-classic
entry: >
bash -c '
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
else
git diff --cached --name-only
fi | grep -qE "^classic/(original_autogpt|forge)/poetry\.lock$" || exit 0;
poetry -C classic/original_autogpt install
'
# include forge source (since it's a path dependency)
always_run: true
language: system
pass_filenames: false
stages: [pre-commit, post-checkout]

- id: poetry-install
name: Check & Install dependencies - Classic - Forge
alias: poetry-install-classic-forge
entry: >
bash -c '
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
else
git diff --cached --name-only
fi | grep -qE "^classic/forge/poetry\.lock$" || exit 0;
poetry -C classic/forge install
'
always_run: true
language: system
pass_filenames: false
stages: [pre-commit, post-checkout]

- id: poetry-install
name: Check & Install dependencies - Classic - Benchmark
alias: poetry-install-classic-benchmark
entry: >
bash -c '
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
else
git diff --cached --name-only
fi | grep -qE "^classic/benchmark/poetry\.lock$" || exit 0;
poetry -C classic/benchmark install
fi | grep -qE "^classic/poetry\.lock$" || exit 0;
poetry -C classic install
'
always_run: true
language: system
@@ -223,26 +194,10 @@ repos:
language: system

- id: isort
name: Lint (isort) - Classic - AutoGPT
alias: isort-classic-autogpt
entry: poetry -P classic/original_autogpt run isort -p autogpt
files: ^classic/original_autogpt/
types: [file, python]
language: system

- id: isort
name: Lint (isort) - Classic - Forge
alias: isort-classic-forge
entry: poetry -P classic/forge run isort -p forge
files: ^classic/forge/
types: [file, python]
language: system

- id: isort
name: Lint (isort) - Classic - Benchmark
alias: isort-classic-benchmark
entry: poetry -P classic/benchmark run isort -p agbenchmark
files: ^classic/benchmark/
name: Lint (isort) - Classic
alias: isort-classic
entry: bash -c 'cd classic && poetry run isort $(echo "$@" | sed "s|classic/||g")' --
files: ^classic/(original_autogpt|forge|direct_benchmark)/
types: [file, python]
language: system

@@ -256,26 +211,13 @@ repos:

- repo: https://github.com/PyCQA/flake8
rev: 7.0.0
# To have flake8 load the config of the individual subprojects, we have to call
# them separately.
# Use consolidated flake8 config at classic/.flake8
hooks:
- id: flake8
name: Lint (Flake8) - Classic - AutoGPT
alias: flake8-classic-autogpt
files: ^classic/original_autogpt/(autogpt|scripts|tests)/
args: [--config=classic/original_autogpt/.flake8]

- id: flake8
name: Lint (Flake8) - Classic - Forge
alias: flake8-classic-forge
files: ^classic/forge/(forge|tests)/
args: [--config=classic/forge/.flake8]

- id: flake8
name: Lint (Flake8) - Classic - Benchmark
alias: flake8-classic-benchmark
files: ^classic/benchmark/(agbenchmark|tests)/((?!reports).)*[/.]
args: [--config=classic/benchmark/.flake8]
name: Lint (Flake8) - Classic
alias: flake8-classic
files: ^classic/(original_autogpt|forge|direct_benchmark)/
args: [--config=classic/.flake8]

- repo: local
hooks:
@@ -311,29 +253,10 @@ repos:
pass_filenames: false

- id: pyright
name: Typecheck - Classic - AutoGPT
alias: pyright-classic-autogpt
entry: poetry -C classic/original_autogpt run pyright
# include forge source (since it's a path dependency) but exclude *_test.py files:
files: ^(classic/original_autogpt/((autogpt|scripts|tests)/|poetry\.lock$)|classic/forge/(forge/.*(?<!_test)\.py|poetry\.lock)$)
types: [file]
language: system
pass_filenames: false

- id: pyright
name: Typecheck - Classic - Forge
alias: pyright-classic-forge
entry: poetry -C classic/forge run pyright
files: ^classic/forge/(forge/|poetry\.lock$)
types: [file]
language: system
pass_filenames: false

- id: pyright
name: Typecheck - Classic - Benchmark
alias: pyright-classic-benchmark
entry: poetry -C classic/benchmark run pyright
files: ^classic/benchmark/(agbenchmark/|tests/|poetry\.lock$)
name: Typecheck - Classic
alias: pyright-classic
entry: poetry -C classic run pyright
files: ^classic/(original_autogpt|forge|direct_benchmark)/.*\.py$|^classic/poetry\.lock$
types: [file]
language: system
pass_filenames: false
@@ -360,26 +283,9 @@ repos:
# pass_filenames: false

# - id: pytest
# name: Run tests - Classic - AutoGPT (excl. slow tests)
# alias: pytest-classic-autogpt
# entry: bash -c 'cd classic/original_autogpt && poetry run pytest --cov=autogpt -m "not slow" tests/unit tests/integration'
# # include forge source (since it's a path dependency) but exclude *_test.py files:
# files: ^(classic/original_autogpt/((autogpt|tests)/|poetry\.lock$)|classic/forge/(forge/.*(?<!_test)\.py|poetry\.lock)$)
# language: system
# pass_filenames: false

# - id: pytest
# name: Run tests - Classic - Forge (excl. slow tests)
# alias: pytest-classic-forge
# entry: bash -c 'cd classic/forge && poetry run pytest --cov=forge -m "not slow"'
# files: ^classic/forge/(forge/|tests/|poetry\.lock$)
# language: system
# pass_filenames: false

# - id: pytest
# name: Run tests - Classic - Benchmark
# alias: pytest-classic-benchmark
# entry: bash -c 'cd classic/benchmark && poetry run pytest --cov=benchmark'
|
||||
# files: ^classic/benchmark/(agbenchmark/|tests/|poetry\.lock$)
|
||||
# name: Run tests - Classic (excl. slow tests)
|
||||
# alias: pytest-classic
|
||||
# entry: bash -c 'cd classic && poetry run pytest -m "not slow"'
|
||||
# files: ^classic/(original_autogpt|forge|direct_benchmark)/
|
||||
# language: system
|
||||
# pass_filenames: false
|
||||
|
||||
467
.secrets.baseline
Normal file
467
.secrets.baseline
Normal file
@@ -0,0 +1,467 @@
{
  "version": "1.5.0",
  "plugins_used": [
    {
      "name": "ArtifactoryDetector"
    },
    {
      "name": "AWSKeyDetector"
    },
    {
      "name": "AzureStorageKeyDetector"
    },
    {
      "name": "Base64HighEntropyString",
      "limit": 4.5
    },
    {
      "name": "BasicAuthDetector"
    },
    {
      "name": "CloudantDetector"
    },
    {
      "name": "DiscordBotTokenDetector"
    },
    {
      "name": "GitHubTokenDetector"
    },
    {
      "name": "GitLabTokenDetector"
    },
    {
      "name": "HexHighEntropyString",
      "limit": 3.0
    },
    {
      "name": "IbmCloudIamDetector"
    },
    {
      "name": "IbmCosHmacDetector"
    },
    {
      "name": "IPPublicDetector"
    },
    {
      "name": "JwtTokenDetector"
    },
    {
      "name": "KeywordDetector",
      "keyword_exclude": ""
    },
    {
      "name": "MailchimpDetector"
    },
    {
      "name": "NpmDetector"
    },
    {
      "name": "OpenAIDetector"
    },
    {
      "name": "PrivateKeyDetector"
    },
    {
      "name": "PypiTokenDetector"
    },
    {
      "name": "SendGridDetector"
    },
    {
      "name": "SlackDetector"
    },
    {
      "name": "SoftlayerDetector"
    },
    {
      "name": "SquareOAuthDetector"
    },
    {
      "name": "StripeDetector"
    },
    {
      "name": "TelegramBotTokenDetector"
    },
    {
      "name": "TwilioKeyDetector"
    }
  ],
  "filters_used": [
    {
      "path": "detect_secrets.filters.allowlist.is_line_allowlisted"
    },
    {
      "path": "detect_secrets.filters.common.is_ignored_due_to_verification_policies",
      "min_level": 2
    },
    {
      "path": "detect_secrets.filters.heuristic.is_indirect_reference"
    },
    {
      "path": "detect_secrets.filters.heuristic.is_likely_id_string"
    },
    {
      "path": "detect_secrets.filters.heuristic.is_lock_file"
    },
    {
      "path": "detect_secrets.filters.heuristic.is_not_alphanumeric_string"
    },
    {
      "path": "detect_secrets.filters.heuristic.is_potential_uuid"
    },
    {
      "path": "detect_secrets.filters.heuristic.is_prefixed_with_dollar_sign"
    },
    {
      "path": "detect_secrets.filters.heuristic.is_sequential_string"
    },
    {
      "path": "detect_secrets.filters.heuristic.is_swagger_file"
    },
    {
      "path": "detect_secrets.filters.heuristic.is_templated_secret"
    },
    {
      "path": "detect_secrets.filters.regex.should_exclude_file",
      "pattern": [
        "\\.env$",
        "pnpm-lock\\.yaml$",
        "\\.env\\.(default|example|template)$",
        "__pycache__",
        "_test\\.py$",
        "test_.*\\.py$",
        "conftest\\.py$",
        "poetry\\.lock$",
        "node_modules"
      ]
    }
  ],
  "results": {
    "autogpt_platform/backend/backend/api/external/v1/integrations.py": [
      {
        "type": "Secret Keyword",
        "filename": "autogpt_platform/backend/backend/api/external/v1/integrations.py",
        "hashed_secret": "665b1e3851eefefa3fb878654292f16597d25155",
        "is_verified": false,
        "line_number": 289
      }
    ],
    "autogpt_platform/backend/backend/blocks/airtable/_config.py": [
      {
        "type": "Secret Keyword",
        "filename": "autogpt_platform/backend/backend/blocks/airtable/_config.py",
        "hashed_secret": "57e168b03afb7c1ee3cdc4ee3db2fe1cc6e0df26",
        "is_verified": false,
        "line_number": 29
      }
    ],
    "autogpt_platform/backend/backend/blocks/dataforseo/_config.py": [
      {
        "type": "Secret Keyword",
        "filename": "autogpt_platform/backend/backend/blocks/dataforseo/_config.py",
        "hashed_secret": "32ce93887331fa5d192f2876ea15ec000c7d58b8",
        "is_verified": false,
        "line_number": 12
      }
    ],
    "autogpt_platform/backend/backend/blocks/github/checks.py": [
      {
        "type": "Hex High Entropy String",
        "filename": "autogpt_platform/backend/backend/blocks/github/checks.py",
        "hashed_secret": "8ac6f92737d8586790519c5d7bfb4d2eb172c238",
        "is_verified": false,
        "line_number": 108
      }
    ],
    "autogpt_platform/backend/backend/blocks/github/ci.py": [
      {
        "type": "Hex High Entropy String",
        "filename": "autogpt_platform/backend/backend/blocks/github/ci.py",
        "hashed_secret": "90bd1b48e958257948487b90bee080ba5ed00caa",
        "is_verified": false,
        "line_number": 123
      }
    ],
    "autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json": [
      {
        "type": "Hex High Entropy String",
        "filename": "autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json",
        "hashed_secret": "f96896dafced7387dcd22343b8ea29d3d2c65663",
        "is_verified": false,
        "line_number": 42
      },
      {
        "type": "Hex High Entropy String",
        "filename": "autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json",
        "hashed_secret": "b80a94d5e70bedf4f5f89d2f5a5255cc9492d12e",
        "is_verified": false,
        "line_number": 193
      },
      {
        "type": "Hex High Entropy String",
        "filename": "autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json",
        "hashed_secret": "75b17e517fe1b3136394f6bec80c4f892da75e42",
        "is_verified": false,
        "line_number": 344
      },
      {
        "type": "Hex High Entropy String",
        "filename": "autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json",
        "hashed_secret": "b0bfb5e4e2394e7f8906e5ed1dffd88b2bc89dd5",
        "is_verified": false,
        "line_number": 534
      }
    ],
    "autogpt_platform/backend/backend/blocks/github/statuses.py": [
      {
        "type": "Hex High Entropy String",
        "filename": "autogpt_platform/backend/backend/blocks/github/statuses.py",
        "hashed_secret": "8ac6f92737d8586790519c5d7bfb4d2eb172c238",
        "is_verified": false,
        "line_number": 85
      }
    ],
    "autogpt_platform/backend/backend/blocks/google/docs.py": [
      {
        "type": "Hex High Entropy String",
        "filename": "autogpt_platform/backend/backend/blocks/google/docs.py",
        "hashed_secret": "c95da0c6696342c867ef0c8258d2f74d20fd94d4",
        "is_verified": false,
        "line_number": 203
      }
    ],
    "autogpt_platform/backend/backend/blocks/google/sheets.py": [
      {
        "type": "Base64 High Entropy String",
        "filename": "autogpt_platform/backend/backend/blocks/google/sheets.py",
        "hashed_secret": "bd5a04fa3667e693edc13239b6d310c5c7a8564b",
        "is_verified": false,
        "line_number": 57
      }
    ],
    "autogpt_platform/backend/backend/blocks/linear/_config.py": [
      {
        "type": "Secret Keyword",
        "filename": "autogpt_platform/backend/backend/blocks/linear/_config.py",
        "hashed_secret": "b37f020f42d6d613b6ce30103e4d408c4499b3bb",
        "is_verified": false,
        "line_number": 53
      }
    ],
    "autogpt_platform/backend/backend/blocks/medium.py": [
      {
        "type": "Hex High Entropy String",
        "filename": "autogpt_platform/backend/backend/blocks/medium.py",
        "hashed_secret": "ff998abc1ce6d8f01a675fa197368e44c8916e9c",
        "is_verified": false,
        "line_number": 131
      }
    ],
    "autogpt_platform/backend/backend/blocks/replicate/replicate_block.py": [
      {
        "type": "Hex High Entropy String",
        "filename": "autogpt_platform/backend/backend/blocks/replicate/replicate_block.py",
        "hashed_secret": "8bbdd6f26368f58ea4011d13d7f763cb662e66f0",
        "is_verified": false,
        "line_number": 55
      }
    ],
    "autogpt_platform/backend/backend/blocks/slant3d/webhook.py": [
      {
        "type": "Hex High Entropy String",
        "filename": "autogpt_platform/backend/backend/blocks/slant3d/webhook.py",
        "hashed_secret": "36263c76947443b2f6e6b78153967ac4a7da99f9",
        "is_verified": false,
        "line_number": 100
      }
    ],
    "autogpt_platform/backend/backend/blocks/talking_head.py": [
      {
        "type": "Base64 High Entropy String",
        "filename": "autogpt_platform/backend/backend/blocks/talking_head.py",
        "hashed_secret": "44ce2d66222529eea4a32932823466fc0601c799",
        "is_verified": false,
        "line_number": 113
      }
    ],
    "autogpt_platform/backend/backend/blocks/wordpress/_config.py": [
      {
        "type": "Secret Keyword",
        "filename": "autogpt_platform/backend/backend/blocks/wordpress/_config.py",
        "hashed_secret": "e62679512436161b78e8a8d68c8829c2a1031ccb",
        "is_verified": false,
        "line_number": 17
      }
    ],
    "autogpt_platform/backend/backend/util/cache.py": [
      {
        "type": "Secret Keyword",
        "filename": "autogpt_platform/backend/backend/util/cache.py",
        "hashed_secret": "37f0c918c3fa47ca4a70e42037f9f123fdfbc75b",
        "is_verified": false,
        "line_number": 449
      }
    ],
    "autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/helpers.ts": [
      {
        "type": "Secret Keyword",
        "filename": "autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/helpers.ts",
        "hashed_secret": "5baa61e4c9b93f3f0682250b6cf8331b7ee68fd8",
        "is_verified": false,
        "line_number": 6
      }
    ],
    "autogpt_platform/frontend/src/app/(platform)/dictionaries/en.json": [
      {
        "type": "Secret Keyword",
        "filename": "autogpt_platform/frontend/src/app/(platform)/dictionaries/en.json",
        "hashed_secret": "8be3c943b1609fffbfc51aad666d0a04adf83c9d",
        "is_verified": false,
        "line_number": 5
      }
    ],
    "autogpt_platform/frontend/src/app/(platform)/dictionaries/es.json": [
      {
        "type": "Secret Keyword",
        "filename": "autogpt_platform/frontend/src/app/(platform)/dictionaries/es.json",
        "hashed_secret": "5a6d1c612954979ea99ee33dbb2d231b00f6ac0a",
        "is_verified": false,
        "line_number": 5
      }
    ],
    "autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/AgentInputsReadOnly/helpers.ts": [
      {
        "type": "Secret Keyword",
        "filename": "autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/AgentInputsReadOnly/helpers.ts",
        "hashed_secret": "cf678cab87dc1f7d1b95b964f15375e088461679",
        "is_verified": false,
        "line_number": 6
      },
      {
        "type": "Secret Keyword",
        "filename": "autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/AgentInputsReadOnly/helpers.ts",
        "hashed_secret": "f72cbb45464d487064610c5411c576ca4019d380",
        "is_verified": false,
        "line_number": 8
      }
    ],
    "autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/components/ModalRunSection/helpers.ts": [
      {
        "type": "Secret Keyword",
        "filename": "autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/components/ModalRunSection/helpers.ts",
        "hashed_secret": "cf678cab87dc1f7d1b95b964f15375e088461679",
        "is_verified": false,
        "line_number": 5
      },
      {
        "type": "Secret Keyword",
        "filename": "autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/components/ModalRunSection/helpers.ts",
        "hashed_secret": "f72cbb45464d487064610c5411c576ca4019d380",
        "is_verified": false,
        "line_number": 7
      }
    ],
    "autogpt_platform/frontend/src/app/(platform)/profile/(user)/integrations/page.tsx": [
      {
        "type": "Secret Keyword",
        "filename": "autogpt_platform/frontend/src/app/(platform)/profile/(user)/integrations/page.tsx",
        "hashed_secret": "cf678cab87dc1f7d1b95b964f15375e088461679",
        "is_verified": false,
        "line_number": 192
      },
      {
        "type": "Secret Keyword",
        "filename": "autogpt_platform/frontend/src/app/(platform)/profile/(user)/integrations/page.tsx",
        "hashed_secret": "86275db852204937bbdbdebe5fabe8536e030ab6",
        "is_verified": false,
        "line_number": 193
      }
    ],
    "autogpt_platform/frontend/src/components/contextual/CredentialsInput/helpers.ts": [
      {
        "type": "Secret Keyword",
        "filename": "autogpt_platform/frontend/src/components/contextual/CredentialsInput/helpers.ts",
        "hashed_secret": "47acd2028cf81b5da88ddeedb2aea4eca4b71fbd",
        "is_verified": false,
        "line_number": 102
      },
      {
        "type": "Secret Keyword",
        "filename": "autogpt_platform/frontend/src/components/contextual/CredentialsInput/helpers.ts",
        "hashed_secret": "8be3c943b1609fffbfc51aad666d0a04adf83c9d",
        "is_verified": false,
        "line_number": 103
      }
    ],
    "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts": [
      {
        "type": "Base64 High Entropy String",
        "filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
        "hashed_secret": "9c486c92f1a7420e1045c7ad963fbb7ba3621025",
        "is_verified": false,
        "line_number": 73
      },
      {
        "type": "Base64 High Entropy String",
        "filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
        "hashed_secret": "9277508c7a6effc8fb59163efbfada189e35425c",
        "is_verified": false,
        "line_number": 75
      },
      {
        "type": "Base64 High Entropy String",
        "filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
        "hashed_secret": "8dc7e2cb1d0935897d541bf5facab389b8a50340",
        "is_verified": false,
        "line_number": 77
      },
      {
        "type": "Base64 High Entropy String",
        "filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
        "hashed_secret": "79a26ad48775944299be6aaf9fb1d5302c1ed75b",
        "is_verified": false,
        "line_number": 79
      },
      {
        "type": "Base64 High Entropy String",
        "filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
        "hashed_secret": "a3b62b44500a1612e48d4cab8294df81561b3b1a",
        "is_verified": false,
        "line_number": 81
      },
      {
        "type": "Base64 High Entropy String",
        "filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
        "hashed_secret": "a58979bd0b21ef4f50417d001008e60dd7a85c64",
        "is_verified": false,
        "line_number": 83
      },
      {
        "type": "Base64 High Entropy String",
        "filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
        "hashed_secret": "6cb6e075f8e8c7c850f9d128d6608e5dbe209a79",
        "is_verified": false,
        "line_number": 85
      }
    ],
    "autogpt_platform/frontend/src/lib/constants.ts": [
      {
        "type": "Secret Keyword",
        "filename": "autogpt_platform/frontend/src/lib/constants.ts",
        "hashed_secret": "27b924db06a28cc755fb07c54f0fddc30659fe4d",
        "is_verified": false,
        "line_number": 10
      }
    ],
    "autogpt_platform/frontend/src/tests/credentials/index.ts": [
      {
        "type": "Secret Keyword",
        "filename": "autogpt_platform/frontend/src/tests/credentials/index.ts",
        "hashed_secret": "c18006fc138809314751cd1991f1e0b820fabd37",
        "is_verified": false,
        "line_number": 4
      }
    ]
  },
  "generated_at": "2026-04-02T13:10:54Z"
}
@@ -1,6 +1,6 @@
# AutoGPT Platform Contribution Guide

This guide provides context for Codex when updating the **autogpt_platform** folder.
This guide provides context for coding agents when updating the **autogpt_platform** folder.

## Directory overview

@@ -30,7 +30,7 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
   - Regenerate with `pnpm generate:api`
   - Pattern: `use{Method}{Version}{OperationName}`
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
5. **Testing**: Integration tests (Vitest + RTL + MSW) are the default (~90%, page-level). Playwright for E2E critical flows. Storybook for design system components. See `autogpt_platform/frontend/TESTING.md`
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers

   - Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component
@@ -47,7 +47,9 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
## Testing

- Backend: `poetry run test` (runs pytest with a docker based postgres + prisma).
- Frontend: `pnpm test` or `pnpm test-ui` for Playwright tests. See `docs/content/platform/contributing/tests.md` for tips.
- Frontend integration tests: `pnpm test:unit` (Vitest + RTL + MSW, primary testing approach).
- Frontend E2E tests: `pnpm test` or `pnpm test-ui` for Playwright tests.
- See `autogpt_platform/frontend/TESTING.md` for the full testing strategy.

Always run the relevant linters and tests before committing.
Use conventional commit messages for all commits (e.g. `feat(backend): add API`).

@@ -83,13 +83,13 @@ The AutoGPT frontend is where users interact with our powerful AI automation pla

**Agent Builder:** For those who want to customize, our intuitive, low-code interface allows you to design and configure your own AI agents.

**Workflow Management:** Build, modify, and optimize your automation workflows with ease. You build your agent by connecting blocks, where each block performs a single action.

**Deployment Controls:** Manage the lifecycle of your agents, from testing to production.

**Ready-to-Use Agents:** Don't want to build? Simply select from our library of pre-configured agents and put them to work immediately.

**Agent Interaction:** Whether you've built your own or are using pre-configured agents, easily run and interact with them through our user-friendly interface.

**Monitoring and Analytics:** Keep track of your agents' performance and gain insights to continually improve your automation processes.

120 autogpt_platform/AGENTS.md Normal file
@@ -0,0 +1,120 @@
# AutoGPT Platform

This file provides guidance to coding agents when working with code in this repository.

## Repository Overview

AutoGPT Platform is a monorepo containing:

- **Backend** (`backend`): Python FastAPI server with async support
- **Frontend** (`frontend`): Next.js React application
- **Shared Libraries** (`autogpt_libs`): Common Python utilities

## Component Documentation

- **Backend**: See @backend/AGENTS.md for backend-specific commands, architecture, and development tasks
- **Frontend**: See @frontend/AGENTS.md for frontend-specific commands, architecture, and development patterns

## Key Concepts

1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
2. **Blocks**: Reusable components in `backend/backend/blocks/` that perform specific tasks
3. **Integrations**: OAuth and API connections stored per user
4. **Store**: Marketplace for sharing agent templates
5. **Virus Scanning**: ClamAV integration for file upload security

### Environment Configuration

#### Configuration Files

- **Backend**: `backend/.env.default` (defaults) → `backend/.env` (user overrides)
- **Frontend**: `frontend/.env.default` (defaults) → `frontend/.env` (user overrides)
- **Platform**: `.env.default` (Supabase/shared defaults) → `.env` (user overrides)

#### Docker Environment Loading Order

1. `.env.default` files provide base configuration (tracked in git)
2. `.env` files provide user-specific overrides (gitignored)
3. Docker Compose `environment:` sections provide service-specific overrides
4. Shell environment variables have highest precedence

#### Key Points

- All services use hardcoded defaults in docker-compose files (no `${VARIABLE}` substitutions)
- The `env_file` directive loads variables INTO containers at runtime
- Backend/Frontend services use YAML anchors for consistent configuration
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern

### Branching Strategy

- **`dev`** is the main development branch. All PRs should target `dev`.
- **`master`** is the production branch. Only used for production releases.

### Creating Pull Requests

- Create the PR against the `dev` branch of the repository.
- **Split PRs by concern** — each PR should have a single clear purpose. For example, "usage tracking" and "credit charging" should be separate PRs even if related. Combining multiple concerns makes it harder for reviewers to understand what belongs to what.
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)
- Use conventional commit messages (see below)
- **Structure the PR description with Why / What / How** — Why: the motivation (what problem it solves, what's broken/missing without it); What: high-level summary of changes; How: approach, key implementation details, or architecture decisions. Reviewers need all three to judge whether the approach fits the problem.
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description
- Always use `--body-file` to pass PR body — avoids shell interpretation of backticks and special characters:
  ```bash
  PR_BODY=$(mktemp)
  cat > "$PR_BODY" << 'PREOF'
  ## Summary
  - use `backticks` freely here
  PREOF
  gh pr create --title "..." --body-file "$PR_BODY" --base dev
  rm "$PR_BODY"
  ```
- Run the github pre-commit hooks to ensure code quality.

### Test-Driven Development (TDD)

When fixing a bug or adding a feature, follow a test-first approach:

1. **Write a failing test first** — create a test that reproduces the bug or validates the new behavior, marked with `@pytest.mark.xfail` (backend) or `.fixme` (Playwright). Run it to confirm it fails for the right reason.
2. **Implement the fix/feature** — write the minimal code to make the test pass.
3. **Remove the xfail marker** — once the test passes, remove the `xfail`/`.fixme` annotation and run the full test suite to confirm nothing else broke.

This ensures every change is covered by a test and that the test actually validates the intended behavior.

### Reviewing/Revising Pull Requests

Use `/pr-review` to review a PR or `/pr-address` to address comments.

When fetching comments manually:
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate` — top-level reviews
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments --paginate` — inline review comments (always paginate to avoid missing comments beyond page 1)
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` — PR conversation comments

### Conventional Commits

Use this format for commit messages and Pull Request titles:

**Conventional Commit Types:**

- `feat`: Introduces a new feature to the codebase
- `fix`: Patches a bug in the codebase
- `refactor`: Code change that neither fixes a bug nor adds a feature; also applies to removing features
- `ci`: Changes to CI configuration
- `docs`: Documentation-only changes
- `dx`: Improvements to the developer experience

**Recommended Base Scopes:**

- `platform`: Changes affecting both frontend and backend
- `frontend`
- `backend`
- `infra`
- `blocks`: Modifications/additions of individual blocks

**Subscope Examples:**

- `backend/executor`
- `backend/db`
- `frontend/builder` (includes changes to the block UI component)
- `infra/prod`

Use these scopes and subscopes for clarity and consistency in commit messages.
@@ -1,120 +1 @@
# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Repository Overview

AutoGPT Platform is a monorepo containing:

- **Backend** (`backend`): Python FastAPI server with async support
- **Frontend** (`frontend`): Next.js React application
- **Shared Libraries** (`autogpt_libs`): Common Python utilities

## Component Documentation

- **Backend**: See @backend/CLAUDE.md for backend-specific commands, architecture, and development tasks
- **Frontend**: See @frontend/CLAUDE.md for frontend-specific commands, architecture, and development patterns

## Key Concepts

1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
2. **Blocks**: Reusable components in `backend/backend/blocks/` that perform specific tasks
3. **Integrations**: OAuth and API connections stored per user
4. **Store**: Marketplace for sharing agent templates
5. **Virus Scanning**: ClamAV integration for file upload security

### Environment Configuration

#### Configuration Files

- **Backend**: `backend/.env.default` (defaults) → `backend/.env` (user overrides)
- **Frontend**: `frontend/.env.default` (defaults) → `frontend/.env` (user overrides)
- **Platform**: `.env.default` (Supabase/shared defaults) → `.env` (user overrides)

#### Docker Environment Loading Order

1. `.env.default` files provide base configuration (tracked in git)
2. `.env` files provide user-specific overrides (gitignored)
3. Docker Compose `environment:` sections provide service-specific overrides
4. Shell environment variables have highest precedence

#### Key Points

- All services use hardcoded defaults in docker-compose files (no `${VARIABLE}` substitutions)
- The `env_file` directive loads variables INTO containers at runtime
- Backend/Frontend services use YAML anchors for consistent configuration
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern

### Branching Strategy

- **`dev`** is the main development branch. All PRs should target `dev`.
- **`master`** is the production branch. Only used for production releases.

### Creating Pull Requests

- Create the PR against the `dev` branch of the repository.
- **Split PRs by concern** — each PR should have a single clear purpose. For example, "usage tracking" and "credit charging" should be separate PRs even if related. Combining multiple concerns makes it harder for reviewers to understand what belongs to what.
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)
- Use conventional commit messages (see below)
- **Structure the PR description with Why / What / How** — Why: the motivation (what problem it solves, what's broken/missing without it); What: high-level summary of changes; How: approach, key implementation details, or architecture decisions. Reviewers need all three to judge whether the approach fits the problem.
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description
- Always use `--body-file` to pass PR body — avoids shell interpretation of backticks and special characters:
  ```bash
  PR_BODY=$(mktemp)
  cat > "$PR_BODY" << 'PREOF'
  ## Summary
  - use `backticks` freely here
  PREOF
  gh pr create --title "..." --body-file "$PR_BODY" --base dev
  rm "$PR_BODY"
  ```
- Run the github pre-commit hooks to ensure code quality.

### Test-Driven Development (TDD)

When fixing a bug or adding a feature, follow a test-first approach:

1. **Write a failing test first** — create a test that reproduces the bug or validates the new behavior, marked with `@pytest.mark.xfail` (backend) or `.fixme` (Playwright). Run it to confirm it fails for the right reason.
2. **Implement the fix/feature** — write the minimal code to make the test pass.
3. **Remove the xfail marker** — once the test passes, remove the `xfail`/`.fixme` annotation and run the full test suite to confirm nothing else broke.

This ensures every change is covered by a test and that the test actually validates the intended behavior.

### Reviewing/Revising Pull Requests

Use `/pr-review` to review a PR or `/pr-address` to address comments.

When fetching comments manually:
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate` — top-level reviews
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments --paginate` — inline review comments (always paginate to avoid missing comments beyond page 1)
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` — PR conversation comments

### Conventional Commits

Use this format for commit messages and Pull Request titles:

**Conventional Commit Types:**

- `feat`: Introduces a new feature to the codebase
- `fix`: Patches a bug in the codebase
- `refactor`: Code change that neither fixes a bug nor adds a feature; also applies to removing features
- `ci`: Changes to CI configuration
- `docs`: Documentation-only changes
- `dx`: Improvements to the developer experience

**Recommended Base Scopes:**

- `platform`: Changes affecting both frontend and backend
- `frontend`
- `backend`
- `infra`
- `blocks`: Modifications/additions of individual blocks

**Subscope Examples:**

- `backend/executor`
- `backend/db`
- `frontend/builder` (includes changes to the block UI component)
- `infra/prod`

Use these scopes and subscopes for clarity and consistency in commit messages.
@AGENTS.md

@@ -178,6 +178,7 @@ SMTP_USERNAME=
SMTP_PASSWORD=

# Business & Marketing Tools
AGENTMAIL_API_KEY=
APOLLO_API_KEY=
ENRICHLAYER_API_KEY=
AYRSHARE_API_KEY=

227 autogpt_platform/backend/AGENTS.md Normal file
@@ -0,0 +1,227 @@
# Backend

This file provides guidance to coding agents when working with the backend.

## Essential Commands

To run something with Python package dependencies you MUST use `poetry run ...`.

```bash
# Install dependencies
poetry install

# Run database migrations
poetry run prisma migrate dev

# Start all services (database, redis, rabbitmq, clamav)
docker compose up -d

# Run the backend as a whole
poetry run app

# Run tests
poetry run test

# Run specific test
poetry run pytest path/to/test_file.py::test_function_name

# Run block tests (tests that validate all blocks work correctly)
poetry run pytest backend/blocks/test/test_block.py -xvs

# Run tests for a specific block (e.g., GetCurrentTimeBlock)
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs

# Lint and format
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
poetry run format  # Black + isort
poetry run lint    # ruff
```

More details can be found in @TESTING.md

### Creating/Updating Snapshots

When you first write a test or when the expected output changes:

```bash
poetry run pytest path/to/test.py --snapshot-update
```

⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.

## Architecture

- **API Layer**: FastAPI with REST and WebSocket endpoints
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
- **Queue System**: RabbitMQ for async task processing
- **Execution Engine**: Separate executor service processes agent workflows
- **Authentication**: JWT-based with Supabase integration
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies

## Code Style

- **Top-level imports only** — no local/inner imports (lazy imports only for heavy optional deps like `openpyxl`)
- **Absolute imports** — use `from backend.module import ...` for cross-package imports. Single-dot relative (`from .sibling import ...`) is acceptable for sibling modules within the same package (e.g., blocks). Avoid double-dot relative imports (`from ..parent import ...`) — use the absolute path instead
- **No duck typing** — no `hasattr`/`getattr`/`isinstance` for type dispatch; use typed interfaces/unions/protocols
- **Pydantic models** over dataclass/namedtuple/dict for structured data
- **No linter suppressors** — no `# type: ignore`, `# noqa`, `# pyright: ignore`; fix the type/code
- **List comprehensions** over manual loop-and-append
- **Early return** — guard clauses first, avoid deep nesting
- **f-strings vs printf syntax in log statements** — Use `%s` for deferred interpolation in `debug` statements, f-strings elsewhere for readability: `logger.debug("Processing %s items", count)`, `logger.info(f"Processing {count} items")`
- **Sanitize error paths** — `os.path.basename()` in error messages to avoid leaking directory structure
- **TOCTOU awareness** — avoid check-then-act patterns for file access and credit charging
- **`Security()` vs `Depends()`** — use `Security()` for auth deps to get proper OpenAPI security spec
- **Redis pipelines** — `transaction=True` for atomicity on multi-step operations
- **`max(0, value)` guards** — for computed values that should never be negative
- **SSE protocol** — `data:` lines for frontend-parsed events (must match Zod schema), `: comment` lines for heartbeats/status
- **File length** — keep files under ~300 lines; if a file grows beyond this, split by responsibility (e.g. extract helpers, models, or a sub-module into a new file). Never keep appending to a long file.
- **Function length** — keep functions under ~40 lines; extract named helpers when a function grows longer. Long functions are a sign of mixed concerns, not complexity.
- **Top-down ordering** — define the main/public function or class first, then the helpers it uses below. A reader should encounter high-level logic before implementation details.

## Testing Approach

- Uses pytest with snapshot testing for API responses
- Test files are colocated with source files (`*_test.py`)
- Mock at boundaries — mock where the symbol is **used**, not where it's **defined**
- After refactoring, update mock targets to match new module paths
- Use `AsyncMock` for async functions (`from unittest.mock import AsyncMock`)
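
To make the mock-at-boundaries rule concrete, here is a minimal sketch. The module layout (`backend.blocks.weather` importing `fetch_forecast` from `backend.integrations.weather_api`) and the `get_temperature` helper are hypothetical, invented purely for illustration:

```python
from unittest.mock import AsyncMock, patch

import pytest

# Hypothetical layout: backend/blocks/weather.py contains
#   from backend.integrations.weather_api import fetch_forecast
# so the block module holds its own reference to the imported symbol.


@pytest.mark.asyncio
async def test_weather_block_reads_mocked_forecast():
    mock_fetch = AsyncMock(return_value={"temp_c": 21})
    # Correct: patch the name in the module that USES it ...
    with patch("backend.blocks.weather.fetch_forecast", mock_fetch):
        from backend.blocks.weather import get_temperature  # hypothetical

        assert await get_temperature("Berlin") == 21
    mock_fetch.assert_awaited_once_with("Berlin")
    # Wrong: patching "backend.integrations.weather_api.fetch_forecast"
    # (where it is DEFINED) would leave the block's already-imported
    # reference untouched, and the real function would still run.
```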

### Test-Driven Development (TDD)

When fixing a bug or adding a feature, write the test **before** the implementation:

```python
# 1. Write a failing test marked xfail
@pytest.mark.xfail(reason="Bug #1234: widget crashes on empty input")
def test_widget_handles_empty_input():
    result = widget.process("")
    assert result == Widget.EMPTY_RESULT


# 2. Run it — confirm it fails (XFAIL)
# poetry run pytest path/to/test.py::test_widget_handles_empty_input -xvs


# 3. Implement the fix


# 4. Remove xfail, run again — confirm it passes
def test_widget_handles_empty_input():
    result = widget.process("")
    assert result == Widget.EMPTY_RESULT
```

This catches regressions and proves the fix actually works. **Every bug fix should include a test that would have caught it.**

## Database Schema

Key models (defined in `schema.prisma`):

- `User`: Authentication and profile data
- `AgentGraph`: Workflow definitions with version control
- `AgentGraphExecution`: Execution history and results
- `AgentNode`: Individual nodes in a workflow
- `StoreListing`: Marketplace listings for sharing agents

## Environment Configuration

- **Backend**: `.env.default` (defaults) → `.env` (user overrides)

## Common Development Tasks

### Adding a new block

Follow the comprehensive [Block SDK Guide](@../../docs/platform/block-sdk-guide.md) which covers:

- Provider configuration with `ProviderBuilder`
- Block schema definition
- Authentication (API keys, OAuth, webhooks)
- Testing and validation
- File organization

Quick steps:

1. Create new file in `backend/blocks/`
2. Configure provider using `ProviderBuilder` in `_config.py`
3. Inherit from `Block` base class
4. Define input/output schemas using `BlockSchema`
5. Implement async `run` method
6. Generate unique block ID using `uuid.uuid4()`
7. Test with `poetry run pytest backend/blocks/test/test_block.py`

Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph-based editor or would they struggle to connect productively?
ex: do the inputs and outputs tie well together?

If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.
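
Put together, those quick steps amount to very little code. A minimal sketch of a block, assuming the `Block`/`BlockSchema`/`BlockOutput` interfaces from `backend.data.block` and `SchemaField` from `backend.data.model`; check any existing block in `backend/blocks/` for the exact constructor arguments:

```python
from backend.data.block import Block, BlockOutput, BlockSchema
from backend.data.model import SchemaField


class ReverseTextBlock(Block):
    class Input(BlockSchema):
        text: str = SchemaField(description="Text to reverse")

    class Output(BlockSchema):
        reversed_text: str = SchemaField(description="Input text, reversed")

    def __init__(self):
        super().__init__(
            # Generated once with uuid.uuid4() and then hardcoded;
            # this particular value is illustrative.
            id="5b0f7a63-2b3c-4f1d-9e6a-8c2d4e7f1a90",
            description="Reverses the input text.",
            input_schema=ReverseTextBlock.Input,
            output_schema=ReverseTextBlock.Output,
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        # Blocks yield (output_name, value) pairs rather than returning.
        yield "reversed_text", input_data.text[::-1]
```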

#### Handling files in blocks with `store_media_file()`

When blocks need to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. The `return_format` parameter determines what you get back:

| Format | Use When | Returns |
|--------|----------|---------|
| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) |
| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) |
| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs |

**Examples:**

```python
# INPUT: Need to process file locally with ffmpeg
local_path = await store_media_file(
    file=input_data.video,
    execution_context=execution_context,
    return_format="for_local_processing",
)
# local_path = "video.mp4" - use with Path/ffmpeg/etc

# INPUT: Need to send to external API like Replicate
image_b64 = await store_media_file(
    file=input_data.image,
    execution_context=execution_context,
    return_format="for_external_api",
)
# image_b64 = "data:image/png;base64,iVBORw0..." - send to API

# OUTPUT: Returning result from block
result_url = await store_media_file(
    file=generated_image_url,
    execution_context=execution_context,
    return_format="for_block_output",
)
yield "image_url", result_url
# In CoPilot: result_url = "workspace://abc123"
# In graphs: result_url = "data:image/png;base64,..."
```

**Key points:**

- `for_block_output` is the ONLY format that auto-adapts to execution context
- Always use `for_block_output` for block outputs unless you have a specific reason not to
- Never hardcode workspace checks - let `for_block_output` handle it

### Modifying the API

1. Update route in `backend/api/features/`
2. Add/update Pydantic models in same directory
3. Write tests alongside the route file
4. Run `poetry run test` to verify

## Workspace & Media Files

**Read [Workspace & Media Architecture](../../docs/platform/workspace-media-architecture.md) when:**
- Working on CoPilot file upload/download features
- Building blocks that handle `MediaFileType` inputs/outputs
- Modifying `WorkspaceManager` or `store_media_file()`
- Debugging file persistence or virus scanning issues

Covers: `WorkspaceManager` (persistent storage with session scoping), `store_media_file()` (media normalization pipeline), and responsibility boundaries for virus scanning and persistence.

## Security Implementation

### Cache Protection Middleware

- Located in `backend/api/middleware/security.py`
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
- Uses an allow list approach - only explicitly permitted paths can be cached
- Cacheable paths include: static assets (`static/*`, `_next/static/*`), health checks, public store pages, documentation
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
- Applied to both main API server and external API applications
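
For context, the allow-list pattern itself is small. Here is a sketch of the idea using Starlette's `BaseHTTPMiddleware`; the real middleware in `backend/api/middleware/security.py` may differ in details, and the paths below are an illustrative subset, not the actual `CACHEABLE_PATHS`:

```python
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request

# Illustrative subset of cacheable path prefixes (assumption, not the real list).
CACHEABLE_PATHS = ("/static/", "/_next/static/", "/health")


class CacheProtectionMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        response = await call_next(request)
        # Default-deny: only explicitly allow-listed paths may be cached;
        # everything else gets the strict no-store header.
        if not request.url.path.startswith(CACHEABLE_PATHS):
            response.headers["Cache-Control"] = (
                "no-store, no-cache, must-revalidate, private"
            )
        return response
```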
@@ -1,227 +1 @@
|
||||
# CLAUDE.md - Backend
|
||||
|
||||
This file provides guidance to Claude Code when working with the backend.
|
||||
|
||||
## Essential Commands
|
||||
|
||||
To run something with Python package dependencies you MUST use `poetry run ...`.
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
poetry install
|
||||
|
||||
# Run database migrations
|
||||
poetry run prisma migrate dev
|
||||
|
||||
# Start all services (database, redis, rabbitmq, clamav)
|
||||
docker compose up -d
|
||||
|
||||
# Run the backend as a whole
|
||||
poetry run app
|
||||
|
||||
# Run tests
|
||||
poetry run test
|
||||
|
||||
# Run specific test
|
||||
poetry run pytest path/to/test_file.py::test_function_name
|
||||
|
||||
# Run block tests (tests that validate all blocks work correctly)
|
||||
poetry run pytest backend/blocks/test/test_block.py -xvs
|
||||
|
||||
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
|
||||
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
|
||||
|
||||
# Lint and format
|
||||
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
|
||||
poetry run format # Black + isort
|
||||
poetry run lint # ruff
|
||||
```
|
||||
|
||||
More details can be found in @TESTING.md
|
||||
|
||||
### Creating/Updating Snapshots
|
||||
|
||||
When you first write a test or when the expected output changes:
|
||||
|
||||
```bash
|
||||
poetry run pytest path/to/test.py --snapshot-update
|
||||
```
|
||||
|
||||
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
|
||||
|
||||
## Architecture
|
||||
|
||||
- **API Layer**: FastAPI with REST and WebSocket endpoints
|
||||
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
|
||||
- **Queue System**: RabbitMQ for async task processing
|
||||
- **Execution Engine**: Separate executor service processes agent workflows
|
||||
- **Authentication**: JWT-based with Supabase integration
|
||||
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
|
||||
|
||||
## Code Style
|
||||
|
||||
- **Top-level imports only** — no local/inner imports (lazy imports only for heavy optional deps like `openpyxl`)
|
||||
- **Absolute imports** — use `from backend.module import ...` for cross-package imports. Single-dot relative (`from .sibling import ...`) is acceptable for sibling modules within the same package (e.g., blocks). Avoid double-dot relative imports (`from ..parent import ...`) — use the absolute path instead
|
||||
- **No duck typing** — no `hasattr`/`getattr`/`isinstance` for type dispatch; use typed interfaces/unions/protocols
|
||||
- **Pydantic models** over dataclass/namedtuple/dict for structured data
|
||||
- **No linter suppressors** — no `# type: ignore`, `# noqa`, `# pyright: ignore`; fix the type/code
|
||||
- **List comprehensions** over manual loop-and-append
|
||||
- **Early return** — guard clauses first, avoid deep nesting
|
||||
- **f-strings vs printf syntax in log statements** — Use `%s` for deferred interpolation in `debug` statements, f-strings elsewhere for readability: `logger.debug("Processing %s items", count)`, `logger.info(f"Processing {count} items")`
|
||||
- **Sanitize error paths** — `os.path.basename()` in error messages to avoid leaking directory structure
|
||||
- **TOCTOU awareness** — avoid check-then-act patterns for file access and credit charging
|
||||
- **`Security()` vs `Depends()`** — use `Security()` for auth deps to get proper OpenAPI security spec
|
||||
- **Redis pipelines** — `transaction=True` for atomicity on multi-step operations
|
||||
- **`max(0, value)` guards** — for computed values that should never be negative
|
||||
- **SSE protocol** — `data:` lines for frontend-parsed events (must match Zod schema), `: comment` lines for heartbeats/status
|
||||
- **File length** — keep files under ~300 lines; if a file grows beyond this, split by responsibility (e.g. extract helpers, models, or a sub-module into a new file). Never keep appending to a long file.
|
||||
- **Function length** — keep functions under ~40 lines; extract named helpers when a function grows longer. Long functions are a sign of mixed concerns, not complexity.
|
||||
- **Top-down ordering** — define the main/public function or class first, then the helpers it uses below. A reader should encounter high-level logic before implementation details.
|
||||
|
||||
## Testing Approach
|
||||
|
||||
- Uses pytest with snapshot testing for API responses
|
||||
- Test files are colocated with source files (`*_test.py`)
|
||||
- Mock at boundaries — mock where the symbol is **used**, not where it's **defined**
|
||||
- After refactoring, update mock targets to match new module paths
|
||||
- Use `AsyncMock` for async functions (`from unittest.mock import AsyncMock`)
|
||||
|
||||
### Test-Driven Development (TDD)
|
||||
|
||||
When fixing a bug or adding a feature, write the test **before** the implementation:
|
||||
|
||||
```python
|
||||
# 1. Write a failing test marked xfail
|
||||
@pytest.mark.xfail(reason="Bug #1234: widget crashes on empty input")
|
||||
def test_widget_handles_empty_input():
|
||||
result = widget.process("")
|
||||
assert result == Widget.EMPTY_RESULT
|
||||
|
||||
# 2. Run it — confirm it fails (XFAIL)
|
||||
# poetry run pytest path/to/test.py::test_widget_handles_empty_input -xvs
|
||||
|
||||
# 3. Implement the fix
|
||||
|
||||
# 4. Remove xfail, run again — confirm it passes
|
||||
def test_widget_handles_empty_input():
|
||||
result = widget.process("")
|
||||
assert result == Widget.EMPTY_RESULT
|
||||
```
|
||||
|
||||
This catches regressions and proves the fix actually works. **Every bug fix should include a test that would have caught it.**
|
||||
|
||||
## Database Schema
|
||||
|
||||
Key models (defined in `schema.prisma`):
|
||||
|
||||
- `User`: Authentication and profile data
|
||||
- `AgentGraph`: Workflow definitions with version control
|
||||
- `AgentGraphExecution`: Execution history and results
|
||||
- `AgentNode`: Individual nodes in a workflow
|
||||
- `StoreListing`: Marketplace listings for sharing agents
|
||||
|
||||
## Environment Configuration
|
||||
|
||||
- **Backend**: `.env.default` (defaults) → `.env` (user overrides)
|
||||
|
||||
## Common Development Tasks
|
||||
|
||||
### Adding a new block
|
||||
|
||||
Follow the comprehensive [Block SDK Guide](@../../docs/content/platform/block-sdk-guide.md) which covers:
|
||||
|
||||
- Provider configuration with `ProviderBuilder`
|
||||
- Block schema definition
|
||||
- Authentication (API keys, OAuth, webhooks)
|
||||
- Testing and validation
|
||||
- File organization
|
||||
|
||||
Quick steps:
|
||||
|
||||
1. Create new file in `backend/blocks/`
|
||||
2. Configure provider using `ProviderBuilder` in `_config.py`
|
||||
3. Inherit from `Block` base class
|
||||
4. Define input/output schemas using `BlockSchema`
|
||||
5. Implement async `run` method
|
||||
6. Generate unique block ID using `uuid.uuid4()`
|
||||
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
|
||||
|
||||
Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph-based editor or would they struggle to connect productively?
|
||||
ex: do the inputs and outputs tie well together?
|
||||
|
||||
If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.

#### Handling files in blocks with `store_media_file()`

When blocks need to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. The `return_format` parameter determines what you get back:

| Format | Use When | Returns |
|--------|----------|---------|
| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) |
| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) |
| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs |

**Examples:**

```python
# INPUT: Need to process file locally with ffmpeg
local_path = await store_media_file(
    file=input_data.video,
    execution_context=execution_context,
    return_format="for_local_processing",
)
# local_path = "video.mp4" - use with Path/ffmpeg/etc

# INPUT: Need to send to external API like Replicate
image_b64 = await store_media_file(
    file=input_data.image,
    execution_context=execution_context,
    return_format="for_external_api",
)
# image_b64 = "data:image/png;base64,iVBORw0..." - send to API

# OUTPUT: Returning result from block
result_url = await store_media_file(
    file=generated_image_url,
    execution_context=execution_context,
    return_format="for_block_output",
)
yield "image_url", result_url
# In CoPilot: result_url = "workspace://abc123"
# In graphs: result_url = "data:image/png;base64,..."
```

**Key points:**

- `for_block_output` is the ONLY format that auto-adapts to execution context
- Always use `for_block_output` for block outputs unless you have a specific reason not to
- Never hardcode workspace checks - let `for_block_output` handle it

### Modifying the API

1. Update route in `backend/api/features/`
2. Add/update Pydantic models in same directory
3. Write tests alongside the route file
4. Run `poetry run test` to verify
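
As a hedged illustration of steps 1 and 2 (the feature name, model, and the auth import path are hypothetical; mirror an existing module under `backend/api/features/`):

```python
# Hypothetical sketch for backend/api/features/widgets/routes.py
from fastapi import APIRouter, Security
from pydantic import BaseModel

from backend.api.auth import get_user_id  # import path assumed

router = APIRouter()


class WidgetResponse(BaseModel):
    """Response model lives next to the route (step 2)."""

    id: str
    name: str


@router.get("/widgets/{widget_id}", response_model=WidgetResponse)
async def get_widget(
    widget_id: str,
    user_id: str = Security(get_user_id),
) -> WidgetResponse:
    # The data-layer call is stubbed; a real route would query Prisma here.
    return WidgetResponse(id=widget_id, name="example")
```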

## Workspace & Media Files

**Read [Workspace & Media Architecture](../../docs/platform/workspace-media-architecture.md) when:**
- Working on CoPilot file upload/download features
- Building blocks that handle `MediaFileType` inputs/outputs
- Modifying `WorkspaceManager` or `store_media_file()`
- Debugging file persistence or virus scanning issues

Covers: `WorkspaceManager` (persistent storage with session scoping), `store_media_file()` (media normalization pipeline), and responsibility boundaries for virus scanning and persistence.

## Security Implementation

### Cache Protection Middleware

- Located in `backend/api/middleware/security.py`
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
- Uses an allow-list approach: only explicitly permitted paths can be cached
- Cacheable paths include: static assets (`static/*`, `_next/static/*`), health checks, public store pages, documentation
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
- Applied to both main API server and external API applications
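
Conceptually, the allow-list decision works like the sketch below (hypothetical names and paths; the real constants and middleware live in `backend/api/middleware/security.py`):

```python
# Hypothetical sketch of the allow-list check, not the actual middleware.
CACHEABLE_PATHS = ("static/", "_next/static/", "health", "docs")

NO_STORE = "no-store, no-cache, must-revalidate, private"


def cache_control_for(path: str) -> str | None:
    """Return the Cache-Control value to force, or None to allow caching."""
    # Only explicitly permitted prefixes may be cached; everything else gets
    # the restrictive default so tokens and user data never land in caches.
    if any(path.lstrip("/").startswith(prefix) for prefix in CACHEABLE_PATHS):
        return None
    return NO_STORE
```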

@AGENTS.md

@@ -31,7 +31,10 @@ from backend.data.model import (
    UserPasswordCredentials,
    is_sdk_default,
)
from backend.integrations.credentials_store import provider_matches
from backend.integrations.credentials_store import (
    is_system_credential,
    provider_matches,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
@@ -618,6 +621,11 @@ async def delete_credential(
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
        )
    if is_system_credential(cred_id):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="System-managed credentials cannot be deleted",
        )
    creds = await creds_manager.store.get_creds_by_id(auth.user_id, cred_id)
    if not creds:
        raise HTTPException(

@@ -72,7 +72,7 @@ class RunAgentRequest(BaseModel):

def _create_ephemeral_session(user_id: str) -> ChatSession:
    """Create an ephemeral session for stateless API requests."""
    return ChatSession.new(user_id)
    return ChatSession.new(user_id, dry_run=False)


@tools_router.post(

@@ -9,11 +9,14 @@ from pydantic import BaseModel

from backend.copilot.config import ChatConfig
from backend.copilot.rate_limit import (
    SubscriptionTier,
    get_global_rate_limits,
    get_usage_status,
    get_user_tier,
    reset_user_usage,
    set_user_tier,
)
from backend.data.user import get_user_by_email, get_user_email_by_id
from backend.data.user import get_user_by_email, get_user_email_by_id, search_users

logger = logging.getLogger(__name__)

@@ -33,6 +36,17 @@ class UserRateLimitResponse(BaseModel):
    weekly_token_limit: int
    daily_tokens_used: int
    weekly_tokens_used: int
    tier: SubscriptionTier


class UserTierResponse(BaseModel):
    user_id: str
    tier: SubscriptionTier


class SetUserTierRequest(BaseModel):
    user_id: str
    tier: SubscriptionTier


async def _resolve_user_id(
@@ -86,10 +100,10 @@ async def get_user_rate_limit(

    logger.info("Admin %s checking rate limit for user %s", admin_user_id, resolved_id)

    daily_limit, weekly_limit = await get_global_rate_limits(
    daily_limit, weekly_limit, tier = await get_global_rate_limits(
        resolved_id, config.daily_token_limit, config.weekly_token_limit
    )
    usage = await get_usage_status(resolved_id, daily_limit, weekly_limit)
    usage = await get_usage_status(resolved_id, daily_limit, weekly_limit, tier=tier)

    return UserRateLimitResponse(
        user_id=resolved_id,
@@ -98,6 +112,7 @@ async def get_user_rate_limit(
        weekly_token_limit=weekly_limit,
        daily_tokens_used=usage.daily.used,
        weekly_tokens_used=usage.weekly.used,
        tier=tier,
    )


@@ -125,10 +140,10 @@ async def reset_user_rate_limit(
        logger.exception("Failed to reset user usage")
        raise HTTPException(status_code=500, detail="Failed to reset usage") from e

    daily_limit, weekly_limit = await get_global_rate_limits(
    daily_limit, weekly_limit, tier = await get_global_rate_limits(
        user_id, config.daily_token_limit, config.weekly_token_limit
    )
    usage = await get_usage_status(user_id, daily_limit, weekly_limit)
    usage = await get_usage_status(user_id, daily_limit, weekly_limit, tier=tier)

    try:
        resolved_email = await get_user_email_by_id(user_id)
@@ -143,4 +158,102 @@ async def reset_user_rate_limit(
        weekly_token_limit=weekly_limit,
        daily_tokens_used=usage.daily.used,
        weekly_tokens_used=usage.weekly.used,
        tier=tier,
    )


@router.get(
    "/rate_limit/tier",
    response_model=UserTierResponse,
    summary="Get User Rate Limit Tier",
)
async def get_user_rate_limit_tier(
    user_id: str,
    admin_user_id: str = Security(get_user_id),
) -> UserTierResponse:
    """Get a user's current rate-limit tier. Admin-only.

    Returns 404 if the user does not exist in the database.
    """
    logger.info("Admin %s checking tier for user %s", admin_user_id, user_id)

    resolved_email = await get_user_email_by_id(user_id)
    if resolved_email is None:
        raise HTTPException(status_code=404, detail=f"User {user_id} not found")

    tier = await get_user_tier(user_id)
    return UserTierResponse(user_id=user_id, tier=tier)


@router.post(
    "/rate_limit/tier",
    response_model=UserTierResponse,
    summary="Set User Rate Limit Tier",
)
async def set_user_rate_limit_tier(
    request: SetUserTierRequest,
    admin_user_id: str = Security(get_user_id),
) -> UserTierResponse:
    """Set a user's rate-limit tier. Admin-only.

    Returns 404 if the user does not exist in the database.
    """
    try:
        resolved_email = await get_user_email_by_id(request.user_id)
    except Exception:
        logger.warning(
            "Failed to resolve email for user %s",
            request.user_id,
            exc_info=True,
        )
        resolved_email = None

    if resolved_email is None:
        raise HTTPException(status_code=404, detail=f"User {request.user_id} not found")

    old_tier = await get_user_tier(request.user_id)
    logger.info(
        "Admin %s changing tier for user %s (%s): %s -> %s",
        admin_user_id,
        request.user_id,
        resolved_email,
        old_tier.value,
        request.tier.value,
    )
    try:
        await set_user_tier(request.user_id, request.tier)
    except Exception as e:
        logger.exception("Failed to set user tier")
        raise HTTPException(status_code=500, detail="Failed to set tier") from e

    return UserTierResponse(user_id=request.user_id, tier=request.tier)


class UserSearchResult(BaseModel):
    user_id: str
    user_email: Optional[str] = None


@router.get(
    "/rate_limit/search_users",
    response_model=list[UserSearchResult],
    summary="Search Users by Name or Email",
)
async def admin_search_users(
    query: str,
    limit: int = 20,
    admin_user_id: str = Security(get_user_id),
) -> list[UserSearchResult]:
    """Search users by partial email or name. Admin-only.

    Queries the User table directly — returns results even for users
    without credit transaction history.
    """
    if len(query.strip()) < 3:
        raise HTTPException(
            status_code=400,
            detail="Search query must be at least 3 characters.",
        )
    logger.info("Admin %s searching users with query=%r", admin_user_id, query)
    results = await search_users(query, limit=max(1, min(limit, 50)))
    return [UserSearchResult(user_id=uid, user_email=email) for uid, email in results]

@@ -9,7 +9,7 @@ import pytest_mock
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from pytest_snapshot.plugin import Snapshot

from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
from backend.copilot.rate_limit import CoPilotUsageStatus, SubscriptionTier, UsageWindow

from .rate_limit_admin_routes import router as rate_limit_admin_router

@@ -57,7 +57,7 @@ def _patch_rate_limit_deps(
    mocker.patch(
        f"{_MOCK_MODULE}.get_global_rate_limits",
        new_callable=AsyncMock,
        return_value=(2_500_000, 12_500_000),
        return_value=(2_500_000, 12_500_000, SubscriptionTier.FREE),
    )
    mocker.patch(
        f"{_MOCK_MODULE}.get_usage_status",
@@ -89,6 +89,7 @@ def test_get_rate_limit(
    assert data["weekly_token_limit"] == 12_500_000
    assert data["daily_tokens_used"] == 500_000
    assert data["weekly_tokens_used"] == 3_000_000
    assert data["tier"] == "FREE"

    configured_snapshot.assert_match(
        json.dumps(data, indent=2, sort_keys=True) + "\n",
@@ -162,6 +163,7 @@ def test_reset_user_usage_daily_only(
    assert data["daily_tokens_used"] == 0
    # Weekly is untouched
    assert data["weekly_tokens_used"] == 3_000_000
    assert data["tier"] == "FREE"

    mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=False)

@@ -192,6 +194,7 @@ def test_reset_user_usage_daily_and_weekly(
    data = response.json()
    assert data["daily_tokens_used"] == 0
    assert data["weekly_tokens_used"] == 0
    assert data["tier"] == "FREE"

    mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=True)

@@ -228,7 +231,7 @@ def test_get_rate_limit_email_lookup_failure(
    mocker.patch(
        f"{_MOCK_MODULE}.get_global_rate_limits",
        new_callable=AsyncMock,
        return_value=(2_500_000, 12_500_000),
        return_value=(2_500_000, 12_500_000, SubscriptionTier.FREE),
    )
    mocker.patch(
        f"{_MOCK_MODULE}.get_usage_status",
@@ -261,3 +264,303 @@ def test_admin_endpoints_require_admin_role(mock_jwt_user) -> None:
        json={"user_id": "test"},
    )
    assert response.status_code == 403

# ---------------------------------------------------------------------------
# Tier management endpoints
# ---------------------------------------------------------------------------


def test_get_user_tier(
    mocker: pytest_mock.MockerFixture,
    target_user_id: str,
) -> None:
    """Test getting a user's rate-limit tier."""
    mocker.patch(
        f"{_MOCK_MODULE}.get_user_email_by_id",
        new_callable=AsyncMock,
        return_value=_TARGET_EMAIL,
    )
    mocker.patch(
        f"{_MOCK_MODULE}.get_user_tier",
        new_callable=AsyncMock,
        return_value=SubscriptionTier.PRO,
    )

    response = client.get("/admin/rate_limit/tier", params={"user_id": target_user_id})

    assert response.status_code == 200
    data = response.json()
    assert data["user_id"] == target_user_id
    assert data["tier"] == "PRO"


def test_get_user_tier_user_not_found(
    mocker: pytest_mock.MockerFixture,
    target_user_id: str,
) -> None:
    """Test that getting tier for a non-existent user returns 404."""
    mocker.patch(
        f"{_MOCK_MODULE}.get_user_email_by_id",
        new_callable=AsyncMock,
        return_value=None,
    )

    response = client.get("/admin/rate_limit/tier", params={"user_id": target_user_id})

    assert response.status_code == 404


def test_set_user_tier(
    mocker: pytest_mock.MockerFixture,
    target_user_id: str,
) -> None:
    """Test setting a user's rate-limit tier (upgrade)."""
    mocker.patch(
        f"{_MOCK_MODULE}.get_user_email_by_id",
        new_callable=AsyncMock,
        return_value=_TARGET_EMAIL,
    )
    mocker.patch(
        f"{_MOCK_MODULE}.get_user_tier",
        new_callable=AsyncMock,
        return_value=SubscriptionTier.FREE,
    )
    mock_set = mocker.patch(
        f"{_MOCK_MODULE}.set_user_tier",
        new_callable=AsyncMock,
    )

    response = client.post(
        "/admin/rate_limit/tier",
        json={"user_id": target_user_id, "tier": "ENTERPRISE"},
    )

    assert response.status_code == 200
    data = response.json()
    assert data["user_id"] == target_user_id
    assert data["tier"] == "ENTERPRISE"
    mock_set.assert_awaited_once_with(target_user_id, SubscriptionTier.ENTERPRISE)


def test_set_user_tier_downgrade(
    mocker: pytest_mock.MockerFixture,
    target_user_id: str,
) -> None:
    """Test downgrading a user's tier from PRO to FREE."""
    mocker.patch(
        f"{_MOCK_MODULE}.get_user_email_by_id",
        new_callable=AsyncMock,
        return_value=_TARGET_EMAIL,
    )
    mocker.patch(
        f"{_MOCK_MODULE}.get_user_tier",
        new_callable=AsyncMock,
        return_value=SubscriptionTier.PRO,
    )
    mock_set = mocker.patch(
        f"{_MOCK_MODULE}.set_user_tier",
        new_callable=AsyncMock,
    )

    response = client.post(
        "/admin/rate_limit/tier",
        json={"user_id": target_user_id, "tier": "FREE"},
    )

    assert response.status_code == 200
    data = response.json()
    assert data["user_id"] == target_user_id
    assert data["tier"] == "FREE"
    mock_set.assert_awaited_once_with(target_user_id, SubscriptionTier.FREE)


def test_set_user_tier_invalid_tier(
    target_user_id: str,
) -> None:
    """Test that setting an invalid tier returns 422."""
    response = client.post(
        "/admin/rate_limit/tier",
        json={"user_id": target_user_id, "tier": "invalid"},
    )

    assert response.status_code == 422


def test_set_user_tier_invalid_tier_uppercase(
    target_user_id: str,
) -> None:
    """Test that setting an unrecognised uppercase tier (e.g. 'INVALID') returns 422.

    Regression: ensures Pydantic enum validation rejects values that are not
    members of SubscriptionTier, even when they look like valid enum names.
    """
    response = client.post(
        "/admin/rate_limit/tier",
        json={"user_id": target_user_id, "tier": "INVALID"},
    )

    assert response.status_code == 422
    body = response.json()
    assert "detail" in body


def test_set_user_tier_email_lookup_failure_returns_404(
    mocker: pytest_mock.MockerFixture,
    target_user_id: str,
) -> None:
    """Test that email lookup failure returns 404 (user unverifiable)."""
    mocker.patch(
        f"{_MOCK_MODULE}.get_user_email_by_id",
        new_callable=AsyncMock,
        side_effect=Exception("DB connection failed"),
    )

    response = client.post(
        "/admin/rate_limit/tier",
        json={"user_id": target_user_id, "tier": "PRO"},
    )

    assert response.status_code == 404


def test_set_user_tier_user_not_found(
    mocker: pytest_mock.MockerFixture,
    target_user_id: str,
) -> None:
    """Test that setting tier for a non-existent user returns 404."""
    mocker.patch(
        f"{_MOCK_MODULE}.get_user_email_by_id",
        new_callable=AsyncMock,
        return_value=None,
    )

    response = client.post(
        "/admin/rate_limit/tier",
        json={"user_id": target_user_id, "tier": "PRO"},
    )

    assert response.status_code == 404


def test_set_user_tier_db_failure(
    mocker: pytest_mock.MockerFixture,
    target_user_id: str,
) -> None:
    """Test that DB failure on set tier returns 500."""
    mocker.patch(
        f"{_MOCK_MODULE}.get_user_email_by_id",
        new_callable=AsyncMock,
        return_value=_TARGET_EMAIL,
    )
    mocker.patch(
        f"{_MOCK_MODULE}.get_user_tier",
        new_callable=AsyncMock,
        return_value=SubscriptionTier.FREE,
    )
    mocker.patch(
        f"{_MOCK_MODULE}.set_user_tier",
        new_callable=AsyncMock,
        side_effect=Exception("DB connection refused"),
    )

    response = client.post(
        "/admin/rate_limit/tier",
        json={"user_id": target_user_id, "tier": "PRO"},
    )

    assert response.status_code == 500


def test_tier_endpoints_require_admin_role(mock_jwt_user) -> None:
    """Test that tier admin endpoints require admin role."""
    app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]

    response = client.get("/admin/rate_limit/tier", params={"user_id": "test"})
    assert response.status_code == 403

    response = client.post(
        "/admin/rate_limit/tier",
        json={"user_id": "test", "tier": "PRO"},
    )
    assert response.status_code == 403

# ─── search_users endpoint ──────────────────────────────────────────


def test_search_users_returns_matching_users(
    mocker: pytest_mock.MockerFixture,
    admin_user_id: str,
) -> None:
    """Partial search should return all matching users from the User table."""
    mocker.patch(
        _MOCK_MODULE + ".search_users",
        new_callable=AsyncMock,
        return_value=[
            ("user-1", "zamil.majdy@gmail.com"),
            ("user-2", "zamil.majdy@agpt.co"),
        ],
    )

    response = client.get("/admin/rate_limit/search_users", params={"query": "zamil"})

    assert response.status_code == 200
    results = response.json()
    assert len(results) == 2
    assert results[0]["user_email"] == "zamil.majdy@gmail.com"
    assert results[1]["user_email"] == "zamil.majdy@agpt.co"


def test_search_users_empty_results(
    mocker: pytest_mock.MockerFixture,
    admin_user_id: str,
) -> None:
    """Search with no matches returns empty list."""
    mocker.patch(
        _MOCK_MODULE + ".search_users",
        new_callable=AsyncMock,
        return_value=[],
    )

    response = client.get(
        "/admin/rate_limit/search_users", params={"query": "nonexistent"}
    )

    assert response.status_code == 200
    assert response.json() == []


def test_search_users_short_query_rejected(
    admin_user_id: str,
) -> None:
    """Query shorter than 3 characters should return 400."""
    response = client.get("/admin/rate_limit/search_users", params={"query": "ab"})
    assert response.status_code == 400


def test_search_users_negative_limit_clamped(
    mocker: pytest_mock.MockerFixture,
    admin_user_id: str,
) -> None:
    """Negative limit should be clamped to 1, not passed through."""
    mock_search = mocker.patch(
        _MOCK_MODULE + ".search_users",
        new_callable=AsyncMock,
        return_value=[],
    )

    response = client.get(
        "/admin/rate_limit/search_users", params={"query": "test", "limit": -1}
    )

    assert response.status_code == 200
    mock_search.assert_awaited_once_with("test", limit=1)


def test_search_users_requires_admin_role(mock_jwt_user) -> None:
    """Test that the search_users endpoint requires admin role."""
    app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]

    response = client.get("/admin/rate_limit/search_users", params={"query": "test"})
    assert response.status_code == 403

@@ -11,15 +11,16 @@ from autogpt_libs import auth
from fastapi import APIRouter, HTTPException, Query, Response, Security
from fastapi.responses import StreamingResponse
from prisma.models import UserWorkspaceFile
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator

from backend.copilot import service as chat_service
from backend.copilot import stream_registry
from backend.copilot.config import ChatConfig
from backend.copilot.config import ChatConfig, CopilotMode
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
from backend.copilot.model import (
    ChatMessage,
    ChatSession,
    ChatSessionMetadata,
    append_and_save_message,
    create_chat_session,
    delete_chat_session,
@@ -110,6 +111,23 @@ class StreamChatRequest(BaseModel):
    file_ids: list[str] | None = Field(
        default=None, max_length=20
    )  # Workspace file IDs attached to this message
    mode: CopilotMode | None = Field(
        default=None,
        description="Autopilot mode: 'fast' for baseline LLM, 'extended_thinking' for Claude Agent SDK. "
        "If None, uses the server default (extended_thinking).",
    )


class CreateSessionRequest(BaseModel):
    """Request model for creating a new chat session.

    ``dry_run`` is a **top-level** field — do not nest it inside ``metadata``.
    Extra/unknown fields are rejected (422) to prevent silent mis-use.
    """

    model_config = ConfigDict(extra="forbid")

    dry_run: bool = False


class CreateSessionResponse(BaseModel):
@@ -118,6 +136,7 @@ class CreateSessionResponse(BaseModel):
    id: str
    created_at: str
    user_id: str | None
    metadata: ChatSessionMetadata = ChatSessionMetadata()


class ActiveStreamInfo(BaseModel):
@@ -138,6 +157,7 @@ class SessionDetailResponse(BaseModel):
    active_stream: ActiveStreamInfo | None = None  # Present if stream is still active
    total_prompt_tokens: int = 0
    total_completion_tokens: int = 0
    metadata: ChatSessionMetadata = ChatSessionMetadata()


class SessionSummaryResponse(BaseModel):
@@ -248,6 +268,7 @@ async def list_sessions(
)
async def create_session(
    user_id: Annotated[str, Security(auth.get_user_id)],
    request: CreateSessionRequest | None = None,
) -> CreateSessionResponse:
    """
    Create a new chat session.
@@ -256,22 +277,28 @@ async def create_session(

    Args:
        user_id: The authenticated user ID parsed from the JWT (required).
        request: Optional request body. When provided, ``dry_run=True``
            forces run_block and run_agent calls to use dry-run simulation.

    Returns:
        CreateSessionResponse: Details of the created session.

    """
    dry_run = request.dry_run if request else False

    logger.info(
        f"Creating session with user_id: "
        f"...{user_id[-8:] if len(user_id) > 8 else '<redacted>'}"
        f"{', dry_run=True' if dry_run else ''}"
    )

    session = await create_chat_session(user_id)
    session = await create_chat_session(user_id, dry_run=dry_run)

    return CreateSessionResponse(
        id=session.session_id,
        created_at=session.started_at.isoformat(),
        user_id=session.user_id,
        metadata=session.metadata,
    )


@@ -420,6 +447,7 @@ async def get_session(
        active_stream=active_stream_info,
        total_prompt_tokens=total_prompt,
        total_completion_tokens=total_completion,
        metadata=session.metadata,
    )


@@ -433,8 +461,9 @@ async def get_copilot_usage(

    Returns current token usage vs limits for daily and weekly windows.
    Global defaults sourced from LaunchDarkly (falling back to config).
    Includes the user's rate-limit tier.
    """
    daily_limit, weekly_limit = await get_global_rate_limits(
    daily_limit, weekly_limit, tier = await get_global_rate_limits(
        user_id, config.daily_token_limit, config.weekly_token_limit
    )
    return await get_usage_status(
@@ -442,6 +471,7 @@ async def get_copilot_usage(
        daily_token_limit=daily_limit,
        weekly_token_limit=weekly_limit,
        rate_limit_reset_cost=config.rate_limit_reset_cost,
        tier=tier,
    )


@@ -493,7 +523,7 @@ async def reset_copilot_usage(
            detail="Rate limit reset is not available (credit system is disabled).",
        )

    daily_limit, weekly_limit = await get_global_rate_limits(
    daily_limit, weekly_limit, tier = await get_global_rate_limits(
        user_id, config.daily_token_limit, config.weekly_token_limit
    )

@@ -527,10 +557,13 @@ async def reset_copilot_usage(

    try:
        # Verify the user is actually at or over their daily limit.
        # (rate_limit_reset_cost intentionally omitted — this object is only
        # used for limit checks, not returned to the client.)
        usage_status = await get_usage_status(
            user_id=user_id,
            daily_token_limit=daily_limit,
            weekly_token_limit=weekly_limit,
            tier=tier,
        )
        if daily_limit > 0 and usage_status.daily.used < daily_limit:
            raise HTTPException(
@@ -606,6 +639,7 @@ async def reset_copilot_usage(
        daily_token_limit=daily_limit,
        weekly_token_limit=weekly_limit,
        rate_limit_reset_cost=config.rate_limit_reset_cost,
        tier=tier,
    )

    return RateLimitResetResponse(
@@ -716,7 +750,7 @@ async def stream_chat_post(
    # Global defaults sourced from LaunchDarkly, falling back to config.
    if user_id:
        try:
            daily_limit, weekly_limit = await get_global_rate_limits(
            daily_limit, weekly_limit, _ = await get_global_rate_limits(
                user_id, config.daily_token_limit, config.weekly_token_limit
            )
            await check_rate_limit(
@@ -811,6 +845,7 @@ async def stream_chat_post(
        is_user_message=request.is_user_message,
        context=request.context,
        file_ids=sanitized_file_ids,
        mode=request.mode,
    )

    setup_time = (time.perf_counter() - stream_start_time) * 1000
@@ -1174,7 +1209,7 @@ async def health_check() -> dict:
        )

        # Create and retrieve session to verify full data layer
        session = await create_chat_session(health_check_user_id)
        session = await create_chat_session(health_check_user_id, dry_run=False)
        await get_chat_session(session.session_id, health_check_user_id)

        return {

@@ -9,6 +9,7 @@ import pytest
import pytest_mock

from backend.api.features.chat import routes as chat_routes
from backend.copilot.rate_limit import SubscriptionTier

app = fastapi.FastAPI()
app.include_router(chat_routes.router)
@@ -331,14 +332,28 @@ def _mock_usage(
    *,
    daily_used: int = 500,
    weekly_used: int = 2000,
    daily_limit: int = 10000,
    weekly_limit: int = 50000,
    tier: "SubscriptionTier" = SubscriptionTier.FREE,
) -> AsyncMock:
    """Mock get_usage_status to return a predictable CoPilotUsageStatus."""
    """Mock get_usage_status and get_global_rate_limits for usage endpoint tests.

    Mocks both ``get_global_rate_limits`` (returns the given limits + tier) and
    ``get_usage_status`` so that tests exercise the endpoint without hitting
    LaunchDarkly or Prisma.
    """
    from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow

    mocker.patch(
        "backend.api.features.chat.routes.get_global_rate_limits",
        new_callable=AsyncMock,
        return_value=(daily_limit, weekly_limit, tier),
    )

    resets_at = datetime.now(UTC) + timedelta(days=1)
    status = CoPilotUsageStatus(
        daily=UsageWindow(used=daily_used, limit=10000, resets_at=resets_at),
        weekly=UsageWindow(used=weekly_used, limit=50000, resets_at=resets_at),
        daily=UsageWindow(used=daily_used, limit=daily_limit, resets_at=resets_at),
        weekly=UsageWindow(used=weekly_used, limit=weekly_limit, resets_at=resets_at),
    )
    return mocker.patch(
        "backend.api.features.chat.routes.get_usage_status",
@@ -369,6 +384,7 @@ def test_usage_returns_daily_and_weekly(
        daily_token_limit=10000,
        weekly_token_limit=50000,
        rate_limit_reset_cost=chat_routes.config.rate_limit_reset_cost,
        tier=SubscriptionTier.FREE,
    )


@@ -376,11 +392,9 @@ def test_usage_uses_config_limits(
    mocker: pytest_mock.MockerFixture,
    test_user_id: str,
) -> None:
    """The endpoint forwards daily_token_limit and weekly_token_limit from config."""
    mock_get = _mock_usage(mocker)
    """The endpoint forwards resolved limits from get_global_rate_limits to get_usage_status."""
    mock_get = _mock_usage(mocker, daily_limit=99999, weekly_limit=77777)

    mocker.patch.object(chat_routes.config, "daily_token_limit", 99999)
    mocker.patch.object(chat_routes.config, "weekly_token_limit", 77777)
    mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", 500)

    response = client.get("/usage")
@@ -391,6 +405,7 @@ def test_usage_uses_config_limits(
        daily_token_limit=99999,
        weekly_token_limit=77777,
        rate_limit_reset_cost=500,
        tier=SubscriptionTier.FREE,
    )


@@ -469,3 +484,98 @@ def test_suggested_prompts_empty_prompts(

    assert response.status_code == 200
    assert response.json() == {"themes": []}


# ─── Create session: dry_run contract ─────────────────────────────────


def _mock_create_chat_session(mocker: pytest_mock.MockerFixture):
    """Mock create_chat_session to return a fake session."""
    from backend.copilot.model import ChatSession

    async def _fake_create(user_id: str, *, dry_run: bool):
        return ChatSession.new(user_id, dry_run=dry_run)

    return mocker.patch(
        "backend.api.features.chat.routes.create_chat_session",
        new_callable=AsyncMock,
        side_effect=_fake_create,
    )


def test_create_session_dry_run_true(
    mocker: pytest_mock.MockerFixture,
    test_user_id: str,
) -> None:
    """Sending ``{"dry_run": true}`` sets metadata.dry_run to True."""
    _mock_create_chat_session(mocker)

    response = client.post("/sessions", json={"dry_run": True})

    assert response.status_code == 200
    assert response.json()["metadata"]["dry_run"] is True


def test_create_session_dry_run_default_false(
    mocker: pytest_mock.MockerFixture,
    test_user_id: str,
) -> None:
    """Empty body defaults dry_run to False."""
    _mock_create_chat_session(mocker)

    response = client.post("/sessions")

    assert response.status_code == 200
    assert response.json()["metadata"]["dry_run"] is False


def test_create_session_rejects_nested_metadata(
    test_user_id: str,
) -> None:
    """Sending ``{"metadata": {"dry_run": true}}`` must return 422, not silently
    default to ``dry_run=False``. This guards against the common mistake of
    nesting dry_run inside metadata instead of providing it at the top level."""
    response = client.post(
        "/sessions",
        json={"metadata": {"dry_run": True}},
    )

    assert response.status_code == 422


class TestStreamChatRequestModeValidation:
    """Pydantic-level validation of the ``mode`` field on StreamChatRequest."""

    def test_rejects_invalid_mode_value(self) -> None:
        """Any string outside the Literal set must raise ValidationError."""
        from pydantic import ValidationError

        from backend.api.features.chat.routes import StreamChatRequest

        with pytest.raises(ValidationError):
            StreamChatRequest(message="hi", mode="turbo")  # type: ignore[arg-type]

    def test_accepts_fast_mode(self) -> None:
        from backend.api.features.chat.routes import StreamChatRequest

        req = StreamChatRequest(message="hi", mode="fast")
        assert req.mode == "fast"

    def test_accepts_extended_thinking_mode(self) -> None:
        from backend.api.features.chat.routes import StreamChatRequest

        req = StreamChatRequest(message="hi", mode="extended_thinking")
        assert req.mode == "extended_thinking"

    def test_accepts_none_mode(self) -> None:
        """``mode=None`` is valid (server decides via feature flags)."""
        from backend.api.features.chat.routes import StreamChatRequest

        req = StreamChatRequest(message="hi", mode=None)
        assert req.mode is None

    def test_mode_defaults_to_none_when_omitted(self) -> None:
        from backend.api.features.chat.routes import StreamChatRequest

        req = StreamChatRequest(message="hi")
        assert req.mode is None

@@ -40,11 +40,15 @@ from backend.data.onboarding import OnboardingStep, complete_onboarding_step
from backend.data.user import get_user_integrations
from backend.executor.utils import add_graph_execution
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
from backend.integrations.credentials_store import provider_matches
from backend.integrations.credentials_store import (
    is_system_credential,
    provider_matches,
)
from backend.integrations.creds_manager import (
    IntegrationCredentialsManager,
    create_mcp_oauth_handler,
)
from backend.integrations.managed_credentials import ensure_managed_credentials
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks import get_webhook_manager
@@ -110,6 +114,7 @@ class CredentialsMetaResponse(BaseModel):
        default=None,
        description="Host pattern for host-scoped or MCP server URL for MCP credentials",
    )
    is_managed: bool = False

    @model_validator(mode="before")
    @classmethod
@@ -148,6 +153,7 @@ def to_meta_response(cred: Credentials) -> CredentialsMetaResponse:
        scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
        username=cred.username if isinstance(cred, OAuth2Credentials) else None,
        host=CredentialsMetaResponse.get_host(cred),
        is_managed=cred.is_managed,
    )


@@ -224,6 +230,9 @@ async def callback(
async def list_credentials(
    user_id: Annotated[str, Security(get_user_id)],
) -> list[CredentialsMetaResponse]:
    # Fire-and-forget: provision missing managed credentials in the background.
    # The credential appears on the next page load; listing is never blocked.
    asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store))
    credentials = await creds_manager.store.get_all_creds(user_id)

    return [
@@ -238,6 +247,7 @@ async def list_credentials_by_provider(
    ],
    user_id: Annotated[str, Security(get_user_id)],
) -> list[CredentialsMetaResponse]:
    asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store))
    credentials = await creds_manager.store.get_creds_by_provider(user_id, provider)

    return [
@@ -332,6 +342,11 @@ async def delete_credentials(
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
        )
    if is_system_credential(cred_id):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="System-managed credentials cannot be deleted",
        )
    creds = await creds_manager.store.get_creds_by_id(user_id, cred_id)
    if not creds:
        raise HTTPException(
@@ -342,6 +357,11 @@ async def delete_credentials(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Credentials not found",
        )
    if creds.is_managed:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="AutoGPT-managed credentials cannot be deleted",
        )

    try:
        await remove_all_webhooks_for_credentials(user_id, creds, force)

@@ -1,6 +1,7 @@
"""Tests for credentials API security: no secret leakage, SDK defaults filtered."""

from unittest.mock import AsyncMock, patch
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, MagicMock, patch

import fastapi
import fastapi.testclient
@@ -276,3 +277,294 @@ class TestCreateCredentialNoSecretInResponse:

        assert resp.status_code == 403
        mock_mgr.create.assert_not_called()

class TestManagedCredentials:
    """AutoGPT-managed credentials cannot be deleted by users."""

    def test_delete_is_managed_returns_403(self):
        cred = APIKeyCredentials(
            id="managed-cred-1",
            provider="agent_mail",
            title="AgentMail (managed by AutoGPT)",
            api_key=SecretStr("sk-managed-key"),
            is_managed=True,
        )
        with patch(
            "backend.api.features.integrations.router.creds_manager"
        ) as mock_mgr:
            mock_mgr.store.get_creds_by_id = AsyncMock(return_value=cred)
            resp = client.request("DELETE", "/agent_mail/credentials/managed-cred-1")

        assert resp.status_code == 403
        assert "AutoGPT-managed" in resp.json()["detail"]

    def test_list_credentials_includes_is_managed_field(self):
        managed = APIKeyCredentials(
            id="managed-1",
            provider="agent_mail",
            title="AgentMail (managed)",
            api_key=SecretStr("sk-key"),
            is_managed=True,
        )
        regular = APIKeyCredentials(
            id="regular-1",
            provider="openai",
            title="My Key",
            api_key=SecretStr("sk-key"),
        )
        with patch(
            "backend.api.features.integrations.router.creds_manager"
        ) as mock_mgr:
            mock_mgr.store.get_all_creds = AsyncMock(return_value=[managed, regular])
            resp = client.get("/credentials")

        assert resp.status_code == 200
        data = resp.json()
        managed_cred = next(c for c in data if c["id"] == "managed-1")
        regular_cred = next(c for c in data if c["id"] == "regular-1")
        assert managed_cred["is_managed"] is True
        assert regular_cred["is_managed"] is False


# ---------------------------------------------------------------------------
# Managed credential provisioning infrastructure
# ---------------------------------------------------------------------------


def _make_managed_cred(
    provider: str = "agent_mail", pod_id: str = "pod-abc"
) -> APIKeyCredentials:
    return APIKeyCredentials(
        id="managed-auto",
        provider=provider,
        title="AgentMail (managed by AutoGPT)",
        api_key=SecretStr("sk-pod-key"),
        is_managed=True,
        metadata={"pod_id": pod_id},
    )


def _make_store_mock(**kwargs) -> MagicMock:
    """Create a store mock with a working async ``locks()`` context manager."""

    @asynccontextmanager
    async def _noop_locked(key):
        yield

    locks_obj = MagicMock()
    locks_obj.locked = _noop_locked

    store = MagicMock(**kwargs)
    store.locks = AsyncMock(return_value=locks_obj)
    return store


class TestEnsureManagedCredentials:
    """Unit tests for the ensure/cleanup helpers in managed_credentials.py."""

    @pytest.mark.asyncio
    async def test_provisions_when_missing(self):
        """Provider.provision() is called when no managed credential exists."""
        from backend.integrations.managed_credentials import (
            _PROVIDERS,
            _provisioned_users,
            ensure_managed_credentials,
        )

        cred = _make_managed_cred()
        provider = MagicMock()
        provider.provider_name = "test_provider"
        provider.is_available = AsyncMock(return_value=True)
        provider.provision = AsyncMock(return_value=cred)

        store = _make_store_mock()
        store.has_managed_credential = AsyncMock(return_value=False)
        store.add_managed_credential = AsyncMock()

        saved = dict(_PROVIDERS)
        _PROVIDERS.clear()
        _PROVIDERS["test_provider"] = provider
        _provisioned_users.pop("user-1", None)
        try:
            await ensure_managed_credentials("user-1", store)
        finally:
            _PROVIDERS.clear()
            _PROVIDERS.update(saved)
            _provisioned_users.pop("user-1", None)

        provider.provision.assert_awaited_once_with("user-1")
        store.add_managed_credential.assert_awaited_once_with("user-1", cred)

    @pytest.mark.asyncio
    async def test_skips_when_already_exists(self):
        """Provider.provision() is NOT called when managed credential exists."""
        from backend.integrations.managed_credentials import (
            _PROVIDERS,
            _provisioned_users,
            ensure_managed_credentials,
        )

        provider = MagicMock()
        provider.provider_name = "test_provider"
        provider.is_available = AsyncMock(return_value=True)
        provider.provision = AsyncMock()

        store = _make_store_mock()
        store.has_managed_credential = AsyncMock(return_value=True)

        saved = dict(_PROVIDERS)
        _PROVIDERS.clear()
        _PROVIDERS["test_provider"] = provider
        _provisioned_users.pop("user-1", None)
        try:
            await ensure_managed_credentials("user-1", store)
        finally:
            _PROVIDERS.clear()
            _PROVIDERS.update(saved)
            _provisioned_users.pop("user-1", None)

        provider.provision.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_skips_when_unavailable(self):
        """Provider.provision() is NOT called when provider is not available."""
        from backend.integrations.managed_credentials import (
            _PROVIDERS,
            _provisioned_users,
            ensure_managed_credentials,
        )

        provider = MagicMock()
        provider.provider_name = "test_provider"
        provider.is_available = AsyncMock(return_value=False)
        provider.provision = AsyncMock()

        store = _make_store_mock()
        store.has_managed_credential = AsyncMock()

        saved = dict(_PROVIDERS)
        _PROVIDERS.clear()
        _PROVIDERS["test_provider"] = provider
        _provisioned_users.pop("user-1", None)
        try:
            await ensure_managed_credentials("user-1", store)
        finally:
            _PROVIDERS.clear()
            _PROVIDERS.update(saved)
            _provisioned_users.pop("user-1", None)

        provider.provision.assert_not_awaited()
        store.has_managed_credential.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_provision_failure_does_not_propagate(self):
        """A failed provision is logged but does not raise."""
        from backend.integrations.managed_credentials import (
            _PROVIDERS,
            _provisioned_users,
            ensure_managed_credentials,
        )

        provider = MagicMock()
        provider.provider_name = "test_provider"
        provider.is_available = AsyncMock(return_value=True)
        provider.provision = AsyncMock(side_effect=RuntimeError("boom"))

        store = _make_store_mock()
        store.has_managed_credential = AsyncMock(return_value=False)

        saved = dict(_PROVIDERS)
        _PROVIDERS.clear()
        _PROVIDERS["test_provider"] = provider
        _provisioned_users.pop("user-1", None)
        try:
            await ensure_managed_credentials("user-1", store)
        finally:
            _PROVIDERS.clear()
            _PROVIDERS.update(saved)
            _provisioned_users.pop("user-1", None)

        # No exception raised — provisioning failure is swallowed.

class TestCleanupManagedCredentials:
    """Unit tests for cleanup_managed_credentials."""

    @pytest.mark.asyncio
    async def test_calls_deprovision_for_managed_creds(self):
        from backend.integrations.managed_credentials import (
            _PROVIDERS,
            cleanup_managed_credentials,
        )

        cred = _make_managed_cred()
        provider = MagicMock()
        provider.provider_name = "agent_mail"
        provider.deprovision = AsyncMock()

        store = MagicMock()
        store.get_all_creds = AsyncMock(return_value=[cred])

        saved = dict(_PROVIDERS)
        _PROVIDERS.clear()
        _PROVIDERS["agent_mail"] = provider
        try:
            await cleanup_managed_credentials("user-1", store)
        finally:
            _PROVIDERS.clear()
            _PROVIDERS.update(saved)

        provider.deprovision.assert_awaited_once_with("user-1", cred)

    @pytest.mark.asyncio
    async def test_skips_non_managed_creds(self):
        from backend.integrations.managed_credentials import (
            _PROVIDERS,
            cleanup_managed_credentials,
        )

        regular = _make_api_key_cred()
        provider = MagicMock()
        provider.provider_name = "openai"
        provider.deprovision = AsyncMock()

        store = MagicMock()
        store.get_all_creds = AsyncMock(return_value=[regular])

        saved = dict(_PROVIDERS)
        _PROVIDERS.clear()
        _PROVIDERS["openai"] = provider
        try:
            await cleanup_managed_credentials("user-1", store)
        finally:
            _PROVIDERS.clear()
            _PROVIDERS.update(saved)

        provider.deprovision.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_deprovision_failure_does_not_propagate(self):
        from backend.integrations.managed_credentials import (
            _PROVIDERS,
            cleanup_managed_credentials,
        )

        cred = _make_managed_cred()
        provider = MagicMock()
        provider.provider_name = "agent_mail"
        provider.deprovision = AsyncMock(side_effect=RuntimeError("boom"))

        store = MagicMock()
        store.get_all_creds = AsyncMock(return_value=[cred])

        saved = dict(_PROVIDERS)
        _PROVIDERS.clear()
        _PROVIDERS["agent_mail"] = provider
        try:
            await cleanup_managed_credentials("user-1", store)
        finally:
            _PROVIDERS.clear()
            _PROVIDERS.update(saved)

        # No exception raised — cleanup failure is swallowed.

@@ -487,6 +487,11 @@ async def create_library_agent(
                "topIntegrations": SafeJson(
                    library_model._compute_top_integrations(graph_entry)
                ),
                **(
                    {"Folder": {"connect": {"id": folder_id}}}
                    if folder_id and graph_entry is graph
                    else {}
                ),
            },
        },
        include=library_agent_include(

@@ -12,6 +12,7 @@ Tests cover:
5. Complete OAuth flow end-to-end
"""

import asyncio
import base64
import hashlib
import secrets
@@ -58,14 +59,27 @@ async def test_user(server, test_user_id: str):

    yield test_user_id

    # Cleanup - delete in correct order due to foreign key constraints
    await PrismaOAuthAccessToken.prisma().delete_many(where={"userId": test_user_id})
    await PrismaOAuthRefreshToken.prisma().delete_many(where={"userId": test_user_id})
    await PrismaOAuthAuthorizationCode.prisma().delete_many(
        where={"userId": test_user_id}
    )
    await PrismaOAuthApplication.prisma().delete_many(where={"ownerId": test_user_id})
    await PrismaUser.prisma().delete(where={"id": test_user_id})
    # Cleanup - delete in correct order due to foreign key constraints.
    # Wrap in try/except because the event loop or Prisma engine may already
    # be closed during session teardown on Python 3.12+.
    try:
        await asyncio.gather(
            PrismaOAuthAccessToken.prisma().delete_many(where={"userId": test_user_id}),
            PrismaOAuthRefreshToken.prisma().delete_many(
                where={"userId": test_user_id}
            ),
            PrismaOAuthAuthorizationCode.prisma().delete_many(
                where={"userId": test_user_id}
            ),
        )
        await asyncio.gather(
            PrismaOAuthApplication.prisma().delete_many(
                where={"ownerId": test_user_id}
            ),
            PrismaUser.prisma().delete(where={"id": test_user_id}),
        )
    except RuntimeError:
        pass


@pytest_asyncio.fixture

@@ -0,0 +1,61 @@
from unittest.mock import AsyncMock

import fastapi
import fastapi.testclient
import pytest

from backend.api.features.v1 import v1_router

app = fastapi.FastAPI()
app.include_router(v1_router)
client = fastapi.testclient.TestClient(app)


@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
    from autogpt_libs.auth.jwt_utils import get_jwt_payload

    app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
    yield
    app.dependency_overrides.clear()


def test_onboarding_profile_success(mocker):
    mock_extract = mocker.patch(
        "backend.api.features.v1.extract_business_understanding",
        new_callable=AsyncMock,
    )
    mock_upsert = mocker.patch(
        "backend.api.features.v1.upsert_business_understanding",
        new_callable=AsyncMock,
    )

    from backend.data.understanding import BusinessUnderstandingInput

    mock_extract.return_value = BusinessUnderstandingInput.model_construct(
        user_name="John",
        user_role="Founder/CEO",
        pain_points=["Finding leads"],
        suggested_prompts={"Learn": ["How do I automate lead gen?"]},
    )
    mock_upsert.return_value = AsyncMock()

    response = client.post(
        "/onboarding/profile",
        json={
            "user_name": "John",
            "user_role": "Founder/CEO",
            "pain_points": ["Finding leads", "Email & outreach"],
        },
    )
    assert response.status_code == 200
    mock_extract.assert_awaited_once()
    mock_upsert.assert_awaited_once()


def test_onboarding_profile_missing_fields():
    response = client.post(
        "/onboarding/profile",
        json={"user_name": "John"},
    )
    assert response.status_code == 422

@@ -189,6 +189,7 @@ async def test_create_store_submission(mocker):
        notifyOnAgentApproved=True,
        notifyOnAgentRejected=True,
        timezone="Europe/Delft",
        subscriptionTier=prisma.enums.SubscriptionTier.FREE,  # type: ignore[reportCallIssue,reportAttributeAccessIssue]
    )
    mock_agent = prisma.models.AgentGraph(
        id="agent-id",

@@ -63,12 +63,17 @@ from backend.data.onboarding import (
    UserOnboardingUpdate,
    complete_onboarding_step,
    complete_re_run_agent,
    format_onboarding_for_extraction,
    get_recommended_agents,
    get_user_onboarding,
    onboarding_enabled,
    reset_user_onboarding,
    update_user_onboarding,
)
from backend.data.tally import extract_business_understanding
from backend.data.understanding import (
    BusinessUnderstandingInput,
    upsert_business_understanding,
)
from backend.data.user import (
    get_or_create_user,
    get_user_by_id,
@@ -282,35 +287,33 @@ async def get_onboarding_agents(
    return await get_recommended_agents(user_id)


class OnboardingStatusResponse(pydantic.BaseModel):
    """Response for onboarding status check."""
class OnboardingProfileRequest(pydantic.BaseModel):
    """Request body for onboarding profile submission."""

    is_onboarding_enabled: bool
    is_chat_enabled: bool
    user_name: str = pydantic.Field(min_length=1, max_length=100)
    user_role: str = pydantic.Field(min_length=1, max_length=100)
    pain_points: list[str] = pydantic.Field(default_factory=list, max_length=20)


class OnboardingStatusResponse(pydantic.BaseModel):
    """Response for onboarding completion check."""

    is_completed: bool


@v1_router.get(
    "/onboarding/enabled",
    summary="Is onboarding enabled",
    "/onboarding/completed",
    summary="Check if onboarding is completed",
    tags=["onboarding", "public"],
    response_model=OnboardingStatusResponse,
    dependencies=[Security(requires_user)],
)
async def is_onboarding_enabled(
async def is_onboarding_completed(
    user_id: Annotated[str, Security(get_user_id)],
) -> OnboardingStatusResponse:
    # Check if chat is enabled for user
    is_chat_enabled = await is_feature_enabled(Flag.CHAT, user_id, False)

    # If chat is enabled, skip legacy onboarding
    if is_chat_enabled:
        return OnboardingStatusResponse(
            is_onboarding_enabled=False,
            is_chat_enabled=True,
        )

    user_onboarding = await get_user_onboarding(user_id)
    return OnboardingStatusResponse(
        is_onboarding_enabled=await onboarding_enabled(),
        is_chat_enabled=False,
        is_completed=OnboardingStep.VISIT_COPILOT in user_onboarding.completedSteps,
    )


@@ -325,6 +328,38 @@ async def reset_onboarding(user_id: Annotated[str, Security(get_user_id)]):
    return await reset_user_onboarding(user_id)


@v1_router.post(
    "/onboarding/profile",
    summary="Submit onboarding profile",
    tags=["onboarding"],
    dependencies=[Security(requires_user)],
)
async def submit_onboarding_profile(
    data: OnboardingProfileRequest,
    user_id: Annotated[str, Security(get_user_id)],
):
    formatted = format_onboarding_for_extraction(
        user_name=data.user_name,
        user_role=data.user_role,
        pain_points=data.pain_points,
    )

    try:
        understanding_input = await extract_business_understanding(formatted)
    except Exception:
        understanding_input = BusinessUnderstandingInput.model_construct()

    # Ensure the direct fields are set even if LLM missed them
    understanding_input.user_name = data.user_name
    understanding_input.user_role = data.user_role
    if not understanding_input.pain_points:
        understanding_input.pain_points = data.pain_points

    await upsert_business_understanding(user_id, understanding_input)

    return {"status": "ok"}

########################################################
|
||||
##################### Blocks ###########################
|
||||
########################################################
|
||||
|
||||
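A hypothetical client call against the new profile endpoint; the base URL, route prefix, and bearer token below are placeholders, not the platform's actual values:

```python
import httpx

resp = httpx.post(
    "http://localhost:8006/api/onboarding/profile",  # assumed host and prefix
    headers={"Authorization": "Bearer <token>"},
    json={
        "user_name": "Ada",
        "user_role": "Data engineer",
        "pain_points": ["manual reporting", "flaky integrations"],
    },
)
resp.raise_for_status()
assert resp.json() == {"status": "ok"}
```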
@@ -118,6 +118,11 @@ async def lifespan_context(app: fastapi.FastAPI):

    AutoRegistry.patch_integrations()

    # Register managed credential providers (e.g. AgentMail)
    from backend.integrations.managed_providers import register_all

    register_all()

    await backend.data.block.initialize_blocks()

    await backend.data.user.migrate_and_encrypt_user_integrations()

@@ -698,13 +698,30 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
        if should_pause:
            return

        # Validate the input data (original or reviewer-modified) once
        if error := self.input_schema.validate_data(input_data):
            raise BlockInputError(
                message=f"Unable to execute block with invalid input data: {error}",
                block_name=self.name,
                block_id=self.id,
            )
        # Validate the input data (original or reviewer-modified) once.
        # In dry-run mode, credential fields may contain sentinel None values
        # that would fail JSON schema required checks. We still validate the
        # non-credential fields so blocks that execute for real during dry-run
        # (e.g. AgentExecutorBlock) get proper input validation.
        is_dry_run = getattr(kwargs.get("execution_context"), "dry_run", False)
        if is_dry_run:
            cred_field_names = set(self.input_schema.get_credentials_fields().keys())
            non_cred_data = {
                k: v for k, v in input_data.items() if k not in cred_field_names
            }
            if error := self.input_schema.validate_data(non_cred_data):
                raise BlockInputError(
                    message=f"Unable to execute block with invalid input data: {error}",
                    block_name=self.name,
                    block_id=self.id,
                )
        else:
            if error := self.input_schema.validate_data(input_data):
                raise BlockInputError(
                    message=f"Unable to execute block with invalid input data: {error}",
                    block_name=self.name,
                    block_id=self.id,
                )

        # Use the validated input data
        async for output_name, output_data in self.run(

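A minimal sketch of the dry-run validation split above, using plain jsonschema; the schema and field names here are invented for illustration, not the platform's real ones:

```python
from jsonschema import Draft7Validator  # assumed dependency for this sketch

schema = {
    "type": "object",
    "required": ["credentials", "query"],
    "properties": {"credentials": {"type": "object"}, "query": {"type": "string"}},
}
input_data = {"credentials": None, "query": "hello"}  # sentinel None in dry-run

# Drop credential fields from both the data and the required list, then
# validate only what remains; the same idea as the dry-run branch above.
cred_fields = {"credentials"}
non_cred_data = {k: v for k, v in input_data.items() if k not in cred_fields}
non_cred_schema = {
    **schema,
    "required": [f for f in schema["required"] if f not in cred_fields],
}
assert not list(Draft7Validator(non_cred_schema).iter_errors(non_cred_data))
```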
@@ -49,11 +49,17 @@ class AgentExecutorBlock(Block):
    @classmethod
    def get_missing_input(cls, data: BlockInput) -> set[str]:
        required_fields = cls.get_input_schema(data).get("required", [])
        return set(required_fields) - set(data)
        # Check against the nested `inputs` dict, not the top-level node
        # data — required fields like "topic" live inside data["inputs"],
        # not at data["topic"].
        provided = data.get("inputs", {})
        return set(required_fields) - set(provided)

    @classmethod
    def get_mismatch_error(cls, data: BlockInput) -> str | None:
        return validate_with_jsonschema(cls.get_input_schema(data), data)
        return validate_with_jsonschema(
            cls.get_input_schema(data), data.get("inputs", {})
        )

    class Output(BlockSchema):
        # Use BlockSchema to avoid automatic error field that could clash with graph outputs
@@ -88,6 +94,7 @@ class AgentExecutorBlock(Block):
            execution_context=execution_context.model_copy(
                update={"parent_execution_id": graph_exec_id},
            ),
            dry_run=execution_context.dry_run,
        )

        logger = execution_utils.LogMetadata(
@@ -149,14 +156,19 @@ class AgentExecutorBlock(Block):
                ExecutionStatus.TERMINATED,
                ExecutionStatus.FAILED,
            ]:
                logger.debug(
                    f"Execution {log_id} received event {event.event_type} with status {event.status}"
                logger.info(
                    f"Execution {log_id} skipping event {event.event_type} status={event.status} "
                    f"node={getattr(event, 'node_exec_id', '?')}"
                )
                continue

            if event.event_type == ExecutionEventType.GRAPH_EXEC_UPDATE:
                # If the graph execution is COMPLETED, TERMINATED, or FAILED,
                # we can stop listening for further events.
                logger.info(
                    f"Execution {log_id} graph completed with status {event.status}, "
                    f"yielded {len(yielded_node_exec_ids)} outputs"
                )
                self.merge_stats(
                    NodeExecutionStats(
                        extra_cost=event.stats.cost if event.stats else 0,

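Why get_missing_input now looks inside data["inputs"]: for an AgentExecutorBlock node, graph input values are nested under an "inputs" key rather than sitting at the top level. The field names below are illustrative:

```python
node_data = {
    "graph_id": "g-123",
    "inputs": {"topic": "quarterly report"},
}
required = ["topic", "audience"]

# Old behavior: checked top-level keys, so "topic" looked missing.
missing_old = set(required) - set(node_data)                     # {'topic', 'audience'}
# New behavior: checks the nested inputs dict.
missing_new = set(required) - set(node_data.get("inputs", {}))   # {'audience'}
```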
@@ -146,6 +146,21 @@ class AutoPilotBlock(Block):
            advanced=True,
        )

        dry_run: bool = SchemaField(
            description=(
                "When enabled, run_block and run_agent tool calls in this "
                "autopilot session are forced to use dry-run simulation mode. "
                "No real API calls, side effects, or credits are consumed "
                "by those tools. Useful for testing agent wiring and "
                "previewing outputs. "
                "Only applies when creating a new session (session_id is empty). "
                "When reusing an existing session_id, the session's original "
                "dry_run setting is preserved."
            ),
            default=False,
            advanced=True,
        )

        # timeout_seconds removed: the SDK manages its own heartbeat-based
        # timeouts internally; wrapping with asyncio.timeout corrupts the
        # SDK's internal stream (see service.py CRITICAL comment).
@@ -232,11 +247,11 @@ class AutoPilotBlock(Block):
            },
        )

    async def create_session(self, user_id: str) -> str:
    async def create_session(self, user_id: str, *, dry_run: bool) -> str:
        """Create a new chat session and return its ID (mockable for tests)."""
        from backend.copilot.model import create_chat_session  # avoid circular import

        session = await create_chat_session(user_id)
        session = await create_chat_session(user_id, dry_run=dry_run)
        return session.session_id

    async def execute_copilot(
@@ -367,7 +382,9 @@ class AutoPilotBlock(Block):
        # even if the downstream stream fails (avoids orphaned sessions).
        sid = input_data.session_id
        if not sid:
            sid = await self.create_session(execution_context.user_id)
            sid = await self.create_session(
                execution_context.user_id, dry_run=input_data.dry_run
            )

        # NOTE: No asyncio.timeout() here — the SDK manages its own
        # heartbeat-based timeouts internally. Wrapping with asyncio.timeout

@@ -1,5 +1,6 @@
import asyncio
import base64
import re
from abc import ABC
from email import encoders
from email.mime.base import MIMEBase
@@ -8,7 +9,7 @@ from email.mime.text import MIMEText
from email.policy import SMTP
from email.utils import getaddresses, parseaddr
from pathlib import Path
from typing import List, Literal, Optional
from typing import List, Literal, Optional, Protocol, runtime_checkable

from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
@@ -42,8 +43,52 @@ NO_WRAP_POLICY = SMTP.clone(max_line_length=0)


def serialize_email_recipients(recipients: list[str]) -> str:
    """Serialize recipients list to comma-separated string."""
    return ", ".join(recipients)
    """Serialize recipients list to comma-separated string.

    Strips leading/trailing whitespace from each address to keep MIME
    headers clean (mirrors the strip done in ``validate_email_recipients``).
    """
    return ", ".join(addr.strip() for addr in recipients)


# RFC 5322 simplified pattern: local@domain where domain has at least one dot
_EMAIL_RE = re.compile(r"^[^@\s]+@[^@\s]+\.[^@\s]+$")


def validate_email_recipients(recipients: list[str], field_name: str = "to") -> None:
    """Validate that all recipients are plausible email addresses.

    Raises ``ValueError`` with a user-friendly message listing every
    invalid entry so the caller (or LLM) can correct them in one pass.
    """
    invalid = [addr for addr in recipients if not _EMAIL_RE.match(addr.strip())]
    if invalid:
        formatted = ", ".join(f"'{a}'" for a in invalid)
        raise ValueError(
            f"Invalid email address(es) in '{field_name}': {formatted}. "
            f"Each entry must be a valid email address (e.g. user@example.com)."
        )


@runtime_checkable
class HasRecipients(Protocol):
    to: list[str]
    cc: list[str]
    bcc: list[str]


def validate_all_recipients(input_data: HasRecipients) -> None:
    """Validate to/cc/bcc recipient fields on an input namespace.

    Calls ``validate_email_recipients`` for ``to`` (required) and
    ``cc``/``bcc`` (when non-empty), raising ``ValueError`` on the
    first field that contains an invalid address.
    """
    validate_email_recipients(input_data.to, "to")
    if input_data.cc:
        validate_email_recipients(input_data.cc, "cc")
    if input_data.bcc:
        validate_email_recipients(input_data.bcc, "bcc")


def _make_mime_text(
@@ -100,14 +145,16 @@ async def create_mime_message(
) -> str:
    """Create a MIME message with attachments and return base64-encoded raw message."""

    validate_all_recipients(input_data)

    message = MIMEMultipart()
    message["to"] = serialize_email_recipients(input_data.to)
    message["subject"] = input_data.subject

    if input_data.cc:
        message["cc"] = ", ".join(input_data.cc)
        message["cc"] = serialize_email_recipients(input_data.cc)
    if input_data.bcc:
        message["bcc"] = ", ".join(input_data.bcc)
        message["bcc"] = serialize_email_recipients(input_data.bcc)

    # Use the new helper function with content_type if available
    content_type = getattr(input_data, "content_type", None)
@@ -1167,13 +1214,15 @@ async def _build_reply_message(
        references.append(headers["message-id"])

    # Create MIME message
    validate_all_recipients(input_data)

    msg = MIMEMultipart()
    if input_data.to:
        msg["To"] = ", ".join(input_data.to)
        msg["To"] = serialize_email_recipients(input_data.to)
    if input_data.cc:
        msg["Cc"] = ", ".join(input_data.cc)
        msg["Cc"] = serialize_email_recipients(input_data.cc)
    if input_data.bcc:
        msg["Bcc"] = ", ".join(input_data.bcc)
        msg["Bcc"] = serialize_email_recipients(input_data.bcc)
    msg["Subject"] = subject
    if headers.get("message-id"):
        msg["In-Reply-To"] = headers["message-id"]
@@ -1685,13 +1734,16 @@ To: {original_to}
    else:
        body = f"{forward_header}\n\n{original_body}"

    # Validate all recipient lists before building the MIME message
    validate_all_recipients(input_data)

    # Create MIME message
    msg = MIMEMultipart()
    msg["To"] = ", ".join(input_data.to)
    msg["To"] = serialize_email_recipients(input_data.to)
    if input_data.cc:
        msg["Cc"] = ", ".join(input_data.cc)
        msg["Cc"] = serialize_email_recipients(input_data.cc)
    if input_data.bcc:
        msg["Bcc"] = ", ".join(input_data.bcc)
        msg["Bcc"] = serialize_email_recipients(input_data.bcc)
    msg["Subject"] = subject

    # Add body with proper content type

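A quick check of the simplified RFC 5322 pattern introduced above; the sample addresses are illustrative:

```python
import re

_EMAIL_RE = re.compile(r"^[^@\s]+@[^@\s]+\.[^@\s]+$")

assert _EMAIL_RE.match("user@example.com")
assert _EMAIL_RE.match("  user@example.com  ".strip())  # whitespace is stripped first
assert not _EMAIL_RE.match("user@localhost")    # domain needs at least one dot
assert not _EMAIL_RE.match("user example.com")  # no @ at all
assert not _EMAIL_RE.match("a@b@c.com")         # extra @ is rejected
```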
@@ -2,6 +2,8 @@ import copy
from datetime import date, time
from typing import Any, Optional

from pydantic import AliasChoices, Field

from backend.blocks._base import (
    Block,
    BlockCategory,
@@ -467,7 +469,8 @@ class AgentFileInputBlock(AgentInputBlock):

class AgentDropdownInputBlock(AgentInputBlock):
    """
    A specialized text input block that relies on placeholder_values to present a dropdown.
    A specialized text input block that presents a dropdown selector
    restricted to a fixed set of values.
    """

    class Input(AgentInputBlock.Input):
@@ -477,16 +480,23 @@ class AgentDropdownInputBlock(AgentInputBlock):
            advanced=False,
            title="Default Value",
        )
        placeholder_values: list = SchemaField(
            description="Possible values for the dropdown.",
        # Use Field() directly (not SchemaField) to pass validation_alias,
        # which handles backward compat for legacy "placeholder_values" across
        # all construction paths (model_construct, __init__, model_validate).
        options: list = Field(
            default_factory=list,
            advanced=False,
            title="Dropdown Options",
            description=(
                "If provided, renders the input as a dropdown selector "
                "restricted to these values. Leave empty for free-text input."
            ),
            validation_alias=AliasChoices("options", "placeholder_values"),
            json_schema_extra={"advanced": False, "secret": False},
        )

        def generate_schema(self):
            schema = super().generate_schema()
            if possible_values := self.placeholder_values:
            if possible_values := self.options:
                schema["enum"] = possible_values
            return schema

@@ -504,13 +514,13 @@ class AgentDropdownInputBlock(AgentInputBlock):
                {
                    "value": "Option A",
                    "name": "dropdown_1",
                    "placeholder_values": ["Option A", "Option B", "Option C"],
                    "options": ["Option A", "Option B", "Option C"],
                    "description": "Dropdown example 1",
                },
                {
                    "value": "Option C",
                    "name": "dropdown_2",
                    "placeholder_values": ["Option A", "Option B", "Option C"],
                    "options": ["Option A", "Option B", "Option C"],
                    "description": "Dropdown example 2",
                },
            ],

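A standalone sketch of the validation_alias backward-compat trick used above; the model here is made up, not the real Input class:

```python
from pydantic import AliasChoices, BaseModel, Field


class DropdownInput(BaseModel):
    options: list = Field(
        default_factory=list,
        validation_alias=AliasChoices("options", "placeholder_values"),
    )


# Both the canonical and the legacy field name populate `options`.
assert DropdownInput.model_validate({"options": ["A", "B"]}).options == ["A", "B"]
assert DropdownInput.model_validate({"placeholder_values": ["A"]}).options == ["A"]
```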
@@ -205,6 +205,19 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
    KIMI_K2 = "moonshotai/kimi-k2"
    QWEN3_235B_A22B_THINKING = "qwen/qwen3-235b-a22b-thinking-2507"
    QWEN3_CODER = "qwen/qwen3-coder"
    # Z.ai (Zhipu) models
    ZAI_GLM_4_32B = "z-ai/glm-4-32b"
    ZAI_GLM_4_5 = "z-ai/glm-4.5"
    ZAI_GLM_4_5_AIR = "z-ai/glm-4.5-air"
    ZAI_GLM_4_5_AIR_FREE = "z-ai/glm-4.5-air:free"
    ZAI_GLM_4_5V = "z-ai/glm-4.5v"
    ZAI_GLM_4_6 = "z-ai/glm-4.6"
    ZAI_GLM_4_6V = "z-ai/glm-4.6v"
    ZAI_GLM_4_7 = "z-ai/glm-4.7"
    ZAI_GLM_4_7_FLASH = "z-ai/glm-4.7-flash"
    ZAI_GLM_5 = "z-ai/glm-5"
    ZAI_GLM_5_TURBO = "z-ai/glm-5-turbo"
    ZAI_GLM_5V_TURBO = "z-ai/glm-5v-turbo"
    # Llama API models
    LLAMA_API_LLAMA_4_SCOUT = "Llama-4-Scout-17B-16E-Instruct-FP8"
    LLAMA_API_LLAMA4_MAVERICK = "Llama-4-Maverick-17B-128E-Instruct-FP8"
@@ -630,6 +643,43 @@ MODEL_METADATA = {
    LlmModel.QWEN3_CODER: ModelMetadata(
        "open_router", 262144, 262144, "Qwen 3 Coder", "OpenRouter", "Qwen", 3
    ),
    # https://openrouter.ai/models?q=z-ai
    LlmModel.ZAI_GLM_4_32B: ModelMetadata(
        "open_router", 128000, 128000, "GLM 4 32B", "OpenRouter", "Z.ai", 1
    ),
    LlmModel.ZAI_GLM_4_5: ModelMetadata(
        "open_router", 131072, 98304, "GLM 4.5", "OpenRouter", "Z.ai", 2
    ),
    LlmModel.ZAI_GLM_4_5_AIR: ModelMetadata(
        "open_router", 131072, 98304, "GLM 4.5 Air", "OpenRouter", "Z.ai", 1
    ),
    LlmModel.ZAI_GLM_4_5_AIR_FREE: ModelMetadata(
        "open_router", 131072, 96000, "GLM 4.5 Air (Free)", "OpenRouter", "Z.ai", 1
    ),
    LlmModel.ZAI_GLM_4_5V: ModelMetadata(
        "open_router", 65536, 16384, "GLM 4.5V", "OpenRouter", "Z.ai", 2
    ),
    LlmModel.ZAI_GLM_4_6: ModelMetadata(
        "open_router", 204800, 204800, "GLM 4.6", "OpenRouter", "Z.ai", 1
    ),
    LlmModel.ZAI_GLM_4_6V: ModelMetadata(
        "open_router", 131072, 131072, "GLM 4.6V", "OpenRouter", "Z.ai", 1
    ),
    LlmModel.ZAI_GLM_4_7: ModelMetadata(
        "open_router", 202752, 65535, "GLM 4.7", "OpenRouter", "Z.ai", 1
    ),
    LlmModel.ZAI_GLM_4_7_FLASH: ModelMetadata(
        "open_router", 202752, 202752, "GLM 4.7 Flash", "OpenRouter", "Z.ai", 1
    ),
    LlmModel.ZAI_GLM_5: ModelMetadata(
        "open_router", 80000, 80000, "GLM 5", "OpenRouter", "Z.ai", 2
    ),
    LlmModel.ZAI_GLM_5_TURBO: ModelMetadata(
        "open_router", 202752, 131072, "GLM 5 Turbo", "OpenRouter", "Z.ai", 3
    ),
    LlmModel.ZAI_GLM_5V_TURBO: ModelMetadata(
        "open_router", 202752, 131072, "GLM 5V Turbo", "OpenRouter", "Z.ai", 3
    ),
    # Llama API models
    LlmModel.LLAMA_API_LLAMA_4_SCOUT: ModelMetadata(
        "llama_api",
@@ -724,6 +774,9 @@ def convert_openai_tool_fmt_to_anthropic(
def extract_openai_reasoning(response) -> str | None:
    """Extract reasoning from OpenAI-compatible response if available."""
    """Note: this will likely not work, since the reasoning is not present in the Response API."""
    if not response.choices:
        logger.warning("LLM response has empty choices in extract_openai_reasoning")
        return None
    reasoning = None
    choice = response.choices[0]
    if hasattr(choice, "reasoning") and getattr(choice, "reasoning", None):
@@ -739,6 +792,9 @@ def extract_openai_reasoning(response) -> str | None:

def extract_openai_tool_calls(response) -> list[ToolContentBlock] | None:
    """Extract tool calls from OpenAI-compatible response."""
    if not response.choices:
        logger.warning("LLM response has empty choices in extract_openai_tool_calls")
        return None
    if response.choices[0].message.tool_calls:
        return [
            ToolContentBlock(
@@ -972,6 +1028,8 @@ async def llm_call(
            response_format=response_format,  # type: ignore
            max_tokens=max_tokens,
        )
        if not response.choices:
            raise ValueError("Groq returned empty choices in response")
        return LLMResponse(
            raw_response=response.choices[0].message,
            prompt=prompt,
@@ -1031,12 +1089,8 @@ async def llm_call(
            parallel_tool_calls=parallel_tool_calls_param,
        )

        # If there's no response, raise an error
        if not response.choices:
            if response:
                raise ValueError(f"OpenRouter error: {response}")
            else:
                raise ValueError("No response from OpenRouter.")
            raise ValueError(f"OpenRouter returned empty choices: {response}")

        tool_calls = extract_openai_tool_calls(response)
        reasoning = extract_openai_reasoning(response)
@@ -1073,12 +1127,8 @@ async def llm_call(
            parallel_tool_calls=parallel_tool_calls_param,
        )

        # If there's no response, raise an error
        if not response.choices:
            if response:
                raise ValueError(f"Llama API error: {response}")
            else:
                raise ValueError("No response from Llama API.")
            raise ValueError(f"Llama API returned empty choices: {response}")

        tool_calls = extract_openai_tool_calls(response)
        reasoning = extract_openai_reasoning(response)
@@ -1108,6 +1158,8 @@ async def llm_call(
            messages=prompt,  # type: ignore
            max_tokens=max_tokens,
        )
        if not completion.choices:
            raise ValueError("AI/ML API returned empty choices in response")

        return LLMResponse(
            raw_response=completion.choices[0].message,
@@ -1144,6 +1196,9 @@ async def llm_call(
            parallel_tool_calls=parallel_tool_calls_param,
        )

        if not response.choices:
            raise ValueError(f"v0 API returned empty choices: {response}")

        tool_calls = extract_openai_tool_calls(response)
        reasoning = extract_openai_reasoning(response)

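The same empty-choices guard, pulled out as a standalone helper to show the pattern each provider branch now applies; the response class is a stand-in, not a real SDK type:

```python
class FakeResponse:  # stand-in for an OpenAI-style completion object
    choices: list = []


def first_choice(response):
    # Raise one uniform, provider-tagged error instead of hitting an
    # IndexError on response.choices[0] deep inside a provider branch.
    if not response.choices:
        raise ValueError(f"Provider returned empty choices: {response}")
    return response.choices[0]


try:
    first_choice(FakeResponse())
except ValueError as err:
    print(err)  # Provider returned empty choices: <...>
```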
@@ -2011,6 +2066,19 @@ class AIConversationBlock(AIBlockBase):
    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        has_messages = any(
            isinstance(m, dict)
            and isinstance(m.get("content"), str)
            and bool(m["content"].strip())
            for m in (input_data.messages or [])
        )
        has_prompt = bool(input_data.prompt and input_data.prompt.strip())
        if not has_messages and not has_prompt:
            raise ValueError(
                "Cannot call LLM with no messages and no prompt. "
                "Provide at least one message or a non-empty prompt."
            )

        response = await self.llm_call(
            AIStructuredResponseGeneratorBlock.Input(
                prompt=input_data.prompt,

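The has_messages predicate above, checked against the edge cases the new tests exercise; the messages are illustrative:

```python
def has_usable_messages(messages):
    return any(
        isinstance(m, dict)
        and isinstance(m.get("content"), str)
        and bool(m["content"].strip())
        for m in (messages or [])
    )


assert not has_usable_messages([])                                    # empty list
assert not has_usable_messages(None)                                  # None
assert not has_usable_messages([None, {}])                            # junk entries
assert not has_usable_messages([{"role": "user", "content": "  "}])   # whitespace only
assert not has_usable_messages([{"role": "user", "content": None}])   # content=None
assert has_usable_messages([{"role": "user", "content": "Hello"}])
```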
@@ -89,6 +89,12 @@ class MCPToolBlock(Block):
            default={},
            hidden=True,
        )
        tool_description: str = SchemaField(
            description="Description of the selected MCP tool. "
            "Populated automatically when a tool is selected.",
            default="",
            hidden=True,
        )

        tool_arguments: dict[str, Any] = SchemaField(
            description="Arguments to pass to the selected MCP tool. "

File diff suppressed because it is too large
323
autogpt_platform/backend/backend/blocks/sql_query_block.py
Normal file
@@ -0,0 +1,323 @@
import asyncio
from typing import Any, Literal

from pydantic import SecretStr
from sqlalchemy.engine.url import URL
from sqlalchemy.exc import DBAPIError, OperationalError, ProgrammingError

from backend.blocks._base import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.blocks.sql_query_helpers import (
    _DATABASE_TYPE_DEFAULT_PORT,
    _DATABASE_TYPE_TO_DRIVER,
    DatabaseType,
    _execute_query,
    _sanitize_error,
    _validate_query_is_read_only,
    _validate_single_statement,
)
from backend.data.model import (
    CredentialsField,
    CredentialsMetaInput,
    SchemaField,
    UserPasswordCredentials,
)
from backend.integrations.providers import ProviderName
from backend.util.request import resolve_and_check_blocked

TEST_CREDENTIALS = UserPasswordCredentials(
    id="01234567-89ab-cdef-0123-456789abcdef",
    provider="database",
    username=SecretStr("test_user"),
    password=SecretStr("test_pass"),
    title="Mock Database credentials",
)

TEST_CREDENTIALS_INPUT = {
    "provider": TEST_CREDENTIALS.provider,
    "id": TEST_CREDENTIALS.id,
    "type": TEST_CREDENTIALS.type,
    "title": TEST_CREDENTIALS.title,
}

DatabaseCredentials = UserPasswordCredentials
DatabaseCredentialsInput = CredentialsMetaInput[
    Literal[ProviderName.DATABASE],
    Literal["user_password"],
]


def DatabaseCredentialsField() -> DatabaseCredentialsInput:
    return CredentialsField(
        description="Database username and password",
    )


class SQLQueryBlock(Block):
    class Input(BlockSchemaInput):
        database_type: DatabaseType = SchemaField(
            default=DatabaseType.POSTGRES,
            description="Database engine",
            advanced=False,
        )
        host: SecretStr = SchemaField(
            description=(
                "Database hostname or IP address. "
                "Treated as a secret to avoid leaking infrastructure details. "
                "Private/internal IPs are blocked (SSRF protection)."
            ),
            placeholder="db.example.com",
            secret=True,
        )
        port: int | None = SchemaField(
            default=None,
            description=(
                "Database port (leave empty for default: "
                "PostgreSQL: 5432, MySQL: 3306, MSSQL: 1433)"
            ),
            ge=1,
            le=65535,
        )
        database: str = SchemaField(
            description="Name of the database to connect to",
            placeholder="my_database",
        )
        query: str = SchemaField(
            description="SQL query to execute",
            placeholder="SELECT * FROM analytics.daily_active_users LIMIT 10",
        )
        read_only: bool = SchemaField(
            default=True,
            description=(
                "When enabled (default), only SELECT queries are allowed "
                "and the database session is set to read-only mode. "
                "Disable to allow write operations (INSERT, UPDATE, DELETE, etc.)."
            ),
        )
        timeout: int = SchemaField(
            default=30,
            description="Query timeout in seconds (max 120)",
            ge=1,
            le=120,
        )
        max_rows: int = SchemaField(
            default=1000,
            description="Maximum number of rows to return (max 10000)",
            ge=1,
            le=10000,
        )
        credentials: DatabaseCredentialsInput = DatabaseCredentialsField()

    class Output(BlockSchemaOutput):
        results: list[dict[str, Any]] = SchemaField(
            description="Query results as a list of row dictionaries"
        )
        columns: list[str] = SchemaField(
            description="Column names from the query result"
        )
        row_count: int = SchemaField(description="Number of rows returned")
        truncated: bool = SchemaField(
            description=(
                "True when the result set was capped by max_rows, "
                "indicating additional rows exist in the database"
            )
        )
        affected_rows: int = SchemaField(
            description="Number of rows affected by a write query (INSERT/UPDATE/DELETE)"
        )
        error: str = SchemaField(description="Error message if the query failed")

    def __init__(self):
        super().__init__(
            id="4dc35c0f-4fd8-465e-9616-5a216f1ba2bc",
            description=(
                "Execute a SQL query. Read-only by default for safety "
                "-- disable to allow write operations. "
                "Supports PostgreSQL, MySQL, and MSSQL via SQLAlchemy."
            ),
            categories={BlockCategory.DATA},
            input_schema=SQLQueryBlock.Input,
            output_schema=SQLQueryBlock.Output,
            test_input={
                "query": "SELECT 1 AS test_col",
                "database_type": DatabaseType.POSTGRES,
                "host": "localhost",
                "database": "test_db",
                "timeout": 30,
                "max_rows": 1000,
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("results", [{"test_col": 1}]),
                ("columns", ["test_col"]),
                ("row_count", 1),
                ("truncated", False),
            ],
            test_mock={
                "execute_query": lambda *_args, **_kwargs: (
                    [{"test_col": 1}],
                    ["test_col"],
                    -1,
                    False,
                ),
                "check_host_allowed": lambda *_args, **_kwargs: ["127.0.0.1"],
            },
        )

    @staticmethod
    async def check_host_allowed(host: str) -> list[str]:
        """Validate that the given host is not a private/blocked address.

        Returns the list of resolved IP addresses so the caller can pin the
        connection to the validated IP (preventing DNS rebinding / TOCTOU).
        Raises ValueError or OSError if the host is blocked.
        Extracted as a method so it can be mocked during block tests.
        """
        return await resolve_and_check_blocked(host)

    @staticmethod
    def execute_query(
        connection_url: URL | str,
        query: str,
        timeout: int,
        max_rows: int,
        read_only: bool = True,
        database_type: DatabaseType = DatabaseType.POSTGRES,
    ) -> tuple[list[dict[str, Any]], list[str], int, bool]:
        """Execute a SQL query and return (rows, columns, affected_rows, truncated).

        Delegates to ``_execute_query`` in ``sql_query_helpers``.
        Extracted as a method so it can be mocked during block tests.
        """
        return _execute_query(
            connection_url=connection_url,
            query=query,
            timeout=timeout,
            max_rows=max_rows,
            read_only=read_only,
            database_type=database_type,
        )

    async def run(
        self,
        input_data: Input,
        *,
        credentials: DatabaseCredentials,
        **_kwargs: Any,
    ) -> BlockOutput:
        # Validate query structure and read-only constraints.
        error = self._validate_query(input_data)
        if error:
            yield "error", error
            return

        # Validate host and resolve for SSRF protection.
        host, pinned_host, error = await self._resolve_host(input_data)
        if error:
            yield "error", error
            return

        # Build connection URL and execute.
        port = input_data.port or _DATABASE_TYPE_DEFAULT_PORT[input_data.database_type]
        username = credentials.username.get_secret_value()
        connection_url = URL.create(
            drivername=_DATABASE_TYPE_TO_DRIVER[input_data.database_type],
            username=username,
            password=credentials.password.get_secret_value(),
            host=pinned_host,
            port=port,
            database=input_data.database,
        )
        conn_str = connection_url.render_as_string(hide_password=True)
        db_name = input_data.database

        def _sanitize(err: Exception) -> str:
            return _sanitize_error(
                str(err).strip(),
                conn_str,
                host=pinned_host,
                original_host=host,
                username=username,
                port=port,
                database=db_name,
            )

        try:
            results, columns, affected, truncated = await asyncio.to_thread(
                self.execute_query,
                connection_url=connection_url,
                query=input_data.query,
                timeout=input_data.timeout,
                max_rows=input_data.max_rows,
                read_only=input_data.read_only,
                database_type=input_data.database_type,
            )
            yield "results", results
            yield "columns", columns
            yield "row_count", len(results)
            yield "truncated", truncated
            if affected >= 0:
                yield "affected_rows", affected
        except OperationalError as e:
            yield (
                "error",
                self._classify_operational_error(
                    _sanitize(e),
                    input_data.timeout,
                ),
            )
        except ProgrammingError as e:
            yield "error", f"SQL error: {_sanitize(e)}"
        except DBAPIError as e:
            yield "error", f"Database error: {_sanitize(e)}"
        except ModuleNotFoundError:
            yield (
                "error",
                (
                    f"Database driver not available for "
                    f"{input_data.database_type.value}. "
                    f"Please contact the platform administrator."
                ),
            )

    @staticmethod
    def _validate_query(input_data: "SQLQueryBlock.Input") -> str | None:
        """Validate query structure and read-only constraints."""
        stmt_error, parsed_stmt = _validate_single_statement(input_data.query)
        if stmt_error:
            return stmt_error
        assert parsed_stmt is not None
        if input_data.read_only:
            return _validate_query_is_read_only(parsed_stmt)
        return None

    async def _resolve_host(
        self, input_data: "SQLQueryBlock.Input"
    ) -> tuple[str, str, str | None]:
        """Validate and resolve the database host. Returns (host, pinned_ip, error)."""
        host = input_data.host.get_secret_value().strip()
        if not host:
            return "", "", "Database host is required."
        if host.startswith("/"):
            return host, "", "Unix socket connections are not allowed."
        try:
            resolved_ips = await self.check_host_allowed(host)
        except (ValueError, OSError) as e:
            return host, "", f"Blocked host: {str(e).strip()}"
        return host, resolved_ips[0], None

    @staticmethod
    def _classify_operational_error(sanitized_msg: str, timeout: int) -> str:
        """Classify an already-sanitized OperationalError for user display."""
        lower = sanitized_msg.lower()
        if "timeout" in lower or "cancel" in lower:
            return f"Query timed out after {timeout}s."
        if "connect" in lower:
            return f"Failed to connect to database: {sanitized_msg}"
        return f"Database error: {sanitized_msg}"
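A minimal sketch of the resolve-then-pin idea behind _resolve_host and the URL.create call above: validate the hostname once, then connect to the exact IP that passed validation so a later DNS change cannot redirect the connection (DNS rebinding / TOCTOU). The resolver here is plain socket.getaddrinfo, not the platform's resolve_and_check_blocked:

```python
import ipaddress
import socket


def resolve_and_pin(host: str) -> str:
    infos = socket.getaddrinfo(host, None, family=socket.AF_INET)
    ips = [info[4][0] for info in infos]
    for ip in ips:
        if ipaddress.ip_address(ip).is_private:
            raise ValueError(f"Blocked host: {host} resolves to a private IP")
    return ips[0]  # pin this IP into the connection URL


# pinned = resolve_and_pin("db.example.com")
```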
1851
autogpt_platform/backend/backend/blocks/sql_query_block_test.py
Normal file
File diff suppressed because it is too large
430
autogpt_platform/backend/backend/blocks/sql_query_helpers.py
Normal file
@@ -0,0 +1,430 @@
import re
from datetime import date, datetime, time
from decimal import Decimal
from enum import Enum
from typing import Any

import sqlparse
from sqlalchemy import create_engine, text
from sqlalchemy.engine.url import URL


class DatabaseType(str, Enum):
    POSTGRES = "postgres"
    MYSQL = "mysql"
    MSSQL = "mssql"


# Defense-in-depth: reject queries containing data-modifying keywords.
# These are checked against parsed SQL tokens (not raw text) so column names
# and string literals do not cause false positives.
_DISALLOWED_KEYWORDS = {
    "INSERT",
    "UPDATE",
    "DELETE",
    "DROP",
    "ALTER",
    "CREATE",
    "TRUNCATE",
    "GRANT",
    "REVOKE",
    "COPY",
    "EXECUTE",
    "CALL",
    "SET",
    "RESET",
    "DISCARD",
    "NOTIFY",
    "DO",
    # MySQL file exfiltration: LOAD DATA LOCAL INFILE reads server/client files
    "LOAD",
    # MySQL REPLACE is INSERT-or-UPDATE; data modification
    "REPLACE",
    # ANSI MERGE (UPSERT) modifies data
    "MERGE",
    # MSSQL BULK INSERT loads external files into tables
    "BULK",
    # MSSQL EXEC / EXEC sp_name runs stored procedures (arbitrary code)
    "EXEC",
}

# Map DatabaseType enum values to the expected SQLAlchemy driver prefix.
_DATABASE_TYPE_TO_DRIVER = {
    DatabaseType.POSTGRES: "postgresql",
    DatabaseType.MYSQL: "mysql+pymysql",
    DatabaseType.MSSQL: "mssql+pymssql",
}

# Connection timeout in seconds passed to the DBAPI driver (connect_timeout /
# login_timeout). This bounds how long the driver waits to establish a TCP
# connection to the database server. It is separate from the per-statement
# timeout configured via SET commands inside _configure_session().
_CONNECT_TIMEOUT_SECONDS = 10

# Default ports for each database type.
_DATABASE_TYPE_DEFAULT_PORT = {
    DatabaseType.POSTGRES: 5432,
    DatabaseType.MYSQL: 3306,
    DatabaseType.MSSQL: 1433,
}


def _sanitize_error(
    error_msg: str,
    connection_string: str,
    *,
    host: str = "",
    original_host: str = "",
    username: str = "",
    port: int = 0,
    database: str = "",
) -> str:
    """Remove connection string, credentials, and infrastructure details
    from error messages so they are safe to expose to the LLM.

    Scrubs:
    - The full connection string
    - URL-embedded credentials (``://user:pass@``)
    - ``password=<value>`` key-value pairs
    - The database hostname / IP used for the connection
    - The original (pre-resolution) hostname provided by the user
    - Any IPv4 addresses that appear in the message
    - Any bracketed IPv6 addresses (e.g. ``[::1]``, ``[fe80::1%eth0]``)
    - The database username
    - The database port number
    - The database name
    """
    sanitized = error_msg.replace(connection_string, "<connection_string>")
    sanitized = re.sub(r"password=[^\s&]+", "password=***", sanitized)
    sanitized = re.sub(r"://[^@]+@", "://***:***@", sanitized)

    # Replace the known host (may be an IP already) before the generic IP pass.
    # Also replace the original (pre-DNS-resolution) hostname if it differs.
    if original_host and original_host != host:
        sanitized = sanitized.replace(original_host, "<host>")
    if host:
        sanitized = sanitized.replace(host, "<host>")

    # Replace any remaining IPv4 addresses (e.g. resolved IPs the driver logs)
    sanitized = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", "<ip>", sanitized)

    # Replace bracketed IPv6 addresses (e.g. "[::1]", "[fe80::1%eth0]")
    sanitized = re.sub(r"\[[0-9a-fA-F:]+(?:%[^\]]+)?\]", "<ip>", sanitized)

    # Replace the database username (handles double-quoted, single-quoted,
    # and unquoted formats across PostgreSQL, MySQL, and MSSQL error messages).
    if username:
        sanitized = re.sub(
            r"""for user ["']?""" + re.escape(username) + r"""["']?""",
            "for user <user>",
            sanitized,
        )
        # Catch remaining bare occurrences in various quote styles:
        # - PostgreSQL: "FATAL: role "myuser" does not exist"
        # - MySQL: "Access denied for user 'myuser'@'host'"
        # - MSSQL: "Login failed for user 'myuser'"
        sanitized = sanitized.replace(f'"{username}"', "<user>")
        sanitized = sanitized.replace(f"'{username}'", "<user>")

    # Replace the port number (handles "port 5432" and ":5432" formats)
    if port:
        port_str = re.escape(str(port))
        sanitized = re.sub(
            r"(?:port |:)" + port_str + r"(?![0-9])",
            lambda m: ("port " if m.group().startswith("p") else ":") + "<port>",
            sanitized,
        )

    # Replace the database name to avoid leaking internal infrastructure names.
    # Use word-boundary regex to prevent mangling when the database name is a
    # common substring (e.g. "test", "data", "on").
    if database:
        sanitized = re.sub(r"\b" + re.escape(database) + r"\b", "<database>", sanitized)

    return sanitized


def _extract_keyword_tokens(parsed: sqlparse.sql.Statement) -> list[str]:
    """Extract keyword tokens from a parsed SQL statement.

    Uses sqlparse token type classification to collect Keyword/DML/DDL/DCL
    tokens. String literals and identifiers have different token types, so
    they are naturally excluded from the result.
    """
    return [
        token.normalized.upper()
        for token in parsed.flatten()
        if token.ttype
        in (
            sqlparse.tokens.Keyword,
            sqlparse.tokens.Keyword.DML,
            sqlparse.tokens.Keyword.DDL,
            sqlparse.tokens.Keyword.DCL,
        )
    ]


def _has_disallowed_into(stmt: sqlparse.sql.Statement) -> bool:
    """Check if a statement contains a disallowed ``INTO`` clause.

    ``SELECT ... INTO @variable`` is a valid read-only MySQL syntax that stores
    a query result into a session-scoped user variable. All other forms of
    ``INTO`` are data-modifying or file-writing and must be blocked:

    * ``SELECT ... INTO new_table`` (PostgreSQL / MSSQL – creates a table)
    * ``SELECT ... INTO OUTFILE`` (MySQL – writes to the filesystem)
    * ``SELECT ... INTO DUMPFILE`` (MySQL – writes to the filesystem)
    * ``INSERT INTO ...`` (already blocked by INSERT being in the
      disallowed set, but we reject INTO as well for defense-in-depth)

    Returns ``True`` if the statement contains a disallowed ``INTO``.
    """
    flat = list(stmt.flatten())
    for i, token in enumerate(flat):
        if not (
            token.ttype in (sqlparse.tokens.Keyword,)
            and token.normalized.upper() == "INTO"
        ):
            continue

        # Look at the first non-whitespace token after INTO.
        j = i + 1
        while j < len(flat) and flat[j].ttype is sqlparse.tokens.Text.Whitespace:
            j += 1

        if j >= len(flat):
            # INTO at the very end – malformed, block it.
            return True

        next_token = flat[j]
        # MySQL user variable: either a single Name starting with "@"
        # (e.g. ``@total``) or a bare ``@`` Operator token followed by a Name.
        if next_token.ttype is sqlparse.tokens.Name and next_token.value.startswith(
            "@"
        ):
            continue
        if next_token.ttype is sqlparse.tokens.Operator and next_token.value == "@":
            continue

        # Everything else (table name, OUTFILE, DUMPFILE, etc.) is disallowed.
        return True

    return False


def _validate_query_is_read_only(stmt: sqlparse.sql.Statement) -> str | None:
    """Validate that a parsed SQL statement is read-only (SELECT/WITH only).

    Accepts an already-parsed statement from ``_validate_single_statement``
    to avoid re-parsing. Checks:
    1. Statement type must be SELECT (sqlparse classifies WITH...SELECT as SELECT)
    2. No disallowed keywords (INSERT, UPDATE, DELETE, DROP, etc.)
    3. No disallowed INTO clauses (allows MySQL ``SELECT ... INTO @variable``)

    Returns an error message if the query is not read-only, None otherwise.
    """
    # sqlparse returns 'SELECT' for SELECT and WITH...SELECT queries
    if stmt.get_type() != "SELECT":
        return "Only SELECT queries are allowed."

    # Defense-in-depth: check parsed keyword tokens for disallowed keywords
    for kw in _extract_keyword_tokens(stmt):
        # Normalize multi-word tokens (e.g. "SET LOCAL" -> "SET")
        base_kw = kw.split()[0] if " " in kw else kw
        if base_kw in _DISALLOWED_KEYWORDS:
            return f"Disallowed SQL keyword: {kw}"

    # Contextual check for INTO: allow MySQL @variable syntax, block everything else
    if _has_disallowed_into(stmt):
        return "Disallowed SQL keyword: INTO"

    return None


def _validate_single_statement(
    query: str,
) -> tuple[str | None, sqlparse.sql.Statement | None]:
    """Validate that the query contains exactly one non-empty SQL statement.

    Returns (error_message, parsed_statement). If error_message is not None,
    the query is invalid and parsed_statement will be None.
    """
    stripped = query.strip().rstrip(";").strip()
    if not stripped:
        return "Query is empty.", None

    # Parse the SQL using sqlparse for proper tokenization
    statements = sqlparse.parse(stripped)

    # Filter out empty statements and comment-only statements
    statements = [
        s
        for s in statements
        if s.tokens
        and str(s).strip()
        and not all(
            t.is_whitespace or t.ttype in sqlparse.tokens.Comment for t in s.flatten()
        )
    ]

    if not statements:
        return "Query is empty.", None

    # Reject multiple statements -- prevents injection via semicolons
    if len(statements) > 1:
        return "Only single statements are allowed.", None

    return None, statements[0]


def _serialize_value(value: Any) -> Any:
    """Convert database-specific types to JSON-serializable Python types."""
    if isinstance(value, Decimal):
        # NaN / Infinity are not valid JSON numbers; serialize as strings.
        if value.is_nan() or value.is_infinite():
            return str(value)
        # Use int for whole numbers; use str for fractional to preserve exact
        # precision (float would silently round high-precision analytics values).
        if value == value.to_integral_value():
            return int(value)
        return str(value)
    if isinstance(value, (datetime, date, time)):
        return value.isoformat()
    if isinstance(value, memoryview):
        return bytes(value).hex()
    if isinstance(value, bytes):
        return value.hex()
    return value


def _configure_session(
    conn: Any,
    dialect_name: str,
    timeout_ms: str,
    read_only: bool,
) -> None:
    """Set session-level timeout and read-only mode for the given dialect.

    Timeout limitations by database:

    * **PostgreSQL** – ``statement_timeout`` reliably cancels any running
      statement (SELECT or DML) after the configured duration.
    * **MySQL** – ``MAX_EXECUTION_TIME`` only applies to **read-only SELECT**
      statements. DML (INSERT/UPDATE/DELETE) and DDL are *not* bounded by
      this hint; they rely on the server's ``wait_timeout`` /
      ``interactive_timeout`` instead. There is no session-level setting in
      MySQL that reliably cancels long-running writes.
    * **MSSQL** – ``SET LOCK_TIMEOUT`` only limits how long the server waits
      to acquire a **lock**. CPU-bound queries (e.g. large scans, hash
      joins) that do not block on locks will *not* be cancelled. MSSQL has
      no session-level ``statement_timeout`` equivalent; the closest
      mechanism is Resource Governor (requires sysadmin configuration) or
      ``CONTEXT_INFO``-based external monitoring.

    Note: SQLite is not supported by this block. The ``_configure_session``
    function is a no-op for unrecognised dialect names, so an SQLite engine
    would skip all SET commands silently. The block's ``DatabaseType`` enum
    intentionally excludes SQLite.
    """
    if dialect_name == "postgresql":
        conn.execute(text("SET statement_timeout = " + timeout_ms))
        if read_only:
            conn.execute(text("SET default_transaction_read_only = ON"))
    elif dialect_name == "mysql":
        # NOTE: MAX_EXECUTION_TIME only applies to SELECT statements.
        # Write queries (INSERT/UPDATE/DELETE) are not bounded by this
        # setting; they rely on the database's wait_timeout instead.
        # See docstring above for full limitations.
        conn.execute(text("SET SESSION MAX_EXECUTION_TIME = " + timeout_ms))
        if read_only:
            conn.execute(text("SET SESSION TRANSACTION READ ONLY"))
    elif dialect_name == "mssql":
        # MSSQL: SET LOCK_TIMEOUT limits lock-wait time (ms) only.
        # CPU-bound queries without lock contention are NOT cancelled.
        # See docstring above for full limitations.
        conn.execute(text("SET LOCK_TIMEOUT " + timeout_ms))
        # MSSQL lacks a session-level read-only mode like
        # PostgreSQL/MySQL. Read-only enforcement is handled by
        # the SQL validation layer (_validate_query_is_read_only)
        # and the ROLLBACK in the finally block.


def _run_in_transaction(
    conn: Any,
    dialect_name: str,
    query: str,
    max_rows: int,
    read_only: bool,
) -> tuple[list[dict[str, Any]], list[str], int, bool]:
    """Execute a query inside an explicit transaction, returning results.

    Returns ``(rows, columns, affected_rows, truncated)`` where *truncated*
    is ``True`` when ``fetchmany`` returned exactly ``max_rows`` rows,
    indicating that additional rows may exist in the result set.
    """
    # MSSQL uses T-SQL "BEGIN TRANSACTION"; others use "BEGIN".
    begin_stmt = "BEGIN TRANSACTION" if dialect_name == "mssql" else "BEGIN"
    conn.execute(text(begin_stmt))
    try:
        result = conn.execute(text(query))
        affected = result.rowcount if not result.returns_rows else -1
        columns = list(result.keys()) if result.returns_rows else []
        rows = result.fetchmany(max_rows) if result.returns_rows else []
        truncated = len(rows) == max_rows
        results = [
            {col: _serialize_value(val) for col, val in zip(columns, row)}
            for row in rows
        ]
    except Exception:
        try:
            conn.execute(text("ROLLBACK"))
        except Exception:
            pass
        raise
    else:
        conn.execute(text("ROLLBACK" if read_only else "COMMIT"))
        return results, columns, affected, truncated


def _execute_query(
    connection_url: URL | str,
    query: str,
    timeout: int,
    max_rows: int,
    read_only: bool = True,
    database_type: DatabaseType = DatabaseType.POSTGRES,
) -> tuple[list[dict[str, Any]], list[str], int, bool]:
    """Execute a SQL query and return (rows, columns, affected_rows, truncated).

    Uses SQLAlchemy to connect to any supported database.
    For SELECT queries, rows are limited to ``max_rows`` via DBAPI fetchmany.
    ``truncated`` is ``True`` when the result set was capped by ``max_rows``.
    For write queries, affected_rows contains the rowcount from the driver.
    When ``read_only`` is True, the database session is set to read-only
    mode and the transaction is always rolled back.
    """
    # Determine driver-specific connection timeout argument.
    # pymssql uses "login_timeout", while PostgreSQL/MySQL use "connect_timeout".
    timeout_key = (
        "login_timeout" if database_type == DatabaseType.MSSQL else "connect_timeout"
    )
    engine = create_engine(
        connection_url, connect_args={timeout_key: _CONNECT_TIMEOUT_SECONDS}
    )
    try:
        with engine.connect() as conn:
            # Use AUTOCOMMIT so SET commands take effect immediately.
            conn = conn.execution_options(isolation_level="AUTOCOMMIT")

            # Compute timeout in milliseconds. The value is Pydantic-validated
            # (ge=1, le=120), but we use int() as defense-in-depth.
            # NOTE: SET commands do not support bind parameters in most
            # databases, so we use str(int(...)) for safe interpolation.
            timeout_ms = str(int(timeout * 1000))

            _configure_session(conn, engine.dialect.name, timeout_ms, read_only)
            return _run_in_transaction(
                conn, engine.dialect.name, query, max_rows, read_only
            )
    finally:
        engine.dispose()
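A quick standalone exercise of the validators above; it assumes the helpers module is importable under the path shown in sql_query_block.py, and the queries are made-up examples:

```python
from backend.blocks.sql_query_helpers import (
    _validate_query_is_read_only,
    _validate_single_statement,
)

err, stmt = _validate_single_statement("SELECT id FROM users; DROP TABLE users")
assert err == "Only single statements are allowed."

err, stmt = _validate_single_statement("SELECT id FROM users LIMIT 5")
assert err is None and stmt is not None
assert _validate_query_is_read_only(stmt) is None

err, stmt = _validate_single_statement("DELETE FROM users")
assert err is None
assert _validate_query_is_read_only(stmt) == "Only SELECT queries are allowed."

err, stmt = _validate_single_statement("SELECT * INTO reports FROM users")
assert err is None
assert _validate_query_is_read_only(stmt) == "Disallowed SQL keyword: INTO"
```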
@@ -300,13 +300,27 @@ def test_agent_input_block_ignores_legacy_placeholder_values():


def test_dropdown_input_block_produces_enum():
    """Verify AgentDropdownInputBlock.Input.generate_schema() produces enum."""
    options = ["Option A", "Option B"]
    """Verify AgentDropdownInputBlock.Input.generate_schema() produces enum
    using the canonical 'options' field name."""
    opts = ["Option A", "Option B"]
    instance = AgentDropdownInputBlock.Input.model_construct(
        name="choice", value=None, placeholder_values=options
        name="choice", value=None, options=opts
    )
    schema = instance.generate_schema()
    assert schema.get("enum") == options
    assert schema.get("enum") == opts


def test_dropdown_input_block_legacy_placeholder_values_produces_enum():
    """Verify backward compat: passing legacy 'placeholder_values' to
    AgentDropdownInputBlock still produces enum via model_construct remap."""
    opts = ["Option A", "Option B"]
    instance = AgentDropdownInputBlock.Input.model_construct(
        name="choice", value=None, placeholder_values=opts
    )
    schema = instance.generate_schema()
    assert (
        schema.get("enum") == opts
    ), "Legacy placeholder_values should be remapped to options"


def test_generate_schema_integration_legacy_placeholder_values():
@@ -329,11 +343,11 @@ def test_generate_schema_integration_legacy_placeholder_values():

def test_generate_schema_integration_dropdown_produces_enum():
    """Test the full Graph._generate_schema path with AgentDropdownInputBlock
    — verifies enum IS produced for dropdown blocks."""
    — verifies enum IS produced for dropdown blocks using canonical field name."""
    dropdown_input_default = {
        "name": "color",
        "value": None,
        "placeholder_values": ["Red", "Green", "Blue"],
        "options": ["Red", "Green", "Blue"],
    }
    result = BaseGraph._generate_schema(
        (AgentDropdownInputBlock.Input, dropdown_input_default),
@@ -344,3 +358,36 @@ def test_generate_schema_integration_dropdown_produces_enum():
        "Green",
        "Blue",
    ], "Graph schema should contain enum from AgentDropdownInputBlock"


def test_generate_schema_integration_dropdown_legacy_placeholder_values():
    """Test the full Graph._generate_schema path with AgentDropdownInputBlock
    using legacy 'placeholder_values' — verifies backward compat produces enum."""
    legacy_dropdown_input_default = {
        "name": "color",
        "value": None,
        "placeholder_values": ["Red", "Green", "Blue"],
    }
    result = BaseGraph._generate_schema(
        (AgentDropdownInputBlock.Input, legacy_dropdown_input_default),
    )
    color_props = result["properties"]["color"]
    assert color_props.get("enum") == [
        "Red",
        "Green",
        "Blue",
    ], "Legacy placeholder_values should still produce enum via model_construct remap"


def test_dropdown_input_block_init_legacy_placeholder_values():
    """Verify backward compat: constructing AgentDropdownInputBlock.Input via
    model_validate with legacy 'placeholder_values' correctly maps to 'options'."""
    opts = ["Option A", "Option B"]
    instance = AgentDropdownInputBlock.Input.model_validate(
        {"name": "choice", "value": None, "placeholder_values": opts}
    )
    assert (
        instance.options == opts
    ), "Legacy placeholder_values should be remapped to options via model_validate"
    schema = instance.generate_schema()
    assert schema.get("enum") == opts

@@ -488,6 +488,154 @@ class TestLLMStatsTracking:
        assert outputs["response"] == {"result": "test"}


class TestAIConversationBlockValidation:
    """Test that AIConversationBlock validates inputs before calling the LLM."""

    @pytest.mark.asyncio
    async def test_empty_messages_and_empty_prompt_raises_error(self):
        """Empty messages with no prompt should raise ValueError, not a cryptic API error."""
        block = llm.AIConversationBlock()

        input_data = llm.AIConversationBlock.Input(
            messages=[],
            prompt="",
            model=llm.DEFAULT_LLM_MODEL,
            credentials=_TEST_AI_CREDENTIALS,
        )

        with pytest.raises(ValueError, match="no messages and no prompt"):
            async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
                pass

    @pytest.mark.asyncio
    async def test_empty_messages_with_prompt_succeeds(self):
        """Empty messages but a non-empty prompt should proceed without error."""
        block = llm.AIConversationBlock()

        async def mock_llm_call(input_data, credentials):
            return {"response": "OK"}

        with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
            input_data = llm.AIConversationBlock.Input(
                messages=[],
                prompt="Hello, how are you?",
                model=llm.DEFAULT_LLM_MODEL,
                credentials=_TEST_AI_CREDENTIALS,
            )

            outputs = {}
            async for name, data in block.run(
                input_data, credentials=llm.TEST_CREDENTIALS
            ):
                outputs[name] = data

            assert outputs["response"] == "OK"

    @pytest.mark.asyncio
    async def test_nonempty_messages_with_empty_prompt_succeeds(self):
        """Non-empty messages with no prompt should proceed without error."""
        block = llm.AIConversationBlock()

        async def mock_llm_call(input_data, credentials):
            return {"response": "response from conversation"}

        with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
            input_data = llm.AIConversationBlock.Input(
                messages=[{"role": "user", "content": "Hello"}],
                prompt="",
                model=llm.DEFAULT_LLM_MODEL,
                credentials=_TEST_AI_CREDENTIALS,
            )

            outputs = {}
            async for name, data in block.run(
                input_data, credentials=llm.TEST_CREDENTIALS
            ):
                outputs[name] = data

            assert outputs["response"] == "response from conversation"

    @pytest.mark.asyncio
    async def test_messages_with_empty_content_raises_error(self):
        """Messages with empty content strings should be treated as no messages."""
        block = llm.AIConversationBlock()

        input_data = llm.AIConversationBlock.Input(
            messages=[{"role": "user", "content": ""}],
            prompt="",
            model=llm.DEFAULT_LLM_MODEL,
            credentials=_TEST_AI_CREDENTIALS,
        )

        with pytest.raises(ValueError, match="no messages and no prompt"):
            async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
                pass

    @pytest.mark.asyncio
    async def test_messages_with_whitespace_content_raises_error(self):
        """Messages with whitespace-only content should be treated as no messages."""
        block = llm.AIConversationBlock()

        input_data = llm.AIConversationBlock.Input(
            messages=[{"role": "user", "content": "   "}],
            prompt="",
            model=llm.DEFAULT_LLM_MODEL,
            credentials=_TEST_AI_CREDENTIALS,
        )

        with pytest.raises(ValueError, match="no messages and no prompt"):
            async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
                pass

    @pytest.mark.asyncio
    async def test_messages_with_none_entry_raises_error(self):
        """Messages list containing None should be treated as no messages."""
        block = llm.AIConversationBlock()

        input_data = llm.AIConversationBlock.Input(
            messages=[None],
            prompt="",
            model=llm.DEFAULT_LLM_MODEL,
            credentials=_TEST_AI_CREDENTIALS,
        )

        with pytest.raises(ValueError, match="no messages and no prompt"):
            async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
                pass

    @pytest.mark.asyncio
    async def test_messages_with_empty_dict_raises_error(self):
        """Messages list containing empty dict should be treated as no messages."""
        block = llm.AIConversationBlock()

        input_data = llm.AIConversationBlock.Input(
            messages=[{}],
            prompt="",
            model=llm.DEFAULT_LLM_MODEL,
            credentials=_TEST_AI_CREDENTIALS,
        )

        with pytest.raises(ValueError, match="no messages and no prompt"):
            async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
                pass

    @pytest.mark.asyncio
    async def test_messages_with_none_content_raises_error(self):
        """Messages with content=None should not crash with AttributeError."""
        block = llm.AIConversationBlock()

        input_data = llm.AIConversationBlock.Input(
            messages=[{"role": "user", "content": None}],
            prompt="",
            model=llm.DEFAULT_LLM_MODEL,
            credentials=_TEST_AI_CREDENTIALS,
        )

        with pytest.raises(ValueError, match="no messages and no prompt"):
            async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
                pass


class TestAITextSummarizerValidation:
    """Test that AITextSummarizerBlock validates LLM responses are strings."""

@@ -0,0 +1,87 @@
|
||||
"""Tests for empty-choices guard in extract_openai_tool_calls() and extract_openai_reasoning()."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from backend.blocks.llm import extract_openai_reasoning, extract_openai_tool_calls
|
||||
|
||||
|
||||
class TestExtractOpenaiToolCallsEmptyChoices:
|
||||
"""extract_openai_tool_calls() must return None when choices is empty."""
|
||||
|
||||
def test_returns_none_for_empty_choices(self):
|
||||
response = MagicMock()
|
||||
response.choices = []
|
||||
assert extract_openai_tool_calls(response) is None
|
||||
|
||||
def test_returns_none_for_none_choices(self):
|
||||
response = MagicMock()
|
||||
response.choices = None
|
||||
assert extract_openai_tool_calls(response) is None
|
||||
|
||||
def test_returns_tool_calls_when_choices_present(self):
|
||||
tool = MagicMock()
|
||||
tool.id = "call_1"
|
||||
tool.type = "function"
|
||||
tool.function.name = "my_func"
|
||||
tool.function.arguments = '{"a": 1}'
|
||||
|
||||
message = MagicMock()
|
||||
message.tool_calls = [tool]
|
||||
|
||||
choice = MagicMock()
|
||||
choice.message = message
|
||||
|
||||
response = MagicMock()
|
||||
response.choices = [choice]
|
||||
|
||||
result = extract_openai_tool_calls(response)
|
||||
assert result is not None
|
||||
assert len(result) == 1
|
||||
assert result[0].function.name == "my_func"
|
||||
|
||||
def test_returns_none_when_no_tool_calls(self):
|
||||
message = MagicMock()
|
||||
message.tool_calls = None
|
||||
|
||||
choice = MagicMock()
|
||||
choice.message = message
|
||||
|
||||
response = MagicMock()
|
||||
response.choices = [choice]
|
||||
|
||||
assert extract_openai_tool_calls(response) is None
|
||||
|
||||
|
||||
class TestExtractOpenaiReasoningEmptyChoices:
|
||||
"""extract_openai_reasoning() must return None when choices is empty."""
|
||||
|
||||
def test_returns_none_for_empty_choices(self):
|
||||
response = MagicMock()
|
||||
response.choices = []
|
||||
assert extract_openai_reasoning(response) is None
|
||||
|
||||
def test_returns_none_for_none_choices(self):
|
||||
response = MagicMock()
|
||||
response.choices = None
|
||||
assert extract_openai_reasoning(response) is None
|
||||
|
||||
def test_returns_reasoning_from_choice(self):
|
||||
choice = MagicMock()
|
||||
choice.reasoning = "Step-by-step reasoning"
|
||||
choice.message = MagicMock(spec=[]) # no 'reasoning' attr on message
|
||||
|
||||
response = MagicMock(spec=[]) # no 'reasoning' attr on response
|
||||
response.choices = [choice]
|
||||
|
||||
result = extract_openai_reasoning(response)
|
||||
assert result == "Step-by-step reasoning"
|
||||
|
||||
def test_returns_none_when_no_reasoning(self):
|
||||
choice = MagicMock(spec=[]) # no 'reasoning' attr
|
||||
choice.message = MagicMock(spec=[]) # no 'reasoning' attr
|
||||
|
||||
response = MagicMock(spec=[]) # no 'reasoning' attr
|
||||
response.choices = [choice]
|
||||
|
||||
result = extract_openai_reasoning(response)
|
||||
assert result is None
|
||||
@@ -1074,6 +1074,7 @@ async def test_orchestrator_uses_customized_name_for_blocks():
|
||||
mock_node.block_id = StoreValueBlock().id
|
||||
mock_node.metadata = {"customized_name": "My Custom Tool Name"}
|
||||
mock_node.block = StoreValueBlock()
|
||||
mock_node.input_default = {}
|
||||
|
||||
# Create a mock link
|
||||
mock_link = MagicMock(spec=Link)
|
||||
@@ -1105,6 +1106,7 @@ async def test_orchestrator_falls_back_to_block_name():
|
||||
mock_node.block_id = StoreValueBlock().id
|
||||
mock_node.metadata = {} # No customized_name
|
||||
mock_node.block = StoreValueBlock()
|
||||
mock_node.input_default = {}
|
||||
|
||||
# Create a mock link
|
||||
mock_link = MagicMock(spec=Link)
|
||||
|
||||
@@ -0,0 +1,202 @@
|
||||
"""Tests for ExecutionMode enum and provider validation in the orchestrator.
|
||||
|
||||
Covers:
|
||||
- ExecutionMode enum members exist and have stable values
|
||||
- EXTENDED_THINKING provider validation (anthropic/open_router allowed, others rejected)
|
||||
- EXTENDED_THINKING model-name validation (must start with "claude")
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.llm import LlmModel
|
||||
from backend.blocks.orchestrator import ExecutionMode, OrchestratorBlock
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ExecutionMode enum integrity
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExecutionModeEnum:
|
||||
"""Guard against accidental renames or removals of enum members."""
|
||||
|
||||
def test_built_in_exists(self):
|
||||
assert hasattr(ExecutionMode, "BUILT_IN")
|
||||
assert ExecutionMode.BUILT_IN.value == "built_in"
|
||||
|
||||
def test_extended_thinking_exists(self):
|
||||
assert hasattr(ExecutionMode, "EXTENDED_THINKING")
|
||||
assert ExecutionMode.EXTENDED_THINKING.value == "extended_thinking"
|
||||
|
||||
def test_exactly_two_members(self):
|
||||
"""If a new mode is added, this test should be updated intentionally."""
|
||||
assert set(ExecutionMode.__members__.keys()) == {
|
||||
"BUILT_IN",
|
||||
"EXTENDED_THINKING",
|
||||
}
|
||||
|
||||
def test_string_enum(self):
|
||||
"""ExecutionMode is a str enum so it serialises cleanly to JSON."""
|
||||
assert isinstance(ExecutionMode.BUILT_IN, str)
|
||||
assert isinstance(ExecutionMode.EXTENDED_THINKING, str)
|
||||
|
||||
def test_round_trip_from_value(self):
|
||||
"""Constructing from the string value should return the same member."""
|
||||
assert ExecutionMode("built_in") is ExecutionMode.BUILT_IN
|
||||
assert ExecutionMode("extended_thinking") is ExecutionMode.EXTENDED_THINKING
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider validation (inline in OrchestratorBlock.run)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_model_stub(provider: str, value: str):
|
||||
"""Create a lightweight stub that behaves like LlmModel for validation."""
|
||||
metadata = MagicMock()
|
||||
metadata.provider = provider
|
||||
stub = MagicMock()
|
||||
stub.metadata = metadata
|
||||
stub.value = value
|
||||
return stub
|
||||
|
||||
|
||||
class TestExtendedThinkingProviderValidation:
|
||||
"""The orchestrator rejects EXTENDED_THINKING for non-Anthropic providers."""
|
||||
|
||||
def test_anthropic_provider_accepted(self):
|
||||
"""provider='anthropic' + claude model should not raise."""
|
||||
model = _make_model_stub("anthropic", "claude-opus-4-6")
|
||||
provider = model.metadata.provider
|
||||
model_name = model.value
|
||||
assert provider in ("anthropic", "open_router")
|
||||
assert model_name.startswith("claude")
|
||||
|
||||
def test_open_router_provider_accepted(self):
|
||||
"""provider='open_router' + claude model should not raise."""
|
||||
model = _make_model_stub("open_router", "claude-sonnet-4-6")
|
||||
provider = model.metadata.provider
|
||||
model_name = model.value
|
||||
assert provider in ("anthropic", "open_router")
|
||||
assert model_name.startswith("claude")
|
||||
|
||||
def test_openai_provider_rejected(self):
|
||||
"""provider='openai' should be rejected for EXTENDED_THINKING."""
|
||||
model = _make_model_stub("openai", "gpt-4o")
|
||||
provider = model.metadata.provider
|
||||
assert provider not in ("anthropic", "open_router")
|
||||
|
||||
def test_groq_provider_rejected(self):
|
||||
model = _make_model_stub("groq", "llama-3.3-70b-versatile")
|
||||
provider = model.metadata.provider
|
||||
assert provider not in ("anthropic", "open_router")
|
||||
|
||||
def test_non_claude_model_rejected_even_if_anthropic_provider(self):
|
||||
"""A hypothetical non-Claude model with provider='anthropic' is rejected."""
|
||||
model = _make_model_stub("anthropic", "not-a-claude-model")
|
||||
model_name = model.value
|
||||
assert not model_name.startswith("claude")
|
||||
|
||||
def test_real_gpt4o_model_rejected(self):
|
||||
"""Verify a real LlmModel enum member (GPT4O) fails the provider check."""
|
||||
model = LlmModel.GPT4O
|
||||
provider = model.metadata.provider
|
||||
assert provider not in ("anthropic", "open_router")
|
||||
|
||||
def test_real_claude_model_passes(self):
|
||||
"""Verify a real LlmModel enum member (CLAUDE_4_6_SONNET) passes."""
|
||||
model = LlmModel.CLAUDE_4_6_SONNET
|
||||
provider = model.metadata.provider
|
||||
model_name = model.value
|
||||
assert provider in ("anthropic", "open_router")
|
||||
assert model_name.startswith("claude")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration-style: exercise the validation branch via OrchestratorBlock.run
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_input_data(model, execution_mode=ExecutionMode.EXTENDED_THINKING):
|
||||
"""Build a minimal MagicMock that satisfies OrchestratorBlock.run's early path."""
|
||||
inp = MagicMock()
|
||||
inp.execution_mode = execution_mode
|
||||
inp.model = model
|
||||
inp.prompt = "test"
|
||||
inp.sys_prompt = ""
|
||||
inp.conversation_history = []
|
||||
inp.last_tool_output = None
|
||||
inp.prompt_values = {}
|
||||
return inp
|
||||
|
||||
|
||||
async def _collect_run_outputs(block, input_data, **kwargs):
|
||||
"""Exhaust the OrchestratorBlock.run async generator, collecting outputs."""
|
||||
outputs = []
|
||||
async for item in block.run(input_data, **kwargs):
|
||||
outputs.append(item)
|
||||
return outputs
|
||||
|
||||
|
||||
class TestExtendedThinkingValidationRaisesInBlock:
|
||||
"""Call OrchestratorBlock.run far enough to trigger the ValueError."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_anthropic_provider_raises_valueerror(self):
|
||||
"""EXTENDED_THINKING + openai provider raises ValueError."""
|
||||
block = OrchestratorBlock()
|
||||
input_data = _make_input_data(model=LlmModel.GPT4O)
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
block,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
),
|
||||
pytest.raises(ValueError, match="Anthropic-compatible"),
|
||||
):
|
||||
await _collect_run_outputs(
|
||||
block,
|
||||
input_data,
|
||||
credentials=MagicMock(),
|
||||
graph_id="g",
|
||||
node_id="n",
|
||||
graph_exec_id="ge",
|
||||
node_exec_id="ne",
|
||||
user_id="u",
|
||||
graph_version=1,
|
||||
execution_context=MagicMock(),
|
||||
execution_processor=MagicMock(),
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_claude_model_with_anthropic_provider_raises(self):
|
||||
"""A model with anthropic provider but non-claude name raises ValueError."""
|
||||
block = OrchestratorBlock()
|
||||
fake_model = _make_model_stub("anthropic", "not-a-claude-model")
|
||||
input_data = _make_input_data(model=fake_model)
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
block,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
),
|
||||
pytest.raises(ValueError, match="only supports Claude models"),
|
||||
):
|
||||
await _collect_run_outputs(
|
||||
block,
|
||||
input_data,
|
||||
credentials=MagicMock(),
|
||||
graph_id="g",
|
||||
node_id="n",
|
||||
graph_exec_id="ge",
|
||||
node_exec_id="ne",
|
||||
user_id="u",
|
||||
graph_version=1,
|
||||
execution_context=MagicMock(),
|
||||
execution_processor=MagicMock(),
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -31,7 +31,7 @@ async def test_baseline_multi_turn(setup_test_user, test_user_id):
|
||||
if not api_key:
|
||||
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
|
||||
|
||||
session = await create_chat_session(test_user_id)
|
||||
session = await create_chat_session(test_user_id, dry_run=False)
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
# --- Turn 1: send a message with a unique keyword ---
|
||||
|
||||
@@ -0,0 +1,633 @@
|
||||
"""Unit tests for baseline service pure-logic helpers.
|
||||
|
||||
These tests cover ``_baseline_conversation_updater`` and ``_BaselineStreamState``
|
||||
without requiring API keys, database connections, or network access.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from backend.copilot.baseline.service import (
|
||||
_baseline_conversation_updater,
|
||||
_BaselineStreamState,
|
||||
_compress_session_messages,
|
||||
_ThinkingStripper,
|
||||
)
|
||||
from backend.copilot.model import ChatMessage
|
||||
from backend.copilot.transcript_builder import TranscriptBuilder
|
||||
from backend.util.prompt import CompressResult
|
||||
from backend.util.tool_call_loop import LLMLoopResponse, LLMToolCall, ToolCallResult
|
||||
|
||||
|
||||
class TestBaselineStreamState:
|
||||
def test_defaults(self):
|
||||
state = _BaselineStreamState()
|
||||
assert state.pending_events == []
|
||||
assert state.assistant_text == ""
|
||||
assert state.text_started is False
|
||||
assert state.turn_prompt_tokens == 0
|
||||
assert state.turn_completion_tokens == 0
|
||||
assert state.text_block_id # Should be a UUID string
|
||||
|
||||
def test_mutable_fields(self):
|
||||
state = _BaselineStreamState()
|
||||
state.assistant_text = "hello"
|
||||
state.turn_prompt_tokens = 100
|
||||
state.turn_completion_tokens = 50
|
||||
assert state.assistant_text == "hello"
|
||||
assert state.turn_prompt_tokens == 100
|
||||
assert state.turn_completion_tokens == 50
|
||||
|
||||
|
||||
class TestBaselineConversationUpdater:
|
||||
"""Tests for _baseline_conversation_updater which updates the OpenAI
|
||||
message list and transcript builder after each LLM call."""
|
||||
|
||||
def _make_transcript_builder(self) -> TranscriptBuilder:
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user("test question")
|
||||
return builder
|
||||
|
||||
def test_text_only_response(self):
|
||||
"""When the LLM returns text without tool calls, the updater appends
|
||||
a single assistant message and records it in the transcript."""
|
||||
messages: list = []
|
||||
builder = self._make_transcript_builder()
|
||||
response = LLMLoopResponse(
|
||||
response_text="Hello, world!",
|
||||
tool_calls=[],
|
||||
raw_response=None,
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
)
|
||||
|
||||
_baseline_conversation_updater(
|
||||
messages,
|
||||
response,
|
||||
tool_results=None,
|
||||
transcript_builder=builder,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["role"] == "assistant"
|
||||
assert messages[0]["content"] == "Hello, world!"
|
||||
# Transcript should have user + assistant
|
||||
assert builder.entry_count == 2
|
||||
assert builder.last_entry_type == "assistant"
|
||||
|
||||
def test_tool_calls_response(self):
|
||||
"""When the LLM returns tool calls, the updater appends the assistant
|
||||
message with tool_calls and tool result messages."""
|
||||
messages: list = []
|
||||
builder = self._make_transcript_builder()
|
||||
response = LLMLoopResponse(
|
||||
response_text="Let me search...",
|
||||
tool_calls=[
|
||||
LLMToolCall(
|
||||
id="tc_1",
|
||||
name="search",
|
||||
arguments='{"query": "test"}',
|
||||
),
|
||||
],
|
||||
raw_response=None,
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
)
|
||||
tool_results = [
|
||||
ToolCallResult(
|
||||
tool_call_id="tc_1",
|
||||
tool_name="search",
|
||||
content="Found result",
|
||||
),
|
||||
]
|
||||
|
||||
_baseline_conversation_updater(
|
||||
messages,
|
||||
response,
|
||||
tool_results=tool_results,
|
||||
transcript_builder=builder,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
# Messages: assistant (with tool_calls) + tool result
|
||||
assert len(messages) == 2
|
||||
assert messages[0]["role"] == "assistant"
|
||||
assert messages[0]["content"] == "Let me search..."
|
||||
assert len(messages[0]["tool_calls"]) == 1
|
||||
assert messages[0]["tool_calls"][0]["id"] == "tc_1"
|
||||
assert messages[1]["role"] == "tool"
|
||||
assert messages[1]["tool_call_id"] == "tc_1"
|
||||
assert messages[1]["content"] == "Found result"
|
||||
|
||||
# Transcript: user + assistant(tool_use) + user(tool_result)
|
||||
assert builder.entry_count == 3
|
||||
|
||||
def test_tool_calls_without_text(self):
|
||||
"""Tool calls without accompanying text should still work."""
|
||||
messages: list = []
|
||||
builder = self._make_transcript_builder()
|
||||
response = LLMLoopResponse(
|
||||
response_text=None,
|
||||
tool_calls=[
|
||||
LLMToolCall(id="tc_1", name="run", arguments="{}"),
|
||||
],
|
||||
raw_response=None,
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
)
|
||||
tool_results = [
|
||||
ToolCallResult(tool_call_id="tc_1", tool_name="run", content="done"),
|
||||
]
|
||||
|
||||
_baseline_conversation_updater(
|
||||
messages,
|
||||
response,
|
||||
tool_results=tool_results,
|
||||
transcript_builder=builder,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
assert "content" not in messages[0] # No text content
|
||||
assert messages[0]["tool_calls"][0]["function"]["name"] == "run"
|
||||
|
||||
def test_no_text_no_tools(self):
|
||||
"""When the response has no text and no tool calls, nothing is appended."""
|
||||
messages: list = []
|
||||
builder = self._make_transcript_builder()
|
||||
response = LLMLoopResponse(
|
||||
response_text=None,
|
||||
tool_calls=[],
|
||||
raw_response=None,
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
)
|
||||
|
||||
_baseline_conversation_updater(
|
||||
messages,
|
||||
response,
|
||||
tool_results=None,
|
||||
transcript_builder=builder,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
assert len(messages) == 0
|
||||
# Only the user entry from setup
|
||||
assert builder.entry_count == 1
|
||||
|
||||
def test_multiple_tool_calls(self):
|
||||
"""Multiple tool calls in a single response are all recorded."""
|
||||
messages: list = []
|
||||
builder = self._make_transcript_builder()
|
||||
response = LLMLoopResponse(
|
||||
response_text=None,
|
||||
tool_calls=[
|
||||
LLMToolCall(id="tc_1", name="tool_a", arguments="{}"),
|
||||
LLMToolCall(id="tc_2", name="tool_b", arguments='{"x": 1}'),
|
||||
],
|
||||
raw_response=None,
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
)
|
||||
tool_results = [
|
||||
ToolCallResult(tool_call_id="tc_1", tool_name="tool_a", content="result_a"),
|
||||
ToolCallResult(tool_call_id="tc_2", tool_name="tool_b", content="result_b"),
|
||||
]
|
||||
|
||||
_baseline_conversation_updater(
|
||||
messages,
|
||||
response,
|
||||
tool_results=tool_results,
|
||||
transcript_builder=builder,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
# 1 assistant + 2 tool results
|
||||
assert len(messages) == 3
|
||||
assert len(messages[0]["tool_calls"]) == 2
|
||||
assert messages[1]["tool_call_id"] == "tc_1"
|
||||
assert messages[2]["tool_call_id"] == "tc_2"
|
||||
|
||||
def test_invalid_tool_arguments_handled(self):
|
||||
"""Tool call with invalid JSON arguments: the arguments field is
|
||||
stored as-is in the message, and orjson failure falls back to {}
|
||||
in the transcript content_blocks."""
|
||||
messages: list = []
|
||||
builder = self._make_transcript_builder()
|
||||
response = LLMLoopResponse(
|
||||
response_text=None,
|
||||
tool_calls=[
|
||||
LLMToolCall(id="tc_1", name="tool_x", arguments="not-json"),
|
||||
],
|
||||
raw_response=None,
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
)
|
||||
tool_results = [
|
||||
ToolCallResult(tool_call_id="tc_1", tool_name="tool_x", content="ok"),
|
||||
]
|
||||
|
||||
_baseline_conversation_updater(
|
||||
messages,
|
||||
response,
|
||||
tool_results=tool_results,
|
||||
transcript_builder=builder,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
# Should not raise — invalid JSON falls back to {} in transcript
|
||||
assert len(messages) == 2
|
||||
assert messages[0]["tool_calls"][0]["function"]["arguments"] == "not-json"
|
||||
|
||||
|
||||
class TestCompressSessionMessagesPreservesToolCalls:
|
||||
"""``_compress_session_messages`` must round-trip tool_calls + tool_call_id.
|
||||
|
||||
Compression serialises ChatMessage to dict for ``compress_context`` and
|
||||
reifies the result back to ChatMessage. A regression that drops
|
||||
``tool_calls`` or ``tool_call_id`` would corrupt the OpenAI message
|
||||
list and break downstream tool-execution rounds.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compressed_output_keeps_tool_calls_and_ids(self):
|
||||
# Simulate compression that returns a summary + the most recent
|
||||
# assistant(tool_call) + tool(tool_result) intact.
|
||||
summary = {"role": "system", "content": "prior turns: user asked X"}
|
||||
assistant_with_tc = {
|
||||
"role": "assistant",
|
||||
"content": "calling tool",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "tc_abc",
|
||||
"type": "function",
|
||||
"function": {"name": "search", "arguments": '{"q":"y"}'},
|
||||
}
|
||||
],
|
||||
}
|
||||
tool_result = {
|
||||
"role": "tool",
|
||||
"tool_call_id": "tc_abc",
|
||||
"content": "search result",
|
||||
}
|
||||
|
||||
compress_result = CompressResult(
|
||||
messages=[summary, assistant_with_tc, tool_result],
|
||||
token_count=100,
|
||||
was_compacted=True,
|
||||
original_token_count=5000,
|
||||
messages_summarized=10,
|
||||
messages_dropped=0,
|
||||
)
|
||||
|
||||
# Input: messages that should be compressed.
|
||||
input_messages = [
|
||||
ChatMessage(role="user", content="q1"),
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content="calling tool",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "tc_abc",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search",
|
||||
"arguments": '{"q":"y"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
),
|
||||
ChatMessage(
|
||||
role="tool",
|
||||
tool_call_id="tc_abc",
|
||||
content="search result",
|
||||
),
|
||||
]
|
||||
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.compress_context",
|
||||
new=AsyncMock(return_value=compress_result),
|
||||
):
|
||||
compressed = await _compress_session_messages(
|
||||
input_messages, model="openrouter/anthropic/claude-opus-4"
|
||||
)
|
||||
|
||||
# Summary, assistant(tool_calls), tool(tool_call_id).
|
||||
assert len(compressed) == 3
|
||||
# Assistant message must keep its tool_calls intact.
|
||||
assistant_msg = compressed[1]
|
||||
assert assistant_msg.role == "assistant"
|
||||
assert assistant_msg.tool_calls is not None
|
||||
assert len(assistant_msg.tool_calls) == 1
|
||||
assert assistant_msg.tool_calls[0]["id"] == "tc_abc"
|
||||
assert assistant_msg.tool_calls[0]["function"]["name"] == "search"
|
||||
# Tool-role message must keep tool_call_id for OpenAI linkage.
|
||||
tool_msg = compressed[2]
|
||||
assert tool_msg.role == "tool"
|
||||
assert tool_msg.tool_call_id == "tc_abc"
|
||||
assert tool_msg.content == "search result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_uncompressed_passthrough_keeps_fields(self):
|
||||
"""When compression is a no-op (was_compacted=False), the original
|
||||
messages must be returned unchanged — including tool_calls."""
|
||||
input_messages = [
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content="c",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "t1",
|
||||
"type": "function",
|
||||
"function": {"name": "f", "arguments": "{}"},
|
||||
}
|
||||
],
|
||||
),
|
||||
ChatMessage(role="tool", tool_call_id="t1", content="ok"),
|
||||
]
|
||||
|
||||
noop_result = CompressResult(
|
||||
messages=[], # ignored when was_compacted=False
|
||||
token_count=10,
|
||||
was_compacted=False,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.compress_context",
|
||||
new=AsyncMock(return_value=noop_result),
|
||||
):
|
||||
out = await _compress_session_messages(
|
||||
input_messages, model="openrouter/anthropic/claude-opus-4"
|
||||
)
|
||||
|
||||
assert out is input_messages # same list returned
|
||||
assert out[0].tool_calls is not None
|
||||
assert out[0].tool_calls[0]["id"] == "t1"
|
||||
assert out[1].tool_call_id == "t1"
|
||||
|
||||
|
||||
# ---- _ThinkingStripper tests ---- #
|
||||
|
||||
|
||||
def test_thinking_stripper_basic_thinking_tag() -> None:
|
||||
"""<thinking>...</thinking> blocks are fully stripped."""
|
||||
s = _ThinkingStripper()
|
||||
assert s.process("<thinking>internal reasoning here</thinking>Hello!") == "Hello!"
|
||||
|
||||
|
||||
def test_thinking_stripper_internal_reasoning_tag() -> None:
|
||||
"""<internal_reasoning>...</internal_reasoning> blocks (Gemini) are stripped."""
|
||||
s = _ThinkingStripper()
|
||||
assert (
|
||||
s.process("<internal_reasoning>step by step</internal_reasoning>Answer")
|
||||
== "Answer"
|
||||
)
|
||||
|
||||
|
||||
def test_thinking_stripper_split_across_chunks() -> None:
|
||||
"""Tags split across multiple chunks are handled correctly."""
|
||||
s = _ThinkingStripper()
|
||||
out = s.process("Hello <thin")
|
||||
out += s.process("king>secret</thinking> world")
|
||||
assert out == "Hello world"
|
||||
|
||||
|
||||
def test_thinking_stripper_plain_text_preserved() -> None:
|
||||
"""Plain text with the word 'thinking' is not stripped."""
|
||||
s = _ThinkingStripper()
|
||||
assert (
|
||||
s.process("I am thinking about this problem")
|
||||
== "I am thinking about this problem"
|
||||
)
|
||||
|
||||
|
||||
def test_thinking_stripper_multiple_blocks() -> None:
|
||||
"""Multiple reasoning blocks in one stream are all stripped."""
|
||||
s = _ThinkingStripper()
|
||||
result = s.process(
|
||||
"A<thinking>x</thinking>B<internal_reasoning>y</internal_reasoning>C"
|
||||
)
|
||||
assert result == "ABC"
|
||||
|
||||
|
||||
def test_thinking_stripper_flush_discards_unclosed() -> None:
|
||||
"""Unclosed reasoning block is discarded on flush."""
|
||||
s = _ThinkingStripper()
|
||||
s.process("Start<thinking>never closed")
|
||||
flushed = s.flush()
|
||||
assert "never closed" not in flushed
|
||||
|
||||
|
||||
def test_thinking_stripper_empty_block() -> None:
|
||||
"""Empty reasoning blocks are handled gracefully."""
|
||||
s = _ThinkingStripper()
|
||||
assert s.process("Before<thinking></thinking>After") == "BeforeAfter"
|
||||
|
||||
|
||||
# ---- _filter_tools_by_permissions tests ---- #
|
||||
|
||||
|
||||
def _make_tool(name: str) -> ChatCompletionToolParam:
|
||||
"""Build a minimal OpenAI ChatCompletionToolParam."""
|
||||
return ChatCompletionToolParam(
|
||||
type="function",
|
||||
function={"name": name, "parameters": {}},
|
||||
)
|
||||
|
||||
|
||||
class TestFilterToolsByPermissions:
|
||||
"""Tests for _filter_tools_by_permissions."""
|
||||
|
||||
@patch(
|
||||
"backend.copilot.permissions.all_known_tool_names",
|
||||
return_value=frozenset({"run_block", "web_fetch", "bash_exec"}),
|
||||
)
|
||||
def test_empty_permissions_returns_all(self, _mock_names):
|
||||
"""Empty permissions (no filtering) returns every tool unchanged."""
|
||||
from backend.copilot.baseline.service import _filter_tools_by_permissions
|
||||
from backend.copilot.permissions import CopilotPermissions
|
||||
|
||||
tools = [_make_tool("run_block"), _make_tool("web_fetch")]
|
||||
perms = CopilotPermissions()
|
||||
result = _filter_tools_by_permissions(tools, perms)
|
||||
assert result == tools
|
||||
|
||||
@patch(
|
||||
"backend.copilot.permissions.all_known_tool_names",
|
||||
return_value=frozenset({"run_block", "web_fetch", "bash_exec"}),
|
||||
)
|
||||
def test_allowlist_keeps_only_matching(self, _mock_names):
|
||||
"""Explicit allowlist (tools_exclude=False) keeps only listed tools."""
|
||||
from backend.copilot.baseline.service import _filter_tools_by_permissions
|
||||
from backend.copilot.permissions import CopilotPermissions
|
||||
|
||||
tools = [
|
||||
_make_tool("run_block"),
|
||||
_make_tool("web_fetch"),
|
||||
_make_tool("bash_exec"),
|
||||
]
|
||||
perms = CopilotPermissions(tools=["web_fetch"], tools_exclude=False)
|
||||
result = _filter_tools_by_permissions(tools, perms)
|
||||
assert len(result) == 1
|
||||
assert result[0]["function"]["name"] == "web_fetch"
|
||||
|
||||
@patch(
|
||||
"backend.copilot.permissions.all_known_tool_names",
|
||||
return_value=frozenset({"run_block", "web_fetch", "bash_exec"}),
|
||||
)
|
||||
def test_blacklist_excludes_listed(self, _mock_names):
|
||||
"""Blacklist (tools_exclude=True) removes only the listed tools."""
|
||||
from backend.copilot.baseline.service import _filter_tools_by_permissions
|
||||
from backend.copilot.permissions import CopilotPermissions
|
||||
|
||||
tools = [
|
||||
_make_tool("run_block"),
|
||||
_make_tool("web_fetch"),
|
||||
_make_tool("bash_exec"),
|
||||
]
|
||||
perms = CopilotPermissions(tools=["bash_exec"], tools_exclude=True)
|
||||
result = _filter_tools_by_permissions(tools, perms)
|
||||
names = [t["function"]["name"] for t in result]
|
||||
assert "bash_exec" not in names
|
||||
assert "run_block" in names
|
||||
assert "web_fetch" in names
|
||||
assert len(result) == 2
|
||||
|
||||
@patch(
|
||||
"backend.copilot.permissions.all_known_tool_names",
|
||||
return_value=frozenset({"run_block", "web_fetch", "bash_exec"}),
|
||||
)
|
||||
def test_unknown_tool_name_filtered_out(self, _mock_names):
|
||||
"""A tool whose name is not in all_known_tool_names is dropped."""
|
||||
from backend.copilot.baseline.service import _filter_tools_by_permissions
|
||||
from backend.copilot.permissions import CopilotPermissions
|
||||
|
||||
tools = [_make_tool("run_block"), _make_tool("unknown_tool")]
|
||||
perms = CopilotPermissions(tools=["run_block"], tools_exclude=False)
|
||||
result = _filter_tools_by_permissions(tools, perms)
|
||||
names = [t["function"]["name"] for t in result]
|
||||
assert "unknown_tool" not in names
|
||||
assert names == ["run_block"]
|
||||
|
||||
|
||||
# ---- _prepare_baseline_attachments tests ---- #
|
||||
|
||||
|
||||
class TestPrepareBaselineAttachments:
|
||||
"""Tests for _prepare_baseline_attachments."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_file_ids(self):
|
||||
"""Empty file_ids returns empty hint and blocks."""
|
||||
from backend.copilot.baseline.service import _prepare_baseline_attachments
|
||||
|
||||
hint, blocks = await _prepare_baseline_attachments([], "user1", "sess1", "/tmp")
|
||||
assert hint == ""
|
||||
assert blocks == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_user_id(self):
|
||||
"""Empty user_id returns empty hint and blocks."""
|
||||
from backend.copilot.baseline.service import _prepare_baseline_attachments
|
||||
|
||||
hint, blocks = await _prepare_baseline_attachments(
|
||||
["file1"], "", "sess1", "/tmp"
|
||||
)
|
||||
assert hint == ""
|
||||
assert blocks == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_file_returns_vision_blocks(self):
|
||||
"""A PNG image within size limits is returned as a base64 vision block."""
|
||||
from backend.copilot.baseline.service import _prepare_baseline_attachments
|
||||
|
||||
fake_info = AsyncMock()
|
||||
fake_info.name = "photo.png"
|
||||
fake_info.mime_type = "image/png"
|
||||
fake_info.size_bytes = 1024
|
||||
|
||||
fake_manager = AsyncMock()
|
||||
fake_manager.get_file_info = AsyncMock(return_value=fake_info)
|
||||
fake_manager.read_file_by_id = AsyncMock(return_value=b"\x89PNG_FAKE_DATA")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.get_workspace_manager",
|
||||
new=AsyncMock(return_value=fake_manager),
|
||||
):
|
||||
hint, blocks = await _prepare_baseline_attachments(
|
||||
["fid1"], "user1", "sess1", "/tmp/workdir"
|
||||
)
|
||||
|
||||
assert len(blocks) == 1
|
||||
assert blocks[0]["type"] == "image"
|
||||
assert blocks[0]["source"]["media_type"] == "image/png"
|
||||
assert blocks[0]["source"]["type"] == "base64"
|
||||
assert "photo.png" in hint
|
||||
assert "embedded as image" in hint
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_image_file_saved_to_working_dir(self, tmp_path):
|
||||
"""A non-image file is written to working_dir."""
|
||||
from backend.copilot.baseline.service import _prepare_baseline_attachments
|
||||
|
||||
fake_info = AsyncMock()
|
||||
fake_info.name = "data.csv"
|
||||
fake_info.mime_type = "text/csv"
|
||||
fake_info.size_bytes = 42
|
||||
|
||||
fake_manager = AsyncMock()
|
||||
fake_manager.get_file_info = AsyncMock(return_value=fake_info)
|
||||
fake_manager.read_file_by_id = AsyncMock(return_value=b"col1,col2\na,b")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.get_workspace_manager",
|
||||
new=AsyncMock(return_value=fake_manager),
|
||||
):
|
||||
hint, blocks = await _prepare_baseline_attachments(
|
||||
["fid1"], "user1", "sess1", str(tmp_path)
|
||||
)
|
||||
|
||||
assert blocks == []
|
||||
assert "data.csv" in hint
|
||||
assert "saved to" in hint
|
||||
saved = tmp_path / "data.csv"
|
||||
assert saved.exists()
|
||||
assert saved.read_bytes() == b"col1,col2\na,b"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_not_found_skipped(self):
|
||||
"""When get_file_info returns None the file is silently skipped."""
|
||||
from backend.copilot.baseline.service import _prepare_baseline_attachments
|
||||
|
||||
fake_manager = AsyncMock()
|
||||
fake_manager.get_file_info = AsyncMock(return_value=None)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.get_workspace_manager",
|
||||
new=AsyncMock(return_value=fake_manager),
|
||||
):
|
||||
hint, blocks = await _prepare_baseline_attachments(
|
||||
["missing_id"], "user1", "sess1", "/tmp"
|
||||
)
|
||||
|
||||
assert hint == ""
|
||||
assert blocks == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_workspace_manager_error(self):
|
||||
"""When get_workspace_manager raises, returns empty results."""
|
||||
from backend.copilot.baseline.service import _prepare_baseline_attachments
|
||||
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.get_workspace_manager",
|
||||
new=AsyncMock(side_effect=RuntimeError("connection failed")),
|
||||
):
|
||||
hint, blocks = await _prepare_baseline_attachments(
|
||||
["fid1"], "user1", "sess1", "/tmp"
|
||||
)
|
||||
|
||||
assert hint == ""
|
||||
assert blocks == []
|
||||
@@ -0,0 +1,667 @@
|
||||
"""Integration tests for baseline transcript flow.
|
||||
|
||||
Exercises the real helpers in ``baseline/service.py`` that download,
|
||||
validate, load, append to, backfill, and upload the transcript.
|
||||
Storage is mocked via ``download_transcript`` / ``upload_transcript``
|
||||
patches; no network access is required.
|
||||
"""
|
||||
|
||||
import json as stdlib_json
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.baseline.service import (
|
||||
_load_prior_transcript,
|
||||
_record_turn_to_transcript,
|
||||
_resolve_baseline_model,
|
||||
_upload_final_transcript,
|
||||
is_transcript_stale,
|
||||
should_upload_transcript,
|
||||
)
|
||||
from backend.copilot.service import config
|
||||
from backend.copilot.transcript import (
|
||||
STOP_REASON_END_TURN,
|
||||
STOP_REASON_TOOL_USE,
|
||||
TranscriptDownload,
|
||||
)
|
||||
from backend.copilot.transcript_builder import TranscriptBuilder
|
||||
from backend.util.tool_call_loop import LLMLoopResponse, LLMToolCall, ToolCallResult
|
||||
|
||||
|
||||
def _make_transcript_content(*roles: str) -> str:
|
||||
"""Build a minimal valid JSONL transcript from role names."""
|
||||
lines = []
|
||||
parent = ""
|
||||
for i, role in enumerate(roles):
|
||||
uid = f"uuid-{i}"
|
||||
entry: dict = {
|
||||
"type": role,
|
||||
"uuid": uid,
|
||||
"parentUuid": parent,
|
||||
"message": {
|
||||
"role": role,
|
||||
"content": [{"type": "text", "text": f"{role} message {i}"}],
|
||||
},
|
||||
}
|
||||
if role == "assistant":
|
||||
entry["message"]["id"] = f"msg_{i}"
|
||||
entry["message"]["model"] = "test-model"
|
||||
entry["message"]["type"] = "message"
|
||||
entry["message"]["stop_reason"] = STOP_REASON_END_TURN
|
||||
lines.append(stdlib_json.dumps(entry))
|
||||
parent = uid
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
|
||||
class TestResolveBaselineModel:
|
||||
"""Model selection honours the per-request mode."""
|
||||
|
||||
def test_fast_mode_selects_fast_model(self):
|
||||
assert _resolve_baseline_model("fast") == config.fast_model
|
||||
|
||||
def test_extended_thinking_selects_default_model(self):
|
||||
assert _resolve_baseline_model("extended_thinking") == config.model
|
||||
|
||||
def test_none_mode_selects_default_model(self):
|
||||
"""Critical: baseline users without a mode MUST keep the default (opus)."""
|
||||
assert _resolve_baseline_model(None) == config.model
|
||||
|
||||
def test_default_and_fast_models_differ(self):
|
||||
"""Sanity: the two tiers are actually distinct in production config."""
|
||||
assert config.model != config.fast_model
|
||||
|
||||
|
||||
class TestLoadPriorTranscript:
|
||||
"""``_load_prior_transcript`` wraps the download + validate + load flow."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_loads_fresh_transcript(self):
|
||||
builder = TranscriptBuilder()
|
||||
content = _make_transcript_content("user", "assistant")
|
||||
download = TranscriptDownload(content=content, message_count=2)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=download),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=3,
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
assert covers is True
|
||||
assert builder.entry_count == 2
|
||||
assert builder.last_entry_type == "assistant"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_stale_transcript(self):
|
||||
"""msg_count strictly less than session-1 is treated as stale."""
|
||||
builder = TranscriptBuilder()
|
||||
content = _make_transcript_content("user", "assistant")
|
||||
# session has 6 messages, transcript only covers 2 → stale.
|
||||
download = TranscriptDownload(content=content, message_count=2)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=download),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=6,
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
assert covers is False
|
||||
assert builder.is_empty
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_transcript_returns_false(self):
|
||||
builder = TranscriptBuilder()
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=None),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=2,
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
assert covers is False
|
||||
assert builder.is_empty
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_transcript_returns_false(self):
|
||||
builder = TranscriptBuilder()
|
||||
download = TranscriptDownload(
|
||||
content='{"type":"progress","uuid":"a"}\n',
|
||||
message_count=1,
|
||||
)
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=download),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=2,
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
assert covers is False
|
||||
assert builder.is_empty
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_exception_returns_false(self):
|
||||
builder = TranscriptBuilder()
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(side_effect=RuntimeError("boom")),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=2,
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
assert covers is False
|
||||
assert builder.is_empty
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zero_message_count_not_stale(self):
|
||||
"""When msg_count is 0 (unknown), staleness check is skipped."""
|
||||
builder = TranscriptBuilder()
|
||||
download = TranscriptDownload(
|
||||
content=_make_transcript_content("user", "assistant"),
|
||||
message_count=0,
|
||||
)
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=download),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=20,
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
assert covers is True
|
||||
assert builder.entry_count == 2
|
||||
|
||||
|
||||
class TestUploadFinalTranscript:
|
||||
"""``_upload_final_transcript`` serialises and calls storage."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_uploads_valid_transcript(self):
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user(content="hi")
|
||||
builder.append_assistant(
|
||||
content_blocks=[{"type": "text", "text": "hello"}],
|
||||
model="test-model",
|
||||
stop_reason=STOP_REASON_END_TURN,
|
||||
)
|
||||
|
||||
upload_mock = AsyncMock(return_value=None)
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.upload_transcript",
|
||||
new=upload_mock,
|
||||
):
|
||||
await _upload_final_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
transcript_builder=builder,
|
||||
session_msg_count=2,
|
||||
)
|
||||
|
||||
upload_mock.assert_awaited_once()
|
||||
assert upload_mock.await_args is not None
|
||||
call_kwargs = upload_mock.await_args.kwargs
|
||||
assert call_kwargs["user_id"] == "user-1"
|
||||
assert call_kwargs["session_id"] == "session-1"
|
||||
assert call_kwargs["message_count"] == 2
|
||||
assert "hello" in call_kwargs["content"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_upload_when_builder_empty(self):
|
||||
builder = TranscriptBuilder()
|
||||
upload_mock = AsyncMock(return_value=None)
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.upload_transcript",
|
||||
new=upload_mock,
|
||||
):
|
||||
await _upload_final_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
transcript_builder=builder,
|
||||
session_msg_count=0,
|
||||
)
|
||||
|
||||
upload_mock.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_swallows_upload_exceptions(self):
|
||||
"""Upload failures should not propagate (flow continues for the user)."""
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user(content="hi")
|
||||
builder.append_assistant(
|
||||
content_blocks=[{"type": "text", "text": "hello"}],
|
||||
model="test-model",
|
||||
stop_reason=STOP_REASON_END_TURN,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.upload_transcript",
|
||||
new=AsyncMock(side_effect=RuntimeError("storage unavailable")),
|
||||
):
|
||||
# Should not raise.
|
||||
await _upload_final_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
transcript_builder=builder,
|
||||
session_msg_count=2,
|
||||
)
|
||||
|
||||
|
||||
class TestRecordTurnToTranscript:
|
||||
"""``_record_turn_to_transcript`` translates LLMLoopResponse → transcript."""
|
||||
|
||||
def test_records_final_assistant_text(self):
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user(content="hi")
|
||||
|
||||
response = LLMLoopResponse(
|
||||
response_text="hello there",
|
||||
tool_calls=[],
|
||||
raw_response=None,
|
||||
)
|
||||
_record_turn_to_transcript(
|
||||
response,
|
||||
tool_results=None,
|
||||
transcript_builder=builder,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
assert builder.entry_count == 2
|
||||
assert builder.last_entry_type == "assistant"
|
||||
jsonl = builder.to_jsonl()
|
||||
assert "hello there" in jsonl
|
||||
assert STOP_REASON_END_TURN in jsonl
|
||||
|
||||
def test_records_tool_use_then_tool_result(self):
|
||||
"""Anthropic ordering: assistant(tool_use) → user(tool_result)."""
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user(content="use a tool")
|
||||
|
||||
response = LLMLoopResponse(
|
||||
response_text=None,
|
||||
tool_calls=[
|
||||
LLMToolCall(id="call-1", name="echo", arguments='{"text":"hi"}')
|
||||
],
|
||||
raw_response=None,
|
||||
)
|
||||
tool_results = [
|
||||
ToolCallResult(tool_call_id="call-1", tool_name="echo", content="hi")
|
||||
]
|
||||
_record_turn_to_transcript(
|
||||
response,
|
||||
tool_results,
|
||||
transcript_builder=builder,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
# user, assistant(tool_use), user(tool_result) = 3 entries
|
||||
assert builder.entry_count == 3
|
||||
jsonl = builder.to_jsonl()
|
||||
assert STOP_REASON_TOOL_USE in jsonl
|
||||
assert "tool_use" in jsonl
|
||||
assert "tool_result" in jsonl
|
||||
assert "call-1" in jsonl
|
||||
|
||||
def test_records_nothing_on_empty_response(self):
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user(content="hi")
|
||||
|
||||
response = LLMLoopResponse(
|
||||
response_text=None,
|
||||
tool_calls=[],
|
||||
raw_response=None,
|
||||
)
|
||||
_record_turn_to_transcript(
|
||||
response,
|
||||
tool_results=None,
|
||||
transcript_builder=builder,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
assert builder.entry_count == 1
|
||||
|
||||
def test_malformed_tool_args_dont_crash(self):
|
||||
"""Bad JSON in tool arguments falls back to {} without raising."""
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user(content="hi")
|
||||
|
||||
response = LLMLoopResponse(
|
||||
response_text=None,
|
||||
tool_calls=[LLMToolCall(id="call-1", name="echo", arguments="{not-json")],
|
||||
raw_response=None,
|
||||
)
|
||||
tool_results = [
|
||||
ToolCallResult(tool_call_id="call-1", tool_name="echo", content="ok")
|
||||
]
|
||||
_record_turn_to_transcript(
|
||||
response,
|
||||
tool_results,
|
||||
transcript_builder=builder,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
assert builder.entry_count == 3
|
||||
jsonl = builder.to_jsonl()
|
||||
assert '"input":{}' in jsonl
|
||||
|
||||
|
||||
class TestRoundTrip:
|
||||
"""End-to-end: load prior → append new turn → upload."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_round_trip(self):
|
||||
prior = _make_transcript_content("user", "assistant")
|
||||
download = TranscriptDownload(content=prior, message_count=2)
|
||||
|
||||
builder = TranscriptBuilder()
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=download),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=3,
|
||||
transcript_builder=builder,
|
||||
)
|
||||
assert covers is True
|
||||
assert builder.entry_count == 2
|
||||
|
||||
# New user turn.
|
||||
builder.append_user(content="new question")
|
||||
assert builder.entry_count == 3
|
||||
|
||||
# New assistant turn.
|
||||
response = LLMLoopResponse(
|
||||
response_text="new answer",
|
||||
tool_calls=[],
|
||||
raw_response=None,
|
||||
)
|
||||
_record_turn_to_transcript(
|
||||
response,
|
||||
tool_results=None,
|
||||
transcript_builder=builder,
|
||||
model="test-model",
|
||||
)
|
||||
assert builder.entry_count == 4
|
||||
|
||||
# Upload.
|
||||
upload_mock = AsyncMock(return_value=None)
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.upload_transcript",
|
||||
new=upload_mock,
|
||||
):
|
||||
await _upload_final_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
transcript_builder=builder,
|
||||
session_msg_count=4,
|
||||
)
|
||||
|
||||
upload_mock.assert_awaited_once()
|
||||
assert upload_mock.await_args is not None
|
||||
uploaded = upload_mock.await_args.kwargs["content"]
|
||||
assert "new question" in uploaded
|
||||
assert "new answer" in uploaded
|
||||
# Original content preserved in the round trip.
|
||||
assert "user message 0" in uploaded
|
||||
assert "assistant message 1" in uploaded
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_backfill_append_guard(self):
|
||||
"""Backfill only runs when the last entry is not already assistant."""
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user(content="hi")
|
||||
|
||||
# Simulate the backfill guard from stream_chat_completion_baseline.
|
||||
assistant_text = "partial text before error"
|
||||
if builder.last_entry_type != "assistant":
|
||||
builder.append_assistant(
|
||||
content_blocks=[{"type": "text", "text": assistant_text}],
|
||||
model="test-model",
|
||||
stop_reason=STOP_REASON_END_TURN,
|
||||
)
|
||||
|
||||
assert builder.last_entry_type == "assistant"
|
||||
assert "partial text before error" in builder.to_jsonl()
|
||||
|
||||
# Second invocation: the guard must prevent double-append.
|
||||
initial_count = builder.entry_count
|
||||
if builder.last_entry_type != "assistant":
|
||||
builder.append_assistant(
|
||||
content_blocks=[{"type": "text", "text": "duplicate"}],
|
||||
model="test-model",
|
||||
stop_reason=STOP_REASON_END_TURN,
|
||||
)
|
||||
assert builder.entry_count == initial_count
|
||||
|
||||
|
||||
class TestIsTranscriptStale:
|
||||
"""``is_transcript_stale`` gates prior-transcript loading."""
|
||||
|
||||
def test_none_download_is_not_stale(self):
|
||||
assert is_transcript_stale(None, session_msg_count=5) is False
|
||||
|
||||
def test_zero_message_count_is_not_stale(self):
|
||||
"""Legacy transcripts without msg_count tracking must remain usable."""
|
||||
dl = TranscriptDownload(content="", message_count=0)
|
||||
assert is_transcript_stale(dl, session_msg_count=20) is False
|
||||
|
||||
def test_stale_when_covers_less_than_prefix(self):
|
||||
dl = TranscriptDownload(content="", message_count=2)
|
||||
# session has 6 messages; transcript must cover at least 5 (6-1).
|
||||
assert is_transcript_stale(dl, session_msg_count=6) is True
|
||||
|
||||
def test_fresh_when_covers_full_prefix(self):
|
||||
dl = TranscriptDownload(content="", message_count=5)
|
||||
assert is_transcript_stale(dl, session_msg_count=6) is False
|
||||
|
||||
def test_fresh_when_exceeds_prefix(self):
|
||||
"""Race: transcript ahead of session count is still acceptable."""
|
||||
dl = TranscriptDownload(content="", message_count=10)
|
||||
assert is_transcript_stale(dl, session_msg_count=6) is False
|
||||
|
||||
def test_boundary_equal_to_prefix_minus_one(self):
|
||||
        dl = TranscriptDownload(content="", message_count=5)
        assert is_transcript_stale(dl, session_msg_count=6) is False


class TestShouldUploadTranscript:
    """``should_upload_transcript`` gates the final upload."""

    def test_upload_allowed_for_user_with_coverage(self):
        assert should_upload_transcript("user-1", True) is True

    def test_upload_skipped_when_no_user(self):
        assert should_upload_transcript(None, True) is False

    def test_upload_skipped_when_empty_user(self):
        assert should_upload_transcript("", True) is False

    def test_upload_skipped_without_coverage(self):
        """Partial transcript must never clobber a more complete stored one."""
        assert should_upload_transcript("user-1", False) is False

    def test_upload_skipped_when_no_user_and_no_coverage(self):
        assert should_upload_transcript(None, False) is False


class TestTranscriptLifecycle:
    """End-to-end: download → validate → build → upload.

    Simulates the full transcript lifecycle inside
    ``stream_chat_completion_baseline`` by mocking the storage layer and
    driving each step through the real helpers.
    """

    @pytest.mark.asyncio
    async def test_full_lifecycle_happy_path(self):
        """Fresh download, append a turn, upload covers the session."""
        builder = TranscriptBuilder()
        prior = _make_transcript_content("user", "assistant")
        download = TranscriptDownload(content=prior, message_count=2)

        upload_mock = AsyncMock(return_value=None)
        with (
            patch(
                "backend.copilot.baseline.service.download_transcript",
                new=AsyncMock(return_value=download),
            ),
            patch(
                "backend.copilot.baseline.service.upload_transcript",
                new=upload_mock,
            ),
        ):
            # --- 1. Download & load prior transcript ---
            covers = await _load_prior_transcript(
                user_id="user-1",
                session_id="session-1",
                session_msg_count=3,
                transcript_builder=builder,
            )
            assert covers is True

            # --- 2. Append a new user turn + a new assistant response ---
            builder.append_user(content="follow-up question")
            _record_turn_to_transcript(
                LLMLoopResponse(
                    response_text="follow-up answer",
                    tool_calls=[],
                    raw_response=None,
                ),
                tool_results=None,
                transcript_builder=builder,
                model="test-model",
            )

            # --- 3. Gate + upload ---
            assert (
                should_upload_transcript(
                    user_id="user-1", transcript_covers_prefix=covers
                )
                is True
            )
            await _upload_final_transcript(
                user_id="user-1",
                session_id="session-1",
                transcript_builder=builder,
                session_msg_count=4,
            )

        upload_mock.assert_awaited_once()
        assert upload_mock.await_args is not None
        uploaded = upload_mock.await_args.kwargs["content"]
        assert "follow-up question" in uploaded
        assert "follow-up answer" in uploaded
        # Original prior-turn content preserved.
        assert "user message 0" in uploaded
        assert "assistant message 1" in uploaded

    @pytest.mark.asyncio
    async def test_lifecycle_stale_download_suppresses_upload(self):
        """Stale download → covers=False → upload must be skipped."""
        builder = TranscriptBuilder()
        # session has 10 msgs but stored transcript only covers 2 → stale.
        stale = TranscriptDownload(
            content=_make_transcript_content("user", "assistant"),
            message_count=2,
        )

        upload_mock = AsyncMock(return_value=None)
        with (
            patch(
                "backend.copilot.baseline.service.download_transcript",
                new=AsyncMock(return_value=stale),
            ),
            patch(
                "backend.copilot.baseline.service.upload_transcript",
                new=upload_mock,
            ),
        ):
            covers = await _load_prior_transcript(
                user_id="user-1",
                session_id="session-1",
                session_msg_count=10,
                transcript_builder=builder,
            )

            assert covers is False
            # The caller's gate mirrors the production path.
            assert (
                should_upload_transcript(user_id="user-1", transcript_covers_prefix=covers)
                is False
            )
            upload_mock.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_lifecycle_anonymous_user_skips_upload(self):
        """Anonymous (user_id=None) → upload gate must return False."""
        builder = TranscriptBuilder()
        builder.append_user(content="hi")
        builder.append_assistant(
            content_blocks=[{"type": "text", "text": "hello"}],
            model="test-model",
            stop_reason=STOP_REASON_END_TURN,
        )

        assert (
            should_upload_transcript(user_id=None, transcript_covers_prefix=True)
            is False
        )

    @pytest.mark.asyncio
    async def test_lifecycle_missing_download_skips_upload(self):
        """No prior transcript → covers is False → the final upload is
        skipped, so a single-turn snapshot can never overwrite a future,
        more complete stored transcript."""
        builder = TranscriptBuilder()
        upload_mock = AsyncMock(return_value=None)
        with (
            patch(
                "backend.copilot.baseline.service.download_transcript",
                new=AsyncMock(return_value=None),
            ),
            patch(
                "backend.copilot.baseline.service.upload_transcript",
                new=upload_mock,
            ),
        ):
            covers = await _load_prior_transcript(
                user_id="user-1",
                session_id="session-1",
                session_msg_count=1,
                transcript_builder=builder,
            )
            # No download: covers is False, so the production path would
            # skip upload. This protects against overwriting a future
            # more-complete transcript with a single-turn snapshot.
            assert covers is False
            assert (
                should_upload_transcript(
                    user_id="user-1", transcript_covers_prefix=covers
                )
                is False
            )
            upload_mock.assert_not_awaited()
@@ -8,18 +8,35 @@ from pydantic_settings import BaseSettings

from backend.util.clients import OPENROUTER_BASE_URL

# Per-request routing mode for a single chat turn.
# - 'fast': route to the baseline OpenAI-compatible path with the cheaper model.
# - 'extended_thinking': route to the Claude Agent SDK path with the default
#   (opus) model.
# ``None`` means "no override"; the server falls back to the Claude Code
# subscription flag → LaunchDarkly COPILOT_SDK → config.use_claude_agent_sdk.
CopilotMode = Literal["fast", "extended_thinking"]


class ChatConfig(BaseSettings):
    """Configuration for the chat system."""

    # OpenAI API Configuration
    model: str = Field(
-        default="anthropic/claude-opus-4.6", description="Default model to use"
+        default="anthropic/claude-opus-4.6",
+        description="Default model for extended thinking mode",
    )
    fast_model: str = Field(
        default="anthropic/claude-sonnet-4",
        description="Model for fast mode (baseline path). Should be faster/cheaper than the default model.",
    )
    title_model: str = Field(
        default="openai/gpt-4o-mini",
        description="Model to use for generating session titles (should be fast/cheap)",
    )
    simulation_model: str = Field(
        default="google/gemini-2.5-flash",
        description="Model for dry-run block simulation (should be fast/cheap with good JSON output)",
    )
    api_key: str | None = Field(default=None, description="OpenAI API key")
    base_url: str | None = Field(
        default=OPENROUTER_BASE_URL,
@@ -77,11 +94,11 @@ class ChatConfig(BaseSettings):
    # allows ~70-100 turns/day.
    # Checked at the HTTP layer (routes.py) before each turn.
    #
-    # TODO: These are deploy-time constants applied identically to every user.
-    # If per-user or per-plan limits are needed (e.g., free tier vs paid), these
-    # must move to the database (e.g., a UserPlan table) and get_usage_status /
-    # check_rate_limit would look up each user's specific limits instead of
-    # reading config.daily_token_limit / config.weekly_token_limit.
+    # These are base limits for the FREE tier. Higher tiers (PRO, BUSINESS,
+    # ENTERPRISE) multiply these by their tier multiplier (see
+    # rate_limit.TIER_MULTIPLIERS). User tier is stored in the
+    # User.subscriptionTier DB column and resolved inside
+    # get_global_rate_limits().
    daily_token_limit: int = Field(
        default=2_500_000,
        description="Max tokens per day, resets at midnight UTC (0 = unlimited)",
@@ -178,7 +195,7 @@ class ChatConfig(BaseSettings):

        Single source of truth for "will the SDK route through OpenRouter?".
        Checks the flag *and* that ``api_key`` + a valid ``base_url`` are
-        present — mirrors the fallback logic in ``_build_sdk_env``.
+        present — mirrors the fallback logic in ``build_sdk_env``.
        """
        if not self.use_openrouter:
            return False

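Taken together, `model` and `fast_model` imply a per-turn model choice. A minimal sketch of that mapping, assuming a hypothetical `pick_model` helper that is not part of this diff:

```python
from backend.copilot.config import ChatConfig, CopilotMode


def pick_model(config: ChatConfig, mode: CopilotMode | None) -> str:
    """Hypothetical helper: map a per-request mode to a configured model.

    'fast' rides the baseline path on the cheaper fast_model. For any
    other value (including None, where the real server first resolves
    the fallback chain) this sketch uses the extended-thinking default.
    """
    return config.fast_model if mode == "fast" else config.model
```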
@@ -149,7 +149,8 @@ def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:

    Allowed:
    - Files under *sdk_cwd* (``/tmp/copilot-<session>/``)
-    - Files under ``~/.claude/projects/<encoded-cwd>/<uuid>/tool-results/...``.
+    - Files under ``~/.claude/projects/<encoded-cwd>/<uuid>/tool-results/...``
+      or ``tool-outputs/...``.
      The SDK nests tool-results under a conversation UUID directory;
      the UUID segment is validated with ``_UUID_RE``.
    """
@@ -174,17 +175,20 @@ def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
        # Defence-in-depth: ensure project_dir didn't escape the base.
        if not project_dir.startswith(SDK_PROJECTS_DIR + os.sep):
            return False
-        # Only allow: <encoded-cwd>/<uuid>/tool-results/<file>
+        # Only allow: <encoded-cwd>/<uuid>/<tool-dir>/<file>
        # The SDK always creates a conversation UUID directory between
-        # the project dir and tool-results/.
+        # the project dir and the tool directory.
+        # Accept both "tool-results" (SDK's persisted outputs) and
+        # "tool-outputs" (the model sometimes confuses workspace paths
+        # with filesystem paths and generates this variant).
        if resolved.startswith(project_dir + os.sep):
            relative = resolved[len(project_dir) + 1 :]
            parts = relative.split(os.sep)
-            # Require exactly: [<uuid>, "tool-results", <file>, ...]
+            # Require exactly: [<uuid>, "tool-results"|"tool-outputs", <file>, ...]
            if (
                len(parts) >= 3
                and _UUID_RE.match(parts[0])
-                and parts[1] == "tool-results"
+                and parts[1] in ("tool-results", "tool-outputs")
            ):
                return True

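The acceptance rule the second hunk encodes can be checked standalone. A self-contained sketch, with the UUID pattern re-declared locally for illustration (the real `_UUID_RE` lives in the module and may differ in detail):

```python
import os
import re

# Local stand-in for the module's _UUID_RE (assumed shape: 8-4-4-4-12 hex).
UUID_RE = re.compile(
    r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", re.I
)


def segment_ok(relative: str) -> bool:
    """Mirror of the hunk's check: <uuid>/(tool-results|tool-outputs)/<file>."""
    parts = relative.split(os.sep)
    return (
        len(parts) >= 3
        and UUID_RE.match(parts[0]) is not None
        and parts[1] in ("tool-results", "tool-outputs")
    )


assert segment_ok("a1b2c3d4-e5f6-7890-abcd-ef1234567890/tool-outputs/out.json")
assert not segment_ok("a1b2c3d4-e5f6-7890-abcd-ef1234567890/scratch/out.json")
assert not segment_ok("tool-results/out.json")  # no UUID segment in front
```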
@@ -134,6 +134,21 @@ def test_is_allowed_local_path_tool_results_with_uuid():
    _current_project_dir.set("")


def test_is_allowed_local_path_tool_outputs_with_uuid():
    """Files under <encoded-cwd>/<uuid>/tool-outputs/ are also allowed."""
    encoded = "test-encoded-dir"
    conv_uuid = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
    path = os.path.join(
        SDK_PROJECTS_DIR, encoded, conv_uuid, "tool-outputs", "output.json"
    )

    _current_project_dir.set(encoded)
    try:
        assert is_allowed_local_path(path, sdk_cwd=None)
    finally:
        _current_project_dir.set("")


def test_is_allowed_local_path_tool_results_without_uuid_rejected():
    """Direct <encoded-cwd>/tool-results/ (no UUID) is rejected."""
    encoded = "test-encoded-dir"
@@ -159,7 +174,7 @@ def test_is_allowed_local_path_sibling_of_tool_results_is_rejected():


def test_is_allowed_local_path_valid_uuid_wrong_segment_name_rejected():
-    """A valid UUID dir but non-'tool-results' second segment is rejected."""
+    """A valid UUID dir but non-'tool-results'/'tool-outputs' second segment is rejected."""
    encoded = "test-encoded-dir"
    uuid_str = "12345678-1234-5678-9abc-def012345678"
    path = os.path.join(
@@ -18,7 +18,14 @@ from prisma.types import (
from backend.data import db
from backend.util.json import SafeJson, sanitize_string

-from .model import ChatMessage, ChatSession, ChatSessionInfo, invalidate_session_cache
+from .model import (
+    ChatMessage,
+    ChatSession,
+    ChatSessionInfo,
+    ChatSessionMetadata,
+    cache_chat_session,
+)
+from .model import get_chat_session as get_chat_session_cached

logger = logging.getLogger(__name__)

@@ -35,6 +42,7 @@ async def get_chat_session(session_id: str) -> ChatSession | None:
async def create_chat_session(
    session_id: str,
    user_id: str,
+    metadata: ChatSessionMetadata | None = None,
) -> ChatSessionInfo:
    """Create a new chat session in the database."""
    data = ChatSessionCreateInput(
@@ -43,6 +51,7 @@ async def create_chat_session(
        credentials=SafeJson({}),
        successfulAgentRuns=SafeJson({}),
        successfulAgentSchedules=SafeJson({}),
+        metadata=SafeJson((metadata or ChatSessionMetadata()).model_dump()),
    )
    prisma_session = await PrismaChatSession.prisma().create(data=data)
    return ChatSessionInfo.from_db(prisma_session)
@@ -57,7 +66,12 @@ async def update_chat_session(
    total_completion_tokens: int | None = None,
    title: str | None = None,
) -> ChatSession | None:
-    """Update a chat session's metadata."""
+    """Update a chat session's mutable fields.
+
+    Note: ``metadata`` (which includes ``dry_run``) is intentionally omitted —
+    it is set once at creation time and treated as immutable for the lifetime
+    of the session.
+    """
    data: ChatSessionUpdateInput = {"updatedAt": datetime.now(UTC)}

    if credentials is not None:
@@ -367,8 +381,11 @@ async def update_tool_message_content(
async def set_turn_duration(session_id: str, duration_ms: int) -> None:
    """Set durationMs on the last assistant message in a session.

-    Also invalidates the Redis session cache so the next GET returns
-    the updated duration.
+    Updates the Redis cache in-place instead of invalidating it.
+    Invalidation would delete the key, creating a window where concurrent
+    ``get_chat_session`` calls re-populate the cache from DB — potentially
+    with stale data if the DB write from the previous turn hasn't propagated.
+    This race caused duplicate user messages on the next turn.
    """
    last_msg = await PrismaChatMessage.prisma().find_first(
        where={"sessionId": session_id, "role": "assistant"},
@@ -379,5 +396,13 @@ async def set_turn_duration(session_id: str, duration_ms: int) -> None:
        where={"id": last_msg.id},
        data={"durationMs": duration_ms},
    )
-    # Invalidate cache so the session is re-fetched from DB with durationMs
-    await invalidate_session_cache(session_id)
+    # Update cache in-place rather than invalidating to avoid a
+    # race window where the empty cache gets re-populated with
+    # stale data by a concurrent get_chat_session call.
+    session = await get_chat_session_cached(session_id)
+    if session and session.messages:
+        for msg in reversed(session.messages):
+            if msg.role == "assistant":
+                msg.duration_ms = duration_ms
+                break
+        await cache_chat_session(session)
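The race described in the new `set_turn_duration` docstring is easiest to see as a timeline. A schematic restatement of the diff's own explanation, not additional code:

```python
# Old flow (invalidate):
#   T1  turn ends; DB update of durationMs is issued
#   T1  cache key deleted
#   T2  concurrent get_chat_session misses the cache, reads the DB before
#       the T1 write is visible, and re-populates the cache with a stale
#       session
#   ->  next turn sees the stale session and re-appends the user message
#
# New flow (patch in place):
#   T1  turn ends; DB update issued; cached session mutated and re-written
#   ->  there is never an empty-cache window for T2 to lose the race in
```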
54
autogpt_platform/backend/backend/copilot/db_test.py
Normal file
@@ -0,0 +1,54 @@
import pytest

from .db import set_turn_duration
from .model import ChatMessage, ChatSession, get_chat_session, upsert_chat_session


@pytest.mark.asyncio(loop_scope="session")
async def test_set_turn_duration_updates_cache_in_place(setup_test_user, test_user_id):
    """set_turn_duration patches the cached session without invalidation.

    Verifies that after calling set_turn_duration the Redis-cached session
    reflects the updated durationMs on the last assistant message, without
    the cache having been deleted and re-populated (which could race with
    concurrent get_chat_session calls).
    """
    session = ChatSession.new(user_id=test_user_id, dry_run=False)
    session.messages = [
        ChatMessage(role="user", content="hello"),
        ChatMessage(role="assistant", content="hi there"),
    ]
    session = await upsert_chat_session(session)

    # Ensure the session is in cache
    cached = await get_chat_session(session.session_id, test_user_id)
    assert cached is not None
    assert cached.messages[-1].duration_ms is None

    # Update turn duration — should patch cache in-place
    await set_turn_duration(session.session_id, 1234)

    # Read from cache (not DB) — the cache should already have the update
    updated = await get_chat_session(session.session_id, test_user_id)
    assert updated is not None
    assistant_msgs = [m for m in updated.messages if m.role == "assistant"]
    assert len(assistant_msgs) == 1
    assert assistant_msgs[0].duration_ms == 1234


@pytest.mark.asyncio(loop_scope="session")
async def test_set_turn_duration_no_assistant_message(setup_test_user, test_user_id):
    """set_turn_duration is a no-op when there are no assistant messages."""
    session = ChatSession.new(user_id=test_user_id, dry_run=False)
    session.messages = [
        ChatMessage(role="user", content="hello"),
    ]
    session = await upsert_chat_session(session)

    # Should not raise
    await set_turn_duration(session.session_id, 5678)

    cached = await get_chat_session(session.session_id, test_user_id)
    assert cached is not None
    # User message should not have durationMs
    assert cached.messages[0].duration_ms is None
@@ -13,7 +13,7 @@ import time

from backend.copilot import stream_registry
from backend.copilot.baseline import stream_chat_completion_baseline
-from backend.copilot.config import ChatConfig
+from backend.copilot.config import ChatConfig, CopilotMode
from backend.copilot.response_model import StreamError
from backend.copilot.sdk import service as sdk_service
from backend.copilot.sdk.dummy import stream_chat_completion_dummy
@@ -30,6 +30,57 @@ from .utils import CoPilotExecutionEntry, CoPilotLogMetadata

logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]")


# ============ Mode Routing ============ #


async def resolve_effective_mode(
    mode: CopilotMode | None,
    user_id: str | None,
) -> CopilotMode | None:
    """Strip ``mode`` when the user is not entitled to the toggle.

    The UI gates the mode toggle behind ``CHAT_MODE_OPTION``; the
    processor enforces the same gate server-side so an authenticated
    user cannot bypass the flag by crafting a request directly.
    """
    if mode is None:
        return None
    allowed = await is_feature_enabled(
        Flag.CHAT_MODE_OPTION,
        user_id or "anonymous",
        default=False,
    )
    if not allowed:
        logger.info(f"Ignoring mode={mode} — CHAT_MODE_OPTION is disabled for user")
        return None
    return mode


async def resolve_use_sdk_for_mode(
    mode: CopilotMode | None,
    user_id: str | None,
    *,
    use_claude_code_subscription: bool,
    config_default: bool,
) -> bool:
    """Pick the SDK vs baseline path for a single turn.

    Per-request ``mode`` wins whenever it is set (after the
    ``CHAT_MODE_OPTION`` gate has been applied upstream). Otherwise
    falls back to the Claude Code subscription override, then the
    ``COPILOT_SDK`` LaunchDarkly flag, then the config default.
    """
    if mode == "fast":
        return False
    if mode == "extended_thinking":
        return True
    return use_claude_code_subscription or await is_feature_enabled(
        Flag.COPILOT_SDK,
        user_id or "anonymous",
        default=config_default,
    )


# ============ Module Entry Points ============ #

# Thread-local storage for processor instances
@@ -250,21 +301,26 @@ class CoPilotProcessor:
        if config.test_mode:
            stream_fn = stream_chat_completion_dummy
            log.warning("Using DUMMY service (CHAT_TEST_MODE=true)")
            effective_mode = None
        else:
-            use_sdk = (
-                config.use_claude_code_subscription
-                or await is_feature_enabled(
-                    Flag.COPILOT_SDK,
-                    entry.user_id or "anonymous",
-                    default=config.use_claude_agent_sdk,
-                )
-            )
+            # Enforce server-side feature-flag gate so unauthorised
+            # users cannot force a mode by crafting the request.
+            effective_mode = await resolve_effective_mode(entry.mode, entry.user_id)
+            use_sdk = await resolve_use_sdk_for_mode(
+                effective_mode,
+                entry.user_id,
+                use_claude_code_subscription=config.use_claude_code_subscription,
+                config_default=config.use_claude_agent_sdk,
+            )
            stream_fn = (
                sdk_service.stream_chat_completion_sdk
                if use_sdk
                else stream_chat_completion_baseline
            )
-            log.info(f"Using {'SDK' if use_sdk else 'baseline'} service")
+            log.info(
+                f"Using {'SDK' if use_sdk else 'baseline'} service "
+                f"(mode={effective_mode or 'default'})"
+            )

        # Stream chat completion and publish chunks to Redis.
        # stream_and_publish wraps the raw stream with registry
@@ -276,6 +332,7 @@ class CoPilotProcessor:
            user_id=entry.user_id,
            context=entry.context,
            file_ids=entry.file_ids,
            mode=effective_mode,
        )
        async for chunk in stream_registry.stream_and_publish(
            session_id=entry.session_id,
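Before the tests below, the routing contract of `resolve_use_sdk_for_mode` in one place, restated from its docstring:

```python
# Per-turn routing precedence (use_sdk picks SDK vs baseline):
#   mode == "fast"               -> use_sdk = False  (baseline path)
#   mode == "extended_thinking"  -> use_sdk = True   (Claude Agent SDK path)
#   mode is None                 -> use_claude_code_subscription
#                                   or COPILOT_SDK LaunchDarkly flag
#                                   (flag default: config.use_claude_agent_sdk)
```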
@@ -0,0 +1,175 @@
"""Unit tests for CoPilot mode routing logic in the processor.

Tests cover the mode→service mapping:
- 'fast' → baseline service
- 'extended_thinking' → SDK service
- None → feature flag / config fallback

as well as the ``CHAT_MODE_OPTION`` server-side gate. The tests import
the real production helpers from ``processor.py`` so the routing logic
has meaningful coverage.
"""

from unittest.mock import AsyncMock, patch

import pytest

from backend.copilot.executor.processor import (
    resolve_effective_mode,
    resolve_use_sdk_for_mode,
)


class TestResolveUseSdkForMode:
    """Tests for the per-request mode routing logic."""

    @pytest.mark.asyncio
    async def test_fast_mode_uses_baseline(self):
        """mode='fast' always routes to baseline, regardless of flags."""
        with patch(
            "backend.copilot.executor.processor.is_feature_enabled",
            new=AsyncMock(return_value=True),
        ):
            assert (
                await resolve_use_sdk_for_mode(
                    "fast",
                    "user-1",
                    use_claude_code_subscription=True,
                    config_default=True,
                )
                is False
            )

    @pytest.mark.asyncio
    async def test_extended_thinking_uses_sdk(self):
        """mode='extended_thinking' always routes to SDK, regardless of flags."""
        with patch(
            "backend.copilot.executor.processor.is_feature_enabled",
            new=AsyncMock(return_value=False),
        ):
            assert (
                await resolve_use_sdk_for_mode(
                    "extended_thinking",
                    "user-1",
                    use_claude_code_subscription=False,
                    config_default=False,
                )
                is True
            )

    @pytest.mark.asyncio
    async def test_none_mode_uses_subscription_override(self):
        """mode=None with claude_code_subscription=True routes to SDK."""
        with patch(
            "backend.copilot.executor.processor.is_feature_enabled",
            new=AsyncMock(return_value=False),
        ):
            assert (
                await resolve_use_sdk_for_mode(
                    None,
                    "user-1",
                    use_claude_code_subscription=True,
                    config_default=False,
                )
                is True
            )

    @pytest.mark.asyncio
    async def test_none_mode_uses_feature_flag(self):
        """mode=None with feature flag enabled routes to SDK."""
        with patch(
            "backend.copilot.executor.processor.is_feature_enabled",
            new=AsyncMock(return_value=True),
        ) as flag_mock:
            assert (
                await resolve_use_sdk_for_mode(
                    None,
                    "user-1",
                    use_claude_code_subscription=False,
                    config_default=False,
                )
                is True
            )
            flag_mock.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_none_mode_uses_config_default(self):
        """mode=None falls back to config.use_claude_agent_sdk."""
        # When LaunchDarkly returns the default (True), we expect SDK routing.
        with patch(
            "backend.copilot.executor.processor.is_feature_enabled",
            new=AsyncMock(return_value=True),
        ):
            assert (
                await resolve_use_sdk_for_mode(
                    None,
                    "user-1",
                    use_claude_code_subscription=False,
                    config_default=True,
                )
                is True
            )

    @pytest.mark.asyncio
    async def test_none_mode_all_disabled(self):
        """mode=None with all flags off routes to baseline."""
        with patch(
            "backend.copilot.executor.processor.is_feature_enabled",
            new=AsyncMock(return_value=False),
        ):
            assert (
                await resolve_use_sdk_for_mode(
                    None,
                    "user-1",
                    use_claude_code_subscription=False,
                    config_default=False,
                )
                is False
            )


class TestResolveEffectiveMode:
    """Tests for the CHAT_MODE_OPTION server-side gate."""

    @pytest.mark.asyncio
    async def test_none_mode_passes_through(self):
        """mode=None is returned as-is without a flag check."""
        with patch(
            "backend.copilot.executor.processor.is_feature_enabled",
            new=AsyncMock(return_value=False),
        ) as flag_mock:
            assert await resolve_effective_mode(None, "user-1") is None
            flag_mock.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_mode_stripped_when_flag_disabled(self):
        """When CHAT_MODE_OPTION is off, mode is dropped to None."""
        with patch(
            "backend.copilot.executor.processor.is_feature_enabled",
            new=AsyncMock(return_value=False),
        ):
            assert await resolve_effective_mode("fast", "user-1") is None
            assert await resolve_effective_mode("extended_thinking", "user-1") is None

    @pytest.mark.asyncio
    async def test_mode_preserved_when_flag_enabled(self):
        """When CHAT_MODE_OPTION is on, the user-selected mode is preserved."""
        with patch(
            "backend.copilot.executor.processor.is_feature_enabled",
            new=AsyncMock(return_value=True),
        ):
            assert await resolve_effective_mode("fast", "user-1") == "fast"
            assert (
                await resolve_effective_mode("extended_thinking", "user-1")
                == "extended_thinking"
            )

    @pytest.mark.asyncio
    async def test_anonymous_user_with_mode(self):
        """Anonymous users (user_id=None) still pass through the gate."""
        with patch(
            "backend.copilot.executor.processor.is_feature_enabled",
            new=AsyncMock(return_value=False),
        ) as flag_mock:
            assert await resolve_effective_mode("fast", None) is None
            flag_mock.assert_awaited_once()
@@ -9,6 +9,7 @@ import logging

from pydantic import BaseModel

from backend.copilot.config import CopilotMode
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
from backend.util.logging import TruncatedLogger, is_structured_logging_enabled

@@ -156,6 +157,9 @@ class CoPilotExecutionEntry(BaseModel):
    file_ids: list[str] | None = None
    """Workspace file IDs attached to the user's message"""

    mode: CopilotMode | None = None
    """Copilot mode override: 'fast' or 'extended_thinking'. None = server default."""


class CancelCoPilotEvent(BaseModel):
    """Event to cancel a CoPilot operation."""
@@ -175,6 +179,7 @@ async def enqueue_copilot_turn(
    is_user_message: bool = True,
    context: dict[str, str] | None = None,
    file_ids: list[str] | None = None,
    mode: CopilotMode | None = None,
) -> None:
    """Enqueue a CoPilot task for processing by the executor service.

@@ -186,6 +191,7 @@ async def enqueue_copilot_turn(
        is_user_message: Whether the message is from the user (vs system/assistant)
        context: Optional context for the message (e.g., {url: str, content: str})
        file_ids: Optional workspace file IDs attached to the user's message
        mode: Copilot mode override ('fast' or 'extended_thinking'). None = server default.
    """
    from backend.util.clients import get_async_copilot_queue

@@ -197,6 +203,7 @@ async def enqueue_copilot_turn(
        is_user_message=is_user_message,
        context=context,
        file_ids=file_ids,
        mode=mode,
    )

    queue_client = await get_async_copilot_queue()
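For orientation, a hedged sketch of a call site. Only `mode` comes from this hunk; the other keyword names are inferred from `CoPilotExecutionEntry` and may not match the real signature exactly:

```python
import asyncio

from backend.copilot.executor.utils import enqueue_copilot_turn


async def main() -> None:
    # Force the cheaper baseline path for this turn; mode=None would defer
    # to the server-side default (subscription flag, then COPILOT_SDK flag,
    # then config.use_claude_agent_sdk).
    await enqueue_copilot_turn(
        session_id="session-1",  # illustrative values only
        user_id="user-1",
        message="Summarize my last run",
        mode="fast",
    )


asyncio.run(main())
```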
123
autogpt_platform/backend/backend/copilot/executor/utils_test.py
Normal file
@@ -0,0 +1,123 @@
"""Tests for CoPilot executor utils (queue config, message models, logging)."""
|
||||
|
||||
from backend.copilot.executor.utils import (
|
||||
COPILOT_EXECUTION_EXCHANGE,
|
||||
COPILOT_EXECUTION_QUEUE_NAME,
|
||||
COPILOT_EXECUTION_ROUTING_KEY,
|
||||
CancelCoPilotEvent,
|
||||
CoPilotExecutionEntry,
|
||||
CoPilotLogMetadata,
|
||||
create_copilot_queue_config,
|
||||
)
|
||||
|
||||
|
||||
class TestCoPilotExecutionEntry:
|
||||
def test_basic_fields(self):
|
||||
entry = CoPilotExecutionEntry(
|
||||
session_id="s1",
|
||||
user_id="u1",
|
||||
message="hello",
|
||||
)
|
||||
assert entry.session_id == "s1"
|
||||
assert entry.user_id == "u1"
|
||||
assert entry.message == "hello"
|
||||
assert entry.is_user_message is True
|
||||
assert entry.mode is None
|
||||
assert entry.context is None
|
||||
assert entry.file_ids is None
|
||||
|
||||
def test_mode_field(self):
|
||||
entry = CoPilotExecutionEntry(
|
||||
session_id="s1",
|
||||
user_id="u1",
|
||||
message="test",
|
||||
mode="fast",
|
||||
)
|
||||
assert entry.mode == "fast"
|
||||
|
||||
entry2 = CoPilotExecutionEntry(
|
||||
session_id="s1",
|
||||
user_id="u1",
|
||||
message="test",
|
||||
mode="extended_thinking",
|
||||
)
|
||||
assert entry2.mode == "extended_thinking"
|
||||
|
||||
def test_optional_fields(self):
|
||||
entry = CoPilotExecutionEntry(
|
||||
session_id="s1",
|
||||
user_id="u1",
|
||||
message="test",
|
||||
turn_id="t1",
|
||||
context={"url": "https://example.com"},
|
||||
file_ids=["f1", "f2"],
|
||||
is_user_message=False,
|
||||
)
|
||||
assert entry.turn_id == "t1"
|
||||
assert entry.context == {"url": "https://example.com"}
|
||||
assert entry.file_ids == ["f1", "f2"]
|
||||
assert entry.is_user_message is False
|
||||
|
||||
def test_serialization_roundtrip(self):
|
||||
entry = CoPilotExecutionEntry(
|
||||
session_id="s1",
|
||||
user_id="u1",
|
||||
message="hello",
|
||||
mode="fast",
|
||||
)
|
||||
json_str = entry.model_dump_json()
|
||||
restored = CoPilotExecutionEntry.model_validate_json(json_str)
|
||||
assert restored == entry
|
||||
|
||||
|
||||
class TestCancelCoPilotEvent:
|
||||
def test_basic(self):
|
||||
event = CancelCoPilotEvent(session_id="s1")
|
||||
assert event.session_id == "s1"
|
||||
|
||||
def test_serialization(self):
|
||||
event = CancelCoPilotEvent(session_id="s1")
|
||||
restored = CancelCoPilotEvent.model_validate_json(event.model_dump_json())
|
||||
assert restored.session_id == "s1"
|
||||
|
||||
|
||||
class TestCreateCopilotQueueConfig:
|
||||
def test_returns_valid_config(self):
|
||||
config = create_copilot_queue_config()
|
||||
assert len(config.exchanges) == 2
|
||||
assert len(config.queues) == 2
|
||||
|
||||
def test_execution_queue_properties(self):
|
||||
config = create_copilot_queue_config()
|
||||
exec_queue = next(
|
||||
q for q in config.queues if q.name == COPILOT_EXECUTION_QUEUE_NAME
|
||||
)
|
||||
assert exec_queue.durable is True
|
||||
assert exec_queue.exchange == COPILOT_EXECUTION_EXCHANGE
|
||||
assert exec_queue.routing_key == COPILOT_EXECUTION_ROUTING_KEY
|
||||
|
||||
def test_cancel_queue_uses_fanout(self):
|
||||
config = create_copilot_queue_config()
|
||||
cancel_queue = next(
|
||||
q for q in config.queues if q.name != COPILOT_EXECUTION_QUEUE_NAME
|
||||
)
|
||||
assert cancel_queue.exchange is not None
|
||||
assert cancel_queue.exchange.type.value == "fanout"
|
||||
|
||||
|
||||
class TestCoPilotLogMetadata:
|
||||
def test_creates_logger_with_metadata(self):
|
||||
import logging
|
||||
|
||||
base_logger = logging.getLogger("test")
|
||||
log = CoPilotLogMetadata(base_logger, session_id="s1", user_id="u1")
|
||||
assert log is not None
|
||||
|
||||
def test_filters_none_values(self):
|
||||
import logging
|
||||
|
||||
base_logger = logging.getLogger("test")
|
||||
log = CoPilotLogMetadata(
|
||||
base_logger, session_id="s1", user_id=None, turn_id="t1"
|
||||
)
|
||||
assert log is not None
|
||||
@@ -59,6 +59,16 @@ _null_cache: TTLCache[tuple[str, str], bool] = TTLCache(
    maxsize=_CACHE_MAX_SIZE, ttl=_NULL_CACHE_TTL
)

+# GitHub user identity caches (keyed by user_id only, not provider tuple).
+# Declared here so invalidate_user_provider_cache() can reference them.
+_GH_IDENTITY_CACHE_TTL = 600.0  # 10 min — profile data rarely changes
+_gh_identity_cache: TTLCache[str, dict[str, str]] = TTLCache(
+    maxsize=_CACHE_MAX_SIZE, ttl=_GH_IDENTITY_CACHE_TTL
+)
+_gh_identity_null_cache: TTLCache[str, bool] = TTLCache(
+    maxsize=_CACHE_MAX_SIZE, ttl=_NULL_CACHE_TTL
+)


def invalidate_user_provider_cache(user_id: str, provider: str) -> None:
    """Remove the cached entry for *user_id*/*provider* from both caches.
@@ -66,11 +76,19 @@ def invalidate_user_provider_cache(user_id: str, provider: str) -> None:
    Call this after storing new credentials so that the next
    ``get_provider_token()`` call performs a fresh DB lookup instead of
    serving a stale TTL-cached result.
+
+    For GitHub specifically, also clears the git-identity caches so that
+    ``get_github_user_git_identity()`` re-fetches the user's profile on
+    the next call instead of serving stale identity data.
    """
    key = (user_id, provider)
    _token_cache.pop(key, None)
    _null_cache.pop(key, None)

+    if provider == "github":
+        _gh_identity_cache.pop(user_id, None)
+        _gh_identity_null_cache.pop(user_id, None)


# Register this module's cache-bust function with the credentials manager so
# that any create/update/delete operation immediately evicts stale cache
@@ -123,6 +141,7 @@ async def get_provider_token(user_id: str, provider: str) -> str | None:
        [c for c in creds_list if c.type == "oauth2"],
        key=lambda c: 0 if "repo" in (cast(OAuth2Credentials, c).scopes or []) else 1,
    )
+    refresh_failed = False
    for creds in oauth2_creds:
        if creds.type == "oauth2":
            try:
@@ -141,6 +160,7 @@ async def get_provider_token(user_id: str, provider: str) -> str | None:
                # Do NOT fall back to the stale token — it is likely expired
                # or revoked. Returning None forces the caller to re-auth,
                # preventing the LLM from receiving a non-functional token.
+                refresh_failed = True
                continue
            _token_cache[cache_key] = token
            return token
@@ -152,8 +172,12 @@ async def get_provider_token(user_id: str, provider: str) -> str | None:
            _token_cache[cache_key] = token
            return token

-    # No credentials found — cache to avoid repeated DB hits.
-    _null_cache[cache_key] = True
+    # Only cache "not connected" when the user truly has no credentials for this
+    # provider. If we had OAuth credentials but refresh failed (e.g. transient
+    # network error, event-loop mismatch), do NOT cache the negative result —
+    # the next call should retry the refresh instead of being blocked for 60 s.
+    if not refresh_failed:
+        _null_cache[cache_key] = True
    return None

@@ -171,3 +195,76 @@ async def get_integration_env_vars(user_id: str) -> dict[str, str]:
    for var in var_names:
        env[var] = token
    return env


# ---------------------------------------------------------------------------
# GitHub user identity (for git committer env vars)
# ---------------------------------------------------------------------------


async def get_github_user_git_identity(user_id: str) -> dict[str, str] | None:
    """Fetch the GitHub user's name and email for git committer env vars.

    Uses the ``/user`` GitHub API endpoint with the user's stored token.
    Returns a dict with ``GIT_AUTHOR_NAME``, ``GIT_AUTHOR_EMAIL``,
    ``GIT_COMMITTER_NAME``, and ``GIT_COMMITTER_EMAIL`` if the user has a
    connected GitHub account. Returns ``None`` otherwise.

    Results are cached for 10 minutes; "not connected" results are cached for
    60 s (same as null-token cache).
    """
    if user_id in _gh_identity_null_cache:
        return None
    if cached := _gh_identity_cache.get(user_id):
        return cached

    token = await get_provider_token(user_id, "github")
    if not token:
        _gh_identity_null_cache[user_id] = True
        return None

    import aiohttp

    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(
                "https://api.github.com/user",
                headers={
                    "Authorization": f"token {token}",
                    "Accept": "application/vnd.github+json",
                },
                timeout=aiohttp.ClientTimeout(total=5),
            ) as resp:
                if resp.status != 200:
                    logger.warning(
                        "[git-identity] GitHub /user returned %s for user %s",
                        resp.status,
                        user_id,
                    )
                    return None
                data = await resp.json()
    except Exception as exc:
        logger.warning(
            "[git-identity] Failed to fetch GitHub profile for user %s: %s",
            user_id,
            exc,
        )
        return None

    name = data.get("name") or data.get("login") or "AutoGPT User"
    # GitHub may return email=null if the user has set their email to private.
    # Fall back to the noreply address GitHub generates for every account.
    email = data.get("email")
    if not email:
        gh_id = data.get("id", "")
        login = data.get("login", "user")
        email = f"{gh_id}+{login}@users.noreply.github.com"

    identity = {
        "GIT_AUTHOR_NAME": name,
        "GIT_AUTHOR_EMAIL": email,
        "GIT_COMMITTER_NAME": name,
        "GIT_COMMITTER_EMAIL": email,
    }
    _gh_identity_cache[user_id] = identity
    return identity
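The noreply fallback at the end of `get_github_user_git_identity` follows GitHub's documented `<id>+<login>@users.noreply.github.com` shape. A tiny self-contained check of the same rule, on sample profile data:

```python
def noreply_email(profile: dict) -> str:
    """The hunk's email fallback, extracted for illustration."""
    email = profile.get("email")
    if email:
        return email
    gh_id = profile.get("id", "")
    login = profile.get("login", "user")
    return f"{gh_id}+{login}@users.noreply.github.com"


# email hidden -> synthesized noreply address (sample data, not a real user)
assert (
    noreply_email({"id": 583231, "login": "octocat", "email": None})
    == "583231+octocat@users.noreply.github.com"
)
```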
@@ -9,6 +9,8 @@ from backend.copilot.integration_creds import (
    _NULL_CACHE_TTL,
    _TOKEN_CACHE_TTL,
    PROVIDER_ENV_VARS,
+    _gh_identity_cache,
+    _gh_identity_null_cache,
    _null_cache,
    _token_cache,
    get_integration_env_vars,
@@ -49,9 +51,13 @@ def clear_caches():
    """Ensure clean caches before and after every test."""
    _token_cache.clear()
    _null_cache.clear()
+    _gh_identity_cache.clear()
+    _gh_identity_null_cache.clear()
    yield
    _token_cache.clear()
    _null_cache.clear()
+    _gh_identity_cache.clear()
+    _gh_identity_null_cache.clear()


class TestInvalidateUserProviderCache:
@@ -77,6 +83,34 @@ class TestInvalidateUserProviderCache:
        invalidate_user_provider_cache(_USER, _PROVIDER)
        assert other_key in _token_cache

    def test_clears_gh_identity_cache_for_github_provider(self):
        """When provider is 'github', identity caches must also be cleared."""
        _gh_identity_cache[_USER] = {
            "GIT_AUTHOR_NAME": "Old Name",
            "GIT_AUTHOR_EMAIL": "old@example.com",
            "GIT_COMMITTER_NAME": "Old Name",
            "GIT_COMMITTER_EMAIL": "old@example.com",
        }
        invalidate_user_provider_cache(_USER, "github")
        assert _USER not in _gh_identity_cache

    def test_clears_gh_identity_null_cache_for_github_provider(self):
        """When provider is 'github', the identity null-cache must also be cleared."""
        _gh_identity_null_cache[_USER] = True
        invalidate_user_provider_cache(_USER, "github")
        assert _USER not in _gh_identity_null_cache

    def test_does_not_clear_gh_identity_cache_for_other_providers(self):
        """When provider is NOT 'github', identity caches must be left alone."""
        _gh_identity_cache[_USER] = {
            "GIT_AUTHOR_NAME": "Some Name",
            "GIT_AUTHOR_EMAIL": "some@example.com",
            "GIT_COMMITTER_NAME": "Some Name",
            "GIT_COMMITTER_EMAIL": "some@example.com",
        }
        invalidate_user_provider_cache(_USER, "some-other-provider")
        assert _USER in _gh_identity_cache


class TestGetProviderToken:
    @pytest.mark.asyncio(loop_scope="session")
@@ -129,8 +163,15 @@ class TestGetProviderToken:
        assert result == "oauth-tok"

    @pytest.mark.asyncio(loop_scope="session")
-    async def test_oauth2_refresh_failure_returns_none(self):
-        """On refresh failure, return None instead of caching a stale token."""
+    async def test_oauth2_refresh_failure_returns_none_without_null_cache(self):
+        """On refresh failure, return None but do NOT cache in null_cache.
+
+        The user has credentials — they just couldn't be refreshed right now
+        (e.g. transient network error or event-loop mismatch in the copilot
+        executor). Caching a negative result would block all credential
+        lookups for 60 s even though the creds exist and may refresh fine
+        on the next attempt.
+        """
        oauth_creds = _make_oauth2_creds("stale-oauth-tok")
        mock_manager = MagicMock()
        mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[oauth_creds])
@@ -141,6 +182,8 @@ class TestGetProviderToken:

        # Stale tokens must NOT be returned — forces re-auth.
        assert result is None
+        # Must NOT cache negative result when refresh failed — next call retries.
+        assert (_USER, _PROVIDER) not in _null_cache

    @pytest.mark.asyncio(loop_scope="session")
    async def test_no_credentials_caches_null_entry(self):
@@ -176,6 +219,96 @@ class TestGetProviderToken:
        assert _NULL_CACHE_TTL < _TOKEN_CACHE_TTL


class TestThreadSafetyLocks:
    """Bug reproduction: shared AsyncRedisKeyedMutex across threads caused
    'Future attached to a different loop' when copilot workers accessed
    credentials from different event loops."""

    @pytest.mark.asyncio(loop_scope="session")
    async def test_store_locks_returns_per_thread_instance(self):
        """IntegrationCredentialsStore.locks() must return different instances
        for different threads (via @thread_cached)."""
        import asyncio
        import concurrent.futures

        from backend.integrations.credentials_store import IntegrationCredentialsStore

        store = IntegrationCredentialsStore()

        async def get_locks_id():
            mock_redis = AsyncMock()
            with patch(
                "backend.integrations.credentials_store.get_redis_async",
                return_value=mock_redis,
            ):
                locks = await store.locks()
                return id(locks)

        # Get locks from main thread
        main_id = await get_locks_id()

        # Get locks from a worker thread
        def run_in_thread():
            loop = asyncio.new_event_loop()
            try:
                return loop.run_until_complete(get_locks_id())
            finally:
                loop.close()

        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
            worker_id = await asyncio.get_event_loop().run_in_executor(
                pool, run_in_thread
            )

        assert main_id != worker_id, (
            "Store.locks() returned the same instance across threads. "
            "This would cause 'Future attached to a different loop' errors."
        )

    @pytest.mark.asyncio(loop_scope="session")
    async def test_manager_delegates_to_store_locks(self):
        """IntegrationCredentialsManager.locks() should delegate to store."""
        from backend.integrations.creds_manager import IntegrationCredentialsManager

        manager = IntegrationCredentialsManager()
        mock_redis = AsyncMock()

        with patch(
            "backend.integrations.credentials_store.get_redis_async",
            return_value=mock_redis,
        ):
            locks = await manager.locks()

        # Should have gotten it from the store
        assert locks is not None


class TestRefreshUnlockedPath:
    """Bug reproduction: copilot worker threads need lock-free refresh because
    Redis-backed asyncio.Lock created on one event loop can't be used on another."""

    @pytest.mark.asyncio(loop_scope="session")
    async def test_refresh_if_needed_lock_false_skips_redis(self):
        """refresh_if_needed(lock=False) must not touch Redis locks at all."""
        from backend.integrations.creds_manager import IntegrationCredentialsManager

        manager = IntegrationCredentialsManager()
        creds = _make_oauth2_creds()

        mock_handler = MagicMock()
        mock_handler.needs_refresh = MagicMock(return_value=False)

        with patch(
            "backend.integrations.creds_manager._get_provider_oauth_handler",
            new_callable=AsyncMock,
            return_value=mock_handler,
        ):
            result = await manager.refresh_if_needed(_USER, creds, lock=False)

        # Should return credentials without touching locks
        assert result.id == creds.id


class TestGetIntegrationEnvVars:
    @pytest.mark.asyncio(loop_scope="session")
    async def test_injects_all_env_vars_for_provider(self):
@@ -46,6 +46,16 @@ def _get_session_cache_key(session_id: str) -> str:
# ===================== Chat data models ===================== #


class ChatSessionMetadata(BaseModel):
    """Typed metadata stored in the ``metadata`` JSON column of ChatSession.

    Add new session-level flags here instead of adding DB columns —
    no migration required for new fields as long as a default is provided.
    """

    dry_run: bool = False


class ChatMessage(BaseModel):
    role: str
    content: str | None = None
@@ -71,6 +81,49 @@ class ChatMessage(BaseModel):
    )


def is_message_duplicate(
    messages: list[ChatMessage],
    role: str,
    content: str,
) -> bool:
    """Check whether *content* is already present in the current pending turn.

    Only inspects trailing messages that share the given *role* (i.e. the
    current turn). This ensures legitimately repeated messages across different
    turns are not suppressed, while same-turn duplicates from stale cache are
    still caught.
    """
    for m in reversed(messages):
        if m.role == role:
            if m.content == content:
                return True
        else:
            break
    return False


def maybe_append_user_message(
    session: "ChatSession",
    message: str | None,
    is_user_message: bool,
) -> bool:
    """Append a user/assistant message to the session if not already present.

    The route handler already persists the user message before enqueueing,
    so we check trailing same-role messages to avoid re-appending when the
    session cache is slightly stale.

    Returns True if the message was appended, False if skipped.
    """
    if not message:
        return False
    role = "user" if is_user_message else "assistant"
    if is_message_duplicate(session.messages, role, message):
        return False
    session.messages.append(ChatMessage(role=role, content=message))
    return True


class Usage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
@@ -90,6 +143,12 @@ class ChatSessionInfo(BaseModel):
    updated_at: datetime
    successful_agent_runs: dict[str, int] = {}
    successful_agent_schedules: dict[str, int] = {}
    metadata: ChatSessionMetadata = ChatSessionMetadata()

    @property
    def dry_run(self) -> bool:
        """Convenience accessor for ``metadata.dry_run``."""
        return self.metadata.dry_run

    @classmethod
    def from_db(cls, prisma_session: PrismaChatSession) -> Self:
@@ -103,6 +162,10 @@ class ChatSessionInfo(BaseModel):
            prisma_session.successfulAgentSchedules, default={}
        )

        # Parse typed metadata from the JSON column.
        raw_metadata = _parse_json_field(prisma_session.metadata, default={})
        metadata = ChatSessionMetadata.model_validate(raw_metadata)

        # Calculate usage from token counts.
        # NOTE: Per-turn cache_read_tokens / cache_creation_tokens breakdown
        # is lost after persistence — the DB only stores aggregate prompt and
@@ -128,6 +191,7 @@ class ChatSessionInfo(BaseModel):
            updated_at=prisma_session.updatedAt,
            successful_agent_runs=successful_agent_runs,
            successful_agent_schedules=successful_agent_schedules,
            metadata=metadata,
        )


@@ -135,7 +199,7 @@ class ChatSession(ChatSessionInfo):
    messages: list[ChatMessage]

    @classmethod
-    def new(cls, user_id: str) -> Self:
+    def new(cls, user_id: str, *, dry_run: bool) -> Self:
        return cls(
            session_id=str(uuid.uuid4()),
            user_id=user_id,
@@ -145,6 +209,7 @@ class ChatSession(ChatSessionInfo):
            credentials={},
            started_at=datetime.now(UTC),
            updated_at=datetime.now(UTC),
            metadata=ChatSessionMetadata(dry_run=dry_run),
        )

    @classmethod
@@ -532,6 +597,7 @@ async def _save_session_to_db(
        await db.create_chat_session(
            session_id=session.session_id,
            user_id=session.user_id,
            metadata=session.metadata,
        )
        existing_message_count = 0

@@ -609,21 +675,27 @@ async def append_and_save_message(session_id: str, message: ChatMessage) -> Chat
    return session


-async def create_chat_session(user_id: str) -> ChatSession:
+async def create_chat_session(user_id: str, *, dry_run: bool) -> ChatSession:
    """Create a new chat session and persist it.

    Args:
        user_id: The authenticated user ID.
        dry_run: When True, run_block and run_agent tool calls in this
            session are forced to use dry-run simulation mode.

    Raises:
        DatabaseError: If the database write fails. We fail fast to ensure
            callers never receive a non-persisted session that only exists
            in cache (which would be lost when the cache expires).
    """
-    session = ChatSession.new(user_id)
+    session = ChatSession.new(user_id, dry_run=dry_run)

    # Create in database first - fail fast if this fails
    try:
        await chat_db().create_chat_session(
            session_id=session.session_id,
            user_id=user_id,
            metadata=session.metadata,
        )
    except Exception as e:
        logger.error(f"Failed to create session {session.session_id} in database: {e}")
@@ -17,6 +17,8 @@ from .model import (
|
||||
ChatSession,
|
||||
Usage,
|
||||
get_chat_session,
|
||||
is_message_duplicate,
|
||||
maybe_append_user_message,
|
||||
upsert_chat_session,
|
||||
)
|
||||
|
||||
@@ -46,7 +48,7 @@ messages = [
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_chatsession_serialization_deserialization():
|
||||
s = ChatSession.new(user_id="abc123")
|
||||
s = ChatSession.new(user_id="abc123", dry_run=False)
|
||||
s.messages = messages
|
||||
s.usage = [Usage(prompt_tokens=100, completion_tokens=200, total_tokens=300)]
|
||||
serialized = s.model_dump_json()
|
||||
@@ -57,7 +59,7 @@ async def test_chatsession_serialization_deserialization():
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_chatsession_redis_storage(setup_test_user, test_user_id):
|
||||
|
||||
s = ChatSession.new(user_id=test_user_id)
|
||||
s = ChatSession.new(user_id=test_user_id, dry_run=False)
|
||||
s.messages = messages
|
||||
|
||||
s = await upsert_chat_session(s)
|
||||
@@ -75,7 +77,7 @@ async def test_chatsession_redis_storage_user_id_mismatch(
|
||||
setup_test_user, test_user_id
|
||||
):
|
||||
|
||||
s = ChatSession.new(user_id=test_user_id)
|
||||
s = ChatSession.new(user_id=test_user_id, dry_run=False)
|
||||
s.messages = messages
|
||||
s = await upsert_chat_session(s)
|
||||
|
||||
@@ -90,7 +92,7 @@ async def test_chatsession_db_storage(setup_test_user, test_user_id):
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
# Create session with messages including assistant message
|
||||
s = ChatSession.new(user_id=test_user_id)
|
||||
s = ChatSession.new(user_id=test_user_id, dry_run=False)
|
||||
s.messages = messages # Contains user, assistant, and tool messages
|
||||
assert s.session_id is not None, "Session id is not set"
|
||||
# Upsert to save to both cache and DB
|
||||
@@ -241,7 +243,7 @@ _raw_tc2 = {
|
||||
|
||||
def test_add_tool_call_appends_to_existing_assistant():
|
||||
"""When the last assistant is from the current turn, tool_call is added to it."""
|
||||
session = ChatSession.new(user_id="u")
|
||||
session = ChatSession.new(user_id="u", dry_run=False)
|
||||
session.messages = [
|
||||
ChatMessage(role="user", content="hi"),
|
||||
ChatMessage(role="assistant", content="working on it"),
|
||||
@@ -254,7 +256,7 @@ def test_add_tool_call_appends_to_existing_assistant():
|
||||
|
||||
def test_add_tool_call_creates_assistant_when_none_exists():
|
||||
"""When there's no current-turn assistant, a new one is created."""
|
||||
session = ChatSession.new(user_id="u")
|
||||
session = ChatSession.new(user_id="u", dry_run=False)
|
||||
session.messages = [
|
||||
ChatMessage(role="user", content="hi"),
|
||||
]
|
||||
@@ -267,7 +269,7 @@ def test_add_tool_call_creates_assistant_when_none_exists():
|
||||
|
||||
def test_add_tool_call_does_not_cross_user_boundary():
|
||||
"""A user message acts as a boundary — previous assistant is not modified."""
|
||||
session = ChatSession.new(user_id="u")
|
||||
session = ChatSession.new(user_id="u", dry_run=False)
|
||||
session.messages = [
|
||||
ChatMessage(role="assistant", content="old turn"),
|
||||
ChatMessage(role="user", content="new message"),
|
||||
@@ -282,7 +284,7 @@ def test_add_tool_call_does_not_cross_user_boundary():
|
||||
|
||||
def test_add_tool_call_multiple_times():
|
||||
"""Multiple long-running tool calls accumulate on the same assistant."""
|
||||
session = ChatSession.new(user_id="u")
|
||||
session = ChatSession.new(user_id="u", dry_run=False)
|
||||
session.messages = [
|
||||
ChatMessage(role="user", content="hi"),
|
||||
ChatMessage(role="assistant", content="doing stuff"),
|
||||
@@ -300,7 +302,7 @@ def test_add_tool_call_multiple_times():
|
||||
|
||||
def test_to_openai_messages_merges_split_assistants():
|
||||
"""End-to-end: session with split assistants produces valid OpenAI messages."""
|
||||
session = ChatSession.new(user_id="u")
|
||||
session = ChatSession.new(user_id="u", dry_run=False)
|
||||
session.messages = [
|
||||
ChatMessage(role="user", content="build agent"),
|
||||
ChatMessage(role="assistant", content="Let me build that"),
|
||||
@@ -352,7 +354,7 @@ async def test_concurrent_saves_collision_detection(setup_test_user, test_user_i
|
||||
import asyncio
|
||||
|
||||
# Create a session with initial messages
|
||||
session = ChatSession.new(user_id=test_user_id)
|
||||
session = ChatSession.new(user_id=test_user_id, dry_run=False)
|
||||
for i in range(3):
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
@@ -424,3 +426,151 @@ async def test_concurrent_saves_collision_detection(setup_test_user, test_user_i
|
||||
assert "Streaming message 1" in contents
|
||||
assert "Streaming message 2" in contents
|
||||
assert "Callback result" in contents
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# is_message_duplicate #
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
def test_duplicate_detected_in_trailing_same_role():
    """Duplicate user message at the tail is detected."""
    msgs = [
        ChatMessage(role="user", content="hello"),
        ChatMessage(role="assistant", content="hi there"),
        ChatMessage(role="user", content="yes"),
    ]
    assert is_message_duplicate(msgs, "user", "yes") is True


def test_duplicate_not_detected_across_turns():
    """Same text in a previous turn (separated by assistant) is NOT a duplicate."""
    msgs = [
        ChatMessage(role="user", content="yes"),
        ChatMessage(role="assistant", content="ok"),
    ]
    assert is_message_duplicate(msgs, "user", "yes") is False


def test_no_duplicate_on_empty_messages():
    """Empty message list never reports a duplicate."""
    assert is_message_duplicate([], "user", "hello") is False


def test_no_duplicate_when_content_differs():
    """Different content in the trailing same-role block is not a duplicate."""
    msgs = [
        ChatMessage(role="assistant", content="response"),
        ChatMessage(role="user", content="first message"),
    ]
    assert is_message_duplicate(msgs, "user", "second message") is False


def test_duplicate_with_multiple_trailing_same_role():
    """Detects duplicate among multiple consecutive same-role messages."""
    msgs = [
        ChatMessage(role="assistant", content="response"),
        ChatMessage(role="user", content="msg1"),
        ChatMessage(role="user", content="msg2"),
    ]
    assert is_message_duplicate(msgs, "user", "msg1") is True
    assert is_message_duplicate(msgs, "user", "msg2") is True
    assert is_message_duplicate(msgs, "user", "msg3") is False


def test_duplicate_check_for_assistant_role():
    """Works correctly when checking assistant role too."""
    msgs = [
        ChatMessage(role="user", content="hi"),
        ChatMessage(role="assistant", content="hello"),
        ChatMessage(role="assistant", content="how can I help?"),
    ]
    assert is_message_duplicate(msgs, "assistant", "hello") is True
    assert is_message_duplicate(msgs, "assistant", "new response") is False


def test_no_false_positive_when_content_is_none():
    """Messages with content=None in the trailing block do not match."""
    msgs = [
        ChatMessage(role="user", content=None),
        ChatMessage(role="user", content="hello"),
    ]
    assert is_message_duplicate(msgs, "user", "hello") is True
    # None-content message should not match any string
    msgs2 = [
        ChatMessage(role="user", content=None),
    ]
    assert is_message_duplicate(msgs2, "user", "hello") is False


def test_all_same_role_messages():
    """When all messages share the same role, the entire list is scanned."""
    msgs = [
        ChatMessage(role="user", content="first"),
        ChatMessage(role="user", content="second"),
        ChatMessage(role="user", content="third"),
    ]
    assert is_message_duplicate(msgs, "user", "first") is True
    assert is_message_duplicate(msgs, "user", "new") is False

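The helper under test is not shown in this hunk; a minimal sketch consistent with the assertions above (a hedged reconstruction — the real implementation may differ in naming and typing) is:

```python
def is_message_duplicate(messages, role, content) -> bool:
    """Return True if *content* already appears in the trailing run of
    same-role messages (scanning backwards, stopping at a role change)."""
    for msg in reversed(messages):
        if msg.role != role:
            break  # only the trailing same-role block is checked
        if msg.content is not None and msg.content == content:
            return True  # exact duplicate in the trailing block
    return False  # empty list, role change first, or no content match
```

This also explains why the same text in a *previous* turn is not flagged: the backwards scan stops at the first message with a different role.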
# --------------------------------------------------------------------------- #
# maybe_append_user_message                                                   #
# --------------------------------------------------------------------------- #


def test_maybe_append_user_message_appends_new():
    """A new user message is appended and returns True."""
    session = ChatSession.new(user_id="u", dry_run=False)
    session.messages = [
        ChatMessage(role="assistant", content="hello"),
    ]
    result = maybe_append_user_message(session, "new msg", is_user_message=True)
    assert result is True
    assert len(session.messages) == 2
    assert session.messages[-1].role == "user"
    assert session.messages[-1].content == "new msg"


def test_maybe_append_user_message_skips_duplicate():
    """A duplicate user message is skipped and returns False."""
    session = ChatSession.new(user_id="u", dry_run=False)
    session.messages = [
        ChatMessage(role="assistant", content="hello"),
        ChatMessage(role="user", content="dup"),
    ]
    result = maybe_append_user_message(session, "dup", is_user_message=True)
    assert result is False
    assert len(session.messages) == 2


def test_maybe_append_user_message_none_message():
    """None/empty message returns False without appending."""
    session = ChatSession.new(user_id="u", dry_run=False)
    assert maybe_append_user_message(session, None, is_user_message=True) is False
    assert maybe_append_user_message(session, "", is_user_message=True) is False
    assert len(session.messages) == 0


def test_maybe_append_assistant_message():
    """Works for assistant role when is_user_message=False."""
    session = ChatSession.new(user_id="u", dry_run=False)
    session.messages = [
        ChatMessage(role="user", content="hi"),
    ]
    result = maybe_append_user_message(session, "response", is_user_message=False)
    assert result is True
    assert session.messages[-1].role == "assistant"
    assert session.messages[-1].content == "response"


def test_maybe_append_assistant_skips_duplicate():
    """Duplicate assistant message is skipped."""
    session = ChatSession.new(user_id="u", dry_run=False)
    session.messages = [
        ChatMessage(role="user", content="hi"),
        ChatMessage(role="assistant", content="dup"),
    ]
    result = maybe_append_user_message(session, "dup", is_user_message=False)
    assert result is False
    assert len(session.messages) == 2

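Likewise, the `maybe_append_user_message` behaviour pinned down by these tests can be sketched as follows (again a hedged reconstruction, not the actual source):

```python
def maybe_append_user_message(session, message, is_user_message: bool) -> bool:
    if not message:  # None or "" — nothing to append
        return False
    role = "user" if is_user_message else "assistant"
    if is_message_duplicate(session.messages, role, message):
        return False  # skip exact duplicates in the trailing same-role block
    session.messages.append(ChatMessage(role=role, content=message))
    return True
```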
@@ -66,6 +66,7 @@ from pydantic import BaseModel, PrivateAttr
ToolName = Literal[
    # Platform tools (must match keys in TOOL_REGISTRY)
    "add_understanding",
    "ask_question",
    "bash_exec",
    "browser_act",
    "browser_navigate",
@@ -102,6 +103,7 @@ ToolName = Literal[
    "web_fetch",
    "write_workspace_file",
    # SDK built-ins
    "Agent",
    "Edit",
    "Glob",
    "Grep",

@@ -544,6 +544,7 @@ class TestApplyToolPermissions:
class TestSdkBuiltinToolNames:
    def test_expected_builtins_present(self):
        expected = {
            "Agent",
            "Read",
            "Write",
            "Edit",

@@ -18,6 +18,18 @@ After `write_workspace_file`, embed the `download_url` in Markdown:
- Image: ``
- Video: ``

### Handling binary/image data in tool outputs — CRITICAL
When a tool output contains base64-encoded binary data (images, PDFs, etc.):
1. **NEVER** try to inline or render the base64 content in your response.
2. **Save** the data to workspace using `write_workspace_file` (pass the base64 data URI as content).
3. **Show** the result via the workspace download URL in Markdown: ``.

### Passing large data between tools — CRITICAL
When tool outputs produce large text that you need to feed into another tool:
- **NEVER** copy-paste the full text into the next tool call argument.
- **Save** the output to a file (workspace or local), then use `@@agptfile:` references.
- This avoids token limits and ensures data integrity.

### File references — @@agptfile:
Pass large file content to tools by reference: `@@agptfile:<uri>[<start>-<end>]`
- `workspace://<file_id>` or `workspace:///<path>` — workspace files
@@ -107,6 +119,28 @@ Do not re-fetch or re-generate data you already have from prior tool calls.
After building the file, reference it with `@@agptfile:` in other tools:
`@@agptfile:/home/user/report.md`

### Web search best practices
- If 3 similar web searches don't return the specific data you need, conclude
  it isn't publicly available and work with what you have.
- Prefer fewer, well-targeted searches over many variations of the same query.
- When spawning sub-agents for research, ensure each has a distinct
  non-overlapping scope to avoid redundant searches.


### Tool Discovery Priority

When the user asks to interact with a service or API, follow this order:

1. **find_block first** — Search platform blocks with `find_block`. The platform has hundreds of built-in blocks (Google Sheets, Docs, Calendar, Gmail, Slack, GitHub, etc.) that work without extra setup.

2. **run_mcp_tool** — If no matching block exists, check if a hosted MCP server is available for the service. Only use known MCP server URLs from the registry.

3. **SendAuthenticatedWebRequestBlock** — If no block or MCP server exists, use `SendAuthenticatedWebRequestBlock` with existing host-scoped credentials. Check available credentials via `connect_integration`.

4. **Manual API call** — As a last resort, guide the user to set up credentials and use `SendAuthenticatedWebRequestBlock` with direct API calls.

**Never skip step 1.** Built-in blocks are more reliable, tested, and user-friendly than MCP or raw API calls.

### Sub-agent tasks
- When using the Task tool, NEVER set `run_in_background` to true.
  All tasks must run in the foreground.
@@ -131,6 +165,11 @@ parent autopilot handles orchestration.
# E2B-only notes — E2B has full internet access so gh CLI works there.
# Not shown in local (bubblewrap) mode: --unshare-net blocks all network.
_E2B_TOOL_NOTES = """
### SDK tool-result files in E2B
When you `Read` an SDK tool-result file, it is automatically copied into the
sandbox so `bash_exec` can access it for further processing.
The exact sandbox path is shown in the `[Sandbox copy available at ...]` note.

### GitHub CLI (`gh`) and git
- If the user has connected their GitHub account, both `gh` and `git` are
  pre-authenticated — use them directly without any manual login step.
@@ -196,19 +235,22 @@ def _build_storage_supplement(
- Files here **survive across sessions indefinitely**

### Moving files between storages
- **{file_move_name_1_to_2}**: Copy to persistent workspace
- **{file_move_name_2_to_1}**: Download for processing
- **{file_move_name_1_to_2}**: `write_workspace_file(filename="output.json", source_path="/path/to/local/file")`
- **{file_move_name_2_to_1}**: `read_workspace_file(path="tool-outputs/data.json", save_to_path="{working_dir}/data.json")`

### File persistence
Important files (code, configs, outputs) should be saved to workspace to ensure they persist.

### SDK tool-result files
When tool outputs are large, the SDK truncates them and saves the full output to
a local file under `~/.claude/projects/.../tool-results/`. To read these files,
always use `Read` (NOT `bash_exec`, NOT `read_workspace_file`).
These files are on the host filesystem — `bash_exec` runs in the sandbox and
CANNOT access them. `read_workspace_file` reads from cloud workspace storage,
where SDK tool-results are NOT stored.
a local file under `~/.claude/projects/.../tool-results/` (or `tool-outputs/`).
To read these files, use `Read` — it reads from the host filesystem.

### Large tool outputs saved to workspace
When a tool output contains `<tool-output-truncated workspace_path="...">`, the
full output is in workspace storage (NOT on the local filesystem). To access it:
- Use `read_workspace_file(path="...", offset=..., length=50000)` for reading sections.
- To process in the sandbox, use `read_workspace_file(path="...", save_to_path="{working_dir}/file.json")` first, then use `bash_exec` on the local copy.
{_SHARED_TOOL_NOTES}{extra_notes}"""

28
autogpt_platform/backend/backend/copilot/prompting_test.py
Normal file
@@ -0,0 +1,28 @@
"""Tests for agent generation guide — verifies clarification section."""

from pathlib import Path


class TestAgentGenerationGuideContainsClarifySection:
    """The agent generation guide must include the clarification section."""

    def test_guide_includes_clarify_section(self):
        guide_path = Path(__file__).parent / "sdk" / "agent_generation_guide.md"
        content = guide_path.read_text(encoding="utf-8")
        assert "Before or During Building" in content

    def test_guide_mentions_find_block_for_clarification(self):
        guide_path = Path(__file__).parent / "sdk" / "agent_generation_guide.md"
        content = guide_path.read_text(encoding="utf-8")
        clarify_section = content.split("Before or During Building")[1].split(
            "### Workflow"
        )[0]
        assert "find_block" in clarify_section

    def test_guide_mentions_ask_question_tool(self):
        guide_path = Path(__file__).parent / "sdk" / "agent_generation_guide.md"
        content = guide_path.read_text(encoding="utf-8")
        clarify_section = content.split("Before or During Building")[1].split(
            "### Workflow"
        )[0]
        assert "ask_question" in clarify_section
@@ -9,11 +9,14 @@ UTC). Fails open when Redis is unavailable to avoid blocking users.
import asyncio
import logging
from datetime import UTC, datetime, timedelta
from enum import Enum

from prisma.models import User as PrismaUser
from pydantic import BaseModel, Field
from redis.exceptions import RedisError

from backend.data.redis_client import get_redis_async
from backend.util.cache import cached

logger = logging.getLogger(__name__)

@@ -21,6 +24,40 @@ logger = logging.getLogger(__name__)
_USAGE_KEY_PREFIX = "copilot:usage"


# ---------------------------------------------------------------------------
# Subscription tier definitions
# ---------------------------------------------------------------------------


class SubscriptionTier(str, Enum):
    """Subscription tiers with increasing token allowances.

    Mirrors the ``SubscriptionTier`` enum in ``schema.prisma``.
    Once ``prisma generate`` is run, this can be replaced with::

        from prisma.enums import SubscriptionTier
    """

    FREE = "FREE"
    PRO = "PRO"
    BUSINESS = "BUSINESS"
    ENTERPRISE = "ENTERPRISE"


# Multiplier applied to the base limits (from LD / config) for each tier.
# Intentionally int (not float): keeps limits as whole token counts and avoids
# floating-point rounding. If fractional multipliers are ever needed, change
# the type and round the result in get_global_rate_limits().
TIER_MULTIPLIERS: dict[SubscriptionTier, int] = {
    SubscriptionTier.FREE: 1,
    SubscriptionTier.PRO: 5,
    SubscriptionTier.BUSINESS: 20,
    SubscriptionTier.ENTERPRISE: 60,
}

DEFAULT_TIER = SubscriptionTier.FREE

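To make the multiplier table concrete, a worked example; the base limits below are placeholders for illustration (real values come from LaunchDarkly or `ChatConfig` — though the resulting PRO numbers match the fixtures used in the route tests later in this diff):

```python
base_daily, base_weekly = 500_000, 2_500_000  # hypothetical base limits

for tier, mult in TIER_MULTIPLIERS.items():
    # FREE -> 500_000 / 2_500_000; PRO (5x) -> 2_500_000 / 12_500_000; ...
    print(tier.value, base_daily * mult, base_weekly * mult)
```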
class UsageWindow(BaseModel):
    """Usage within a single time window."""

@@ -36,6 +73,7 @@ class CoPilotUsageStatus(BaseModel):

    daily: UsageWindow
    weekly: UsageWindow
    tier: SubscriptionTier = DEFAULT_TIER
    reset_cost: int = Field(
        default=0,
        description="Credit cost (in cents) to reset the daily limit. 0 = feature disabled.",
@@ -66,6 +104,7 @@ async def get_usage_status(
    daily_token_limit: int,
    weekly_token_limit: int,
    rate_limit_reset_cost: int = 0,
    tier: SubscriptionTier = DEFAULT_TIER,
) -> CoPilotUsageStatus:
    """Get current usage status for a user.

@@ -74,6 +113,7 @@ async def get_usage_status(
        daily_token_limit: Max tokens per day (0 = unlimited).
        weekly_token_limit: Max tokens per week (0 = unlimited).
        rate_limit_reset_cost: Credit cost (cents) to reset daily limit (0 = disabled).
        tier: The user's rate-limit tier (included in the response).

    Returns:
        CoPilotUsageStatus with current usage and limits.
@@ -103,6 +143,7 @@ async def get_usage_status(
            limit=weekly_token_limit,
            resets_at=_weekly_reset_time(now=now),
        ),
        tier=tier,
        reset_cost=rate_limit_reset_cost,
    )

@@ -161,8 +202,9 @@ async def reset_daily_usage(user_id: str, daily_token_limit: int = 0) -> bool:
        daily_token_limit: The configured daily token limit. When positive,
            the weekly counter is reduced by this amount.

    Fails open: returns False if Redis is unavailable (consistent with
    the fail-open design of this module).
    Returns False if Redis is unavailable so the caller can handle
    compensation (fail-closed for billed operations, unlike the read-only
    rate-limit checks which fail-open).
    """
    now = datetime.now(UTC)
    try:
@@ -342,20 +384,100 @@ async def record_token_usage(
    )


class _UserNotFoundError(Exception):
    """Raised when a user record is missing or has no subscription tier.

    Used internally by ``_fetch_user_tier`` to signal a cache-miss condition:
    by raising instead of returning ``DEFAULT_TIER``, we prevent the ``@cached``
    decorator from storing the fallback value. This avoids a race condition
    where a non-existent user's DEFAULT_TIER is cached, then the user is
    created with a higher tier but receives the stale cached FREE tier for
    up to 5 minutes.
    """


@cached(maxsize=1000, ttl_seconds=300, shared_cache=True)
async def _fetch_user_tier(user_id: str) -> SubscriptionTier:
    """Fetch the user's rate-limit tier from the database (cached via Redis).

    Uses ``shared_cache=True`` so that tier changes propagate across all pods
    immediately when the cache entry is invalidated (via ``cache_delete``).

    Only successful DB lookups of existing users with a valid tier are cached.
    Raises ``_UserNotFoundError`` when the user is missing or has no tier, so
    the ``@cached`` decorator does **not** store a fallback value. This
    prevents a race condition where a non-existent user's ``DEFAULT_TIER`` is
    cached and then persists after the user is created with a higher tier.
    """
    user = await PrismaUser.prisma().find_unique(where={"id": user_id})
    if user and user.subscriptionTier:  # type: ignore[reportAttributeAccessIssue]
        return SubscriptionTier(user.subscriptionTier)  # type: ignore[reportAttributeAccessIssue]
    raise _UserNotFoundError(user_id)


async def get_user_tier(user_id: str) -> SubscriptionTier:
    """Look up the user's rate-limit tier from the database.

    Successful results are cached for 5 minutes (via ``_fetch_user_tier``)
    to avoid a DB round-trip on every rate-limit check.

    Falls back to ``DEFAULT_TIER`` **without caching** when the DB is
    unreachable or returns an unrecognised value, so the next call retries
    the query instead of serving a stale fallback for up to 5 minutes.
    """
    try:
        return await _fetch_user_tier(user_id)
    except Exception as exc:
        logger.warning(
            "Failed to resolve rate-limit tier for user %s, defaulting to %s: %s",
            user_id[:8],
            DEFAULT_TIER.value,
            exc,
        )
        return DEFAULT_TIER


# Expose cache management on the public function so callers (including tests)
# never need to reach into the private ``_fetch_user_tier``.
get_user_tier.cache_clear = _fetch_user_tier.cache_clear  # type: ignore[attr-defined]
get_user_tier.cache_delete = _fetch_user_tier.cache_delete  # type: ignore[attr-defined]


async def set_user_tier(user_id: str, tier: SubscriptionTier) -> None:
    """Persist the user's rate-limit tier to the database.

    Also invalidates the ``get_user_tier`` cache for this user so that
    subsequent rate-limit checks immediately see the new tier.

    Raises:
        prisma.errors.RecordNotFoundError: If the user does not exist.
    """
    await PrismaUser.prisma().update(
        where={"id": user_id},
        data={"subscriptionTier": tier.value},
    )
    # Invalidate cached tier so rate-limit checks pick up the change immediately.
    get_user_tier.cache_delete(user_id)  # type: ignore[attr-defined]

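The intended write-then-read pattern, per the docstrings above (the user ID is a placeholder):

```python
await set_user_tier("user-123", SubscriptionTier.PRO)
# The cache entry was invalidated by set_user_tier, so the next lookup
# re-queries the DB and sees PRO immediately rather than a stale FREE.
assert await get_user_tier("user-123") == SubscriptionTier.PRO
```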
async def get_global_rate_limits(
    user_id: str,
    config_daily: int,
    config_weekly: int,
) -> tuple[int, int]:
) -> tuple[int, int, SubscriptionTier]:
    """Resolve global rate limits from LaunchDarkly, falling back to config.

    The base limits (from LD or config) are multiplied by the user's
    tier multiplier so that higher tiers receive proportionally larger
    allowances.

    Args:
        user_id: User ID for LD flag evaluation context.
        config_daily: Fallback daily limit from ChatConfig.
        config_weekly: Fallback weekly limit from ChatConfig.

    Returns:
        (daily_token_limit, weekly_token_limit) tuple.
        (daily_token_limit, weekly_token_limit, tier) 3-tuple.
    """
    # Lazy import to avoid circular dependency:
    # rate_limit -> feature_flag -> settings -> ... -> rate_limit
@@ -377,7 +499,15 @@ async def get_global_rate_limits(
    except (TypeError, ValueError):
        logger.warning("Invalid LD value for weekly token limit: %r", weekly_raw)
        weekly = config_weekly
    return daily, weekly

    # Apply tier multiplier
    tier = await get_user_tier(user_id)
    multiplier = TIER_MULTIPLIERS.get(tier, 1)
    if multiplier != 1:
        daily = daily * multiplier
        weekly = weekly * multiplier

    return daily, weekly, tier

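Callers therefore unpack a 3-tuple now; a minimal sketch of the new call shape (argument values are placeholders):

```python
daily_limit, weekly_limit, tier = await get_global_rate_limits(
    user_id="user-123",      # LD evaluation context
    config_daily=500_000,    # fallback if the LD flag is absent
    config_weekly=2_500_000,
)
# daily_limit and weekly_limit already include the tier multiplier.
```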
async def reset_user_usage(user_id: str, *, reset_weekly: bool = False) -> None:

File diff suppressed because it is too large
@@ -9,7 +9,7 @@ import pytest
from fastapi import HTTPException

from backend.api.features.chat.routes import reset_copilot_usage
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
from backend.copilot.rate_limit import CoPilotUsageStatus, SubscriptionTier, UsageWindow
from backend.util.exceptions import InsufficientBalanceError


@@ -53,6 +53,18 @@ def _mock_settings(enable_credit: bool = True):
    return mock


def _mock_rate_limits(
    daily: int = 2_500_000,
    weekly: int = 12_500_000,
    tier: SubscriptionTier = SubscriptionTier.PRO,
):
    """Mock get_global_rate_limits to return fixed limits (no tier multiplier)."""
    return patch(
        f"{_MODULE}.get_global_rate_limits",
        AsyncMock(return_value=(daily, weekly, tier)),
    )


@pytest.mark.asyncio
class TestResetCopilotUsage:
    async def test_feature_disabled_returns_400(self):
@@ -70,6 +82,7 @@ class TestResetCopilotUsage:
        with (
            patch(f"{_MODULE}.config", _make_config(daily_token_limit=0)),
            patch(f"{_MODULE}.settings", _mock_settings()),
            _mock_rate_limits(daily=0),
        ):
            with pytest.raises(HTTPException) as exc_info:
                await reset_copilot_usage(user_id="user-1")
@@ -83,6 +96,7 @@ class TestResetCopilotUsage:
        with (
            patch(f"{_MODULE}.config", cfg),
            patch(f"{_MODULE}.settings", _mock_settings()),
            _mock_rate_limits(),
            patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
            patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
            patch(f"{_MODULE}.release_reset_lock", AsyncMock()) as mock_release,
@@ -112,6 +126,7 @@ class TestResetCopilotUsage:
        with (
            patch(f"{_MODULE}.config", cfg),
            patch(f"{_MODULE}.settings", _mock_settings()),
            _mock_rate_limits(),
            patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
            patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
            patch(f"{_MODULE}.release_reset_lock", AsyncMock()) as mock_release,
@@ -141,6 +156,7 @@ class TestResetCopilotUsage:
        with (
            patch(f"{_MODULE}.config", cfg),
            patch(f"{_MODULE}.settings", _mock_settings()),
            _mock_rate_limits(),
            patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
            patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
            patch(f"{_MODULE}.release_reset_lock", AsyncMock()),
@@ -171,6 +187,7 @@ class TestResetCopilotUsage:
        with (
            patch(f"{_MODULE}.config", cfg),
            patch(f"{_MODULE}.settings", _mock_settings()),
            _mock_rate_limits(),
            patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=3)),
        ):
            with pytest.raises(HTTPException) as exc_info:
@@ -208,6 +225,7 @@ class TestResetCopilotUsage:
        with (
            patch(f"{_MODULE}.config", cfg),
            patch(f"{_MODULE}.settings", _mock_settings()),
            _mock_rate_limits(),
            patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
            patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
            patch(f"{_MODULE}.release_reset_lock", AsyncMock()) as mock_release,
@@ -228,6 +246,7 @@ class TestResetCopilotUsage:
        with (
            patch(f"{_MODULE}.config", _make_config()),
            patch(f"{_MODULE}.settings", _mock_settings()),
            _mock_rate_limits(),
            patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=None)),
        ):
            with pytest.raises(HTTPException) as exc_info:
@@ -245,6 +264,7 @@ class TestResetCopilotUsage:
        with (
            patch(f"{_MODULE}.config", cfg),
            patch(f"{_MODULE}.settings", _mock_settings()),
            _mock_rate_limits(),
            patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
            patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
            patch(f"{_MODULE}.release_reset_lock", AsyncMock()),
@@ -275,6 +295,7 @@ class TestResetCopilotUsage:
        with (
            patch(f"{_MODULE}.config", cfg),
            patch(f"{_MODULE}.settings", _mock_settings()),
            _mock_rate_limits(),
            patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
            patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
            patch(f"{_MODULE}.release_reset_lock", AsyncMock()),

@@ -3,26 +3,62 @@
You can create, edit, and customize agents directly. You ARE the brain —
generate the agent JSON yourself using block schemas, then validate and save.

### Clarifying — Before or During Building

Use `ask_question` whenever the user's intent is ambiguous — whether
that's before starting or midway through the workflow. Common moments:

- **Before building**: output format, delivery channel, data source, or
  trigger is unspecified.
- **During block discovery**: multiple blocks could fit and the user
  should choose.
- **During JSON generation**: a wiring decision depends on user
  preference.

Steps:
1. Call `find_block` (or another discovery tool) to learn what the
   platform actually supports for the ambiguous dimension.
2. Call `ask_question` with a concrete question listing the discovered
   options (e.g. "The platform supports Gmail, Slack, and Google Docs —
   which should the agent use for delivery?").
3. **Wait for the user's answer** before continuing.

**Skip this** when the goal already specifies all dimensions (e.g.
"scrape prices from Amazon and email me daily").

### Workflow for Creating/Editing Agents

1. **Discover blocks**: Call `find_block(query, include_schemas=true)` to
1. **If editing**: First narrow to the specific agent by UUID, then fetch its
   graph: `find_library_agent(query="<agent_id>", include_graph=true)`. This
   returns the full graph structure (nodes + links). **Never edit blindly** —
   always inspect the current graph first so you know exactly what to change.
   Avoid using `include_graph=true` with broad keyword searches, as fetching
   multiple graphs at once is expensive and consumes LLM context budget.
2. **Discover blocks**: Call `find_block(query, include_schemas=true)` to
   search for relevant blocks. This returns block IDs, names, descriptions,
   and full input/output schemas.
2. **Find library agents**: Call `find_library_agent` to discover reusable
3. **Find library agents**: Call `find_library_agent` to discover reusable
   agents that can be composed as sub-agents via `AgentExecutorBlock`.
3. **Generate JSON**: Build the agent JSON using block schemas:
   - Use block IDs from step 1 as `block_id` in nodes
4. **Generate/modify JSON**: Build or modify the agent JSON using block schemas:
   - Use block IDs from step 2 as `block_id` in nodes
   - Wire outputs to inputs using links
   - Set design-time config in `input_default`
   - Use `AgentInputBlock` for values the user provides at runtime
4. **Write to workspace**: Save the JSON to a workspace file so the user
   - When editing, apply targeted changes and preserve unchanged parts
5. **Write to workspace**: Save the JSON to a workspace file so the user
   can review it: `write_workspace_file(filename="agent.json", content=...)`
5. **Validate**: Call `validate_agent_graph` with the agent JSON to check
6. **Validate**: Call `validate_agent_graph` with the agent JSON to check
   for errors
6. **Fix if needed**: Call `fix_agent_graph` to auto-fix common issues,
7. **Fix if needed**: Call `fix_agent_graph` to auto-fix common issues,
   or fix manually based on the error descriptions. Iterate until valid.
7. **Save**: Call `create_agent` (new) or `edit_agent` (existing) with
8. **Save**: Call `create_agent` (new) or `edit_agent` (existing) with
   the final `agent_json`
8. **Dry-run**: ALWAYS call `run_agent` with `dry_run=True` and
   `wait_for_result=120` to verify the agent works end-to-end.
9. **Inspect & fix**: Check the dry-run output for errors. If issues are
   found, call `edit_agent` to fix and dry-run again. Repeat until the
   simulation passes or the problems are clearly unfixable.
   See "REQUIRED: Dry-Run Verification Loop" section below for details.

### Agent JSON Structure

@@ -74,8 +110,8 @@ These define the agent's interface — what it accepts and what it produces.

**AgentDropdownInputBlock** (ID: `655d6fdf-a334-421c-b733-520549c07cd1`):
- Specialized input block that presents a dropdown/select to the user
- Required `input_default` fields: `name` (str), `placeholder_values` (list of options, must have at least one)
- Optional: `title`, `description`, `value` (default selection)
- Required `input_default` fields: `name` (str)
- Optional: `options` (list of dropdown values; when omitted/empty, input behaves as free-text), `title`, `description`, `value` (default selection)
- Output: `result` — the user-selected value at runtime
- Use this instead of AgentInputBlock when the user should pick from a fixed set of options

@@ -216,19 +252,62 @@ call in a loop until the task is complete:
Regular blocks work exactly like sub-agents as tools — wire each input
field from `source_name: "tools"` on the Orchestrator side.

### Testing with Dry Run
### REQUIRED: Dry-Run Verification Loop (create -> dry-run -> fix)

After saving an agent, suggest a dry run to validate wiring without consuming
real API calls, credentials, or credits:
After creating or editing an agent, you MUST dry-run it before telling the
user the agent is ready. NEVER skip this step.

1. **Run**: Call `run_agent` or `run_block` with `dry_run=True` and provide
   sample inputs. This executes the graph with mock outputs, verifying that
   links resolve correctly and required inputs are satisfied.
2. **Check results**: Call `view_agent_output` with `show_execution_details=True`
   to inspect the full node-by-node execution trace. This shows what each node
   received as input and produced as output, making it easy to spot wiring issues.
3. **Iterate**: If the dry run reveals wiring issues or missing inputs, fix
   the agent JSON and re-save before suggesting a real execution.
#### Step-by-step workflow

1. **Create/Edit**: Call `create_agent` or `edit_agent` to save the agent.
2. **Dry-run**: Call `run_agent` with `dry_run=True`, `wait_for_result=120`,
   and realistic sample inputs that exercise every path in the agent. This
   simulates execution using an LLM for each block — no real API calls,
   credentials, or credits are consumed.
3. **Inspect output**: Examine the dry-run result for problems. If
   `wait_for_result` returns only a summary, call
   `view_agent_output(execution_id=..., show_execution_details=True)` to
   see the full node-by-node execution trace. Look for:
   - **Errors / failed nodes** — a node raised an exception or returned an
     error status. Common causes: wrong `source_name`/`sink_name` in links,
     missing `input_default` values, or referencing a nonexistent block output.
   - **Null / empty outputs** — data did not flow through a link. Verify that
     `source_name` and `sink_name` match the block schemas exactly (case-
     sensitive, including nested `_#_` notation).
   - **Nodes that never executed** — the node was not reached. Likely a
     missing or broken link from an upstream node.
   - **Unexpected values** — data arrived but in the wrong type or
     structure. Check type compatibility between linked ports.
4. **Fix**: If any issues are found, call `edit_agent` with the corrected
   agent JSON, then go back to step 2.
5. **Repeat**: Continue the dry-run -> fix cycle until the simulation passes
   or the problems are clearly unfixable. If you stop making progress,
   report the remaining issues to the user and ask for guidance.

#### Good vs bad dry-run output

**Good output** (agent is ready):
- All nodes executed successfully (no errors in the execution trace)
- Data flows through every link with non-null, correctly-typed values
- The final `AgentOutputBlock` contains a meaningful result
- Status is `COMPLETED`

**Bad output** (needs fixing):
- Status is `FAILED` — check the error message for the failing node
- An output node received `null` — trace back to find the broken link
- A node received data in the wrong format (e.g. string where list expected)
- Nodes downstream of a failing node were skipped entirely

**Special block behaviour in dry-run mode:**
- **OrchestratorBlock** and **AgentExecutorBlock** execute for real so the
  orchestrator can make LLM calls and agent executors can spawn child graphs.
  Their downstream tool blocks and child-graph blocks are still simulated.
  Note: real LLM inference calls are made (consuming API quota), even though
  platform credits are not charged. Agent-mode iterations are capped at 1 in
  dry-run to keep it fast.
- **MCPToolBlock** is simulated using the selected tool's name and JSON Schema
  so the LLM can produce a realistic mock response without connecting to the
  MCP server.

### Example: Simple AI Text Processor

@@ -25,7 +25,7 @@ from backend.copilot.sdk.compaction import (


def _make_session() -> ChatSession:
    return ChatSession.new(user_id="test-user")
    return ChatSession.new(user_id="test-user", dry_run=False)


# ---------------------------------------------------------------------------

@@ -2,14 +2,30 @@

from __future__ import annotations

from collections.abc import AsyncIterator
from unittest.mock import patch
from uuid import uuid4

import pytest
import pytest_asyncio

from backend.util import json


@pytest_asyncio.fixture(scope="session", loop_scope="session", name="server")
async def _server_noop() -> None:
    """No-op server stub — SDK tests don't need the full backend."""
    return None


@pytest_asyncio.fixture(
    scope="session", loop_scope="session", autouse=True, name="graph_cleanup"
)
async def _graph_cleanup_noop() -> AsyncIterator[None]:
    """No-op graph cleanup stub."""
    yield


@pytest.fixture()
def mock_chat_config():
    """Mock ChatConfig so compact_transcript tests skip real config lookup."""

@@ -8,6 +8,9 @@ SDK-internal paths (``~/.claude/projects/…/tool-results/``) are handled
by the separate ``Read`` MCP tool registered in ``tool_adapter.py``.
"""

import asyncio
import base64
import hashlib
import itertools
import json
import logging
@@ -28,6 +31,12 @@ from backend.copilot.context import (

logger = logging.getLogger(__name__)

# Default number of lines returned by ``read_file`` when the caller does not
# specify a limit. Also used as the threshold in ``bridge_to_sandbox`` to
# decide whether the model is requesting the full file (and thus whether the
# bridge copy is worthwhile).
_DEFAULT_READ_LIMIT = 2000


async def _check_sandbox_symlink_escape(
    sandbox: Any,
@@ -89,7 +98,7 @@ def _get_sandbox_and_path(
    return sandbox, remote


async def _sandbox_write(sandbox: Any, path: str, content: str) -> None:
async def _sandbox_write(sandbox: Any, path: str, content: str | bytes) -> None:
    """Write *content* to *path* inside the sandbox.

    The E2B filesystem API (``sandbox.files.write``) and the command API
@@ -102,11 +111,14 @@ async def _sandbox_write(sandbox: Any, path: str, content: str) -> None:
    To work around this, writes targeting ``/tmp`` are performed via
    ``tee`` through the command API, which runs as the sandbox ``user``
    and can therefore always overwrite user-owned files.

    *content* may be ``str`` (text) or ``bytes`` (binary). Both paths
    are handled correctly: text is encoded to bytes for the base64 shell
    pipe, and raw bytes are passed through without any encoding.
    """
    if path == "/tmp" or path.startswith("/tmp/"):
        import base64 as _b64

        encoded = _b64.b64encode(content.encode()).decode()
        raw = content.encode() if isinstance(content, str) else content
        encoded = base64.b64encode(raw).decode()
        result = await sandbox.commands.run(
            f"echo {shlex.quote(encoded)} | base64 -d > {shlex.quote(path)}",
            cwd=E2B_WORKDIR,
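The `/tmp` write path depends on base64 surviving the shell round-trip byte-for-byte; a self-contained check of that invariant (plain stdlib, mirroring the `base64 -d` decode on the sandbox side):

```python
import base64

payload = bytes(range(256))                   # every possible byte value
encoded = base64.b64encode(payload).decode()  # ASCII-safe for shlex.quote
assert base64.b64decode(encoded) == payload   # what `base64 -d` reconstructs
```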
@@ -128,14 +140,25 @@ async def _handle_read_file(args: dict[str, Any]) -> dict[str, Any]:
    """Read lines from a sandbox file, falling back to the local host for SDK-internal paths."""
    file_path: str = args.get("file_path", "")
    offset: int = max(0, int(args.get("offset", 0)))
    limit: int = max(1, int(args.get("limit", 2000)))
    limit: int = max(1, int(args.get("limit", _DEFAULT_READ_LIMIT)))

    if not file_path:
        return _mcp("file_path is required", error=True)

    # SDK-internal paths (tool-results, ephemeral working dir) stay on the host.
    # SDK-internal paths (tool-results/tool-outputs, ephemeral working dir)
    # stay on the host. When E2B is active, also copy the file into the
    # sandbox so bash_exec can access it for further processing.
    if _is_allowed_local(file_path):
        return _read_local(file_path, offset, limit)
        result = _read_local(file_path, offset, limit)
        if not result.get("isError"):
            sandbox = _get_sandbox()
            if sandbox is not None:
                annotation = await bridge_and_annotate(
                    sandbox, file_path, offset, limit
                )
                if annotation:
                    result["content"][0]["text"] += annotation
        return result

    result = _get_sandbox_and_path(file_path)
    if isinstance(result, dict):
@@ -302,6 +325,103 @@ async def _handle_grep(args: dict[str, Any]) -> dict[str, Any]:
    return _mcp(output if output else "No matches found.")


# Bridging: copy SDK-internal files into E2B sandbox

# Files larger than this are written to /home/user/ via sandbox.files.write()
# instead of /tmp/ via shell base64, to avoid shell argument length limits
# and E2B command timeouts. Base64 expands content by ~33%, so keep this
# well under the typical Linux ARG_MAX (128 KB).
_BRIDGE_SHELL_MAX_BYTES = 32 * 1024  # 32 KB
# Files larger than this are skipped entirely to avoid excessive transfer times.
_BRIDGE_SKIP_BYTES = 50 * 1024 * 1024  # 50 MB


async def bridge_to_sandbox(
    sandbox: Any, file_path: str, offset: int, limit: int
) -> str | None:
    """Best-effort copy of a host-side SDK file into the E2B sandbox.

    When the model reads an SDK-internal file (e.g. tool-results), it often
    wants to process the data with bash. Copying the file into the sandbox
    under a stable name lets ``bash_exec`` access it without extra steps.

    Only copies when offset=0 and limit is large enough to indicate the model
    wants the full file. Errors are logged but never propagated.

    Returns the sandbox path on success, or ``None`` on skip/failure.

    Size handling:
    - <= 32 KB: written to ``/tmp/<hash>-<basename>`` via shell base64
      (``_sandbox_write``). Kept small to stay within ARG_MAX.
    - 32 KB - 50 MB: written to ``/home/user/<hash>-<basename>`` via
      ``sandbox.files.write()`` to avoid shell argument length limits.
    - > 50 MB: skipped entirely with a warning.

    The sandbox filename is prefixed with a short hash of the full source
    path to avoid collisions when different source files share the same
    basename (e.g. multiple ``result.json`` files).
    """
    if offset != 0 or limit < _DEFAULT_READ_LIMIT:
        return None
    try:
        expanded = os.path.realpath(os.path.expanduser(file_path))
        basename = os.path.basename(expanded)
        source_id = hashlib.sha256(expanded.encode()).hexdigest()[:12]
        unique_name = f"{source_id}-{basename}"
        file_size = os.path.getsize(expanded)
        if file_size > _BRIDGE_SKIP_BYTES:
            logger.warning(
                "[E2B] Skipping bridge for large file (%d bytes): %s",
                file_size,
                basename,
            )
            return None

        def _read_bytes() -> bytes:
            with open(expanded, "rb") as fh:
                return fh.read()

        raw_content = await asyncio.to_thread(_read_bytes)
        try:
            text_content: str | None = raw_content.decode("utf-8")
        except UnicodeDecodeError:
            text_content = None
        data: str | bytes = text_content if text_content is not None else raw_content
        if file_size <= _BRIDGE_SHELL_MAX_BYTES:
            sandbox_path = f"/tmp/{unique_name}"
            await _sandbox_write(sandbox, sandbox_path, data)
        else:
            sandbox_path = f"/home/user/{unique_name}"
            await sandbox.files.write(sandbox_path, data)
        logger.info(
            "[E2B] Bridged SDK file to sandbox: %s -> %s", basename, sandbox_path
        )
        return sandbox_path
    except Exception:
        logger.warning(
            "[E2B] Failed to bridge SDK file to sandbox: %s",
            file_path,
            exc_info=True,
        )
        return None


async def bridge_and_annotate(
    sandbox: Any, file_path: str, offset: int, limit: int
) -> str | None:
    """Bridge a host file to the sandbox and return a newline-prefixed annotation.

    Combines ``bridge_to_sandbox`` with the standard annotation suffix so
    callers don't need to duplicate the pattern. Returns a string like
    ``"\\n[Sandbox copy available at /tmp/abc-file.txt]"`` on success, or
    ``None`` if bridging was skipped or failed.
    """
    sandbox_path = await bridge_to_sandbox(sandbox, file_path, offset, limit)
    if sandbox_path is None:
        return None
    return f"\n[Sandbox copy available at {sandbox_path}]"

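Typical caller shape, as seen in `_handle_read_file` above — append the annotation to an MCP text result only when bridging actually happened:

```python
annotation = await bridge_and_annotate(sandbox, file_path, offset=0, limit=2000)
if annotation:
    # "...file contents...\n[Sandbox copy available at /tmp/<hash>-<name>]"
    result["content"][0]["text"] += annotation
```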
# Local read (for SDK-internal paths)

@@ -3,6 +3,7 @@
Pure unit tests with no external dependencies (no E2B, no sandbox).
"""

import hashlib
import os
import shutil
from types import SimpleNamespace
@@ -13,12 +14,26 @@ import pytest
from backend.copilot.context import E2B_WORKDIR, SDK_PROJECTS_DIR, _current_project_dir

from .e2b_file_tools import (
    _BRIDGE_SHELL_MAX_BYTES,
    _BRIDGE_SKIP_BYTES,
    _DEFAULT_READ_LIMIT,
    _check_sandbox_symlink_escape,
    _read_local,
    _sandbox_write,
    bridge_and_annotate,
    bridge_to_sandbox,
    resolve_sandbox_path,
)


def _expected_bridge_path(file_path: str, prefix: str = "/tmp") -> str:
    """Compute the expected sandbox path for a bridged file."""
    expanded = os.path.realpath(os.path.expanduser(file_path))
    basename = os.path.basename(expanded)
    source_id = hashlib.sha256(expanded.encode()).hexdigest()[:12]
    return f"{prefix}/{source_id}-{basename}"


# ---------------------------------------------------------------------------
# resolve_sandbox_path — sandbox path normalisation & boundary enforcement
# ---------------------------------------------------------------------------
@@ -91,9 +106,9 @@ class TestResolveSandboxPath:
# ---------------------------------------------------------------------------
# _read_local — host filesystem reads with allowlist enforcement
#
# In E2B mode, _read_local only allows tool-results paths (via
# is_allowed_local_path without sdk_cwd). Regular files live on the
# sandbox, not the host.
# In E2B mode, _read_local only allows tool-results/tool-outputs paths
# (via is_allowed_local_path without sdk_cwd). Regular files live on
# the sandbox, not the host.
# ---------------------------------------------------------------------------


@@ -119,7 +134,7 @@ class TestReadLocal:
        )
        token = _current_project_dir.set(encoded)
        try:
            result = _read_local(filepath, offset=0, limit=2000)
            result = _read_local(filepath, offset=0, limit=_DEFAULT_READ_LIMIT)
            assert result["isError"] is False
            assert "line 1" in result["content"][0]["text"]
            assert "line 2" in result["content"][0]["text"]
@@ -127,6 +142,25 @@ class TestReadLocal:
            _current_project_dir.reset(token)
            os.unlink(filepath)

    def test_read_tool_outputs_file(self):
        """Reading a tool-outputs file should also succeed."""
        encoded = "-tmp-copilot-e2b-test-read-outputs"
        tool_outputs_dir = os.path.join(
            SDK_PROJECTS_DIR, encoded, self._CONV_UUID, "tool-outputs"
        )
        os.makedirs(tool_outputs_dir, exist_ok=True)
        filepath = os.path.join(tool_outputs_dir, "sdk-abc123.json")
        with open(filepath, "w") as f:
            f.write('{"data": "test"}\n')
        token = _current_project_dir.set(encoded)
        try:
            result = _read_local(filepath, offset=0, limit=_DEFAULT_READ_LIMIT)
            assert result["isError"] is False
            assert "test" in result["content"][0]["text"]
        finally:
            _current_project_dir.reset(token)
            shutil.rmtree(os.path.join(SDK_PROJECTS_DIR, encoded), ignore_errors=True)

    def test_read_disallowed_path_blocked(self):
        """Reading /etc/passwd should be blocked by the allowlist."""
        result = _read_local("/etc/passwd", offset=0, limit=10)
@@ -335,3 +369,199 @@ class TestSandboxWrite:
        encoded_in_cmd = call_args.split("echo ")[1].split(" |")[0].strip("'")
        decoded = base64.b64decode(encoded_in_cmd).decode()
        assert decoded == content


# ---------------------------------------------------------------------------
# bridge_to_sandbox — copy SDK-internal files into E2B sandbox
# ---------------------------------------------------------------------------


def _make_bridge_sandbox() -> SimpleNamespace:
    """Build a sandbox mock suitable for bridge_to_sandbox tests."""
    run_result = SimpleNamespace(stdout="", stderr="", exit_code=0)
    commands = SimpleNamespace(run=AsyncMock(return_value=run_result))
    files = SimpleNamespace(write=AsyncMock())
    return SimpleNamespace(commands=commands, files=files)


class TestBridgeToSandbox:
    @pytest.mark.asyncio
    async def test_happy_path_small_file(self, tmp_path):
        """A small file is bridged to /tmp/<hash>-<basename> via _sandbox_write."""
        f = tmp_path / "result.json"
        f.write_text('{"ok": true}')
        sandbox = _make_bridge_sandbox()

        result = await bridge_to_sandbox(
            sandbox, str(f), offset=0, limit=_DEFAULT_READ_LIMIT
        )

        expected = _expected_bridge_path(str(f))
        assert result == expected
        sandbox.commands.run.assert_called_once()
        cmd = sandbox.commands.run.call_args[0][0]
        assert "result.json" in cmd
        sandbox.files.write.assert_not_called()

    @pytest.mark.asyncio
    async def test_skip_when_offset_nonzero(self, tmp_path):
        """Bridging is skipped when offset != 0 (partial read)."""
        f = tmp_path / "data.txt"
        f.write_text("content")
        sandbox = _make_bridge_sandbox()

        result = await bridge_to_sandbox(
            sandbox, str(f), offset=10, limit=_DEFAULT_READ_LIMIT
        )

        assert result is None
        sandbox.commands.run.assert_not_called()
        sandbox.files.write.assert_not_called()

    @pytest.mark.asyncio
    async def test_skip_when_limit_too_small(self, tmp_path):
        """Bridging is skipped when limit < _DEFAULT_READ_LIMIT (partial read)."""
        f = tmp_path / "data.txt"
        f.write_text("content")
        sandbox = _make_bridge_sandbox()

        await bridge_to_sandbox(sandbox, str(f), offset=0, limit=100)

        sandbox.commands.run.assert_not_called()
        sandbox.files.write.assert_not_called()

    @pytest.mark.asyncio
    async def test_nonexistent_file_does_not_raise(self, tmp_path):
        """Bridging a non-existent file logs but does not propagate errors."""
        sandbox = _make_bridge_sandbox()

        await bridge_to_sandbox(
            sandbox, str(tmp_path / "ghost.txt"), offset=0, limit=_DEFAULT_READ_LIMIT
        )

        sandbox.commands.run.assert_not_called()
        sandbox.files.write.assert_not_called()

    @pytest.mark.asyncio
    async def test_sandbox_write_failure_returns_none(self, tmp_path):
        """If sandbox write fails, returns None (best-effort)."""
        f = tmp_path / "data.txt"
        f.write_text("content")
        sandbox = _make_bridge_sandbox()
        sandbox.commands.run.side_effect = RuntimeError("E2B timeout")

        result = await bridge_to_sandbox(
            sandbox, str(f), offset=0, limit=_DEFAULT_READ_LIMIT
        )

        assert result is None

    @pytest.mark.asyncio
    async def test_large_file_uses_files_api(self, tmp_path):
        """Files > 32 KB but <= 50 MB are written to /home/user/ via files.write."""
        f = tmp_path / "big.json"
        f.write_bytes(b"x" * (_BRIDGE_SHELL_MAX_BYTES + 1))
        sandbox = _make_bridge_sandbox()

        result = await bridge_to_sandbox(
            sandbox, str(f), offset=0, limit=_DEFAULT_READ_LIMIT
        )

        expected = _expected_bridge_path(str(f), prefix="/home/user")
        assert result == expected
        sandbox.files.write.assert_called_once()
        call_args = sandbox.files.write.call_args[0]
        assert call_args[0] == expected
        sandbox.commands.run.assert_not_called()

    @pytest.mark.asyncio
    async def test_small_binary_file_preserves_bytes(self, tmp_path):
        """A small binary file is bridged to /tmp via base64 without corruption."""
        binary_data = bytes(range(256))
        f = tmp_path / "image.png"
        f.write_bytes(binary_data)
        sandbox = _make_bridge_sandbox()

        result = await bridge_to_sandbox(
            sandbox, str(f), offset=0, limit=_DEFAULT_READ_LIMIT
        )

        expected = _expected_bridge_path(str(f))
        assert result == expected
        sandbox.commands.run.assert_called_once()
        cmd = sandbox.commands.run.call_args[0][0]
        assert "base64" in cmd
        sandbox.files.write.assert_not_called()

    @pytest.mark.asyncio
    async def test_large_binary_file_writes_raw_bytes(self, tmp_path):
        """A large binary file is bridged to /home/user/ as raw bytes."""
        binary_data = bytes(range(256)) * 200
        f = tmp_path / "photo.jpg"
        f.write_bytes(binary_data)
        sandbox = _make_bridge_sandbox()

        result = await bridge_to_sandbox(
            sandbox, str(f), offset=0, limit=_DEFAULT_READ_LIMIT
        )

        expected = _expected_bridge_path(str(f), prefix="/home/user")
        assert result == expected
        sandbox.files.write.assert_called_once()
        call_args = sandbox.files.write.call_args[0]
        assert call_args[0] == expected
        assert call_args[1] == binary_data
        sandbox.commands.run.assert_not_called()

    @pytest.mark.asyncio
    async def test_very_large_file_skipped(self, tmp_path):
        """Files > 50 MB are skipped entirely."""
        f = tmp_path / "huge.bin"
        # Create a sparse file to avoid actually writing 50 MB
        with open(f, "wb") as fh:
            fh.seek(_BRIDGE_SKIP_BYTES + 1)
            fh.write(b"\0")
        sandbox = _make_bridge_sandbox()

        result = await bridge_to_sandbox(
            sandbox, str(f), offset=0, limit=_DEFAULT_READ_LIMIT
        )

        assert result is None

        sandbox.commands.run.assert_not_called()
        sandbox.files.write.assert_not_called()


# ---------------------------------------------------------------------------
# bridge_and_annotate — shared helper wrapping bridge_to_sandbox + annotation
# ---------------------------------------------------------------------------


class TestBridgeAndAnnotate:
    @pytest.mark.asyncio
    async def test_returns_annotation_on_success(self, tmp_path):
        """On success, returns a newline-prefixed annotation with the sandbox path."""
        f = tmp_path / "data.json"
        f.write_text('{"ok": true}')
        sandbox = _make_bridge_sandbox()

        annotation = await bridge_and_annotate(
            sandbox, str(f), offset=0, limit=_DEFAULT_READ_LIMIT
        )

        expected_path = _expected_bridge_path(str(f))
        assert annotation == f"\n[Sandbox copy available at {expected_path}]"

    @pytest.mark.asyncio
    async def test_returns_none_when_skipped(self, tmp_path):
        """When bridging is skipped (e.g. offset != 0), returns None."""
        f = tmp_path / "data.json"
        f.write_text("content")
        sandbox = _make_bridge_sandbox()

        annotation = await bridge_and_annotate(
            sandbox, str(f), offset=10, limit=_DEFAULT_READ_LIMIT
        )

        assert annotation is None

@@ -275,7 +275,7 @@ class TestCompactionE2E:

        # --- Step 7: CompactionTracker receives PreCompact hook ---
        tracker = CompactionTracker()
        session = ChatSession.new(user_id="test-user")
        session = ChatSession.new(user_id="test-user", dry_run=False)
        tracker.on_compact(str(session_file))

        # --- Step 8: Next SDK message arrives → emit_start ---
@@ -376,7 +376,7 @@ class TestCompactionE2E:
        monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))

        tracker = CompactionTracker()
        session = ChatSession.new(user_id="test")
        session = ChatSession.new(user_id="test", dry_run=False)
        builder = TranscriptBuilder()

        # --- First query with compaction ---

82
autogpt_platform/backend/backend/copilot/sdk/env.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""SDK environment variable builder — importable without circular deps.
|
||||
|
||||
Extracted from ``service.py`` so that ``backend.blocks.orchestrator``
|
||||
can reuse the same subscription / OpenRouter / direct-Anthropic logic
|
||||
without pulling in the full copilot service module (which would create a
|
||||
circular import through ``executor`` → ``credit`` → ``block_cost_config``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.sdk.subscription import validate_subscription
|
||||
|
||||
# ChatConfig is stateless (reads env vars) — a separate instance is fine.
|
||||
# A singleton would require importing service.py which causes the circular dep
|
||||
# this module was created to avoid.
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
def build_sdk_env(
|
||||
session_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
sdk_cwd: str | None = None,
|
||||
) -> dict[str, str]:
|
||||
"""Build env vars for the SDK CLI subprocess.
|
||||
|
||||
Three modes (checked in order):
|
||||
1. **Subscription** — clears all keys; CLI uses ``claude login`` auth.
|
||||
2. **Direct Anthropic** — returns ``{}``; subprocess inherits
|
||||
``ANTHROPIC_API_KEY`` from the parent environment.
|
||||
3. **OpenRouter** (default) — overrides base URL and auth token to
|
||||
route through the proxy, with Langfuse trace headers.
|
||||
|
||||
When *sdk_cwd* is provided, ``CLAUDE_CODE_TMPDIR`` is set so that
|
||||
the CLI writes temp/sub-agent output inside the per-session workspace
|
||||
directory rather than an inaccessible system temp path.
|
||||
"""
|
||||
# --- Mode 1: Claude Code subscription auth ---
|
||||
if config.use_claude_code_subscription:
|
||||
validate_subscription()
|
||||
env: dict[str, str] = {
|
||||
"ANTHROPIC_API_KEY": "",
|
||||
"ANTHROPIC_AUTH_TOKEN": "",
|
||||
"ANTHROPIC_BASE_URL": "",
|
||||
}
|
||||
if sdk_cwd:
|
||||
env["CLAUDE_CODE_TMPDIR"] = sdk_cwd
|
||||
return env
|
||||
|
||||
# --- Mode 2: Direct Anthropic (no proxy hop) ---
|
||||
if not config.openrouter_active:
|
||||
env = {}
|
||||
if sdk_cwd:
|
||||
env["CLAUDE_CODE_TMPDIR"] = sdk_cwd
|
||||
return env
|
||||
|
||||
# --- Mode 3: OpenRouter proxy ---
|
||||
base = (config.base_url or "").rstrip("/")
|
||||
if base.endswith("/v1"):
|
||||
base = base[:-3]
|
||||
env = {
|
||||
"ANTHROPIC_BASE_URL": base,
|
||||
"ANTHROPIC_AUTH_TOKEN": config.api_key or "",
|
||||
"ANTHROPIC_API_KEY": "", # force CLI to use AUTH_TOKEN
|
||||
}
|
||||
|
||||
# Inject broadcast headers so OpenRouter forwards traces to Langfuse.
|
||||
def _safe(v: str) -> str:
|
||||
return v.replace("\r", "").replace("\n", "").strip()[:128]
|
||||
|
||||
parts = []
|
||||
if session_id:
|
||||
parts.append(f"x-session-id: {_safe(session_id)}")
|
||||
if user_id:
|
||||
parts.append(f"x-user-id: {_safe(user_id)}")
|
||||
if parts:
|
||||
env["ANTHROPIC_CUSTOM_HEADERS"] = "\n".join(parts)
|
||||
|
||||
if sdk_cwd:
|
||||
env["CLAUDE_CODE_TMPDIR"] = sdk_cwd
|
||||
|
||||
return env
|
||||
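A quick sketch of consuming `build_sdk_env` when launching the CLI subprocess. Only `build_sdk_env` itself comes from this diff; the `subprocess` wiring and the `claude` command invocation are illustrative assumptions.

```python
import os
import subprocess

from backend.copilot.sdk.env import build_sdk_env

# Merge the overrides into the parent environment. An empty dict (direct
# Anthropic mode) leaves the inherited ANTHROPIC_API_KEY untouched, while
# subscription mode blanks the keys so the CLI falls back to `claude login`.
overrides = build_sdk_env(session_id="sess-123", user_id="user-456", sdk_cwd="/tmp/ws")
child_env = {**os.environ, **overrides}
subprocess.run(["claude", "--help"], env=child_env)  # placeholder command
```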
293  autogpt_platform/backend/backend/copilot/sdk/env_test.py  Normal file
@@ -0,0 +1,293 @@
"""Tests for build_sdk_env() — the SDK subprocess environment builder."""

from unittest.mock import patch

import pytest

from backend.copilot.config import ChatConfig

# ---------------------------------------------------------------------------
# Helpers — build a ChatConfig with explicit field values so tests don't
# depend on real environment variables.
# ---------------------------------------------------------------------------


def _make_config(**overrides) -> ChatConfig:
    """Create a ChatConfig with safe defaults, applying *overrides*."""
    defaults = {
        "use_claude_code_subscription": False,
        "use_openrouter": False,
        "api_key": None,
        "base_url": None,
    }
    defaults.update(overrides)
    return ChatConfig(**defaults)


# ---------------------------------------------------------------------------
# Mode 1 — Subscription auth
# ---------------------------------------------------------------------------


class TestBuildSdkEnvSubscription:
    """When ``use_claude_code_subscription`` is True, keys are blanked."""

    @patch("backend.copilot.sdk.env.validate_subscription")
    def test_returns_blanked_keys(self, mock_validate):
        """Subscription mode clears API_KEY, AUTH_TOKEN, and BASE_URL."""
        cfg = _make_config(use_claude_code_subscription=True)
        with patch("backend.copilot.sdk.env.config", cfg):
            from backend.copilot.sdk.env import build_sdk_env

            result = build_sdk_env()

        assert result == {
            "ANTHROPIC_API_KEY": "",
            "ANTHROPIC_AUTH_TOKEN": "",
            "ANTHROPIC_BASE_URL": "",
        }
        mock_validate.assert_called_once()

    @patch(
        "backend.copilot.sdk.env.validate_subscription",
        side_effect=RuntimeError("CLI not found"),
    )
    def test_propagates_validation_error(self, mock_validate):
        """If validate_subscription fails, the error bubbles up."""
        cfg = _make_config(use_claude_code_subscription=True)
        with patch("backend.copilot.sdk.env.config", cfg):
            from backend.copilot.sdk.env import build_sdk_env

            with pytest.raises(RuntimeError, match="CLI not found"):
                build_sdk_env()


# ---------------------------------------------------------------------------
# Mode 2 — Direct Anthropic (no OpenRouter)
# ---------------------------------------------------------------------------


class TestBuildSdkEnvDirectAnthropic:
    """When OpenRouter is inactive, return empty dict (inherit parent env)."""

    def test_returns_empty_dict_when_openrouter_inactive(self):
        cfg = _make_config(use_openrouter=False)
        with patch("backend.copilot.sdk.env.config", cfg):
            from backend.copilot.sdk.env import build_sdk_env

            result = build_sdk_env()

        assert result == {}

    def test_returns_empty_dict_when_openrouter_flag_true_but_no_key(self):
        """OpenRouter flag is True but no api_key => openrouter_active is False."""
        cfg = _make_config(use_openrouter=True, base_url="https://openrouter.ai/api/v1")
        # Force api_key to None after construction (field_validator may pick up env vars)
        object.__setattr__(cfg, "api_key", None)
        assert not cfg.openrouter_active
        with patch("backend.copilot.sdk.env.config", cfg):
            from backend.copilot.sdk.env import build_sdk_env

            result = build_sdk_env()

        assert result == {}
# ---------------------------------------------------------------------------
# Mode 3 — OpenRouter proxy
# ---------------------------------------------------------------------------


class TestBuildSdkEnvOpenRouter:
    """When OpenRouter is active, return proxy env vars."""

    def _openrouter_config(self, **overrides):
        defaults = {
            "use_openrouter": True,
            "api_key": "sk-or-test-key",
            "base_url": "https://openrouter.ai/api/v1",
        }
        defaults.update(overrides)
        return _make_config(**defaults)

    def test_basic_openrouter_env(self):
        cfg = self._openrouter_config()
        with patch("backend.copilot.sdk.env.config", cfg):
            from backend.copilot.sdk.env import build_sdk_env

            result = build_sdk_env()

        assert result["ANTHROPIC_BASE_URL"] == "https://openrouter.ai/api"
        assert result["ANTHROPIC_AUTH_TOKEN"] == "sk-or-test-key"
        assert result["ANTHROPIC_API_KEY"] == ""
        assert "ANTHROPIC_CUSTOM_HEADERS" not in result

    def test_strips_trailing_v1(self):
        """The /v1 suffix is stripped from the base URL."""
        cfg = self._openrouter_config(base_url="https://openrouter.ai/api/v1")
        with patch("backend.copilot.sdk.env.config", cfg):
            from backend.copilot.sdk.env import build_sdk_env

            result = build_sdk_env()

        assert result["ANTHROPIC_BASE_URL"] == "https://openrouter.ai/api"

    def test_strips_trailing_v1_and_slash(self):
        """Trailing slash before /v1 strip is handled."""
        cfg = self._openrouter_config(base_url="https://openrouter.ai/api/v1/")
        with patch("backend.copilot.sdk.env.config", cfg):
            from backend.copilot.sdk.env import build_sdk_env

            result = build_sdk_env()

        # rstrip("/") first, then remove /v1
        assert result["ANTHROPIC_BASE_URL"] == "https://openrouter.ai/api"

    def test_no_v1_suffix_left_alone(self):
        """A base URL without /v1 is used as-is."""
        cfg = self._openrouter_config(base_url="https://custom-proxy.example.com")
        with patch("backend.copilot.sdk.env.config", cfg):
            from backend.copilot.sdk.env import build_sdk_env

            result = build_sdk_env()

        assert result["ANTHROPIC_BASE_URL"] == "https://custom-proxy.example.com"

    def test_session_id_header(self):
        cfg = self._openrouter_config()
        with patch("backend.copilot.sdk.env.config", cfg):
            from backend.copilot.sdk.env import build_sdk_env

            result = build_sdk_env(session_id="sess-123")

        assert "ANTHROPIC_CUSTOM_HEADERS" in result
        assert "x-session-id: sess-123" in result["ANTHROPIC_CUSTOM_HEADERS"]

    def test_user_id_header(self):
        cfg = self._openrouter_config()
        with patch("backend.copilot.sdk.env.config", cfg):
            from backend.copilot.sdk.env import build_sdk_env

            result = build_sdk_env(user_id="user-456")

        assert "x-user-id: user-456" in result["ANTHROPIC_CUSTOM_HEADERS"]

    def test_both_headers(self):
        cfg = self._openrouter_config()
        with patch("backend.copilot.sdk.env.config", cfg):
            from backend.copilot.sdk.env import build_sdk_env

            result = build_sdk_env(session_id="s1", user_id="u2")

        headers = result["ANTHROPIC_CUSTOM_HEADERS"]
        assert "x-session-id: s1" in headers
        assert "x-user-id: u2" in headers
        # They should be newline-separated
        assert "\n" in headers

    def test_header_sanitisation_strips_newlines(self):
        """Newlines/carriage-returns in header values are stripped."""
        cfg = self._openrouter_config()
        with patch("backend.copilot.sdk.env.config", cfg):
            from backend.copilot.sdk.env import build_sdk_env

            result = build_sdk_env(session_id="bad\r\nvalue")

        header_val = result["ANTHROPIC_CUSTOM_HEADERS"]
        # The _safe helper removes \r and \n
        assert "\r" not in header_val.split(": ", 1)[1]
        assert "badvalue" in header_val

    def test_header_value_truncated_to_128_chars(self):
        """Header values are truncated to 128 characters."""
        cfg = self._openrouter_config()
        with patch("backend.copilot.sdk.env.config", cfg):
            from backend.copilot.sdk.env import build_sdk_env

            long_id = "x" * 200
            result = build_sdk_env(session_id=long_id)

        # The value after "x-session-id: " should be at most 128 chars
        header_line = result["ANTHROPIC_CUSTOM_HEADERS"]
        value = header_line.split(": ", 1)[1]
        assert len(value) == 128
# ---------------------------------------------------------------------------
# Mode priority
# ---------------------------------------------------------------------------


class TestBuildSdkEnvModePriority:
    """Subscription mode takes precedence over OpenRouter."""

    @patch("backend.copilot.sdk.env.validate_subscription")
    def test_subscription_overrides_openrouter(self, mock_validate):
        cfg = _make_config(
            use_claude_code_subscription=True,
            use_openrouter=True,
            api_key="sk-or-key",
            base_url="https://openrouter.ai/api/v1",
        )
        with patch("backend.copilot.sdk.env.config", cfg):
            from backend.copilot.sdk.env import build_sdk_env

            result = build_sdk_env()

        # Should get subscription result, not OpenRouter
        assert result == {
            "ANTHROPIC_API_KEY": "",
            "ANTHROPIC_AUTH_TOKEN": "",
            "ANTHROPIC_BASE_URL": "",
        }


# ---------------------------------------------------------------------------
# CLAUDE_CODE_TMPDIR integration
# ---------------------------------------------------------------------------


class TestClaudeCodeTmpdir:
    """Verify build_sdk_env() sets CLAUDE_CODE_TMPDIR from *sdk_cwd*."""

    def test_tmpdir_set_when_sdk_cwd_is_truthy(self):
        """CLAUDE_CODE_TMPDIR is set to sdk_cwd when sdk_cwd is truthy."""
        cfg = _make_config(use_openrouter=False)
        with patch("backend.copilot.sdk.env.config", cfg):
            from backend.copilot.sdk.env import build_sdk_env

            result = build_sdk_env(sdk_cwd="/tmp/copilot-workspace")

        assert result["CLAUDE_CODE_TMPDIR"] == "/tmp/copilot-workspace"

    def test_tmpdir_not_set_when_sdk_cwd_is_none(self):
        """CLAUDE_CODE_TMPDIR is NOT in the env when sdk_cwd is None."""
        cfg = _make_config(use_openrouter=False)
        with patch("backend.copilot.sdk.env.config", cfg):
            from backend.copilot.sdk.env import build_sdk_env

            result = build_sdk_env(sdk_cwd=None)

        assert "CLAUDE_CODE_TMPDIR" not in result

    def test_tmpdir_not_set_when_sdk_cwd_is_empty_string(self):
        """CLAUDE_CODE_TMPDIR is NOT in the env when sdk_cwd is empty string."""
        cfg = _make_config(use_openrouter=False)
        with patch("backend.copilot.sdk.env.config", cfg):
            from backend.copilot.sdk.env import build_sdk_env

            result = build_sdk_env(sdk_cwd="")

        assert "CLAUDE_CODE_TMPDIR" not in result

    @patch("backend.copilot.sdk.env.validate_subscription")
    def test_tmpdir_set_in_subscription_mode(self, mock_validate):
        """CLAUDE_CODE_TMPDIR is set even in subscription mode."""
        cfg = _make_config(use_claude_code_subscription=True)
        with patch("backend.copilot.sdk.env.config", cfg):
            from backend.copilot.sdk.env import build_sdk_env

            result = build_sdk_env(sdk_cwd="/tmp/sub-workspace")

        assert result["CLAUDE_CODE_TMPDIR"] == "/tmp/sub-workspace"
        assert result["ANTHROPIC_API_KEY"] == ""
@@ -28,13 +28,12 @@ Each result includes a `remotes` array with the exact server URL to use.

### Important: Check blocks first

-Before using `run_mcp_tool`, always check if the platform already has blocks for the service
-using `find_block`. The platform has hundreds of built-in blocks (Google Sheets, Google Docs,
-Google Calendar, Gmail, etc.) that work without MCP setup.
+Always follow the **Tool Discovery Priority** described in the tool notes:
+call `find_block` before resorting to `run_mcp_tool`.

Only use `run_mcp_tool` when:
-- The service is in the known hosted MCP servers list above, OR
-- You searched `find_block` first and found no matching blocks
+- You searched `find_block` first and found no matching blocks, AND
+- The service is in the known hosted MCP servers list above or found via the registry API

**Never guess or construct MCP server URLs.** Only use URLs from the known servers list above
or from the `remotes[].url` field in MCP registry search results.
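The discovery-priority rule reduces to a simple guard. Here is a hedged sketch of the intended control flow; `find_block` and `run_mcp_tool` are the copilot tool names referenced above, but these Python signatures are invented for illustration (the real tools are invoked by the model, not through this API).

```python
# Illustrative pseudo-wiring of the Tool Discovery Priority.
async def resolve_tool(service: str, known_mcp_servers: dict[str, str]):
    blocks = await find_block(service)       # 1. built-in platform blocks first
    if blocks:
        return blocks[0]
    url = known_mcp_servers.get(service)     # 2. only known/registry-provided URLs
    if url is None:
        raise ValueError(f"No block and no trusted MCP server for {service!r}")
    return lambda tool, args: run_mcp_tool(url, tool, args)
```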
@@ -8,20 +8,19 @@ from uuid import uuid4

import pytest

-from backend.util import json
-from backend.util.prompt import CompressResult
-
-from .conftest import build_test_transcript as _build_transcript
-from .service import _friendly_error_text, _is_prompt_too_long
-from .transcript import (
+from backend.copilot.transcript import (
    _flatten_assistant_content,
    _flatten_tool_result_content,
    _messages_to_transcript,
    _run_compression,
    _transcript_to_messages,
    compact_transcript,
    validate_transcript,
)
+from backend.util import json
+from backend.util.prompt import CompressResult
+
+from .conftest import build_test_transcript as _build_transcript
+from .service import _friendly_error_text, _is_prompt_too_long
+from .transcript import compact_transcript, validate_transcript
# ---------------------------------------------------------------------------
# _flatten_assistant_content
@@ -38,7 +37,7 @@ class TestFlattenAssistantContent:

    def test_tool_use_blocks(self):
        blocks = [{"type": "tool_use", "name": "read_file", "input": {}}]
-       assert _flatten_assistant_content(blocks) == "[tool_use: read_file]"
+       assert _flatten_assistant_content(blocks) == ""

    def test_mixed_blocks(self):
        blocks = [
@@ -47,19 +46,22 @@ class TestFlattenAssistantContent:
        ]
        result = _flatten_assistant_content(blocks)
        assert "Let me read that." in result
-       assert "[tool_use: Read]" in result
+       # tool_use blocks are dropped entirely to prevent model mimicry
+       assert "Read" not in result

    def test_raw_strings(self):
        assert _flatten_assistant_content(["hello", "world"]) == "hello\nworld"

-   def test_unknown_block_type_preserved_as_placeholder(self):
+   def test_unknown_block_type_dropped(self):
        blocks = [
            {"type": "text", "text": "See this image:"},
            {"type": "image", "source": {"type": "base64", "data": "..."}},
        ]
        result = _flatten_assistant_content(blocks)
        assert "See this image:" in result
-       assert "[__image__]" in result
+       # Unknown block types are dropped to prevent model mimicry
+       assert "[__image__]" not in result
+       assert "base64" not in result

    def test_empty(self):
        assert _flatten_assistant_content([]) == ""
@@ -279,7 +281,8 @@ class TestTranscriptToMessages:
        messages = _transcript_to_messages(content)
        assert len(messages) == 2
        assert "Let me check." in messages[0]["content"]
-       assert "[tool_use: read_file]" in messages[0]["content"]
+       # tool_use blocks are dropped entirely to prevent model mimicry
+       assert "read_file" not in messages[0]["content"]
        assert messages[1]["content"] == "file contents"
@@ -399,7 +402,7 @@ class TestCompactTranscript:
            },
        )()
        with patch(
-           "backend.copilot.sdk.transcript._run_compression",
+           "backend.copilot.transcript._run_compression",
            new_callable=AsyncMock,
            return_value=mock_result,
        ):
@@ -434,7 +437,7 @@ class TestCompactTranscript:
            },
        )()
        with patch(
-           "backend.copilot.sdk.transcript._run_compression",
+           "backend.copilot.transcript._run_compression",
            new_callable=AsyncMock,
            return_value=mock_result,
        ):
@@ -458,7 +461,7 @@ class TestCompactTranscript:
            ]
        )
        with patch(
-           "backend.copilot.sdk.transcript._run_compression",
+           "backend.copilot.transcript._run_compression",
            new_callable=AsyncMock,
            side_effect=RuntimeError("LLM unavailable"),
        ):
@@ -564,11 +567,11 @@ class TestRunCompressionTimeout:

        with (
            patch(
-               "backend.copilot.sdk.transcript.get_openai_client",
+               "backend.copilot.transcript.get_openai_client",
                return_value="fake-client",
            ),
            patch(
-               "backend.copilot.sdk.transcript.compress_context",
+               "backend.copilot.transcript.compress_context",
                side_effect=_mock_compress,
            ),
        ):
@@ -598,11 +601,11 @@ class TestRunCompressionTimeout:

        with (
            patch(
-               "backend.copilot.sdk.transcript.get_openai_client",
+               "backend.copilot.transcript.get_openai_client",
                return_value=None,
            ),
            patch(
-               "backend.copilot.sdk.transcript.compress_context",
+               "backend.copilot.transcript.compress_context",
                new_callable=AsyncMock,
                return_value=truncation_result,
            ) as mock_compress,
@@ -49,22 +49,22 @@ def test_format_assistant_tool_calls():
        )
    ]
    result = _format_conversation_context(msgs)
-   assert result is not None
-   assert 'You called tool: search({"q": "test"})' in result
+   # Assistant with no content and tool_calls omitted produces no lines
+   assert result is None


def test_format_tool_result():
    msgs = [ChatMessage(role="tool", content='{"result": "ok"}')]
    result = _format_conversation_context(msgs)
    assert result is not None
-   assert 'Tool result: {"result": "ok"}' in result
+   assert 'Tool output: {"result": "ok"}' in result


def test_format_tool_result_none_content():
    msgs = [ChatMessage(role="tool", content=None)]
    result = _format_conversation_context(msgs)
    assert result is not None
-   assert "Tool result: " in result
+   assert "Tool output: " in result


def test_format_full_conversation():
@@ -84,8 +84,8 @@ def test_format_full_conversation():
    assert result is not None
    assert "User: find agents" in result
    assert "You responded: I'll search for agents." in result
-   assert "You called tool: find_agents" in result
-   assert "Tool result:" in result
+   # tool_calls are omitted to prevent model mimicry
+   assert "Tool output:" in result
    assert "You responded: Found Agent1." in result
@@ -27,6 +27,7 @@ from backend.copilot.response_model import (
    StreamError,
    StreamFinish,
    StreamFinishStep,
+   StreamHeartbeat,
    StreamStart,
    StreamStartStep,
    StreamTextDelta,
@@ -76,6 +77,12 @@ class SDKResponseAdapter:
            # Open the first step (matches non-SDK: StreamStart then StreamStartStep)
            responses.append(StreamStartStep())
            self.step_open = True
+       elif sdk_message.subtype == "task_progress":
+           # Emit a heartbeat so publish_chunk is called during long
+           # sub-agent runs. Without this, the Redis stream and meta
+           # key TTLs expire during gaps where no real chunks are
+           # produced (task_progress events were previously silent).
+           responses.append(StreamHeartbeat())

        elif isinstance(sdk_message, AssistantMessage):
            # Flush any SDK built-in tool calls that didn't get a UserMessage
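The fix reads best end to end: a `SystemMessage(subtype="task_progress")` now maps to exactly one `StreamHeartbeat`, which travels the normal publish path and refreshes the Redis stream and meta-key TTLs. A sketch under that assumption; `publish_chunk` is named in the comment above, but this calling code is illustrative.

```python
from claude_agent_sdk import SystemMessage

async def forward_progress(adapter, session_id: str) -> None:
    # task_progress yields exactly one StreamHeartbeat (see the adapter test below)
    for event in adapter.convert_message(SystemMessage(subtype="task_progress", data={})):
        await publish_chunk(session_id, event)  # illustrative publish call
```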
@@ -18,6 +18,7 @@ from backend.copilot.response_model import (
    StreamError,
    StreamFinish,
    StreamFinishStep,
+   StreamHeartbeat,
    StreamStart,
    StreamStartStep,
    StreamTextDelta,
@@ -28,6 +29,7 @@ from backend.copilot.response_model import (
    StreamToolOutputAvailable,
)

+from .compaction import compaction_events
from .response_adapter import SDKResponseAdapter
from .tool_adapter import MCP_TOOL_PREFIX
from .tool_adapter import _pending_tool_outputs as _pto
@@ -59,6 +61,14 @@ def test_system_non_init_emits_nothing():
    assert results == []


+def test_task_progress_emits_heartbeat():
+    """task_progress events emit a StreamHeartbeat to keep Redis TTL alive."""
+    adapter = _adapter()
+    results = adapter.convert_message(SystemMessage(subtype="task_progress", data={}))
+    assert len(results) == 1
+    assert isinstance(results[0], StreamHeartbeat)
+
+
# -- AssistantMessage with TextBlock -----------------------------------------
@@ -680,3 +690,102 @@ def test_already_resolved_tool_skipped_in_user_message():
    assert (
        len(output_events) == 0
    ), "Already-resolved tool should not emit duplicate output"


# -- _end_text_if_open before compaction -------------------------------------


def test_end_text_if_open_emits_text_end_before_finish_step():
    """StreamTextEnd must be emitted before StreamFinishStep during compaction.

    When ``emit_end_if_ready`` fires compaction events while a text block is
    still open, ``_end_text_if_open`` must close it first. If StreamFinishStep
    arrives before StreamTextEnd, the Vercel AI SDK clears ``activeTextParts``
    and raises "Received text-end for missing text part".
    """
    adapter = _adapter()

    # Open a text block by processing an AssistantMessage with text
    msg = AssistantMessage(content=[TextBlock(text="partial response")], model="test")
    adapter.convert_message(msg)
    assert adapter.has_started_text
    assert not adapter.has_ended_text

    # Simulate what service.py does before yielding compaction events
    pre_close: list[StreamBaseResponse] = []
    adapter._end_text_if_open(pre_close)
    combined = pre_close + list(compaction_events("Compacted transcript"))

    text_end_idx = next(
        (i for i, e in enumerate(combined) if isinstance(e, StreamTextEnd)), None
    )
    finish_step_idx = next(
        (i for i, e in enumerate(combined) if isinstance(e, StreamFinishStep)), None
    )

    assert text_end_idx is not None, "StreamTextEnd must be present"
    assert finish_step_idx is not None, "StreamFinishStep must be present"
    assert text_end_idx < finish_step_idx, (
        f"StreamTextEnd (idx={text_end_idx}) must precede "
        f"StreamFinishStep (idx={finish_step_idx}) — otherwise the Vercel AI SDK "
        "clears activeTextParts before text-end arrives"
    )


def test_step_open_must_reset_after_compaction_finish_step():
    """Adapter step_open must be reset when compaction emits StreamFinishStep.

    Compaction events bypass the adapter, so service.py must explicitly clear
    step_open after yielding a StreamFinishStep from compaction. Without this,
    the next AssistantMessage skips StreamStartStep because the adapter still
    thinks a step is open.
    """
    adapter = _adapter()

    # Open a step + text block via an AssistantMessage
    msg = AssistantMessage(content=[TextBlock(text="thinking...")], model="test")
    adapter.convert_message(msg)
    assert adapter.step_open is True

    # Simulate what service.py does: close text, then check compaction events
    pre_close: list[StreamBaseResponse] = []
    adapter._end_text_if_open(pre_close)

    events = list(compaction_events("Compacted transcript"))
    if any(isinstance(ev, StreamFinishStep) for ev in events):
        adapter.step_open = False

    assert (
        adapter.step_open is False
    ), "step_open must be False after compaction emits StreamFinishStep"

    # Next AssistantMessage must open a new step
    msg2 = AssistantMessage(content=[TextBlock(text="continued")], model="test")
    results = adapter.convert_message(msg2)
    assert any(
        isinstance(r, StreamStartStep) for r in results
    ), "A new StreamStartStep must be emitted after compaction closed the step"


def test_end_text_if_open_no_op_when_no_text_open():
    """_end_text_if_open emits nothing when no text block is open."""
    adapter = _adapter()
    results: list[StreamBaseResponse] = []
    adapter._end_text_if_open(results)
    assert results == []


def test_end_text_if_open_no_op_after_text_already_ended():
    """_end_text_if_open emits nothing when the text block is already closed."""
    adapter = _adapter()
    msg = AssistantMessage(content=[TextBlock(text="hello")], model="test")
    adapter.convert_message(msg)
    # Close it once
    first: list[StreamBaseResponse] = []
    adapter._end_text_if_open(first)
    assert len(first) == 1
    assert isinstance(first[0], StreamTextEnd)
    # Second call must be a no-op
    second: list[StreamBaseResponse] = []
    adapter._end_text_if_open(second)
    assert second == []
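The invariant these tests protect is a strict ordering of stream events during compaction: close the open text part, then close the step, then let the next message open a fresh step. The flow below is condensed from the test bodies above (same types, simplified).

```python
# Required wire order for a compacting turn:
#   ... StreamTextDelta → StreamTextEnd → StreamFinishStep → StreamStartStep ...
pre_close: list[StreamBaseResponse] = []
adapter._end_text_if_open(pre_close)          # 1. close the open text part, if any
events = pre_close + list(compaction_events("Compacted transcript"))
if any(isinstance(e, StreamFinishStep) for e in events):
    adapter.step_open = False                 # 2. next AssistantMessage reopens a step
```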
@@ -26,18 +26,17 @@ from unittest.mock import AsyncMock, MagicMock, patch

import pytest

-from backend.util import json
-
-from .conftest import build_test_transcript as _build_transcript
-from .service import _MAX_STREAM_ATTEMPTS, _reduce_context
-from .transcript import (
+from backend.copilot.transcript import (
    _flatten_assistant_content,
    _flatten_tool_result_content,
    _messages_to_transcript,
    _transcript_to_messages,
    compact_transcript,
    validate_transcript,
)
+from backend.util import json
+
+from .conftest import build_test_transcript as _build_transcript
+from .service import _MAX_STREAM_ATTEMPTS, _reduce_context
+from .transcript import compact_transcript, validate_transcript
from .transcript_builder import TranscriptBuilder

# ---------------------------------------------------------------------------
@@ -113,7 +112,7 @@ class TestScenarioCompactAndRetry:
            )(),
        ),
        patch(
-           "backend.copilot.sdk.transcript._run_compression",
+           "backend.copilot.transcript._run_compression",
            new_callable=AsyncMock,
            return_value=mock_result,
        ),
@@ -170,7 +169,7 @@ class TestScenarioCompactFailsFallback:
            )(),
        ),
        patch(
-           "backend.copilot.sdk.transcript._run_compression",
+           "backend.copilot.transcript._run_compression",
            new_callable=AsyncMock,
            side_effect=RuntimeError("LLM unavailable"),
        ),
@@ -261,7 +260,7 @@ class TestScenarioDoubleFailDBFallback:
            )(),
        ),
        patch(
-           "backend.copilot.sdk.transcript._run_compression",
+           "backend.copilot.transcript._run_compression",
            new_callable=AsyncMock,
            return_value=mock_result,
        ),
@@ -337,7 +336,7 @@ class TestScenarioCompactionIdentical:
            )(),
        ),
        patch(
-           "backend.copilot.sdk.transcript._run_compression",
+           "backend.copilot.transcript._run_compression",
            new_callable=AsyncMock,
            return_value=mock_result,
        ),
@@ -730,7 +729,7 @@ class TestRetryEdgeCases:
            )(),
        ),
        patch(
-           "backend.copilot.sdk.transcript._run_compression",
+           "backend.copilot.transcript._run_compression",
            new_callable=AsyncMock,
            return_value=mock_result,
        ),
@@ -841,7 +840,7 @@ class TestRetryStateReset:
            )(),
        ),
        patch(
-           "backend.copilot.sdk.transcript._run_compression",
+           "backend.copilot.transcript._run_compression",
            new_callable=AsyncMock,
            side_effect=RuntimeError("boom"),
        ),
@@ -904,14 +903,14 @@ class TestTranscriptEdgeCases:
        assert restored[1]["content"] == "Second"

    def test_flatten_assistant_with_only_tool_use(self):
-       """Assistant message with only tool_use blocks (no text)."""
+       """Assistant message with only tool_use blocks (no text) flattens to empty."""
        blocks = [
            {"type": "tool_use", "name": "bash", "input": {"cmd": "ls"}},
            {"type": "tool_use", "name": "read", "input": {"path": "/f"}},
        ]
        result = _flatten_assistant_content(blocks)
-       assert "[tool_use: bash]" in result
-       assert "[tool_use: read]" in result
+       # tool_use blocks are dropped entirely to prevent model mimicry
+       assert result == ""

    def test_flatten_tool_result_nested_image(self):
        """Tool result containing image blocks uses placeholder."""
@@ -1010,7 +1009,7 @@ def _make_sdk_patches(
    (f"{_SVC}.create_security_hooks", dict(return_value=MagicMock())),
    (f"{_SVC}.get_copilot_tool_names", dict(return_value=[])),
    (f"{_SVC}.get_sdk_disallowed_tools", dict(return_value=[])),
-   (f"{_SVC}._build_sdk_env", dict(return_value=None)),
+   (f"{_SVC}.build_sdk_env", dict(return_value={})),
    (f"{_SVC}._resolve_sdk_model", dict(return_value=None)),
    (f"{_SVC}.set_execution_context", {}),
    (
@@ -1405,12 +1404,270 @@ class TestStreamChatCompletionRetryIntegration:
            events.append(event)

        # Should NOT retry — only 1 attempt for auth errors
-       assert attempt_count[0] == 1, (
-           f"Expected 1 attempt (no retry for auth error), " f"got {attempt_count[0]}"
-       )
+       assert (
+           attempt_count[0] == 1
+       ), f"Expected 1 attempt (no retry for auth error), got {attempt_count[0]}"
        errors = [e for e in events if isinstance(e, StreamError)]
        assert errors, "Expected StreamError"
        assert errors[0].code == "sdk_stream_error"
        # Verify user-friendly message (not raw SDK text)
        assert "Authentication" in errors[0].errorText
        assert any(isinstance(e, StreamStart) for e in events)

    @pytest.mark.asyncio
    async def test_result_message_prompt_too_long_triggers_compaction(self):
        """CLI returns ResultMessage(subtype="error") with "Prompt is too long".

        When the Claude CLI rejects the prompt pre-API (model=<synthetic>,
        duration_api_ms=0), it sends a ResultMessage with is_error=True
        instead of raising a Python exception. The retry loop must still
        detect this as a context-length error and trigger compaction.
        """
        import contextlib

        from claude_agent_sdk import ResultMessage

        from backend.copilot.response_model import StreamError, StreamStart
        from backend.copilot.sdk.service import stream_chat_completion_sdk

        session = self._make_session()
        success_result = self._make_result_message()
        attempt_count = [0]

        error_result = ResultMessage(
            subtype="error",
            result="Prompt is too long",
            duration_ms=100,
            duration_api_ms=0,
            is_error=True,
            num_turns=0,
            session_id="test-session-id",
        )

        def _client_factory(*args, **kwargs):
            attempt_count[0] += 1
            if attempt_count[0] == 1:
                # First attempt: CLI returns error ResultMessage
                return self._make_client_mock(result_message=error_result)
            # Second attempt (after compaction): succeeds
            return self._make_client_mock(result_message=success_result)

        original_transcript = _build_transcript(
            [("user", "prior question"), ("assistant", "prior answer")]
        )
        compacted_transcript = _build_transcript(
            [("user", "[summary]"), ("assistant", "summary reply")]
        )

        patches = _make_sdk_patches(
            session,
            original_transcript=original_transcript,
            compacted_transcript=compacted_transcript,
            client_side_effect=_client_factory,
        )

        events = []
        with contextlib.ExitStack() as stack:
            for target, kwargs in patches:
                stack.enter_context(patch(target, **kwargs))
            async for event in stream_chat_completion_sdk(
                session_id="test-session-id",
                message="hello",
                is_user_message=True,
                user_id="test-user",
                session=session,
            ):
                events.append(event)

        assert attempt_count[0] == 2, (
            f"Expected 2 SDK attempts (CLI error ResultMessage "
            f"should trigger compaction retry), got {attempt_count[0]}"
        )
        errors = [e for e in events if isinstance(e, StreamError)]
        assert not errors, f"Unexpected StreamError: {errors}"
        assert any(isinstance(e, StreamStart) for e in events)
    @pytest.mark.asyncio
    async def test_result_message_success_subtype_prompt_too_long_triggers_compaction(
        self,
    ):
        """CLI returns ResultMessage(subtype="success") with result="Prompt is too long".

        The SDK internally compacts but the transcript is still too long. It
        returns subtype="success" (process completed) with result="Prompt is
        too long" (the actual rejection message). The retry loop must detect
        this as a context-length error and trigger compaction — the subtype
        "success" must not fool it into treating this as a real response.
        """
        import contextlib

        from claude_agent_sdk import ResultMessage

        from backend.copilot.response_model import StreamError, StreamStart
        from backend.copilot.sdk.service import stream_chat_completion_sdk

        session = self._make_session()
        success_result = self._make_result_message()
        attempt_count = [0]

        error_result = ResultMessage(
            subtype="success",
            result="Prompt is too long",
            duration_ms=100,
            duration_api_ms=0,
            is_error=False,
            num_turns=1,
            session_id="test-session-id",
        )

        def _client_factory(*args, **kwargs):
            attempt_count[0] += 1

            async def _receive_error():
                yield error_result

            async def _receive_success():
                yield success_result

            client = MagicMock()
            client._transport = MagicMock()
            client._transport.write = AsyncMock()
            client.query = AsyncMock()
            if attempt_count[0] == 1:
                client.receive_response = _receive_error
            else:
                client.receive_response = _receive_success
            cm = AsyncMock()
            cm.__aenter__.return_value = client
            cm.__aexit__.return_value = None
            return cm

        original_transcript = _build_transcript(
            [("user", "prior question"), ("assistant", "prior answer")]
        )
        compacted_transcript = _build_transcript(
            [("user", "[summary]"), ("assistant", "summary reply")]
        )

        patches = _make_sdk_patches(
            session,
            original_transcript=original_transcript,
            compacted_transcript=compacted_transcript,
            client_side_effect=_client_factory,
        )

        events = []
        with contextlib.ExitStack() as stack:
            for target, kwargs in patches:
                stack.enter_context(patch(target, **kwargs))
            async for event in stream_chat_completion_sdk(
                session_id="test-session-id",
                message="hello",
                is_user_message=True,
                user_id="test-user",
                session=session,
            ):
                events.append(event)

        assert attempt_count[0] == 2, (
            f"Expected 2 SDK attempts (subtype='success' with 'Prompt is too long' "
            f"result should trigger compaction retry), got {attempt_count[0]}"
        )
        errors = [e for e in events if isinstance(e, StreamError)]
        assert not errors, f"Unexpected StreamError: {errors}"
        assert any(isinstance(e, StreamStart) for e in events)
    @pytest.mark.asyncio
    async def test_assistant_message_error_content_prompt_too_long_triggers_compaction(
        self,
    ):
        """AssistantMessage.error="invalid_request" with content "Prompt is too long".

        The SDK returns error type "invalid_request" but puts the actual
        rejection message ("Prompt is too long") in the content blocks.
        The retry loop must detect this via content inspection (sdk_error
        being set confirms it's an error message, not user content).
        """
        import contextlib

        from claude_agent_sdk import AssistantMessage, ResultMessage, TextBlock

        from backend.copilot.response_model import StreamError, StreamStart
        from backend.copilot.sdk.service import stream_chat_completion_sdk

        session = self._make_session()
        success_result = self._make_result_message()
        attempt_count = [0]

        def _client_factory(*args, **kwargs):
            attempt_count[0] += 1

            async def _receive_error():
                # SDK returns invalid_request with "Prompt is too long" in content.
                # ResultMessage.result is a non-PTL value ("done") to isolate
                # the AssistantMessage content detection path exclusively.
                yield AssistantMessage(
                    content=[TextBlock(text="Prompt is too long")],
                    model="<synthetic>",
                    error="invalid_request",
                )
                yield ResultMessage(
                    subtype="success",
                    result="done",
                    duration_ms=100,
                    duration_api_ms=0,
                    is_error=False,
                    num_turns=1,
                    session_id="test-session-id",
                )

            async def _receive_success():
                yield success_result

            client = MagicMock()
            client._transport = MagicMock()
            client._transport.write = AsyncMock()
            client.query = AsyncMock()
            if attempt_count[0] == 1:
                client.receive_response = _receive_error
            else:
                client.receive_response = _receive_success
            cm = AsyncMock()
            cm.__aenter__.return_value = client
            cm.__aexit__.return_value = None
            return cm

        original_transcript = _build_transcript(
            [("user", "prior question"), ("assistant", "prior answer")]
        )
        compacted_transcript = _build_transcript(
            [("user", "[summary]"), ("assistant", "summary reply")]
        )

        patches = _make_sdk_patches(
            session,
            original_transcript=original_transcript,
            compacted_transcript=compacted_transcript,
            client_side_effect=_client_factory,
        )

        events = []
        with contextlib.ExitStack() as stack:
            for target, kwargs in patches:
                stack.enter_context(patch(target, **kwargs))
            async for event in stream_chat_completion_sdk(
                session_id="test-session-id",
                message="hello",
                is_user_message=True,
                user_id="test-user",
                session=session,
            ):
                events.append(event)

        assert attempt_count[0] == 2, (
            f"Expected 2 SDK attempts (AssistantMessage error content 'Prompt is "
            f"too long' should trigger compaction retry), got {attempt_count[0]}"
        )
        errors = [e for e in events if isinstance(e, StreamError)]
        assert not errors, f"Unexpected StreamError: {errors}"
        assert any(isinstance(e, StreamStart) for e in events)
@@ -22,6 +22,38 @@ from .tool_adapter import (

logger = logging.getLogger(__name__)

# The SDK CLI uses "Task" in older versions and "Agent" in v2.x+.
# Shared across all sessions — used by security hooks for sub-agent detection.
_SUBAGENT_TOOLS: frozenset[str] = frozenset({"Task", "Agent"})

# Unicode ranges stripped by _sanitize():
# - BiDi overrides (U+202A-U+202E, U+2066-U+2069) can trick reviewers
#   into misreading code/logs.
# - Zero-width characters (U+200B-U+200F, U+FEFF) can hide content.
_BIDI_AND_ZW_CHARS = set(
    chr(c)
    for r in (range(0x202A, 0x202F), range(0x2066, 0x206A), range(0x200B, 0x2010))
    for c in r
) | {"\ufeff"}


def _sanitize(value: str, max_len: int = 200) -> str:
    """Strip control characters and truncate for safe logging.

    Removes C0 (U+0000-U+001F), DEL (U+007F), C1 (U+0080-U+009F),
    Unicode BiDi overrides, and zero-width characters to prevent
    log injection and visual spoofing.
    """
    cleaned = "".join(
        c
        for c in value
        if c >= " "
        and c != "\x7f"
        and not ("\x80" <= c <= "\x9f")
        and c not in _BIDI_AND_ZW_CHARS
    )
    return cleaned[:max_len]
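To make the sanitizer concrete, here is what it does to a hostile value; the expected output follows directly from the character classes listed in the docstring.

```python
# BiDi override (U+202E), newline injection, and a zero-width space (U+200B)
# are all stripped; the cleaned string is then truncated to max_len.
hostile = "user\u202e\n[SDK] forged line\u200b" + "x" * 300
print(_sanitize(hostile, max_len=40))
# -> "user[SDK] forged line" followed by 19 "x" chars (40 chars total)
```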
def _deny(reason: str) -> dict[str, Any]:
    """Return a hook denial response."""
@@ -136,11 +168,13 @@ def create_security_hooks(
    - PostToolUse: Log successful tool executions
    - PostToolUseFailure: Log and handle failed tool executions
    - PreCompact: Log context compaction events (SDK handles compaction automatically)
+   - SubagentStart: Log sub-agent lifecycle start
+   - SubagentStop: Log sub-agent lifecycle end

    Args:
        user_id: Current user ID for isolation validation
        sdk_cwd: SDK working directory for workspace-scoped tool validation
-       max_subtasks: Maximum concurrent Task (sub-agent) spawns allowed per session
+       max_subtasks: Maximum concurrent sub-agent spawns allowed per session
        on_compact: Callback invoked when SDK starts compacting context.
            Receives the transcript_path from the hook input.

@@ -151,9 +185,19 @@ def create_security_hooks(
    from claude_agent_sdk import HookMatcher
    from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput

-   # Per-session tracking for Task sub-agent concurrency.
+   # Per-session tracking for sub-agent concurrency.
    # Set of tool_use_ids that consumed a slot — len() is the active count.
-   task_tool_use_ids: set[str] = set()
+   #
+   # LIMITATION: For background (async) agents the SDK returns the
+   # Agent/Task tool immediately with {isAsync: true}, which triggers
+   # PostToolUse and releases the slot while the agent is still running.
+   # SubagentStop fires later when the background process finishes but
+   # does not currently hold a slot. This means the concurrency limit
+   # only gates *launches*, not true concurrent execution. To fix this
+   # we would need to track background agent_ids separately and release
+   # in SubagentStop, but the SDK does not guarantee SubagentStop fires
+   # for every background agent (e.g. on session abort).
+   subagent_tool_use_ids: set[str] = set()

    async def pre_tool_use_hook(
        input_data: HookInput,
@@ -165,29 +209,22 @@ def create_security_hooks(
        tool_name = cast(str, input_data.get("tool_name", ""))
        tool_input = cast(dict[str, Any], input_data.get("tool_input", {}))

-       # Rate-limit Task (sub-agent) spawns per session
-       if tool_name == "Task":
-           # Block background task execution first — denied calls
-           # should not consume a subtask slot.
-           if tool_input.get("run_in_background"):
-               logger.info(f"[SDK] Blocked background Task, user={user_id}")
-               return cast(
-                   SyncHookJSONOutput,
-                   _deny(
-                       "Background task execution is not supported. "
-                       "Run tasks in the foreground instead "
-                       "(remove the run_in_background parameter)."
-                   ),
-               )
-           if len(task_tool_use_ids) >= max_subtasks:
+       # Rate-limit sub-agent spawns per session.
+       # The SDK CLI renamed "Task" → "Agent" in v2.x; handle both.
+       if tool_name in _SUBAGENT_TOOLS:
+           # Background agents are allowed — the SDK returns immediately
+           # with {isAsync: true} and the model polls via TaskOutput.
+           # Still count them against the concurrency limit.
+           if len(subagent_tool_use_ids) >= max_subtasks:
                logger.warning(
-                   f"[SDK] Task limit reached ({max_subtasks}), user={user_id}"
+                   f"[SDK] Sub-agent limit reached ({max_subtasks}), "
+                   f"user={user_id}"
                )
                return cast(
                    SyncHookJSONOutput,
                    _deny(
-                       f"Maximum {max_subtasks} concurrent sub-tasks. "
-                       "Wait for running sub-tasks to finish, "
+                       f"Maximum {max_subtasks} concurrent sub-agents. "
+                       "Wait for running sub-agents to finish, "
                        "or continue in the main conversation."
                    ),
                )
@@ -208,20 +245,20 @@ def create_security_hooks(
        if result:
            return cast(SyncHookJSONOutput, result)

-       # Reserve the Task slot only after all validations pass
-       if tool_name == "Task" and tool_use_id is not None:
-           task_tool_use_ids.add(tool_use_id)
+       # Reserve the sub-agent slot only after all validations pass
+       if tool_name in _SUBAGENT_TOOLS and tool_use_id is not None:
+           subagent_tool_use_ids.add(tool_use_id)

        logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}")
        return cast(SyncHookJSONOutput, {})

-   def _release_task_slot(tool_name: str, tool_use_id: str | None) -> None:
-       """Release a Task concurrency slot if one was reserved."""
-       if tool_name == "Task" and tool_use_id in task_tool_use_ids:
-           task_tool_use_ids.discard(tool_use_id)
+   def _release_subagent_slot(tool_name: str, tool_use_id: str | None) -> None:
+       """Release a sub-agent concurrency slot if one was reserved."""
+       if tool_name in _SUBAGENT_TOOLS and tool_use_id in subagent_tool_use_ids:
+           subagent_tool_use_ids.discard(tool_use_id)
            logger.info(
-               "[SDK] Task slot released, active=%d/%d, user=%s",
-               len(task_tool_use_ids),
+               "[SDK] Sub-agent slot released, active=%d/%d, user=%s",
+               len(subagent_tool_use_ids),
                max_subtasks,
                user_id,
            )
@@ -241,13 +278,14 @@ def create_security_hooks(
        _ = context
        tool_name = cast(str, input_data.get("tool_name", ""))

-       _release_task_slot(tool_name, tool_use_id)
+       _release_subagent_slot(tool_name, tool_use_id)
        is_builtin = not tool_name.startswith(MCP_TOOL_PREFIX)
+       safe_tool_use_id = _sanitize(str(tool_use_id or ""), max_len=12)
        logger.info(
            "[SDK] PostToolUse: %s (builtin=%s, tool_use_id=%s)",
            tool_name,
            is_builtin,
-           (tool_use_id or "")[:12],
+           safe_tool_use_id,
        )

        # Stash output for SDK built-in tools so the response adapter can
@@ -256,7 +294,7 @@ def create_security_hooks(
        if is_builtin:
            tool_response = input_data.get("tool_response")
            if tool_response is not None:
-               resp_preview = str(tool_response)[:100]
+               resp_preview = _sanitize(str(tool_response), max_len=100)
                logger.info(
                    "[SDK] Stashing builtin output for %s (%d chars): %s...",
                    tool_name,
@@ -280,13 +318,17 @@ def create_security_hooks(
        """Log failed tool executions for debugging."""
        _ = context
        tool_name = cast(str, input_data.get("tool_name", ""))
-       error = input_data.get("error", "Unknown error")
+       error = _sanitize(str(input_data.get("error", "Unknown error")))
+       safe_tool_use_id = _sanitize(str(tool_use_id or ""))
        logger.warning(
-           f"[SDK] Tool failed: {tool_name}, error={error}, "
-           f"user={user_id}, tool_use_id={tool_use_id}"
+           "[SDK] Tool failed: %s, error=%s, user=%s, tool_use_id=%s",
+           tool_name,
+           error,
+           user_id,
+           safe_tool_use_id,
        )

-       _release_task_slot(tool_name, tool_use_id)
+       _release_subagent_slot(tool_name, tool_use_id)

        return cast(SyncHookJSONOutput, {})
@@ -301,20 +343,17 @@ def create_security_hooks(
        This hook provides visibility into when compaction happens.
        """
        _ = context, tool_use_id
-       trigger = input_data.get("trigger", "auto")
+       trigger = _sanitize(str(input_data.get("trigger", "auto")), max_len=50)
        # Sanitize untrusted input: strip control chars for logging AND
        # for the value passed downstream. read_compacted_entries()
        # validates against _projects_base() as defence-in-depth, but
        # sanitizing here prevents log injection and rejects obviously
        # malformed paths early.
-       transcript_path = (
-           str(input_data.get("transcript_path", ""))
-           .replace("\n", "")
-           .replace("\r", "")
+       transcript_path = _sanitize(
+           str(input_data.get("transcript_path", "")), max_len=500
        )
        logger.info(
-           "[SDK] Context compaction triggered: %s, user=%s, "
-           "transcript_path=%s",
+           "[SDK] Context compaction triggered: %s, user=%s, transcript_path=%s",
            trigger,
            user_id,
            transcript_path,
@@ -323,6 +362,44 @@ def create_security_hooks(
            on_compact(transcript_path)
        return cast(SyncHookJSONOutput, {})

+   async def subagent_start_hook(
+       input_data: HookInput,
+       tool_use_id: str | None,
+       context: HookContext,
+   ) -> SyncHookJSONOutput:
+       """Log when a sub-agent starts execution."""
+       _ = context, tool_use_id
+       agent_id = _sanitize(str(input_data.get("agent_id", "?")))
+       agent_type = _sanitize(str(input_data.get("agent_type", "?")))
+       logger.info(
+           "[SDK] SubagentStart: agent_id=%s, type=%s, user=%s",
+           agent_id,
+           agent_type,
+           user_id,
+       )
+       return cast(SyncHookJSONOutput, {})
+
+   async def subagent_stop_hook(
+       input_data: HookInput,
+       tool_use_id: str | None,
+       context: HookContext,
+   ) -> SyncHookJSONOutput:
+       """Log when a sub-agent stops."""
+       _ = context, tool_use_id
+       agent_id = _sanitize(str(input_data.get("agent_id", "?")))
+       agent_type = _sanitize(str(input_data.get("agent_type", "?")))
+       transcript = _sanitize(
+           str(input_data.get("agent_transcript_path", "")), max_len=500
+       )
+       logger.info(
+           "[SDK] SubagentStop: agent_id=%s, type=%s, user=%s, transcript=%s",
+           agent_id,
+           agent_type,
+           user_id,
+           transcript,
+       )
+       return cast(SyncHookJSONOutput, {})
+
    hooks: dict[str, Any] = {
        "PreToolUse": [HookMatcher(matcher="*", hooks=[pre_tool_use_hook])],
        "PostToolUse": [HookMatcher(matcher="*", hooks=[post_tool_use_hook])],
@@ -330,6 +407,8 @@ def create_security_hooks(
            HookMatcher(matcher="*", hooks=[post_tool_failure_hook])
        ],
        "PreCompact": [HookMatcher(matcher="*", hooks=[pre_compact_hook])],
+       "SubagentStart": [HookMatcher(matcher="*", hooks=[subagent_start_hook])],
+       "SubagentStop": [HookMatcher(matcher="*", hooks=[subagent_stop_hook])],
    }

    return hooks
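For context on where this hook table lands: the SDK client consumes it through its options object. This is a hedged sketch; `ClaudeAgentOptions` and `ClaudeSDKClient` come from the claude_agent_sdk package, but the exact constructor fields may vary by SDK version, so treat the wiring as illustrative.

```python
from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient  # assumed API surface

hooks = create_security_hooks(user_id="u1", sdk_cwd="/tmp/copilot-abc123", max_subtasks=2)
options = ClaudeAgentOptions(hooks=hooks)  # field name assumed

async def run():
    async with ClaudeSDKClient(options=options) as client:
        await client.query("hello")
```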
@@ -5,13 +5,18 @@ They validate that the security hooks correctly block unauthorized paths,
tool access, and dangerous input patterns.
"""

+import logging
import os

import pytest

from backend.copilot.context import _current_project_dir

-from .security_hooks import _validate_tool_access, _validate_user_isolation
+from .security_hooks import (
+    _validate_tool_access,
+    _validate_user_isolation,
+    create_security_hooks,
+)

SDK_CWD = "/tmp/copilot-abc123"

@@ -132,8 +137,20 @@ def test_read_tool_results_allowed():
        _current_project_dir.reset(token)


+def test_read_tool_outputs_allowed():
+    """tool-outputs/ paths should be allowed, same as tool-results/."""
+    home = os.path.expanduser("~")
+    path = f"{home}/.claude/projects/-tmp-copilot-abc123/a1b2c3d4-e5f6-7890-abcd-ef1234567890/tool-outputs/12345.txt"
+    token = _current_project_dir.set("-tmp-copilot-abc123")
+    try:
+        result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
+        assert result == {}
+    finally:
+        _current_project_dir.reset(token)
+
+
def test_read_claude_projects_settings_json_denied():
-   """SDK-internal artifacts like settings.json are NOT accessible — only tool-results/ is."""
+   """SDK-internal artifacts like settings.json are NOT accessible — only tool-results/tool-outputs is."""
    home = os.path.expanduser("~")
    path = f"{home}/.claude/projects/-tmp-copilot-abc123/settings.json"
    token = _current_project_dir.set("-tmp-copilot-abc123")
@@ -220,8 +237,6 @@ def test_bash_builtin_blocked_message_clarity():
@pytest.fixture()
def _hooks():
    """Create security hooks and return (pre, post, post_failure) handlers."""
-   from .security_hooks import create_security_hooks
-
    hooks = create_security_hooks(user_id="u1", sdk_cwd=SDK_CWD, max_subtasks=2)
    pre = hooks["PreToolUse"][0].hooks[0]
    post = hooks["PostToolUse"][0].hooks[0]
@@ -231,16 +246,15 @@ def _hooks():

@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
-async def test_task_background_blocked(_hooks):
-   """Task with run_in_background=true must be denied."""
+async def test_task_background_allowed(_hooks):
+   """Task with run_in_background=true is allowed (SDK handles async lifecycle)."""
    pre, _, _ = _hooks
    result = await pre(
        {"tool_name": "Task", "tool_input": {"run_in_background": True, "prompt": "x"}},
-       tool_use_id=None,
+       tool_use_id="tu-bg-1",
        context={},
    )
-   assert _is_denied(result)
-   assert "foreground" in _reason(result).lower()
+   assert not _is_denied(result)


@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@@ -354,3 +368,303 @@ async def test_task_slot_released_on_failure(_hooks):
|
||||
context={},
|
||||
)
|
||||
assert not _is_denied(result)


# ---------------------------------------------------------------------------
# "Agent" tool name (SDK v2.x+ renamed "Task" → "Agent")
# ---------------------------------------------------------------------------


@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_agent_background_allowed(_hooks):
    """Agent with run_in_background=true is allowed (SDK handles async lifecycle)."""
    pre, _, _ = _hooks
    result = await pre(
        {
            "tool_name": "Agent",
            "tool_input": {"run_in_background": True, "prompt": "x"},
        },
        tool_use_id="tu-agent-bg-1",
        context={},
    )
    assert not _is_denied(result)


@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_agent_foreground_allowed(_hooks):
    """Agent without run_in_background should be allowed."""
    pre, _, _ = _hooks
    result = await pre(
        {"tool_name": "Agent", "tool_input": {"prompt": "do stuff"}},
        tool_use_id="tu-agent-1",
        context={},
    )
    assert not _is_denied(result)


@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_background_agent_counts_against_limit(_hooks):
    """Background agents still consume concurrency slots."""
    pre, _, _ = _hooks
    # Two background agents fill the limit
    for i in range(2):
        result = await pre(
            {
                "tool_name": "Agent",
                "tool_input": {"run_in_background": True, "prompt": "bg"},
            },
            tool_use_id=f"tu-bglimit-{i}",
            context={},
        )
        assert not _is_denied(result)
    # Third (background or foreground) should be denied
    result = await pre(
        {
            "tool_name": "Agent",
            "tool_input": {"run_in_background": True, "prompt": "over"},
        },
        tool_use_id="tu-bglimit-2",
        context={},
    )
    assert _is_denied(result)
    assert "Maximum" in _reason(result)


@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_agent_limit_enforced(_hooks):
    """Agent spawns beyond max_subtasks should be denied."""
    pre, _, _ = _hooks
    # First two should pass
    for i in range(2):
        result = await pre(
            {"tool_name": "Agent", "tool_input": {"prompt": "ok"}},
            tool_use_id=f"tu-agent-limit-{i}",
            context={},
        )
        assert not _is_denied(result)

    # Third should be denied (limit=2)
    result = await pre(
        {"tool_name": "Agent", "tool_input": {"prompt": "over limit"}},
        tool_use_id="tu-agent-limit-2",
        context={},
    )
    assert _is_denied(result)
    assert "Maximum" in _reason(result)


@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_agent_slot_released_on_completion(_hooks):
    """Completing an Agent should free a slot so new Agents can be spawned."""
    pre, post, _ = _hooks
    # Fill both slots
    for i in range(2):
        result = await pre(
            {"tool_name": "Agent", "tool_input": {"prompt": "ok"}},
            tool_use_id=f"tu-agent-comp-{i}",
            context={},
        )
        assert not _is_denied(result)

    # Third should be denied — at capacity
    result = await pre(
        {"tool_name": "Agent", "tool_input": {"prompt": "over"}},
        tool_use_id="tu-agent-comp-2",
        context={},
    )
    assert _is_denied(result)

    # Complete first agent — frees a slot
    await post(
        {"tool_name": "Agent", "tool_input": {}},
        tool_use_id="tu-agent-comp-0",
        context={},
    )

    # Now a new Agent should be allowed
    result = await pre(
        {"tool_name": "Agent", "tool_input": {"prompt": "after release"}},
        tool_use_id="tu-agent-comp-3",
        context={},
    )
    assert not _is_denied(result)


@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_agent_slot_released_on_failure(_hooks):
    """A failed Agent should also free its concurrency slot."""
    pre, _, post_failure = _hooks
    # Fill both slots
    for i in range(2):
        result = await pre(
            {"tool_name": "Agent", "tool_input": {"prompt": "ok"}},
            tool_use_id=f"tu-agent-fail-{i}",
            context={},
        )
        assert not _is_denied(result)

    # At capacity
    result = await pre(
        {"tool_name": "Agent", "tool_input": {"prompt": "over"}},
        tool_use_id="tu-agent-fail-2",
        context={},
    )
    assert _is_denied(result)

    # Fail first agent — should free a slot
    await post_failure(
        {"tool_name": "Agent", "tool_input": {}, "error": "something broke"},
        tool_use_id="tu-agent-fail-0",
        context={},
    )

    # New Agent should be allowed
    result = await pre(
        {"tool_name": "Agent", "tool_input": {"prompt": "after failure"}},
        tool_use_id="tu-agent-fail-3",
        context={},
    )
    assert not _is_denied(result)


@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_mixed_task_agent_share_slots(_hooks):
    """Task and Agent share the same concurrency pool."""
    pre, post, _ = _hooks
    # Fill one slot with Task, one with Agent
    result = await pre(
        {"tool_name": "Task", "tool_input": {"prompt": "ok"}},
        tool_use_id="tu-mix-task",
        context={},
    )
    assert not _is_denied(result)

    result = await pre(
        {"tool_name": "Agent", "tool_input": {"prompt": "ok"}},
        tool_use_id="tu-mix-agent",
        context={},
    )
    assert not _is_denied(result)

    # Third (either name) should be denied
    result = await pre(
        {"tool_name": "Agent", "tool_input": {"prompt": "over"}},
        tool_use_id="tu-mix-over",
        context={},
    )
    assert _is_denied(result)

    # Release the Task slot
    await post(
        {"tool_name": "Task", "tool_input": {}},
        tool_use_id="tu-mix-task",
        context={},
    )

    # Now an Agent should be allowed
    result = await pre(
        {"tool_name": "Agent", "tool_input": {"prompt": "after task release"}},
        tool_use_id="tu-mix-new",
        context={},
    )
    assert not _is_denied(result)
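
Taken together, these tests pin down one shared concurrency pool keyed by `tool_use_id`, regardless of whether the SDK names the tool "Task" or "Agent". A minimal sketch of that invariant, with hypothetical names (the shipped hooks in `security_hooks.py` carry more state than this):

```python
# Sketch only: illustrates the slot-accounting invariant the tests assert.
_SUBTASK_TOOLS = {"Task", "Agent"}  # SDK v2.x+ renamed "Task" -> "Agent"

class SubtaskSlots:
    def __init__(self, max_subtasks: int) -> None:
        self.max_subtasks = max_subtasks
        self.active: set[str] = set()  # tool_use_ids currently holding a slot

    def try_acquire(self, tool_name: str, tool_use_id: str) -> bool:
        if tool_name not in _SUBTASK_TOOLS:
            return True  # non-subtask tools never consume slots
        if len(self.active) >= self.max_subtasks:
            return False  # denied with a "Maximum ..." reason
        self.active.add(tool_use_id)
        return True

    def release(self, tool_use_id: str) -> None:
        # Called from both PostToolUse and the failure hook, so completed
        # and failed subtasks free their slot alike.
        self.active.discard(tool_use_id)
```
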

# ---------------------------------------------------------------------------
# SubagentStart / SubagentStop hooks
# ---------------------------------------------------------------------------


@pytest.fixture()
def _subagent_hooks():
    """Create hooks and return (subagent_start, subagent_stop) handlers."""
    hooks = create_security_hooks(user_id="u1", sdk_cwd=SDK_CWD, max_subtasks=2)
    start = hooks["SubagentStart"][0].hooks[0]
    stop = hooks["SubagentStop"][0].hooks[0]
    return start, stop


@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_subagent_start_hook_returns_empty(_subagent_hooks):
    """SubagentStart hook should return an empty dict (logging only)."""
    start, _ = _subagent_hooks
    result = await start(
        {"agent_id": "sa-123", "agent_type": "research"},
        tool_use_id=None,
        context={},
    )
    assert result == {}


@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_subagent_stop_hook_returns_empty(_subagent_hooks):
    """SubagentStop hook should return an empty dict (logging only)."""
    _, stop = _subagent_hooks
    result = await stop(
        {
            "agent_id": "sa-123",
            "agent_type": "research",
            "agent_transcript_path": "/tmp/transcript.txt",
        },
        tool_use_id=None,
        context={},
    )
    assert result == {}


@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_subagent_hooks_sanitize_inputs(_subagent_hooks, caplog):
    """SubagentStart/Stop should sanitize control chars from inputs."""
    start, stop = _subagent_hooks
    # Inject control characters (C0, DEL, C1, BiDi overrides, zero-width)
    # — hook should not raise AND logs must be clean
    with caplog.at_level(logging.DEBUG, logger="backend.copilot.sdk.security_hooks"):
        result = await start(
            {
                "agent_id": "sa\n-injected\r\x00\x7f",
                "agent_type": "safe\x80_type\x9f\ttab",
            },
            tool_use_id=None,
            context={},
        )
    assert result == {}
    # Control chars must be stripped from the logged values
    for record in caplog.records:
        assert "\x00" not in record.message
        assert "\r" not in record.message
        assert "\n" not in record.message
        assert "\x7f" not in record.message
        assert "\x80" not in record.message
        assert "\x9f" not in record.message
    assert "safe_type" in caplog.text

    caplog.clear()
    with caplog.at_level(logging.DEBUG, logger="backend.copilot.sdk.security_hooks"):
        result = await stop(
            {
                "agent_id": "sa\n-injected\x7f",
                "agent_type": "type\r\x80\x9f",
                "agent_transcript_path": "/tmp/\x00malicious\npath\u202a\u200b",
            },
            tool_use_id=None,
            context={},
        )
    assert result == {}
    for record in caplog.records:
        assert "\x00" not in record.message
        assert "\r" not in record.message
        assert "\n" not in record.message
        assert "\x7f" not in record.message
        assert "\u202a" not in record.message
        assert "\u200b" not in record.message
    assert "/tmp/maliciouspath" in caplog.text

@@ -33,12 +33,24 @@ from pydantic import BaseModel

from backend.copilot.context import get_workspace_manager
from backend.copilot.permissions import apply_tool_permissions
from backend.copilot.rate_limit import get_user_tier
from backend.copilot.transcript import (
    _run_compression,
    cleanup_stale_project_dirs,
    compact_transcript,
    download_transcript,
    read_compacted_entries,
    upload_transcript,
    validate_transcript,
    write_transcript_to_tempfile,
)
from backend.copilot.transcript_builder import TranscriptBuilder
from backend.data.redis_client import get_redis_async
from backend.executor.cluster_lock import AsyncClusterLock
from backend.util.exceptions import NotFoundError
from backend.util.settings import Settings

-from ..config import ChatConfig
+from ..config import ChatConfig, CopilotMode
from ..constants import (
    COPILOT_ERROR_PREFIX,
    COPILOT_RETRYABLE_ERROR_PREFIX,

@@ -51,6 +63,7 @@ from ..model import (
    ChatMessage,
    ChatSession,
    get_chat_session,
    maybe_append_user_message,
    update_session_title,
    upsert_chat_session,
)

@@ -59,11 +72,14 @@ from ..response_model import (
    StreamBaseResponse,
    StreamError,
    StreamFinish,
    StreamFinishStep,
    StreamHeartbeat,
    StreamStart,
    StreamStartStep,
    StreamStatus,
    StreamTextDelta,
    StreamToolInputAvailable,
    StreamToolInputStart,
    StreamToolOutputAvailable,
    StreamUsage,
)

@@ -77,31 +93,18 @@ from ..tools.e2b_sandbox import get_or_create_sandbox, pause_sandbox_direct
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
from ..tracking import track_user_message
from .compaction import CompactionTracker, filter_compaction_messages
from .env import build_sdk_env  # noqa: F401 — re-export for backward compat
from .response_adapter import SDKResponseAdapter
from .security_hooks import create_security_hooks
from .subscription import validate_subscription as _validate_claude_code_subscription
from .tool_adapter import (
    cancel_pending_tool_tasks,
    create_copilot_mcp_server,
    get_copilot_tool_names,
    get_sdk_disallowed_tools,
    pre_launch_tool_call,
    reset_stash_event,
    reset_tool_failure_counters,
    set_execution_context,
    wait_for_stash,
)
-from .transcript import (
-    _run_compression,
-    cleanup_stale_project_dirs,
-    compact_transcript,
-    download_transcript,
-    read_compacted_entries,
-    upload_transcript,
-    validate_transcript,
-    write_transcript_to_tempfile,
-)
-from .transcript_builder import TranscriptBuilder

logger = logging.getLogger(__name__)
config = ChatConfig()

@@ -115,9 +118,10 @@ _MAX_STREAM_ATTEMPTS = 3

# Hard circuit breaker: abort the stream if the model sends this many
# consecutive tool calls with empty parameters (a sign of context
-# saturation or serialization failure). Empty input ({}) is never
-# legitimate — even one is suspicious, three is conclusive.
-_EMPTY_TOOL_CALL_LIMIT = 3
+# saturation or serialization failure). The MCP wrapper now returns
+# guidance on the first empty call, giving the model a chance to
+# self-correct. The limit is generous to allow recovery attempts.
+_EMPTY_TOOL_CALL_LIMIT = 5

# User-facing error shown when the empty-tool-call circuit breaker trips.
_CIRCUIT_BREAKER_ERROR_MSG = (

@@ -127,6 +131,11 @@ _CIRCUIT_BREAKER_ERROR_MSG = (
    "Try breaking your request into smaller parts."
)
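
The breaker itself is a simple consecutive counter: any well-formed tool call resets it, and only an unbroken run of empty `{}` inputs trips it. A sketch of that shape (simplified; the streaming code below tracks `consecutive_empty_tool_calls` per attempt, and `tool_use_blocks` here is a hypothetical iterable):

```python
# Sketch: simplified shape of the empty-tool-call circuit breaker.
consecutive_empty_tool_calls = 0
for tool_use in tool_use_blocks:  # hypothetical iteration over ToolUseBlocks
    if not tool_use.input:  # empty {} parameters
        consecutive_empty_tool_calls += 1
        if consecutive_empty_tool_calls >= _EMPTY_TOOL_CALL_LIMIT:
            raise RuntimeError(_CIRCUIT_BREAKER_ERROR_MSG)
    else:
        consecutive_empty_tool_calls = 0  # any real call resets the streak
```

Raising the limit from 3 to 5 only makes sense because the first empty call now gets corrective guidance instead of silently counting toward the abort.
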

# Idle timeout: abort the stream if no meaningful SDK message (only heartbeats)
# arrives for this many seconds. This catches hung tool calls (e.g. WebSearch
# hanging on a search provider that never responds).
_IDLE_TIMEOUT_SECONDS = 10 * 60  # 10 minutes

# Patterns that indicate the prompt/request exceeds the model's context limit.
# Matched case-insensitively against the full exception chain.
_PROMPT_TOO_LONG_PATTERNS: tuple[str, ...] = (

@@ -567,60 +576,6 @@ def _resolve_sdk_model() -> str | None:
    return model


-def _build_sdk_env(
-    session_id: str | None = None,
-    user_id: str | None = None,
-) -> dict[str, str]:
-    """Build env vars for the SDK CLI subprocess.
-
-    Three modes (checked in order):
-    1. **Subscription** — clears all keys; CLI uses `claude login` auth.
-    2. **Direct Anthropic** — returns `{}`; subprocess inherits
-       `ANTHROPIC_API_KEY` from the parent environment.
-    3. **OpenRouter** (default) — overrides base URL and auth token to
-       route through the proxy, with Langfuse trace headers.
-    """
-    # --- Mode 1: Claude Code subscription auth ---
-    if config.use_claude_code_subscription:
-        _validate_claude_code_subscription()
-        return {
-            "ANTHROPIC_API_KEY": "",
-            "ANTHROPIC_AUTH_TOKEN": "",
-            "ANTHROPIC_BASE_URL": "",
-        }
-
-    # --- Mode 2: Direct Anthropic (no proxy hop) ---
-    # `openrouter_active` checks the flag *and* credential presence.
-    if not config.openrouter_active:
-        return {}
-
-    # --- Mode 3: OpenRouter proxy ---
-    # Strip /v1 suffix — SDK expects the base URL without a version path.
-    base = (config.base_url or "").rstrip("/")
-    if base.endswith("/v1"):
-        base = base[:-3]
-    env: dict[str, str] = {
-        "ANTHROPIC_BASE_URL": base,
-        "ANTHROPIC_AUTH_TOKEN": config.api_key or "",
-        "ANTHROPIC_API_KEY": "",  # force CLI to use AUTH_TOKEN
-    }
-
-    # Inject broadcast headers so OpenRouter forwards traces to Langfuse.
-    def _safe(v: str) -> str:
-        """Sanitise a header value: strip newlines/whitespace and cap length."""
-        return v.replace("\r", "").replace("\n", "").strip()[:128]
-
-    parts = []
-    if session_id:
-        parts.append(f"x-session-id: {_safe(session_id)}")
-    if user_id:
-        parts.append(f"x-user-id: {_safe(user_id)}")
-    if parts:
-        env["ANTHROPIC_CUSTOM_HEADERS"] = "\n".join(parts)
-
-    return env


def _make_sdk_cwd(session_id: str) -> str:
    """Create a safe, session-specific working directory path.

@@ -800,15 +755,11 @@ def _format_conversation_context(messages: list[ChatMessage]) -> str | None:
        elif msg.role == "assistant":
            if msg.content:
                lines.append(f"You responded: {msg.content}")
-            if msg.tool_calls:
-                for tc in msg.tool_calls:
-                    func = tc.get("function", {})
-                    tool_name = func.get("name", "unknown")
-                    tool_args = func.get("arguments", "")
-                    lines.append(f"You called tool: {tool_name}({tool_args})")
+            # Omit tool_calls — any text representation gets mimicked
+            # by the model. Tool results below provide the context.
        elif msg.role == "tool":
            content = msg.content or ""
-            lines.append(f"Tool result: {content}")
+            lines.append(f"Tool output: {content[:500]}")

    if not lines:
        return None
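
Illustratively, assuming `ChatMessage` accepts a `tool_calls` field (the code above reads `msg.tool_calls`), the rendered context after this change looks roughly like the following sketch, not a captured run:

```python
messages = [
    ChatMessage(role="assistant", content="Let me check the file.",
                tool_calls=[{"function": {"name": "Read", "arguments": "{}"}}]),
    ChatMessage(role="tool", content="x" * 2000),
]
# _format_conversation_context(messages) now yields lines like:
#   You responded: Let me check the file.
#   Tool output: xxxx...x            <- capped at the first 500 characters
# and deliberately no "You called tool: Read({})" line, so the model
# has nothing tool-call-shaped to mimic.
```
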

@@ -1268,6 +1219,14 @@ async def _run_stream_attempt(

    consecutive_empty_tool_calls = 0

    # --- Intermediate persistence tracking ---
    # Flush session messages to DB periodically so page reloads show progress
    # during long-running turns (see incident d2f7cba3: 82-min turn lost on refresh).
    _last_flush_time = time.monotonic()
    _msgs_since_flush = 0
    _FLUSH_INTERVAL_SECONDS = 30.0
    _FLUSH_MESSAGE_THRESHOLD = 10

    # Use manual __aenter__/__aexit__ instead of ``async with`` so we can
    # suppress SDK cleanup errors that occur when the SSE client disconnects
    # mid-stream. GeneratorExit causes the SDK's ``__aexit__`` to run in a

@@ -1319,6 +1278,8 @@ async def _run_stream_attempt(
        await client.query(state.query_message, session_id=ctx.session_id)
        state.transcript_builder.append_user(content=ctx.current_message)

        _last_real_msg_time = time.monotonic()

        async for sdk_msg in _iter_sdk_messages(client):
            # Heartbeat sentinel — refresh lock and keep SSE alive
            if sdk_msg is None:

@@ -1326,8 +1287,34 @@ async def _run_stream_attempt(
                for ev in ctx.compaction.emit_start_if_ready():
                    yield ev
                yield StreamHeartbeat()

                # Idle timeout: if no real SDK message for too long, a tool
                # call is likely hung (e.g. WebSearch provider not responding).
                idle_seconds = time.monotonic() - _last_real_msg_time
                if idle_seconds >= _IDLE_TIMEOUT_SECONDS:
                    logger.error(
                        "%s Idle timeout after %.0fs with no SDK message — "
                        "aborting stream (likely hung tool call)",
                        ctx.log_prefix,
                        idle_seconds,
                    )
                    stream_error_msg = (
                        "A tool call appears to be stuck "
                        "(no response for 10 minutes). "
                        "Please try again."
                    )
                    stream_error_code = "idle_timeout"
                    _append_error_marker(ctx.session, stream_error_msg, retryable=True)
                    yield StreamError(
                        errorText=stream_error_msg,
                        code=stream_error_code,
                    )
                    ended_with_stream_error = True
                    break
                continue

            _last_real_msg_time = time.monotonic()

            logger.info(
                "%s Received: %s %s (unresolved=%d, current=%d, resolved=%d)",
                ctx.log_prefix,

@@ -1354,6 +1341,27 @@ async def _run_stream_attempt(
                        error_preview,
                    )

                    # Intercept prompt-too-long errors surfaced as
                    # AssistantMessage.error (not as a Python exception).
                    # Re-raise so the outer retry loop can compact the
                    # transcript and retry with reduced context.
                    # Check both error_text and error_preview: sdk_error
                    # being set confirms this is an error message (not user
                    # content), so checking content is safe. The actual
                    # error description (e.g. "Prompt is too long") may be
                    # in the content, not the error type field
                    # (e.g. error="invalid_request", content="Prompt is
                    # too long").
                    if _is_prompt_too_long(Exception(error_text)) or _is_prompt_too_long(
                        Exception(error_preview)
                    ):
                        logger.warning(
                            "%s Prompt-too-long detected via AssistantMessage "
                            "error — raising for retry",
                            ctx.log_prefix,
                        )
                        raise RuntimeError("Prompt is too long")

                    # Intercept transient API errors (socket closed,
                    # ECONNRESET) — replace the raw message with a
                    # user-friendly error text and use the retryable

@@ -1381,28 +1389,17 @@ async def _run_stream_attempt(
                    ended_with_stream_error = True
                    break

-            # Parallel tool execution: pre-launch every ToolUseBlock as an
-            # asyncio.Task the moment its AssistantMessage arrives. The SDK
-            # sends one AssistantMessage per tool call when issuing parallel
-            # calls, so each message is pre-launched independently. The MCP
-            # handlers will await the already-running task instead of executing
-            # fresh, making all concurrent tool calls run in parallel.
-            #
-            # Also determine if the message is a tool-only batch (all content
+            # Determine if the message is a tool-only batch (all content
            # items are ToolUseBlocks) — such messages have no text output yet,
            # so we skip the wait_for_stash flush below.
+            #
+            # Note: parallel execution of tools is handled natively by the
+            # SDK CLI via readOnlyHint annotations on tool definitions.
            is_tool_only = False
            if isinstance(sdk_msg, AssistantMessage) and sdk_msg.content:
-                is_tool_only = True
-                # NOTE: Pre-launches are sequential (each await completes
-                # file-ref expansion before the next starts). This is fine
-                # since expansion is typically sub-ms; a future optimisation
-                # could gather all pre-launches concurrently.
-                for tool_use in sdk_msg.content:
-                    if isinstance(tool_use, ToolUseBlock):
-                        await pre_launch_tool_call(tool_use.name, tool_use.input)
-                    else:
-                        is_tool_only = False
+                is_tool_only = all(
+                    isinstance(item, ToolUseBlock) for item in sdk_msg.content
+                )

            # Race-condition fix: SDK hooks (PostToolUse) are
            # executed asynchronously via start_soon() — the next

@@ -1459,6 +1456,16 @@ async def _run_stream_attempt(
                    sdk_msg.result or "(no error message provided)",
                )

                # Check for prompt-too-long regardless of subtype — the
                # SDK may return subtype="success" with result="Prompt is
                # too long" when the CLI rejects the prompt before calling
                # the API (cost_usd=0, no tokens consumed). If we only
                # check the "error" subtype path, the stream appears to
                # complete normally, the synthetic error text is stored
                # in the transcript, and the session grows without bound.
                if _is_prompt_too_long(RuntimeError(sdk_msg.result or "")):
                    raise RuntimeError("Prompt is too long")

                # Capture token usage from ResultMessage.
                # Anthropic reports cached tokens separately:
                # input_tokens = uncached only

@@ -1490,6 +1497,23 @@
                # Emit compaction end if SDK finished compacting.
                # Sync TranscriptBuilder with the CLI's active context.
                compact_result = await ctx.compaction.emit_end_if_ready(ctx.session)
                if compact_result.events:
                    # Compaction events end with StreamFinishStep, which maps to
                    # Vercel AI SDK's "finish-step" — that clears activeTextParts.
                    # Close any open text block BEFORE the compaction events so
                    # the text-end arrives before finish-step, preventing
                    # "text-end for missing text part" errors on the frontend.
                    pre_close: list[StreamBaseResponse] = []
                    state.adapter._end_text_if_open(pre_close)
                    # Compaction events bypass the adapter, so sync step state
                    # when a StreamFinishStep is present — otherwise the adapter
                    # will skip StreamStartStep on the next AssistantMessage.
                    if any(
                        isinstance(ev, StreamFinishStep) for ev in compact_result.events
                    ):
                        state.adapter.step_open = False
                    for r in pre_close:
                        yield r
                    for ev in compact_result.events:
                        yield ev
                    entries_replaced = False

@@ -1536,6 +1560,46 @@
                    model=sdk_msg.model,
                )

                # --- Intermediate persistence ---
                # Flush session messages to DB periodically so page reloads
                # show progress during long-running turns.
                #
                # IMPORTANT: Skip the flush while tool calls are pending
                # (tool_calls set on assistant but results not yet received).
                # The DB save is append-only (uses start_sequence), so if we
                # flush the assistant message before tool_calls are set on it
                # (text and tool_use arrive as separate SDK events), the
                # tool_calls update is lost — the next flush starts past it.
                _msgs_since_flush += 1
                now = time.monotonic()
                has_pending_tools = (
                    acc.has_appended_assistant
                    and acc.accumulated_tool_calls
                    and not acc.has_tool_results
                )
                if not has_pending_tools and (
                    _msgs_since_flush >= _FLUSH_MESSAGE_THRESHOLD
                    or (now - _last_flush_time) >= _FLUSH_INTERVAL_SECONDS
                ):
                    try:
                        await asyncio.shield(upsert_chat_session(ctx.session))
                        logger.debug(
                            "%s Intermediate flush: %d messages "
                            "(msgs_since=%d, elapsed=%.1fs)",
                            ctx.log_prefix,
                            len(ctx.session.messages),
                            _msgs_since_flush,
                            now - _last_flush_time,
                        )
                    except Exception as flush_err:
                        logger.warning(
                            "%s Intermediate flush failed: %s",
                            ctx.log_prefix,
                            flush_err,
                        )
                    _last_flush_time = now
                    _msgs_since_flush = 0

            if acc.stream_completed:
                break
    finally:

@@ -1613,6 +1677,7 @@ async def stream_chat_completion_sdk(
    session: ChatSession | None = None,
    file_ids: list[str] | None = None,
    permissions: "CopilotPermissions | None" = None,
    mode: CopilotMode | None = None,
    **_kwargs: Any,
) -> AsyncIterator[StreamBaseResponse]:
    """Stream chat completion using Claude Agent SDK.

@@ -1621,7 +1686,10 @@
        file_ids: Optional workspace file IDs attached to the user's message.
            Images are embedded as vision content blocks; other files are
            saved to the SDK working directory for the Read tool.
        mode: Accepted for signature compatibility with the baseline path.
            The SDK path does not currently branch on this value.
    """
    _ = mode  # SDK path ignores the requested mode.

    if session is None:
        session = await get_chat_session(session_id, user_id)

@@ -1652,19 +1720,12 @@
        )
        session.messages.pop()

-    # Append the new message to the session if it's not already there
-    new_message_role = "user" if is_user_message else "assistant"
-    if message and (
-        len(session.messages) == 0
-        or not (
-            session.messages[-1].role == new_message_role
-            and session.messages[-1].content == message
-        )
-    ):
-        session.messages.append(ChatMessage(role=new_message_role, content=message))
+    if maybe_append_user_message(session, message, is_user_message):
        if is_user_message:
            track_user_message(
-                user_id=user_id, session_id=session_id, message_length=len(message)
+                user_id=user_id,
+                session_id=session_id,
+                message_length=len(message or ""),
            )
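
The extracted helper centralises the tail-deduplication that used to live inline here. Its behaviour can be reconstructed from the code it replaces; treat the following as a sketch (the real implementation is `maybe_append_user_message` in `..model`, i.e. `backend/copilot/model.py`):

```python
# Sketch: behaviour reconstructed from the removed inline logic above.
# Append unless the session tail already holds an identical message.
def maybe_append_user_message(
    session: ChatSession, message: str | None, is_user_message: bool
) -> bool:
    if not message:
        return False
    role = "user" if is_user_message else "assistant"
    last = session.messages[-1] if session.messages else None
    if last is not None and last.role == role and last.content == message:
        return False  # duplicate of the current tail, nothing appended
    session.messages.append(ChatMessage(role=role, content=message))
    return True  # caller only records analytics when something was appended
```
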

    # Structured log prefix: [SDK][<session>][T<turn>]

@@ -1867,7 +1928,10 @@
    )

    # Fail fast when no API credentials are available at all.
-    sdk_env = _build_sdk_env(session_id=session_id, user_id=user_id)
+    # sdk_cwd routes the CLI's temp dir into the per-session workspace
+    # so sub-agent output files land inside sdk_cwd (see build_sdk_env).
+    sdk_env = build_sdk_env(session_id=session_id, user_id=user_id, sdk_cwd=sdk_cwd)
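
How `build_sdk_env` consumes the new `sdk_cwd` parameter is not shown in this diff; the following is only a plausible sketch, and the temp-dir mechanism is explicitly an assumption (the real function lives in the `.env` module imported above):

```python
# Sketch: assumed handling of sdk_cwd inside build_sdk_env (.env module).
# The comment above says the CLI's temp dir is routed into the per-session
# workspace; pointing TMPDIR at sdk_cwd is one hypothetical way to do that.
def build_sdk_env(
    session_id: str | None = None,
    user_id: str | None = None,
    sdk_cwd: str | None = None,
) -> dict[str, str]:
    env = _auth_env(session_id=session_id, user_id=user_id)  # hypothetical helper
    if sdk_cwd:
        env["TMPDIR"] = sdk_cwd  # assumption: temp-dir routing, see build_sdk_env
    return env
```
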

    if not config.api_key and not config.use_claude_code_subscription:
        raise RuntimeError(
            "No API key configured. Set OPEN_ROUTER_API_KEY, "

@@ -1926,15 +1990,20 @@
    # langsmith tracing integration attaches them to every span. This
    # is what Langfuse (or any OTEL backend) maps to its native
    # user/session fields.
    _user_tier = await get_user_tier(user_id) if user_id else None
    _otel_metadata: dict[str, str] = {
        "resume": str(use_resume),
        "conversation_turn": str(turn),
    }
    if _user_tier:
        _otel_metadata["subscription_tier"] = _user_tier.value

    _otel_ctx = propagate_attributes(
        user_id=user_id,
        session_id=session_id,
        trace_name="copilot-sdk",
        tags=["sdk"],
-        metadata={
-            "resume": str(use_resume),
-            "conversation_turn": str(turn),
-        },
+        metadata=_otel_metadata,
    )
    _otel_ctx.__enter__()

@@ -2062,13 +2131,22 @@

        try:
            async for event in _run_stream_attempt(stream_ctx, state):
-                if not isinstance(event, StreamHeartbeat):
+                if not isinstance(
+                    event,
+                    (
+                        StreamHeartbeat,
+                        # Compaction UI events are cosmetic and must not
+                        # block retry — they're emitted before the SDK
+                        # query on compacted attempts.
+                        StreamStartStep,
+                        StreamFinishStep,
+                        StreamToolInputStart,
+                        StreamToolInputAvailable,
+                        StreamToolOutputAvailable,
+                    ),
+                ):
                    events_yielded += 1
                yield event
            # Cancel any pre-launched tasks that were never dispatched
            # by the SDK (e.g. edge-case SDK behaviour changes). Symmetric
            # with the three error-path await cancel_pending_tool_tasks() calls.
            await cancel_pending_tool_tasks()
            break  # Stream completed — exit retry loop

        except asyncio.CancelledError:
            logger.warning(

@@ -2077,9 +2155,6 @@
                attempt + 1,
                _MAX_STREAM_ATTEMPTS,
            )
-            # Cancel any pre-launched tasks so they don't continue executing
-            # against a rolled-back or abandoned session.
-            await cancel_pending_tool_tasks()
            raise
        except _HandledStreamError as exc:
            # _run_stream_attempt already yielded a StreamError and

@@ -2111,8 +2186,6 @@
                retryable=True,
            )
            ended_with_stream_error = True
-            # Cancel any pre-launched tasks from the failed attempt.
-            await cancel_pending_tool_tasks()
            break
        except Exception as e:
            stream_err = e

@@ -2129,9 +2202,6 @@
                exc_info=True,
            )
            session.messages = session.messages[:pre_attempt_msg_count]
-            # Cancel any pre-launched tasks from the failed attempt so they
-            # don't continue executing against the rolled-back session.
-            await cancel_pending_tool_tasks()
            if events_yielded > 0:
                # Events were already sent to the frontend and cannot be
                # unsent. Retrying would produce duplicate/inconsistent

@@ -27,20 +27,19 @@ from backend.copilot.response_model import (
    StreamTextDelta,
    StreamTextStart,
)
-from backend.util import json
-
-from .conftest import build_structured_transcript
-from .response_adapter import SDKResponseAdapter
-from .service import _format_sdk_content_blocks
-from .transcript import (
+from backend.copilot.transcript import (
    _find_last_assistant_entry,
    _flatten_assistant_content,
    _messages_to_transcript,
    _rechain_tail,
    _transcript_to_messages,
    compact_transcript,
    validate_transcript,
)
+from backend.util import json
+
+from .conftest import build_structured_transcript
+from .response_adapter import SDKResponseAdapter
+from .service import _format_sdk_content_blocks
+from .transcript import compact_transcript, validate_transcript

# ---------------------------------------------------------------------------
# Fixtures: realistic thinking block content

@@ -392,7 +391,7 @@ class TestFlattenThinkingBlocks:
        assert result == ""

    def test_mixed_thinking_text_tool(self):
-        """Mixed blocks: only text and tool_use survive flattening."""
+        """Mixed blocks: only text survives flattening; thinking and tool_use are dropped."""
        blocks = [
            {"type": "thinking", "thinking": "hmm", "signature": "sig"},
            {"type": "redacted_thinking", "data": "xyz"},

@@ -403,7 +402,8 @@ class TestFlattenThinkingBlocks:
        assert "hmm" not in result
        assert "xyz" not in result
        assert "I'll read the file." in result
-        assert "[tool_use: Read]" in result
+        # tool_use blocks are dropped entirely to prevent model mimicry
+        assert "Read" not in result

# ---------------------------------------------------------------------------

@@ -438,7 +438,7 @@ class TestCompactTranscriptThinkingBlocks:
            },
        )()
        with patch(
-            "backend.copilot.sdk.transcript._run_compression",
+            "backend.copilot.transcript._run_compression",
            new_callable=AsyncMock,
            return_value=mock_result,
        ):

@@ -497,7 +497,7 @@ class TestCompactTranscriptThinkingBlocks:
        )()

        with patch(
-            "backend.copilot.sdk.transcript._run_compression",
+            "backend.copilot.transcript._run_compression",
            side_effect=mock_compression,
        ):
            await compact_transcript(transcript, model="test-model")

@@ -550,7 +550,7 @@ class TestCompactTranscriptThinkingBlocks:
            },
        )()
        with patch(
-            "backend.copilot.sdk.transcript._run_compression",
+            "backend.copilot.transcript._run_compression",
            new_callable=AsyncMock,
            return_value=mock_result,
        ):

@@ -600,7 +600,7 @@ class TestCompactTranscriptThinkingBlocks:
            },
        )()
        with patch(
-            "backend.copilot.sdk.transcript._run_compression",
+            "backend.copilot.transcript._run_compression",
            new_callable=AsyncMock,
            return_value=mock_result,
        ):

@@ -637,7 +637,7 @@ class TestCompactTranscriptThinkingBlocks:
            },
        )()
        with patch(
-            "backend.copilot.sdk.transcript._run_compression",
+            "backend.copilot.transcript._run_compression",
            new_callable=AsyncMock,
            return_value=mock_result,
        ):

@@ -698,7 +698,7 @@ class TestCompactTranscriptThinkingBlocks:
            },
        )()
        with patch(
-            "backend.copilot.sdk.transcript._run_compression",
+            "backend.copilot.transcript._run_compression",
            new_callable=AsyncMock,
            return_value=mock_result,
        ):
Some files were not shown because too many files have changed in this diff.