Compare commits


1 Commit

Author SHA1 Message Date
Zamil Majdy
1750c833ee fix(frontend): upgrade Docker Node.js from v21 (EOL) to v22 LTS (#12561)
## Summary
Upgrade the frontend **Docker image** from **Node.js v21** (EOL since
June 2024) to **Node.js v22 LTS** (supported through April 2027).

> **Scope:** This only affects the **Dockerfile** used for local
development (`docker compose`) and CI. It does **not** affect Vercel
(which manages its own Node.js runtime) or Kubernetes (the frontend Helm
chart was removed in Dec 2025 — the frontend is deployed exclusively via
Vercel).

## Why
- Node v21.7.3 has a **known TransformStream race condition bug**
causing `TypeError: controller[kState].transformAlgorithm is not a
function` — this is
[BUILDER-3KF](https://significant-gravitas.sentry.io/issues/BUILDER-3KF)
with **567,000+ Sentry events**
- The error originates entirely in Node.js internals
(`node:internal/webstreams/transformstream`); no first-party code is involved
- Node 21 is **not an LTS release** and has been EOL since June 2024
- `package.json` already declares `"engines": { "node": "22.x" }` — the
Dockerfile was inconsistent
- Node 22.x LTS (v22.22.1) fixes the TransformStream bug
- Next.js 15.4.x requires Node 18.18+, so Node 22 is fully compatible

## Changes
- `autogpt_platform/frontend/Dockerfile`: `node:21-alpine` →
`node:22.22-alpine3.23` (both `base` and `prod` stages)

## Test plan
- [ ] Verify frontend Docker image builds successfully via `docker compose` (see the verification sketch below)
- [ ] Verify frontend starts and serves pages correctly in local Docker
environment
- [ ] Monitor Sentry for BUILDER-3KF — should drop to zero for
Docker-based runs
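A rough local verification sketch for the first two items (the compose service name is an assumption; adjust to the actual service defined in `autogpt_platform`):

```bash
# Rebuild and start only the frontend service from the platform compose file
cd autogpt_platform
docker compose build frontend
docker compose up -d frontend

# Confirm the container really runs Node 22 (service name is an assumption)
docker compose exec frontend node --version   # expect v22.x
```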
2026-03-27 13:11:23 +07:00
198 changed files with 1328 additions and 14995 deletions

View File

@@ -1,106 +0,0 @@
---
name: open-pr
description: Open a pull request with proper PR template, test coverage, and review workflow. Guides agents through creating a PR that follows repo conventions, ensures existing behaviors aren't broken, covers new behaviors with tests, and handles review via bot when local testing isn't possible. TRIGGER when user asks to "open a PR", "create a PR", "make a PR", "submit a PR", "open pull request", "push and create PR", or any variation of opening/submitting a pull request.
user-invocable: true
args: "[base-branch] — optional target branch (defaults to dev)."
metadata:
author: autogpt-team
version: "1.0.0"
---
# Open a Pull Request
## Step 1: Pre-flight checks
Before opening the PR:
1. Ensure all changes are committed
2. Ensure the branch is pushed to the remote (`git push -u origin <branch>`)
3. Run linters/formatters across the whole repo (not just changed files) and commit any fixes (see the sketch below)
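A minimal pre-flight sketch (the lint/format commands are assumptions; use whatever CLAUDE.md actually prescribes for this repo):
```bash
# 1. Nothing left uncommitted
git status --porcelain                                   # should print nothing

# 2. Branch exists on the remote
git push -u origin "$(git rev-parse --abbrev-ref HEAD)"

# 3. Repo-wide formatters/linters, then commit any resulting fixes
(cd autogpt_platform/backend && poetry run format && poetry run lint)
(cd autogpt_platform/frontend && pnpm format && pnpm lint)
git add -A && git commit -m "chore: apply lint/format fixes" || echo "nothing to commit"
```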
## Step 2: Test coverage
**This is critical.** Before opening the PR, verify:
### Existing behavior is not broken
- Identify which modules/components your changes touch
- Run the existing test suites for those areas
- If tests fail, fix them before opening the PR — do not open a PR with known regressions
### New behavior has test coverage
- Every new feature, endpoint, or behavior change needs tests
- If you added a new block, add tests for that block
- If you changed API behavior, add or update API tests
- If you changed frontend behavior, verify it doesn't break existing flows
If you cannot run the full test suite locally, note which tests you ran and which you couldn't in the test plan.
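A sketch of targeted test runs, assuming the changes touch both backend and frontend (paths and script names are illustrative; check the respective tool configs):
```bash
# Backend: run the suites for the modules you touched (path is an example)
(cd autogpt_platform/backend && poetry run pytest backend/api/features/chat -x)

# Frontend: run the unit test script (script name is an assumption; see package.json)
(cd autogpt_platform/frontend && pnpm test)
```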
## Step 3: Create the PR using the repo template
Read the canonical PR template at `.github/PULL_REQUEST_TEMPLATE.md` and use it **verbatim** as your PR body:
1. Read the template: `cat .github/PULL_REQUEST_TEMPLATE.md`
2. Preserve the exact section titles and formatting, including:
- `### Why / What / How`
- `### Changes 🏗️`
- `### Checklist 📋`
3. Replace HTML comment prompts (`<!-- ... -->`) with actual content; do not leave them in
4. **Do not pre-check boxes** — leave all checkboxes as `- [ ]` until each step is actually completed
5. Do not alter the template structure, rename sections, or remove any checklist items
**PR title must use conventional commit format** (e.g., `feat(backend): add new block`, `fix(frontend): resolve routing bug`, `dx(skills): update PR workflow`). See CLAUDE.md for the full list of scopes.
Use `gh pr create` with the base branch (defaults to `dev` if no `[base-branch]` was provided). Use `--body-file` to avoid shell interpretation of backticks and special characters:
```bash
BASE_BRANCH="${BASE_BRANCH:-dev}"
PR_BODY=$(mktemp)
cat > "$PR_BODY" << 'PREOF'
<filled-in template from .github/PULL_REQUEST_TEMPLATE.md>
PREOF
gh pr create --base "$BASE_BRANCH" --title "<type>(scope): short description" --body-file "$PR_BODY"
rm "$PR_BODY"
```
## Step 4: Review workflow
### If you have a workspace that allows testing (docker, running backend, etc.)
- Run `/pr-test` to do E2E manual testing of the PR using docker compose, agent-browser, and API calls. This is the most thorough way to validate your changes before review.
- After testing, run `/pr-review` to self-review the PR for correctness, security, code quality, and testing gaps before requesting human review.
### If you do NOT have a workspace that allows testing
This is common for agents running in worktrees without a full stack. In this case:
1. Run `/pr-review` locally to catch obvious issues before pushing
2. **Comment `/review` on the PR** after creating it to trigger the review bot
3. **Poll for the review** rather than blindly waiting — check for new review comments every 30 seconds using `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate` and the GraphQL inline threads query. The bot typically responds within 30 minutes, but polling lets the agent react as soon as it arrives (see the polling sketch below).
4. Do NOT proceed or merge until the bot review comes back
5. Address any issues the bot raises — use `/pr-address` which has a full polling loop with CI + comment tracking
```bash
# After creating the PR:
PR_NUMBER=$(gh pr view --json number -q .number)
gh pr comment "$PR_NUMBER" --body "/review"
# Then use /pr-address to poll for and address the review when it arrives
```
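A rough polling sketch for step 3 (the interval and timeout are illustrative; `/pr-address` implements the full loop including CI tracking):
```bash
# Poll for reviews every 30 s; give up after ~30 minutes (60 x 30 s)
for _ in $(seq 1 60); do
  REVIEW_COUNT=$(gh api "repos/Significant-Gravitas/AutoGPT/pulls/$PR_NUMBER/reviews" --paginate --jq '.[].id' | wc -l)
  if [ "$REVIEW_COUNT" -gt 0 ]; then
    echo "A review has arrived; read it and address the feedback"
    break
  fi
  sleep 30
done
```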
## Step 5: Address review feedback
Once the review bot or human reviewers leave comments:
- Run `/pr-address` to address review comments. It will loop until CI is green and all comments are resolved.
- Do not merge without human approval.
## Related skills
| Skill | When to use |
|---|---|
| `/pr-test` | E2E testing with docker compose, agent-browser, API calls — use when you have a running workspace |
| `/pr-review` | Review for correctness, security, code quality — use before requesting human review |
| `/pr-address` | Address reviewer comments and loop until CI green — use after reviews come in |
## Step 6: Post-creation
After the PR is created and review is triggered:
- Share the PR URL with the user
- If waiting on the review bot, let the user know the expected wait time (~30 min)
- Do not merge without human approval

View File

@@ -1,195 +0,0 @@
---
name: setup-repo
description: Initialize a worktree-based repo layout for parallel development. Creates a main worktree, a reviews worktree for PR reviews, and N numbered work branches. Handles .env creation, dependency installation, and branchlet config. TRIGGER when user asks to set up the repo from scratch, initialize worktrees, bootstrap their dev environment, "setup repo", "setup worktrees", "initialize dev environment", "set up branches", or when a freshly cloned repo has no sibling worktrees.
user-invocable: true
args: "No arguments — interactive setup via prompts."
metadata:
author: autogpt-team
version: "1.0.0"
---
# Repository Setup
This skill sets up a worktree-based development layout from a freshly cloned repo. It creates:
- A **main** worktree (the primary checkout)
- A **reviews** worktree (for PR reviews)
- **N work branches** (branch1..branchN) for parallel development
## Step 1: Identify the repo
Determine the repo root and parent directory:
```bash
ROOT=$(git rev-parse --show-toplevel)
REPO_NAME=$(basename "$ROOT")
PARENT=$(dirname "$ROOT")
```
Detect if the repo is already inside a worktree layout by counting sibling worktrees (not just checking the directory name, which could be anything):
```bash
# Count worktrees under $PARENT (this includes $ROOT itself, hence the -gt 1 check below)
SIBLING_COUNT=$(git worktree list --porcelain 2>/dev/null | grep "^worktree " | grep -c "$PARENT/" || true)
if [ "$SIBLING_COUNT" -gt 1 ]; then
echo "INFO: Existing worktree layout detected at $PARENT ($SIBLING_COUNT worktrees)"
# Use $ROOT as-is; skip renaming/restructuring
else
echo "INFO: Fresh clone detected, proceeding with setup"
fi
```
## Step 2: Ask the user questions
Use AskUserQuestion to gather setup preferences, then record the answers as shell variables for the later steps (see the sketch after this list):
1. **How many parallel work branches do you need?** (Options: 4, 8, 16, or custom)
- These become `branch1` through `branchN`
2. **Which branch should be the base?** (Options: origin/master, origin/dev, or custom)
- All work branches and reviews will start from this
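The later steps reference `$COUNT` and a base branch, so capture the answers as shell variables (names here match the loops below; the `<base-branch>` placeholder in step 3 would be replaced by `$BASE_BRANCH`):
```bash
# Record the user's answers for the later steps (example values)
COUNT=8                    # number of parallel work branches (question 1)
BASE_BRANCH=origin/dev     # base branch for reviews and branch1..branchN (question 2)
```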
## Step 3: Fetch and set up branches
```bash
cd "$ROOT"
git fetch origin
# Create the reviews branch from base (skip if already exists)
if git show-ref --verify --quiet refs/heads/reviews; then
echo "INFO: Branch 'reviews' already exists, skipping"
else
git branch reviews <base-branch>
fi
# Create numbered work branches from base (skip if already exists)
for i in $(seq 1 "$COUNT"); do
if git show-ref --verify --quiet "refs/heads/branch$i"; then
echo "INFO: Branch 'branch$i' already exists, skipping"
else
git branch "branch$i" <base-branch>
fi
done
```
## Step 4: Create worktrees
Create worktrees as siblings to the main checkout:
```bash
if [ -d "$PARENT/reviews" ]; then
echo "INFO: Worktree '$PARENT/reviews' already exists, skipping"
else
git worktree add "$PARENT/reviews" reviews
fi
for i in $(seq 1 "$COUNT"); do
if [ -d "$PARENT/branch$i" ]; then
echo "INFO: Worktree '$PARENT/branch$i' already exists, skipping"
else
git worktree add "$PARENT/branch$i" "branch$i"
fi
done
```
## Step 5: Set up environment files
**Do NOT assume .env files exist.** For each worktree (including main if needed):
1. Check if `.env` exists in the source worktree for each path
2. If `.env` exists, copy it
3. If only `.env.default` or `.env.example` exists, copy that as `.env`
4. If neither exists, warn the user and list which env files are missing
Env file locations to check (same as the `/worktree` skill — keep these in sync):
- `autogpt_platform/.env`
- `autogpt_platform/backend/.env`
- `autogpt_platform/frontend/.env`
> **Note:** This env copying logic intentionally mirrors the `/worktree` skill's approach. If you update the path list or fallback logic here, update `/worktree` as well.
```bash
SOURCE="$ROOT"
WORKTREES="reviews"
for i in $(seq 1 "$COUNT"); do WORKTREES="$WORKTREES branch$i"; done
FOUND_ANY_ENV=0
for wt in $WORKTREES; do
TARGET="$PARENT/$wt"
for envpath in autogpt_platform autogpt_platform/backend autogpt_platform/frontend; do
if [ -f "$SOURCE/$envpath/.env" ]; then
FOUND_ANY_ENV=1
cp "$SOURCE/$envpath/.env" "$TARGET/$envpath/.env"
elif [ -f "$SOURCE/$envpath/.env.default" ]; then
FOUND_ANY_ENV=1
cp "$SOURCE/$envpath/.env.default" "$TARGET/$envpath/.env"
echo "NOTE: $wt/$envpath/.env was created from .env.default — you may need to edit it"
elif [ -f "$SOURCE/$envpath/.env.example" ]; then
FOUND_ANY_ENV=1
cp "$SOURCE/$envpath/.env.example" "$TARGET/$envpath/.env"
echo "NOTE: $wt/$envpath/.env was created from .env.example — you may need to edit it"
else
echo "WARNING: No .env, .env.default, or .env.example found at $SOURCE/$envpath/"
fi
done
done
if [ "$FOUND_ANY_ENV" -eq 0 ]; then
echo "WARNING: No environment files or templates were found in the source worktree."
# Use AskUserQuestion to confirm: "Continue setup without env files?"
# If the user declines, stop here and let them set up .env files first.
fi
```
## Step 6: Copy branchlet config
Copy `.branchlet.json` from main to each worktree so branchlet can manage sub-worktrees:
```bash
if [ -f "$ROOT/.branchlet.json" ]; then
for wt in $WORKTREES; do
cp "$ROOT/.branchlet.json" "$PARENT/$wt/.branchlet.json"
done
fi
```
## Step 7: Install dependencies
Install deps in all worktrees. Run these sequentially per worktree:
```bash
for wt in $WORKTREES; do
TARGET="$PARENT/$wt"
echo "=== Installing deps for $wt ==="
(cd "$TARGET/autogpt_platform/autogpt_libs" && poetry install) &&
(cd "$TARGET/autogpt_platform/backend" && poetry install && poetry run prisma generate) &&
(cd "$TARGET/autogpt_platform/frontend" && pnpm install) &&
echo "=== Done: $wt ===" ||
echo "=== FAILED: $wt ==="
done
```
This is slow. If possible, run it in the background and notify the user when it completes.
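One way to background it (the script path and log file are illustrative):
```bash
# Save the install loop above as a script, then run it detached with a log
nohup bash /tmp/install-worktrees.sh > /tmp/install-worktrees.log 2>&1 &
echo "Installing in background (PID $!); tail -f /tmp/install-worktrees.log to watch progress"
```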
## Step 8: Verify and report
After setup, verify and report to the user:
```bash
git worktree list
```
Summarize:
- Number of worktrees created
- Which env files were copied vs created from defaults vs missing
- Any warnings or errors encountered
## Final directory layout
```
parent/
main/ # Primary checkout (already exists)
reviews/ # PR review worktree
branch1/ # Work branch 1
branch2/ # Work branch 2
...
branchN/ # Work branch N
```

View File

@@ -83,13 +83,13 @@ The AutoGPT frontend is where users interact with our powerful AI automation pla
**Agent Builder:** For those who want to customize, our intuitive, low-code interface allows you to design and configure your own AI agents.
**Workflow Management:** Build, modify, and optimize your automation workflows with ease. You build your agent by connecting blocks, where each block performs a single action.
**Deployment Controls:** Manage the lifecycle of your agents, from testing to production.
**Ready-to-Use Agents:** Don't want to build? Simply select from our library of pre-configured agents and put them to work immediately.
**Agent Interaction:** Whether you've built your own or are using pre-configured agents, easily run and interact with them through our user-friendly interface.
**Monitoring and Analytics:** Keep track of your agents' performance and gain insights to continually improve your automation processes.

View File

@@ -178,7 +178,6 @@ SMTP_USERNAME=
SMTP_PASSWORD=
# Business & Marketing Tools
AGENTMAIL_API_KEY=
APOLLO_API_KEY=
ENRICHLAYER_API_KEY=
AYRSHARE_API_KEY=

View File

@@ -31,10 +31,7 @@ from backend.data.model import (
UserPasswordCredentials,
is_sdk_default,
)
from backend.integrations.credentials_store import (
is_system_credential,
provider_matches,
)
from backend.integrations.credentials_store import provider_matches
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
@@ -621,11 +618,6 @@ async def delete_credential(
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
)
if is_system_credential(cred_id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="System-managed credentials cannot be deleted",
)
creds = await creds_manager.store.get_creds_by_id(auth.user_id, cred_id)
if not creds:
raise HTTPException(

View File

@@ -72,7 +72,7 @@ class RunAgentRequest(BaseModel):
def _create_ephemeral_session(user_id: str) -> ChatSession:
"""Create an ephemeral session for stateless API requests."""
return ChatSession.new(user_id, dry_run=False)
return ChatSession.new(user_id)
@tools_router.post(

View File

@@ -1,146 +0,0 @@
"""Admin endpoints for checking and resetting user CoPilot rate limit usage."""
import logging
from typing import Optional
from autogpt_libs.auth import get_user_id, requires_admin_user
from fastapi import APIRouter, Body, HTTPException, Security
from pydantic import BaseModel
from backend.copilot.config import ChatConfig
from backend.copilot.rate_limit import (
get_global_rate_limits,
get_usage_status,
reset_user_usage,
)
from backend.data.user import get_user_by_email, get_user_email_by_id
logger = logging.getLogger(__name__)
config = ChatConfig()
router = APIRouter(
prefix="/admin",
tags=["copilot", "admin"],
dependencies=[Security(requires_admin_user)],
)
class UserRateLimitResponse(BaseModel):
user_id: str
user_email: Optional[str] = None
daily_token_limit: int
weekly_token_limit: int
daily_tokens_used: int
weekly_tokens_used: int
async def _resolve_user_id(
user_id: Optional[str], email: Optional[str]
) -> tuple[str, Optional[str]]:
"""Resolve a user_id and email from the provided parameters.
Returns (user_id, email). Accepts either user_id or email; at least one
must be provided. When both are provided, ``email`` takes precedence.
"""
if email:
user = await get_user_by_email(email)
if not user:
raise HTTPException(
status_code=404, detail="No user found with the provided email."
)
return user.id, email
if not user_id:
raise HTTPException(
status_code=400,
detail="Either user_id or email query parameter is required.",
)
# We have a user_id; try to look up their email for display purposes.
# This is non-critical -- a failure should not block the response.
try:
resolved_email = await get_user_email_by_id(user_id)
except Exception:
logger.warning("Failed to resolve email for user %s", user_id, exc_info=True)
resolved_email = None
return user_id, resolved_email
@router.get(
"/rate_limit",
response_model=UserRateLimitResponse,
summary="Get User Rate Limit",
)
async def get_user_rate_limit(
user_id: Optional[str] = None,
email: Optional[str] = None,
admin_user_id: str = Security(get_user_id),
) -> UserRateLimitResponse:
"""Get a user's current usage and effective rate limits. Admin-only.
Accepts either ``user_id`` or ``email`` as a query parameter.
When ``email`` is provided the user is looked up by email first.
"""
resolved_id, resolved_email = await _resolve_user_id(user_id, email)
logger.info("Admin %s checking rate limit for user %s", admin_user_id, resolved_id)
daily_limit, weekly_limit = await get_global_rate_limits(
resolved_id, config.daily_token_limit, config.weekly_token_limit
)
usage = await get_usage_status(resolved_id, daily_limit, weekly_limit)
return UserRateLimitResponse(
user_id=resolved_id,
user_email=resolved_email,
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
daily_tokens_used=usage.daily.used,
weekly_tokens_used=usage.weekly.used,
)
@router.post(
"/rate_limit/reset",
response_model=UserRateLimitResponse,
summary="Reset User Rate Limit Usage",
)
async def reset_user_rate_limit(
user_id: str = Body(embed=True),
reset_weekly: bool = Body(False, embed=True),
admin_user_id: str = Security(get_user_id),
) -> UserRateLimitResponse:
"""Reset a user's daily usage counter (and optionally weekly). Admin-only."""
logger.info(
"Admin %s resetting rate limit for user %s (reset_weekly=%s)",
admin_user_id,
user_id,
reset_weekly,
)
try:
await reset_user_usage(user_id, reset_weekly=reset_weekly)
except Exception as e:
logger.exception("Failed to reset user usage")
raise HTTPException(status_code=500, detail="Failed to reset usage") from e
daily_limit, weekly_limit = await get_global_rate_limits(
user_id, config.daily_token_limit, config.weekly_token_limit
)
usage = await get_usage_status(user_id, daily_limit, weekly_limit)
try:
resolved_email = await get_user_email_by_id(user_id)
except Exception:
logger.warning("Failed to resolve email for user %s", user_id, exc_info=True)
resolved_email = None
return UserRateLimitResponse(
user_id=user_id,
user_email=resolved_email,
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
daily_tokens_used=usage.daily.used,
weekly_tokens_used=usage.weekly.used,
)

View File

@@ -1,263 +0,0 @@
import json
from types import SimpleNamespace
from unittest.mock import AsyncMock
import fastapi
import fastapi.testclient
import pytest
import pytest_mock
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from pytest_snapshot.plugin import Snapshot
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
from .rate_limit_admin_routes import router as rate_limit_admin_router
app = fastapi.FastAPI()
app.include_router(rate_limit_admin_router)
client = fastapi.testclient.TestClient(app)
_MOCK_MODULE = "backend.api.features.admin.rate_limit_admin_routes"
_TARGET_EMAIL = "target@example.com"
@pytest.fixture(autouse=True)
def setup_app_admin_auth(mock_jwt_admin):
"""Setup admin auth overrides for all tests in this module"""
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
yield
app.dependency_overrides.clear()
def _mock_usage_status(
daily_used: int = 500_000, weekly_used: int = 3_000_000
) -> CoPilotUsageStatus:
from datetime import UTC, datetime, timedelta
now = datetime.now(UTC)
return CoPilotUsageStatus(
daily=UsageWindow(
used=daily_used, limit=2_500_000, resets_at=now + timedelta(hours=6)
),
weekly=UsageWindow(
used=weekly_used, limit=12_500_000, resets_at=now + timedelta(days=3)
),
)
def _patch_rate_limit_deps(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
daily_used: int = 500_000,
weekly_used: int = 3_000_000,
):
"""Patch the common rate-limit + user-lookup dependencies."""
mocker.patch(
f"{_MOCK_MODULE}.get_global_rate_limits",
new_callable=AsyncMock,
return_value=(2_500_000, 12_500_000),
)
mocker.patch(
f"{_MOCK_MODULE}.get_usage_status",
new_callable=AsyncMock,
return_value=_mock_usage_status(daily_used=daily_used, weekly_used=weekly_used),
)
mocker.patch(
f"{_MOCK_MODULE}.get_user_email_by_id",
new_callable=AsyncMock,
return_value=_TARGET_EMAIL,
)
def test_get_rate_limit(
mocker: pytest_mock.MockerFixture,
configured_snapshot: Snapshot,
target_user_id: str,
) -> None:
"""Test getting rate limit and usage for a user."""
_patch_rate_limit_deps(mocker, target_user_id)
response = client.get("/admin/rate_limit", params={"user_id": target_user_id})
assert response.status_code == 200
data = response.json()
assert data["user_id"] == target_user_id
assert data["user_email"] == _TARGET_EMAIL
assert data["daily_token_limit"] == 2_500_000
assert data["weekly_token_limit"] == 12_500_000
assert data["daily_tokens_used"] == 500_000
assert data["weekly_tokens_used"] == 3_000_000
configured_snapshot.assert_match(
json.dumps(data, indent=2, sort_keys=True) + "\n",
"get_rate_limit",
)
def test_get_rate_limit_by_email(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
) -> None:
"""Test looking up rate limits via email instead of user_id."""
_patch_rate_limit_deps(mocker, target_user_id)
mock_user = SimpleNamespace(id=target_user_id, email=_TARGET_EMAIL)
mocker.patch(
f"{_MOCK_MODULE}.get_user_by_email",
new_callable=AsyncMock,
return_value=mock_user,
)
response = client.get("/admin/rate_limit", params={"email": _TARGET_EMAIL})
assert response.status_code == 200
data = response.json()
assert data["user_id"] == target_user_id
assert data["user_email"] == _TARGET_EMAIL
assert data["daily_token_limit"] == 2_500_000
def test_get_rate_limit_by_email_not_found(
mocker: pytest_mock.MockerFixture,
) -> None:
"""Test that looking up a non-existent email returns 404."""
mocker.patch(
f"{_MOCK_MODULE}.get_user_by_email",
new_callable=AsyncMock,
return_value=None,
)
response = client.get("/admin/rate_limit", params={"email": "nobody@example.com"})
assert response.status_code == 404
def test_get_rate_limit_no_params() -> None:
"""Test that omitting both user_id and email returns 400."""
response = client.get("/admin/rate_limit")
assert response.status_code == 400
def test_reset_user_usage_daily_only(
mocker: pytest_mock.MockerFixture,
configured_snapshot: Snapshot,
target_user_id: str,
) -> None:
"""Test resetting only daily usage (default behaviour)."""
mock_reset = mocker.patch(
f"{_MOCK_MODULE}.reset_user_usage",
new_callable=AsyncMock,
)
_patch_rate_limit_deps(mocker, target_user_id, daily_used=0, weekly_used=3_000_000)
response = client.post(
"/admin/rate_limit/reset",
json={"user_id": target_user_id},
)
assert response.status_code == 200
data = response.json()
assert data["daily_tokens_used"] == 0
# Weekly is untouched
assert data["weekly_tokens_used"] == 3_000_000
mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=False)
configured_snapshot.assert_match(
json.dumps(data, indent=2, sort_keys=True) + "\n",
"reset_user_usage_daily_only",
)
def test_reset_user_usage_daily_and_weekly(
mocker: pytest_mock.MockerFixture,
configured_snapshot: Snapshot,
target_user_id: str,
) -> None:
"""Test resetting both daily and weekly usage."""
mock_reset = mocker.patch(
f"{_MOCK_MODULE}.reset_user_usage",
new_callable=AsyncMock,
)
_patch_rate_limit_deps(mocker, target_user_id, daily_used=0, weekly_used=0)
response = client.post(
"/admin/rate_limit/reset",
json={"user_id": target_user_id, "reset_weekly": True},
)
assert response.status_code == 200
data = response.json()
assert data["daily_tokens_used"] == 0
assert data["weekly_tokens_used"] == 0
mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=True)
configured_snapshot.assert_match(
json.dumps(data, indent=2, sort_keys=True) + "\n",
"reset_user_usage_daily_and_weekly",
)
def test_reset_user_usage_redis_failure(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
) -> None:
"""Test that Redis failure on reset returns 500."""
mocker.patch(
f"{_MOCK_MODULE}.reset_user_usage",
new_callable=AsyncMock,
side_effect=Exception("Redis connection refused"),
)
response = client.post(
"/admin/rate_limit/reset",
json={"user_id": target_user_id},
)
assert response.status_code == 500
def test_get_rate_limit_email_lookup_failure(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
) -> None:
"""Test that failing to resolve a user email degrades gracefully."""
mocker.patch(
f"{_MOCK_MODULE}.get_global_rate_limits",
new_callable=AsyncMock,
return_value=(2_500_000, 12_500_000),
)
mocker.patch(
f"{_MOCK_MODULE}.get_usage_status",
new_callable=AsyncMock,
return_value=_mock_usage_status(),
)
mocker.patch(
f"{_MOCK_MODULE}.get_user_email_by_id",
new_callable=AsyncMock,
side_effect=Exception("DB connection lost"),
)
response = client.get("/admin/rate_limit", params={"user_id": target_user_id})
assert response.status_code == 200
data = response.json()
assert data["user_id"] == target_user_id
assert data["user_email"] is None
def test_admin_endpoints_require_admin_role(mock_jwt_user) -> None:
"""Test that rate limit admin endpoints require admin role."""
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
response = client.get("/admin/rate_limit", params={"user_id": "test"})
assert response.status_code == 403
response = client.post(
"/admin/rate_limit/reset",
json={"user_id": "test"},
)
assert response.status_code == 403

View File

@@ -11,7 +11,7 @@ from autogpt_libs import auth
from fastapi import APIRouter, HTTPException, Query, Response, Security
from fastapi.responses import StreamingResponse
from prisma.models import UserWorkspaceFile
from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic import BaseModel, Field, field_validator
from backend.copilot import service as chat_service
from backend.copilot import stream_registry
@@ -20,7 +20,6 @@ from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_
from backend.copilot.model import (
ChatMessage,
ChatSession,
ChatSessionMetadata,
append_and_save_message,
create_chat_session,
delete_chat_session,
@@ -31,14 +30,8 @@ from backend.copilot.model import (
from backend.copilot.rate_limit import (
CoPilotUsageStatus,
RateLimitExceeded,
acquire_reset_lock,
check_rate_limit,
get_daily_reset_count,
get_global_rate_limits,
get_usage_status,
increment_daily_reset_count,
release_reset_lock,
reset_daily_usage,
)
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
from backend.copilot.tools.e2b_sandbox import kill_sandbox
@@ -66,16 +59,9 @@ from backend.copilot.tools.models import (
UnderstandingUpdatedResponse,
)
from backend.copilot.tracking import track_user_message
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from backend.data.redis_client import get_redis_async
from backend.data.understanding import get_business_understanding
from backend.data.workspace import get_or_create_workspace
from backend.util.exceptions import InsufficientBalanceError, NotFoundError
from backend.util.settings import Settings
settings = Settings()
logger = logging.getLogger(__name__)
from backend.util.exceptions import NotFoundError
config = ChatConfig()
@@ -83,6 +69,8 @@ _UUID_RE = re.compile(
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", re.I
)
logger = logging.getLogger(__name__)
async def _validate_and_get_session(
session_id: str,
@@ -113,25 +101,12 @@ class StreamChatRequest(BaseModel):
) # Workspace file IDs attached to this message
class CreateSessionRequest(BaseModel):
"""Request model for creating a new chat session.
``dry_run`` is a **top-level** field — do not nest it inside ``metadata``.
Extra/unknown fields are rejected (422) to prevent silent mis-use.
"""
model_config = ConfigDict(extra="forbid")
dry_run: bool = False
class CreateSessionResponse(BaseModel):
"""Response model containing information on a newly created chat session."""
id: str
created_at: str
user_id: str | None
metadata: ChatSessionMetadata = ChatSessionMetadata()
class ActiveStreamInfo(BaseModel):
@@ -152,7 +127,6 @@ class SessionDetailResponse(BaseModel):
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
total_prompt_tokens: int = 0
total_completion_tokens: int = 0
metadata: ChatSessionMetadata = ChatSessionMetadata()
class SessionSummaryResponse(BaseModel):
@@ -263,7 +237,6 @@ async def list_sessions(
)
async def create_session(
user_id: Annotated[str, Security(auth.get_user_id)],
request: CreateSessionRequest | None = None,
) -> CreateSessionResponse:
"""
Create a new chat session.
@@ -272,28 +245,22 @@ async def create_session(
Args:
user_id: The authenticated user ID parsed from the JWT (required).
request: Optional request body. When provided, ``dry_run=True``
forces run_block and run_agent calls to use dry-run simulation.
Returns:
CreateSessionResponse: Details of the created session.
"""
dry_run = request.dry_run if request else False
logger.info(
f"Creating session with user_id: "
f"...{user_id[-8:] if len(user_id) > 8 else '<redacted>'}"
f"{', dry_run=True' if dry_run else ''}"
)
session = await create_chat_session(user_id, dry_run=dry_run)
session = await create_chat_session(user_id)
return CreateSessionResponse(
id=session.session_id,
created_at=session.started_at.isoformat(),
user_id=session.user_id,
metadata=session.metadata,
)
@@ -442,7 +409,6 @@ async def get_session(
active_stream=active_stream_info,
total_prompt_tokens=total_prompt,
total_completion_tokens=total_completion,
metadata=session.metadata,
)
@@ -455,187 +421,11 @@ async def get_copilot_usage(
"""Get CoPilot usage status for the authenticated user.
Returns current token usage vs limits for daily and weekly windows.
Global defaults sourced from LaunchDarkly (falling back to config).
"""
daily_limit, weekly_limit = await get_global_rate_limits(
user_id, config.daily_token_limit, config.weekly_token_limit
)
return await get_usage_status(
user_id=user_id,
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
rate_limit_reset_cost=config.rate_limit_reset_cost,
)
class RateLimitResetResponse(BaseModel):
"""Response from resetting the daily rate limit."""
success: bool
credits_charged: int = Field(description="Credits charged (in cents)")
remaining_balance: int = Field(description="Credit balance after charge (in cents)")
usage: CoPilotUsageStatus = Field(description="Updated usage status after reset")
@router.post(
"/usage/reset",
status_code=200,
responses={
400: {
"description": "Bad Request (feature disabled or daily limit not reached)"
},
402: {"description": "Payment Required (insufficient credits)"},
429: {
"description": "Too Many Requests (max daily resets exceeded or reset in progress)"
},
503: {
"description": "Service Unavailable (Redis reset failed; credits refunded or support needed)"
},
},
)
async def reset_copilot_usage(
user_id: Annotated[str, Security(auth.get_user_id)],
) -> RateLimitResetResponse:
"""Reset the daily CoPilot rate limit by spending credits.
Allows users who have hit their daily token limit to spend credits
to reset their daily usage counter and continue working.
Returns 400 if the feature is disabled or the user is not over the limit.
Returns 402 if the user has insufficient credits.
"""
cost = config.rate_limit_reset_cost
if cost <= 0:
raise HTTPException(
status_code=400,
detail="Rate limit reset is not available.",
)
if not settings.config.enable_credit:
raise HTTPException(
status_code=400,
detail="Rate limit reset is not available (credit system is disabled).",
)
daily_limit, weekly_limit = await get_global_rate_limits(
user_id, config.daily_token_limit, config.weekly_token_limit
)
if daily_limit <= 0:
raise HTTPException(
status_code=400,
detail="No daily limit is configured — nothing to reset.",
)
# Check max daily resets. get_daily_reset_count returns None when Redis
# is unavailable; reject the reset in that case to prevent unlimited
# free resets when the counter store is down.
reset_count = await get_daily_reset_count(user_id)
if reset_count is None:
raise HTTPException(
status_code=503,
detail="Unable to verify reset eligibility — please try again later.",
)
if config.max_daily_resets > 0 and reset_count >= config.max_daily_resets:
raise HTTPException(
status_code=429,
detail=f"You've used all {config.max_daily_resets} resets for today.",
)
# Acquire a per-user lock to prevent TOCTOU races (concurrent resets).
if not await acquire_reset_lock(user_id):
raise HTTPException(
status_code=429,
detail="A reset is already in progress. Please try again.",
)
try:
# Verify the user is actually at or over their daily limit.
usage_status = await get_usage_status(
user_id=user_id,
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
)
if daily_limit > 0 and usage_status.daily.used < daily_limit:
raise HTTPException(
status_code=400,
detail="You have not reached your daily limit yet.",
)
# If the weekly limit is also exhausted, resetting the daily counter
# won't help — the user would still be blocked by the weekly limit.
if weekly_limit > 0 and usage_status.weekly.used >= weekly_limit:
raise HTTPException(
status_code=400,
detail="Your weekly limit is also reached. Resetting the daily limit won't help.",
)
# Charge credits.
credit_model = await get_user_credit_model(user_id)
try:
remaining = await credit_model.spend_credits(
user_id=user_id,
cost=cost,
metadata=UsageTransactionMetadata(
reason="CoPilot daily rate limit reset",
),
)
except InsufficientBalanceError as e:
raise HTTPException(
status_code=402,
detail="Insufficient credits to reset your rate limit.",
) from e
# Reset daily usage in Redis. If this fails, refund the credits
# so the user is not charged for a service they did not receive.
if not await reset_daily_usage(user_id, daily_token_limit=daily_limit):
# Compensate: refund the charged credits.
refunded = False
try:
await credit_model.top_up_credits(user_id, cost)
refunded = True
logger.warning(
"Refunded %d credits to user %s after Redis reset failure",
cost,
user_id[:8],
)
except Exception:
logger.error(
"CRITICAL: Failed to refund %d credits to user %s "
"after Redis reset failure — manual intervention required",
cost,
user_id[:8],
exc_info=True,
)
if refunded:
raise HTTPException(
status_code=503,
detail="Rate limit reset failed — please try again later. "
"Your credits have not been charged.",
)
raise HTTPException(
status_code=503,
detail="Rate limit reset failed and the automatic refund "
"also failed. Please contact support for assistance.",
)
# Track the reset count for daily cap enforcement.
await increment_daily_reset_count(user_id)
finally:
await release_reset_lock(user_id)
# Return updated usage status.
updated_usage = await get_usage_status(
user_id=user_id,
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
rate_limit_reset_cost=config.rate_limit_reset_cost,
)
return RateLimitResetResponse(
success=True,
credits_charged=cost,
remaining_balance=remaining,
usage=updated_usage,
daily_token_limit=config.daily_token_limit,
weekly_token_limit=config.weekly_token_limit,
)
@@ -736,16 +526,12 @@ async def stream_chat_post(
# Pre-turn rate limit check (token-based).
# check_rate_limit short-circuits internally when both limits are 0.
# Global defaults sourced from LaunchDarkly, falling back to config.
if user_id:
try:
daily_limit, weekly_limit = await get_global_rate_limits(
user_id, config.daily_token_limit, config.weekly_token_limit
)
await check_rate_limit(
user_id=user_id,
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
daily_token_limit=config.daily_token_limit,
weekly_token_limit=config.weekly_token_limit,
)
except RateLimitExceeded as e:
raise HTTPException(status_code=429, detail=str(e)) from e
@@ -1108,47 +894,6 @@ async def session_assign_user(
return {"status": "ok"}
# ========== Suggested Prompts ==========
class SuggestedTheme(BaseModel):
"""A themed group of suggested prompts."""
name: str
prompts: list[str]
class SuggestedPromptsResponse(BaseModel):
"""Response model for user-specific suggested prompts grouped by theme."""
themes: list[SuggestedTheme]
@router.get(
"/suggested-prompts",
dependencies=[Security(auth.requires_user)],
)
async def get_suggested_prompts(
user_id: Annotated[str, Security(auth.get_user_id)],
) -> SuggestedPromptsResponse:
"""
Get LLM-generated suggested prompts grouped by theme.
Returns personalized quick-action prompts based on the user's
business understanding. Returns empty themes list if no custom
prompts are available.
"""
understanding = await get_business_understanding(user_id)
if understanding is None or not understanding.suggested_prompts:
return SuggestedPromptsResponse(themes=[])
themes = [
SuggestedTheme(name=name, prompts=prompts)
for name, prompts in understanding.suggested_prompts.items()
]
return SuggestedPromptsResponse(themes=themes)
# ========== Configuration ==========
@@ -1197,7 +942,7 @@ async def health_check() -> dict:
)
# Create and retrieve session to verify full data layer
session = await create_chat_session(health_check_user_id, dry_run=False)
session = await create_chat_session(health_check_user_id)
await get_chat_session(session.session_id, health_check_user_id)
return {

View File

@@ -1,7 +1,7 @@
"""Tests for chat API routes: session title update, file attachment validation, usage, and rate limiting."""
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock
import fastapi
import fastapi.testclient
@@ -368,7 +368,6 @@ def test_usage_returns_daily_and_weekly(
user_id=test_user_id,
daily_token_limit=10000,
weekly_token_limit=50000,
rate_limit_reset_cost=chat_routes.config.rate_limit_reset_cost,
)
@@ -381,7 +380,6 @@ def test_usage_uses_config_limits(
mocker.patch.object(chat_routes.config, "daily_token_limit", 99999)
mocker.patch.object(chat_routes.config, "weekly_token_limit", 77777)
mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", 500)
response = client.get("/usage")
@@ -390,7 +388,6 @@ def test_usage_uses_config_limits(
user_id=test_user_id,
daily_token_limit=99999,
weekly_token_limit=77777,
rate_limit_reset_cost=500,
)
@@ -403,126 +400,3 @@ def test_usage_rejects_unauthenticated_request() -> None:
response = unauthenticated_client.get("/usage")
assert response.status_code == 401
# ─── Suggested prompts endpoint ──────────────────────────────────────
def _mock_get_business_understanding(
mocker: pytest_mock.MockerFixture,
*,
return_value=None,
):
"""Mock get_business_understanding."""
return mocker.patch(
"backend.api.features.chat.routes.get_business_understanding",
new_callable=AsyncMock,
return_value=return_value,
)
def test_suggested_prompts_returns_themes(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""User with themed prompts gets them back as themes list."""
mock_understanding = MagicMock()
mock_understanding.suggested_prompts = {
"Learn": ["L1", "L2"],
"Create": ["C1"],
}
_mock_get_business_understanding(mocker, return_value=mock_understanding)
response = client.get("/suggested-prompts")
assert response.status_code == 200
data = response.json()
assert "themes" in data
themes_by_name = {t["name"]: t["prompts"] for t in data["themes"]}
assert themes_by_name["Learn"] == ["L1", "L2"]
assert themes_by_name["Create"] == ["C1"]
def test_suggested_prompts_no_understanding(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""User with no understanding gets empty themes list."""
_mock_get_business_understanding(mocker, return_value=None)
response = client.get("/suggested-prompts")
assert response.status_code == 200
assert response.json() == {"themes": []}
def test_suggested_prompts_empty_prompts(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""User with understanding but empty prompts gets empty themes list."""
mock_understanding = MagicMock()
mock_understanding.suggested_prompts = {}
_mock_get_business_understanding(mocker, return_value=mock_understanding)
response = client.get("/suggested-prompts")
assert response.status_code == 200
assert response.json() == {"themes": []}
# ─── Create session: dry_run contract ─────────────────────────────────
def _mock_create_chat_session(mocker: pytest_mock.MockerFixture):
"""Mock create_chat_session to return a fake session."""
from backend.copilot.model import ChatSession
async def _fake_create(user_id: str, *, dry_run: bool):
return ChatSession.new(user_id, dry_run=dry_run)
return mocker.patch(
"backend.api.features.chat.routes.create_chat_session",
new_callable=AsyncMock,
side_effect=_fake_create,
)
def test_create_session_dry_run_true(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""Sending ``{"dry_run": true}`` sets metadata.dry_run to True."""
_mock_create_chat_session(mocker)
response = client.post("/sessions", json={"dry_run": True})
assert response.status_code == 200
assert response.json()["metadata"]["dry_run"] is True
def test_create_session_dry_run_default_false(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""Empty body defaults dry_run to False."""
_mock_create_chat_session(mocker)
response = client.post("/sessions")
assert response.status_code == 200
assert response.json()["metadata"]["dry_run"] is False
def test_create_session_rejects_nested_metadata(
test_user_id: str,
) -> None:
"""Sending ``{"metadata": {"dry_run": true}}`` must return 422, not silently
default to ``dry_run=False``. This guards against the common mistake of
nesting dry_run inside metadata instead of providing it at the top level."""
response = client.post(
"/sessions",
json={"metadata": {"dry_run": True}},
)
assert response.status_code == 422

View File

@@ -40,15 +40,11 @@ from backend.data.onboarding import OnboardingStep, complete_onboarding_step
from backend.data.user import get_user_integrations
from backend.executor.utils import add_graph_execution
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
from backend.integrations.credentials_store import (
is_system_credential,
provider_matches,
)
from backend.integrations.credentials_store import provider_matches
from backend.integrations.creds_manager import (
IntegrationCredentialsManager,
create_mcp_oauth_handler,
)
from backend.integrations.managed_credentials import ensure_managed_credentials
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks import get_webhook_manager
@@ -114,7 +110,6 @@ class CredentialsMetaResponse(BaseModel):
default=None,
description="Host pattern for host-scoped or MCP server URL for MCP credentials",
)
is_managed: bool = False
@model_validator(mode="before")
@classmethod
@@ -153,7 +148,6 @@ def to_meta_response(cred: Credentials) -> CredentialsMetaResponse:
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
host=CredentialsMetaResponse.get_host(cred),
is_managed=cred.is_managed,
)
@@ -230,9 +224,6 @@ async def callback(
async def list_credentials(
user_id: Annotated[str, Security(get_user_id)],
) -> list[CredentialsMetaResponse]:
# Fire-and-forget: provision missing managed credentials in the background.
# The credential appears on the next page load; listing is never blocked.
asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store))
credentials = await creds_manager.store.get_all_creds(user_id)
return [
@@ -247,7 +238,6 @@ async def list_credentials_by_provider(
],
user_id: Annotated[str, Security(get_user_id)],
) -> list[CredentialsMetaResponse]:
asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store))
credentials = await creds_manager.store.get_creds_by_provider(user_id, provider)
return [
@@ -342,11 +332,6 @@ async def delete_credentials(
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
)
if is_system_credential(cred_id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="System-managed credentials cannot be deleted",
)
creds = await creds_manager.store.get_creds_by_id(user_id, cred_id)
if not creds:
raise HTTPException(
@@ -357,11 +342,6 @@ async def delete_credentials(
status_code=status.HTTP_404_NOT_FOUND,
detail="Credentials not found",
)
if creds.is_managed:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="AutoGPT-managed credentials cannot be deleted",
)
try:
await remove_all_webhooks_for_credentials(user_id, creds, force)

View File

@@ -1,7 +1,6 @@
"""Tests for credentials API security: no secret leakage, SDK defaults filtered."""
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, patch
import fastapi
import fastapi.testclient
@@ -277,294 +276,3 @@ class TestCreateCredentialNoSecretInResponse:
assert resp.status_code == 403
mock_mgr.create.assert_not_called()
class TestManagedCredentials:
"""AutoGPT-managed credentials cannot be deleted by users."""
def test_delete_is_managed_returns_403(self):
cred = APIKeyCredentials(
id="managed-cred-1",
provider="agent_mail",
title="AgentMail (managed by AutoGPT)",
api_key=SecretStr("sk-managed-key"),
is_managed=True,
)
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.store.get_creds_by_id = AsyncMock(return_value=cred)
resp = client.request("DELETE", "/agent_mail/credentials/managed-cred-1")
assert resp.status_code == 403
assert "AutoGPT-managed" in resp.json()["detail"]
def test_list_credentials_includes_is_managed_field(self):
managed = APIKeyCredentials(
id="managed-1",
provider="agent_mail",
title="AgentMail (managed)",
api_key=SecretStr("sk-key"),
is_managed=True,
)
regular = APIKeyCredentials(
id="regular-1",
provider="openai",
title="My Key",
api_key=SecretStr("sk-key"),
)
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.store.get_all_creds = AsyncMock(return_value=[managed, regular])
resp = client.get("/credentials")
assert resp.status_code == 200
data = resp.json()
managed_cred = next(c for c in data if c["id"] == "managed-1")
regular_cred = next(c for c in data if c["id"] == "regular-1")
assert managed_cred["is_managed"] is True
assert regular_cred["is_managed"] is False
# ---------------------------------------------------------------------------
# Managed credential provisioning infrastructure
# ---------------------------------------------------------------------------
def _make_managed_cred(
provider: str = "agent_mail", pod_id: str = "pod-abc"
) -> APIKeyCredentials:
return APIKeyCredentials(
id="managed-auto",
provider=provider,
title="AgentMail (managed by AutoGPT)",
api_key=SecretStr("sk-pod-key"),
is_managed=True,
metadata={"pod_id": pod_id},
)
def _make_store_mock(**kwargs) -> MagicMock:
"""Create a store mock with a working async ``locks()`` context manager."""
@asynccontextmanager
async def _noop_locked(key):
yield
locks_obj = MagicMock()
locks_obj.locked = _noop_locked
store = MagicMock(**kwargs)
store.locks = AsyncMock(return_value=locks_obj)
return store
class TestEnsureManagedCredentials:
"""Unit tests for the ensure/cleanup helpers in managed_credentials.py."""
@pytest.mark.asyncio
async def test_provisions_when_missing(self):
"""Provider.provision() is called when no managed credential exists."""
from backend.integrations.managed_credentials import (
_PROVIDERS,
_provisioned_users,
ensure_managed_credentials,
)
cred = _make_managed_cred()
provider = MagicMock()
provider.provider_name = "test_provider"
provider.is_available = AsyncMock(return_value=True)
provider.provision = AsyncMock(return_value=cred)
store = _make_store_mock()
store.has_managed_credential = AsyncMock(return_value=False)
store.add_managed_credential = AsyncMock()
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["test_provider"] = provider
_provisioned_users.pop("user-1", None)
try:
await ensure_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
_provisioned_users.pop("user-1", None)
provider.provision.assert_awaited_once_with("user-1")
store.add_managed_credential.assert_awaited_once_with("user-1", cred)
@pytest.mark.asyncio
async def test_skips_when_already_exists(self):
"""Provider.provision() is NOT called when managed credential exists."""
from backend.integrations.managed_credentials import (
_PROVIDERS,
_provisioned_users,
ensure_managed_credentials,
)
provider = MagicMock()
provider.provider_name = "test_provider"
provider.is_available = AsyncMock(return_value=True)
provider.provision = AsyncMock()
store = _make_store_mock()
store.has_managed_credential = AsyncMock(return_value=True)
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["test_provider"] = provider
_provisioned_users.pop("user-1", None)
try:
await ensure_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
_provisioned_users.pop("user-1", None)
provider.provision.assert_not_awaited()
@pytest.mark.asyncio
async def test_skips_when_unavailable(self):
"""Provider.provision() is NOT called when provider is not available."""
from backend.integrations.managed_credentials import (
_PROVIDERS,
_provisioned_users,
ensure_managed_credentials,
)
provider = MagicMock()
provider.provider_name = "test_provider"
provider.is_available = AsyncMock(return_value=False)
provider.provision = AsyncMock()
store = _make_store_mock()
store.has_managed_credential = AsyncMock()
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["test_provider"] = provider
_provisioned_users.pop("user-1", None)
try:
await ensure_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
_provisioned_users.pop("user-1", None)
provider.provision.assert_not_awaited()
store.has_managed_credential.assert_not_awaited()
@pytest.mark.asyncio
async def test_provision_failure_does_not_propagate(self):
"""A failed provision is logged but does not raise."""
from backend.integrations.managed_credentials import (
_PROVIDERS,
_provisioned_users,
ensure_managed_credentials,
)
provider = MagicMock()
provider.provider_name = "test_provider"
provider.is_available = AsyncMock(return_value=True)
provider.provision = AsyncMock(side_effect=RuntimeError("boom"))
store = _make_store_mock()
store.has_managed_credential = AsyncMock(return_value=False)
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["test_provider"] = provider
_provisioned_users.pop("user-1", None)
try:
await ensure_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
_provisioned_users.pop("user-1", None)
# No exception raised — provisioning failure is swallowed.
class TestCleanupManagedCredentials:
"""Unit tests for cleanup_managed_credentials."""
@pytest.mark.asyncio
async def test_calls_deprovision_for_managed_creds(self):
from backend.integrations.managed_credentials import (
_PROVIDERS,
cleanup_managed_credentials,
)
cred = _make_managed_cred()
provider = MagicMock()
provider.provider_name = "agent_mail"
provider.deprovision = AsyncMock()
store = MagicMock()
store.get_all_creds = AsyncMock(return_value=[cred])
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["agent_mail"] = provider
try:
await cleanup_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
provider.deprovision.assert_awaited_once_with("user-1", cred)
@pytest.mark.asyncio
async def test_skips_non_managed_creds(self):
from backend.integrations.managed_credentials import (
_PROVIDERS,
cleanup_managed_credentials,
)
regular = _make_api_key_cred()
provider = MagicMock()
provider.provider_name = "openai"
provider.deprovision = AsyncMock()
store = MagicMock()
store.get_all_creds = AsyncMock(return_value=[regular])
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["openai"] = provider
try:
await cleanup_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
provider.deprovision.assert_not_awaited()
@pytest.mark.asyncio
async def test_deprovision_failure_does_not_propagate(self):
from backend.integrations.managed_credentials import (
_PROVIDERS,
cleanup_managed_credentials,
)
cred = _make_managed_cred()
provider = MagicMock()
provider.provider_name = "agent_mail"
provider.deprovision = AsyncMock(side_effect=RuntimeError("boom"))
store = MagicMock()
store.get_all_creds = AsyncMock(return_value=[cred])
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["agent_mail"] = provider
try:
await cleanup_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
# No exception raised — cleanup failure is swallowed.

View File

@@ -17,6 +17,8 @@ from backend.data.includes import library_agent_include
from backend.util.exceptions import NotFoundError
from backend.util.json import SafeJson
from .db import get_library_agent_by_graph_id, update_library_agent
logger = logging.getLogger(__name__)
@@ -59,17 +61,28 @@ async def add_graph_to_library(
graph_model: GraphModel,
user_id: str,
) -> library_model.LibraryAgent:
"""Check existing / restore soft-deleted / create new LibraryAgent.
"""Check existing / restore soft-deleted / create new LibraryAgent."""
if existing := await get_library_agent_by_graph_id(
user_id, graph_model.id, graph_model.version
):
return existing
Uses a create-then-catch-UniqueViolationError-then-update pattern on
the (userId, agentGraphId, agentGraphVersion) composite unique constraint.
This is more robust than ``upsert`` because Prisma's upsert atomicity
guarantees are not well-documented for all versions.
"""
settings_json = SafeJson(GraphSettings.from_graph(graph_model).model_dump())
_include = library_agent_include(
user_id, include_nodes=False, include_executions=False
deleted_agent = await prisma.models.LibraryAgent.prisma().find_unique(
where={
"userId_agentGraphId_agentGraphVersion": {
"userId": user_id,
"agentGraphId": graph_model.id,
"agentGraphVersion": graph_model.version,
}
},
)
if deleted_agent and (deleted_agent.isDeleted or deleted_agent.isArchived):
return await update_library_agent(
deleted_agent.id,
user_id,
is_deleted=False,
is_archived=False,
)
try:
added_agent = await prisma.models.LibraryAgent.prisma().create(
@@ -85,32 +98,23 @@ async def add_graph_to_library(
},
"isCreatedByUser": False,
"useGraphIsActiveVersion": False,
"settings": settings_json,
"settings": SafeJson(
GraphSettings.from_graph(graph_model).model_dump()
),
},
include=_include,
include=library_agent_include(
user_id, include_nodes=False, include_executions=False
),
)
except prisma.errors.UniqueViolationError:
# Already exists — update to restore if previously soft-deleted/archived
added_agent = await prisma.models.LibraryAgent.prisma().update(
where={
"userId_agentGraphId_agentGraphVersion": {
"userId": user_id,
"agentGraphId": graph_model.id,
"agentGraphVersion": graph_model.version,
}
},
data={
"isDeleted": False,
"isArchived": False,
"settings": settings_json,
},
include=_include,
# Race condition: concurrent request created the row between our
# check and create. Re-read instead of crashing.
existing = await get_library_agent_by_graph_id(
user_id, graph_model.id, graph_model.version
)
if added_agent is None:
raise NotFoundError(
f"LibraryAgent for graph #{graph_model.id} "
f"v{graph_model.version} not found after UniqueViolationError"
)
if existing:
return existing
raise # Shouldn't happen, but don't swallow unexpected errors
logger.debug(
f"Added graph #{graph_model.id} v{graph_model.version} "


@@ -1,80 +1,71 @@
from unittest.mock import AsyncMock, MagicMock, patch
import prisma.errors
import pytest
from ._add_to_library import add_graph_to_library
@pytest.mark.asyncio
async def test_add_graph_to_library_create_new_agent() -> None:
"""When no matching LibraryAgent exists, create inserts a new one."""
graph_model = MagicMock(id="graph-id", version=2, nodes=[])
created_agent = MagicMock(name="CreatedLibraryAgent")
converted_agent = MagicMock(name="ConvertedLibraryAgent")
async def test_add_graph_to_library_restores_archived_agent() -> None:
graph_model = MagicMock(id="graph-id", version=2)
archived_agent = MagicMock(id="library-agent-id", isDeleted=False, isArchived=True)
restored_agent = MagicMock(name="LibraryAgentModel")
with (
patch(
"backend.api.features.library._add_to_library.get_library_agent_by_graph_id",
new=AsyncMock(return_value=None),
),
patch(
"backend.api.features.library._add_to_library.prisma.models.LibraryAgent.prisma"
) as mock_prisma,
patch(
"backend.api.features.library._add_to_library.library_model.LibraryAgent.from_db",
return_value=converted_agent,
) as mock_from_db,
"backend.api.features.library._add_to_library.update_library_agent",
new=AsyncMock(return_value=restored_agent),
) as mock_update,
):
mock_prisma.return_value.create = AsyncMock(return_value=created_agent)
mock_prisma.return_value.find_unique = AsyncMock(return_value=archived_agent)
result = await add_graph_to_library("slv-id", graph_model, "user-id")
assert result is converted_agent
mock_from_db.assert_called_once_with(created_agent)
# Verify create was called with correct data
create_call = mock_prisma.return_value.create.call_args
create_data = create_call.kwargs["data"]
assert create_data["User"] == {"connect": {"id": "user-id"}}
assert create_data["AgentGraph"] == {
"connect": {"graphVersionId": {"id": "graph-id", "version": 2}}
}
assert create_data["isCreatedByUser"] is False
assert create_data["useGraphIsActiveVersion"] is False
assert result is restored_agent
mock_update.assert_awaited_once_with(
"library-agent-id",
"user-id",
is_deleted=False,
is_archived=False,
)
mock_prisma.return_value.create.assert_not_called()
@pytest.mark.asyncio
async def test_add_graph_to_library_unique_violation_updates_existing() -> None:
"""UniqueViolationError on create falls back to update."""
graph_model = MagicMock(id="graph-id", version=2, nodes=[])
updated_agent = MagicMock(name="UpdatedLibraryAgent")
converted_agent = MagicMock(name="ConvertedLibraryAgent")
async def test_add_graph_to_library_restores_deleted_agent() -> None:
graph_model = MagicMock(id="graph-id", version=2)
deleted_agent = MagicMock(id="library-agent-id", isDeleted=True, isArchived=False)
restored_agent = MagicMock(name="LibraryAgentModel")
with (
patch(
"backend.api.features.library._add_to_library.get_library_agent_by_graph_id",
new=AsyncMock(return_value=None),
),
patch(
"backend.api.features.library._add_to_library.prisma.models.LibraryAgent.prisma"
) as mock_prisma,
patch(
"backend.api.features.library._add_to_library.library_model.LibraryAgent.from_db",
return_value=converted_agent,
) as mock_from_db,
"backend.api.features.library._add_to_library.update_library_agent",
new=AsyncMock(return_value=restored_agent),
) as mock_update,
):
mock_prisma.return_value.create = AsyncMock(
side_effect=prisma.errors.UniqueViolationError(
MagicMock(), message="unique constraint"
)
)
mock_prisma.return_value.update = AsyncMock(return_value=updated_agent)
mock_prisma.return_value.find_unique = AsyncMock(return_value=deleted_agent)
result = await add_graph_to_library("slv-id", graph_model, "user-id")
assert result is converted_agent
mock_from_db.assert_called_once_with(updated_agent)
# Verify update was called with correct where and data
update_call = mock_prisma.return_value.update.call_args
assert update_call.kwargs["where"] == {
"userId_agentGraphId_agentGraphVersion": {
"userId": "user-id",
"agentGraphId": "graph-id",
"agentGraphVersion": 2,
}
}
update_data = update_call.kwargs["data"]
assert update_data["isDeleted"] is False
assert update_data["isArchived"] is False
assert result is restored_agent
mock_update.assert_awaited_once_with(
"library-agent-id",
"user-id",
is_deleted=False,
is_archived=False,
)
mock_prisma.return_value.create.assert_not_called()


@@ -436,53 +436,32 @@ async def create_library_agent(
async with transaction() as tx:
library_agents = await asyncio.gather(
*(
prisma.models.LibraryAgent.prisma(tx).upsert(
where={
"userId_agentGraphId_agentGraphVersion": {
"userId": user_id,
"agentGraphId": graph_entry.id,
"agentGraphVersion": graph_entry.version,
}
},
data={
"create": prisma.types.LibraryAgentCreateInput(
isCreatedByUser=(user_id == graph.user_id),
useGraphIsActiveVersion=True,
User={"connect": {"id": user_id}},
AgentGraph={
"connect": {
"graphVersionId": {
"id": graph_entry.id,
"version": graph_entry.version,
}
prisma.models.LibraryAgent.prisma(tx).create(
data=prisma.types.LibraryAgentCreateInput(
isCreatedByUser=(user_id == user_id),
useGraphIsActiveVersion=True,
User={"connect": {"id": user_id}},
AgentGraph={
"connect": {
"graphVersionId": {
"id": graph_entry.id,
"version": graph_entry.version,
}
},
settings=SafeJson(
GraphSettings.from_graph(
graph_entry,
hitl_safe_mode=hitl_safe_mode,
sensitive_action_safe_mode=sensitive_action_safe_mode,
).model_dump()
),
**(
{"Folder": {"connect": {"id": folder_id}}}
if folder_id and graph_entry is graph
else {}
),
),
"update": {
"isDeleted": False,
"isArchived": False,
"useGraphIsActiveVersion": True,
"settings": SafeJson(
GraphSettings.from_graph(
graph_entry,
hitl_safe_mode=hitl_safe_mode,
sensitive_action_safe_mode=sensitive_action_safe_mode,
).model_dump()
),
}
},
},
settings=SafeJson(
GraphSettings.from_graph(
graph_entry,
hitl_safe_mode=hitl_safe_mode,
sensitive_action_safe_mode=sensitive_action_safe_mode,
).model_dump()
),
**(
{"Folder": {"connect": {"id": folder_id}}}
if folder_id and graph_entry is graph
else {}
),
),
include=library_agent_include(
user_id, include_nodes=False, include_executions=False
),


@@ -1,6 +1,4 @@
from contextlib import asynccontextmanager
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch
import prisma.enums
import prisma.models
@@ -87,6 +85,10 @@ async def test_get_library_agents(mocker):
async def test_add_agent_to_library(mocker):
await connect()
# Mock the transaction context
mock_transaction = mocker.patch("backend.api.features.library.db.transaction")
mock_transaction.return_value.__aenter__ = mocker.AsyncMock(return_value=None)
mock_transaction.return_value.__aexit__ = mocker.AsyncMock(return_value=None)
# Mock data
mock_store_listing_data = prisma.models.StoreListingVersion(
id="version123",
@@ -141,11 +143,13 @@ async def test_add_agent_to_library(mocker):
)
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
mock_library_agent.return_value.find_unique = mocker.AsyncMock(return_value=None)
mock_library_agent.return_value.create = mocker.AsyncMock(
return_value=mock_library_agent_data
)
# Mock graph_db.get_graph function that's called in resolve_graph_for_library
# Mock graph_db.get_graph function that's called to check for HITL blocks
# (lives in _add_to_library.py after refactor, not db.py)
mock_graph_db = mocker.patch(
"backend.api.features.library._add_to_library.graph_db"
@@ -171,27 +175,37 @@ async def test_add_agent_to_library(mocker):
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
where={"id": "version123"}, include={"AgentGraph": True}
)
mock_library_agent.return_value.find_unique.assert_called_once_with(
where={
"userId_agentGraphId_agentGraphVersion": {
"userId": "test-user",
"agentGraphId": "agent1",
"agentGraphVersion": 1,
}
},
)
# Check that create was called with the expected data including settings
create_call_args = mock_library_agent.return_value.create.call_args
assert create_call_args is not None
# Verify the create data structure
create_data = create_call_args.kwargs["data"]
expected_create = {
# Verify the main structure
expected_data = {
"User": {"connect": {"id": "test-user"}},
"AgentGraph": {"connect": {"graphVersionId": {"id": "agent1", "version": 1}}},
"isCreatedByUser": False,
"useGraphIsActiveVersion": False,
}
for key, value in expected_create.items():
assert create_data[key] == value
actual_data = create_call_args[1]["data"]
# Check that all expected fields are present
for key, value in expected_data.items():
assert actual_data[key] == value
# Check that settings field is present and is a SafeJson object
assert "settings" in create_data
assert hasattr(create_data["settings"], "__class__") # Should be a SafeJson object
assert "settings" in actual_data
assert hasattr(actual_data["settings"], "__class__") # Should be a SafeJson object
# Check include parameter
assert create_call_args.kwargs["include"] == library_agent_include(
assert create_call_args[1]["include"] == library_agent_include(
"test-user", include_nodes=False, include_executions=False
)
@@ -306,50 +320,3 @@ async def test_update_graph_in_library_allows_archived_library_agent(mocker):
include_archived=True,
)
mock_update_library_agent.assert_awaited_once_with("test-user", created_graph)
@pytest.mark.asyncio
async def test_create_library_agent_uses_upsert():
"""create_library_agent should use upsert (not create) to handle duplicates."""
mock_graph = MagicMock()
mock_graph.id = "graph-1"
mock_graph.version = 1
mock_graph.user_id = "user-1"
mock_graph.nodes = []
mock_graph.sub_graphs = []
mock_upserted = MagicMock(name="UpsertedLibraryAgent")
@asynccontextmanager
async def fake_tx():
yield None
with (
patch("backend.api.features.library.db.transaction", fake_tx),
patch("prisma.models.LibraryAgent.prisma") as mock_prisma,
patch(
"backend.api.features.library.db.add_generated_agent_image",
new=AsyncMock(),
),
patch(
"backend.api.features.library.model.LibraryAgent.from_db",
return_value=MagicMock(),
),
):
mock_prisma.return_value.upsert = AsyncMock(return_value=mock_upserted)
result = await db.create_library_agent(mock_graph, "user-1")
assert len(result) == 1
upsert_call = mock_prisma.return_value.upsert.call_args
assert upsert_call is not None
# Verify the upsert where clause uses the composite unique key
where = upsert_call.kwargs["where"]
assert "userId_agentGraphId_agentGraphVersion" in where
# Verify the upsert data has both create and update branches
data = upsert_call.kwargs["data"]
assert "create" in data
assert "update" in data
# Verify update branch restores soft-deleted/archived agents
assert data["update"]["isDeleted"] is False
assert data["update"]["isArchived"] is False


@@ -12,7 +12,6 @@ Tests cover:
5. Complete OAuth flow end-to-end
"""
import asyncio
import base64
import hashlib
import secrets
@@ -59,27 +58,14 @@ async def test_user(server, test_user_id: str):
yield test_user_id
# Cleanup - delete in correct order due to foreign key constraints.
# Wrap in try/except because the event loop or Prisma engine may already
# be closed during session teardown on Python 3.12+.
try:
await asyncio.gather(
PrismaOAuthAccessToken.prisma().delete_many(where={"userId": test_user_id}),
PrismaOAuthRefreshToken.prisma().delete_many(
where={"userId": test_user_id}
),
PrismaOAuthAuthorizationCode.prisma().delete_many(
where={"userId": test_user_id}
),
)
await asyncio.gather(
PrismaOAuthApplication.prisma().delete_many(
where={"ownerId": test_user_id}
),
PrismaUser.prisma().delete(where={"id": test_user_id}),
)
except RuntimeError:
pass
# Cleanup - delete in correct order due to foreign key constraints
await PrismaOAuthAccessToken.prisma().delete_many(where={"userId": test_user_id})
await PrismaOAuthRefreshToken.prisma().delete_many(where={"userId": test_user_id})
await PrismaOAuthAuthorizationCode.prisma().delete_many(
where={"userId": test_user_id}
)
await PrismaOAuthApplication.prisma().delete_many(where={"ownerId": test_user_id})
await PrismaUser.prisma().delete(where={"id": test_user_id})
@pytest_asyncio.fixture


@@ -18,7 +18,6 @@ from prisma.errors import PrismaError
import backend.api.features.admin.credit_admin_routes
import backend.api.features.admin.execution_analytics_routes
import backend.api.features.admin.rate_limit_admin_routes
import backend.api.features.admin.store_admin_routes
import backend.api.features.builder
import backend.api.features.builder.routes
@@ -118,11 +117,6 @@ async def lifespan_context(app: fastapi.FastAPI):
AutoRegistry.patch_integrations()
# Register managed credential providers (e.g. AgentMail)
from backend.integrations.managed_providers import register_all
register_all()
await backend.data.block.initialize_blocks()
await backend.data.user.migrate_and_encrypt_user_integrations()
@@ -324,11 +318,6 @@ app.include_router(
tags=["v2", "admin"],
prefix="/api/executions",
)
app.include_router(
backend.api.features.admin.rate_limit_admin_routes.router,
tags=["v2", "admin"],
prefix="/api/copilot",
)
app.include_router(
backend.api.features.executions.review.routes.router,
tags=["v2", "executions", "review"],


@@ -1,4 +1,3 @@
import re
from typing import Any
from backend.blocks._base import (
@@ -20,33 +19,6 @@ from backend.blocks.llm import (
)
from backend.data.model import APIKeyCredentials, NodeExecutionStats, SchemaField
# Minimum max_output_tokens accepted by OpenAI-compatible APIs.
# A true/false answer fits comfortably within this budget.
MIN_LLM_OUTPUT_TOKENS = 16
def _parse_boolean_response(response_text: str) -> tuple[bool, str | None]:
"""Parse an LLM response into a boolean result.
Returns a ``(result, error)`` tuple. *error* is ``None`` when the
response is unambiguous; otherwise it contains a diagnostic message
and *result* defaults to ``False``.
"""
text = response_text.strip().lower()
if text == "true":
return True, None
if text == "false":
return False, None
# Fuzzy match: use word boundaries to avoid false positives like "untrue".
tokens = set(re.findall(r"\b(true|false|yes|no|1|0)\b", text))
if tokens == {"true"} or tokens == {"yes"} or tokens == {"1"}:
return True, None
if tokens == {"false"} or tokens == {"no"} or tokens == {"0"}:
return False, None
return False, f"Unclear AI response: '{response_text}'"
class AIConditionBlock(AIBlockBase):
"""
@@ -190,26 +162,54 @@ class AIConditionBlock(AIBlockBase):
]
# Call the LLM
response = await self.llm_call(
credentials=credentials,
llm_model=input_data.model,
prompt=prompt,
max_tokens=MIN_LLM_OUTPUT_TOKENS,
)
# Extract the boolean result from the response
result, error = _parse_boolean_response(response.response)
if error:
yield "error", error
# Update internal stats
self.merge_stats(
NodeExecutionStats(
input_token_count=response.prompt_tokens,
output_token_count=response.completion_tokens,
try:
response = await self.llm_call(
credentials=credentials,
llm_model=input_data.model,
prompt=prompt,
max_tokens=10, # We only expect a true/false response
)
)
self.prompt = response.prompt
# Extract the boolean result from the response
response_text = response.response.strip().lower()
if response_text == "true":
result = True
elif response_text == "false":
result = False
else:
# If the response is not clear, try to interpret it using word boundaries
import re
# Use word boundaries to avoid false positives like 'untrue' or '10'
tokens = set(re.findall(r"\b(true|false|yes|no|1|0)\b", response_text))
if tokens == {"true"} or tokens == {"yes"} or tokens == {"1"}:
result = True
elif tokens == {"false"} or tokens == {"no"} or tokens == {"0"}:
result = False
else:
# Unclear or conflicting response - default to False and yield error
result = False
yield "error", f"Unclear AI response: '{response.response}'"
# Update internal stats
self.merge_stats(
NodeExecutionStats(
input_token_count=response.prompt_tokens,
output_token_count=response.completion_tokens,
)
)
self.prompt = response.prompt
except Exception as e:
# In case of any error, default to False to be safe
result = False
# Log the error but don't fail the block execution
import logging
logger = logging.getLogger(__name__)
logger.error(f"AI condition evaluation failed: {str(e)}")
yield "error", f"AI evaluation failed: {str(e)}"
# Yield results
yield "result", result


@@ -1,147 +0,0 @@
"""Tests for AIConditionBlock regression coverage for max_tokens and error propagation."""
from __future__ import annotations
from typing import cast
import pytest
from backend.blocks.ai_condition import (
MIN_LLM_OUTPUT_TOKENS,
AIConditionBlock,
_parse_boolean_response,
)
from backend.blocks.llm import (
DEFAULT_LLM_MODEL,
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
AICredentials,
LLMResponse,
)
_TEST_AI_CREDENTIALS = cast(AICredentials, TEST_CREDENTIALS_INPUT)
# ---------------------------------------------------------------------------
# Helper to collect all yields from the async generator
# ---------------------------------------------------------------------------
async def _collect_outputs(block: AIConditionBlock, input_data, credentials):
outputs: dict[str, object] = {}
async for name, value in block.run(input_data, credentials=credentials):
outputs[name] = value
return outputs
def _make_input(**overrides) -> AIConditionBlock.Input:
defaults: dict = {
"input_value": "hello@example.com",
"condition": "the input is an email address",
"yes_value": "yes!",
"no_value": "no!",
"model": DEFAULT_LLM_MODEL,
"credentials": TEST_CREDENTIALS_INPUT,
}
defaults.update(overrides)
return AIConditionBlock.Input(**defaults)
def _mock_llm_response(response_text: str) -> LLMResponse:
return LLMResponse(
raw_response="",
prompt=[],
response=response_text,
tool_calls=None,
prompt_tokens=10,
completion_tokens=5,
reasoning=None,
)
# ---------------------------------------------------------------------------
# _parse_boolean_response unit tests
# ---------------------------------------------------------------------------
class TestParseBooleanResponse:
def test_true_exact(self):
assert _parse_boolean_response("true") == (True, None)
def test_false_exact(self):
assert _parse_boolean_response("false") == (False, None)
def test_true_with_whitespace(self):
assert _parse_boolean_response(" True ") == (True, None)
def test_yes_fuzzy(self):
assert _parse_boolean_response("Yes") == (True, None)
def test_no_fuzzy(self):
assert _parse_boolean_response("no") == (False, None)
def test_one_fuzzy(self):
assert _parse_boolean_response("1") == (True, None)
def test_zero_fuzzy(self):
assert _parse_boolean_response("0") == (False, None)
def test_unclear_response(self):
result, error = _parse_boolean_response("I'm not sure")
assert result is False
assert error is not None
assert "Unclear" in error
def test_conflicting_tokens(self):
result, error = _parse_boolean_response("true and false")
assert result is False
assert error is not None
# ---------------------------------------------------------------------------
# Regression: max_tokens is set to MIN_LLM_OUTPUT_TOKENS
# ---------------------------------------------------------------------------
class TestMaxTokensRegression:
@pytest.mark.asyncio
async def test_llm_call_receives_min_output_tokens(self):
"""max_tokens must be MIN_LLM_OUTPUT_TOKENS (16) the previous value
of 1 was too low and caused OpenAI to reject the request."""
block = AIConditionBlock()
captured_kwargs: dict = {}
async def spy_llm_call(**kwargs):
captured_kwargs.update(kwargs)
return _mock_llm_response("true")
block.llm_call = spy_llm_call # type: ignore[assignment]
input_data = _make_input()
await _collect_outputs(block, input_data, credentials=TEST_CREDENTIALS)
assert captured_kwargs["max_tokens"] == MIN_LLM_OUTPUT_TOKENS
assert captured_kwargs["max_tokens"] == 16
# ---------------------------------------------------------------------------
# Regression: exceptions from llm_call must propagate
# ---------------------------------------------------------------------------
class TestExceptionPropagation:
@pytest.mark.asyncio
async def test_llm_call_exception_propagates(self):
"""If llm_call raises, the exception must NOT be swallowed.
Previously the block caught all exceptions and silently returned
result=False."""
block = AIConditionBlock()
async def boom(**kwargs):
raise RuntimeError("LLM provider error")
block.llm_call = boom # type: ignore[assignment]
input_data = _make_input()
with pytest.raises(RuntimeError, match="LLM provider error"):
await _collect_outputs(block, input_data, credentials=TEST_CREDENTIALS)


@@ -146,21 +146,6 @@ class AutoPilotBlock(Block):
advanced=True,
)
dry_run: bool = SchemaField(
description=(
"When enabled, run_block and run_agent tool calls in this "
"autopilot session are forced to use dry-run simulation mode. "
"No real API calls, side effects, or credits are consumed "
"by those tools. Useful for testing agent wiring and "
"previewing outputs. "
"Only applies when creating a new session (session_id is empty). "
"When reusing an existing session_id, the session's original "
"dry_run setting is preserved."
),
default=False,
advanced=True,
)
# timeout_seconds removed: the SDK manages its own heartbeat-based
# timeouts internally; wrapping with asyncio.timeout corrupts the
# SDK's internal stream (see service.py CRITICAL comment).
@@ -247,11 +232,11 @@ class AutoPilotBlock(Block):
},
)
async def create_session(self, user_id: str, *, dry_run: bool) -> str:
async def create_session(self, user_id: str) -> str:
"""Create a new chat session and return its ID (mockable for tests)."""
from backend.copilot.model import create_chat_session # avoid circular import
session = await create_chat_session(user_id, dry_run=dry_run)
session = await create_chat_session(user_id)
return session.session_id
async def execute_copilot(
@@ -382,9 +367,7 @@ class AutoPilotBlock(Block):
# even if the downstream stream fails (avoids orphaned sessions).
sid = input_data.session_id
if not sid:
sid = await self.create_session(
execution_context.user_id, dry_run=input_data.dry_run
)
sid = await self.create_session(execution_context.user_id)
# NOTE: No asyncio.timeout() here — the SDK manages its own
# heartbeat-based timeouts internally. Wrapping with asyncio.timeout


@@ -73,7 +73,7 @@ class ReadDiscordMessagesBlock(Block):
id="df06086a-d5ac-4abb-9996-2ad0acb2eff7",
input_schema=ReadDiscordMessagesBlock.Input, # Assign input schema
output_schema=ReadDiscordMessagesBlock.Output, # Assign output schema
description="Reads new messages from a Discord channel using a bot token and triggers when a new message is posted",
description="Reads messages from a Discord channel using a bot token.",
categories={BlockCategory.SOCIAL},
test_input={
"continuous_read": False,


@@ -1,6 +1,5 @@
import asyncio
import base64
import re
from abc import ABC
from email import encoders
from email.mime.base import MIMEBase
@@ -9,7 +8,7 @@ from email.mime.text import MIMEText
from email.policy import SMTP
from email.utils import getaddresses, parseaddr
from pathlib import Path
from typing import List, Literal, Optional, Protocol, runtime_checkable
from typing import List, Literal, Optional
from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
@@ -43,52 +42,8 @@ NO_WRAP_POLICY = SMTP.clone(max_line_length=0)
def serialize_email_recipients(recipients: list[str]) -> str:
"""Serialize recipients list to comma-separated string.
Strips leading/trailing whitespace from each address to keep MIME
headers clean (mirrors the strip done in ``validate_email_recipients``).
"""
return ", ".join(addr.strip() for addr in recipients)
# RFC 5322 simplified pattern: local@domain where domain has at least one dot
_EMAIL_RE = re.compile(r"^[^@\s]+@[^@\s]+\.[^@\s]+$")
def validate_email_recipients(recipients: list[str], field_name: str = "to") -> None:
"""Validate that all recipients are plausible email addresses.
Raises ``ValueError`` with a user-friendly message listing every
invalid entry so the caller (or LLM) can correct them in one pass.
"""
invalid = [addr for addr in recipients if not _EMAIL_RE.match(addr.strip())]
if invalid:
formatted = ", ".join(f"'{a}'" for a in invalid)
raise ValueError(
f"Invalid email address(es) in '{field_name}': {formatted}. "
f"Each entry must be a valid email address (e.g. user@example.com)."
)
@runtime_checkable
class HasRecipients(Protocol):
to: list[str]
cc: list[str]
bcc: list[str]
def validate_all_recipients(input_data: HasRecipients) -> None:
"""Validate to/cc/bcc recipient fields on an input namespace.
Calls ``validate_email_recipients`` for ``to`` (required) and
``cc``/``bcc`` (when non-empty), raising ``ValueError`` on the
first field that contains an invalid address.
"""
validate_email_recipients(input_data.to, "to")
if input_data.cc:
validate_email_recipients(input_data.cc, "cc")
if input_data.bcc:
validate_email_recipients(input_data.bcc, "bcc")
"""Serialize recipients list to comma-separated string."""
return ", ".join(recipients)
def _make_mime_text(
@@ -145,16 +100,14 @@ async def create_mime_message(
) -> str:
"""Create a MIME message with attachments and return base64-encoded raw message."""
validate_all_recipients(input_data)
message = MIMEMultipart()
message["to"] = serialize_email_recipients(input_data.to)
message["subject"] = input_data.subject
if input_data.cc:
message["cc"] = serialize_email_recipients(input_data.cc)
message["cc"] = ", ".join(input_data.cc)
if input_data.bcc:
message["bcc"] = serialize_email_recipients(input_data.bcc)
message["bcc"] = ", ".join(input_data.bcc)
# Use the new helper function with content_type if available
content_type = getattr(input_data, "content_type", None)
@@ -1214,15 +1167,13 @@ async def _build_reply_message(
references.append(headers["message-id"])
# Create MIME message
validate_all_recipients(input_data)
msg = MIMEMultipart()
if input_data.to:
msg["To"] = serialize_email_recipients(input_data.to)
msg["To"] = ", ".join(input_data.to)
if input_data.cc:
msg["Cc"] = serialize_email_recipients(input_data.cc)
msg["Cc"] = ", ".join(input_data.cc)
if input_data.bcc:
msg["Bcc"] = serialize_email_recipients(input_data.bcc)
msg["Bcc"] = ", ".join(input_data.bcc)
msg["Subject"] = subject
if headers.get("message-id"):
msg["In-Reply-To"] = headers["message-id"]
@@ -1734,16 +1685,13 @@ To: {original_to}
else:
body = f"{forward_header}\n\n{original_body}"
# Validate all recipient lists before building the MIME message
validate_all_recipients(input_data)
# Create MIME message
msg = MIMEMultipart()
msg["To"] = serialize_email_recipients(input_data.to)
msg["To"] = ", ".join(input_data.to)
if input_data.cc:
msg["Cc"] = serialize_email_recipients(input_data.cc)
msg["Cc"] = ", ".join(input_data.cc)
if input_data.bcc:
msg["Bcc"] = serialize_email_recipients(input_data.bcc)
msg["Bcc"] = ", ".join(input_data.bcc)
msg["Subject"] = subject
# Add body with proper content type


@@ -28,9 +28,9 @@ class AgentInputBlock(Block):
"""
This block is used to provide input to the graph.
It takes in a value, name, and description.
It takes in a value, name, description, default values list and bool to limit selection to default values.
It outputs the value passed as input.
It Outputs the value passed as input.
"""
class Input(BlockSchemaInput):
@@ -47,6 +47,12 @@ class AgentInputBlock(Block):
default=None,
advanced=True,
)
placeholder_values: list = SchemaField(
description="The placeholder values to be passed as input.",
default_factory=list,
advanced=True,
hidden=True,
)
advanced: bool = SchemaField(
description="Whether to show the input in the advanced section, if the field is not required.",
default=False,
@@ -59,7 +65,10 @@ class AgentInputBlock(Block):
)
def generate_schema(self):
return copy.deepcopy(self.get_field_schema("value"))
schema = copy.deepcopy(self.get_field_schema("value"))
if possible_values := self.placeholder_values:
schema["enum"] = possible_values
return schema
class Output(BlockSchema):
# Use BlockSchema to avoid automatic error field for interface definition
@@ -77,16 +86,18 @@ class AgentInputBlock(Block):
"value": "Hello, World!",
"name": "input_1",
"description": "Example test input.",
"placeholder_values": [],
},
{
"value": 42,
"value": "Hello, World!",
"name": "input_2",
"description": "Example numeric input.",
"description": "Example test input with placeholders.",
"placeholder_values": ["Hello, World!"],
},
],
"test_output": [
("result", "Hello, World!"),
("result", 42),
("result", "Hello, World!"),
],
"categories": {BlockCategory.INPUT, BlockCategory.BASIC},
"block_type": BlockType.INPUT,
@@ -234,11 +245,13 @@ class AgentShortTextInputBlock(AgentInputBlock):
"value": "Hello",
"name": "short_text_1",
"description": "Short text example 1",
"placeholder_values": [],
},
{
"value": "Quick test",
"name": "short_text_2",
"description": "Short text example 2",
"placeholder_values": ["Quick test", "Another option"],
},
],
test_output=[
@@ -272,11 +285,13 @@ class AgentLongTextInputBlock(AgentInputBlock):
"value": "Lorem ipsum dolor sit amet...",
"name": "long_text_1",
"description": "Long text example 1",
"placeholder_values": [],
},
{
"value": "Another multiline text input.",
"name": "long_text_2",
"description": "Long text example 2",
"placeholder_values": ["Another multiline text input."],
},
],
test_output=[
@@ -310,11 +325,13 @@ class AgentNumberInputBlock(AgentInputBlock):
"value": 42,
"name": "number_input_1",
"description": "Number example 1",
"placeholder_values": [],
},
{
"value": 314,
"name": "number_input_2",
"description": "Number example 2",
"placeholder_values": [314, 2718],
},
],
test_output=[
@@ -484,12 +501,6 @@ class AgentDropdownInputBlock(AgentInputBlock):
title="Dropdown Options",
)
def generate_schema(self):
schema = super().generate_schema()
if possible_values := self.placeholder_values:
schema["enum"] = possible_values
return schema
class Output(AgentInputBlock.Output):
result: str = SchemaField(description="Selected dropdown value.")


@@ -104,18 +104,6 @@ class LlmModelMeta(EnumMeta):
class LlmModel(str, Enum, metaclass=LlmModelMeta):
@classmethod
def _missing_(cls, value: object) -> "LlmModel | None":
"""Handle provider-prefixed model names like 'anthropic/claude-sonnet-4-6'."""
if isinstance(value, str) and "/" in value:
stripped = value.split("/", 1)[1]
try:
return cls(stripped)
except ValueError:
return None
return None
# OpenAI models
O3_MINI = "o3-mini"
O3 = "o3-2025-04-16"
@@ -724,9 +712,6 @@ def convert_openai_tool_fmt_to_anthropic(
def extract_openai_reasoning(response) -> str | None:
"""Extract reasoning from OpenAI-compatible response if available."""
"""Note: This will likely not working since the reasoning is not present in another Response API"""
if not response.choices:
logger.warning("LLM response has empty choices in extract_openai_reasoning")
return None
reasoning = None
choice = response.choices[0]
if hasattr(choice, "reasoning") and getattr(choice, "reasoning", None):
@@ -742,9 +727,6 @@ def extract_openai_reasoning(response) -> str | None:
def extract_openai_tool_calls(response) -> list[ToolContentBlock] | None:
"""Extract tool calls from OpenAI-compatible response."""
if not response.choices:
logger.warning("LLM response has empty choices in extract_openai_tool_calls")
return None
if response.choices[0].message.tool_calls:
return [
ToolContentBlock(
@@ -978,8 +960,6 @@ async def llm_call(
response_format=response_format, # type: ignore
max_tokens=max_tokens,
)
if not response.choices:
raise ValueError("Groq returned empty choices in response")
return LLMResponse(
raw_response=response.choices[0].message,
prompt=prompt,
@@ -1039,8 +1019,12 @@ async def llm_call(
parallel_tool_calls=parallel_tool_calls_param,
)
# If there's no response, raise an error
if not response.choices:
raise ValueError(f"OpenRouter returned empty choices: {response}")
if response:
raise ValueError(f"OpenRouter error: {response}")
else:
raise ValueError("No response from OpenRouter.")
tool_calls = extract_openai_tool_calls(response)
reasoning = extract_openai_reasoning(response)
@@ -1077,8 +1061,12 @@ async def llm_call(
parallel_tool_calls=parallel_tool_calls_param,
)
# If there's no response, raise an error
if not response.choices:
raise ValueError(f"Llama API returned empty choices: {response}")
if response:
raise ValueError(f"Llama API error: {response}")
else:
raise ValueError("No response from Llama API.")
tool_calls = extract_openai_tool_calls(response)
reasoning = extract_openai_reasoning(response)
@@ -1108,8 +1096,6 @@ async def llm_call(
messages=prompt, # type: ignore
max_tokens=max_tokens,
)
if not completion.choices:
raise ValueError("AI/ML API returned empty choices in response")
return LLMResponse(
raw_response=completion.choices[0].message,
@@ -1146,9 +1132,6 @@ async def llm_call(
parallel_tool_calls=parallel_tool_calls_param,
)
if not response.choices:
raise ValueError(f"v0 API returned empty choices: {response}")
tool_calls = extract_openai_tool_calls(response)
reasoning = extract_openai_reasoning(response)
@@ -2016,19 +1999,6 @@ class AIConversationBlock(AIBlockBase):
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
has_messages = any(
isinstance(m, dict)
and isinstance(m.get("content"), str)
and bool(m["content"].strip())
for m in (input_data.messages or [])
)
has_prompt = bool(input_data.prompt and input_data.prompt.strip())
if not has_messages and not has_prompt:
raise ValueError(
"Cannot call LLM with no messages and no prompt. "
"Provide at least one message or a non-empty prompt."
)
response = await self.llm_call(
AIStructuredResponseGeneratorBlock.Input(
prompt=input_data.prompt,

File diff suppressed because it is too large.


@@ -4,8 +4,6 @@ import pytest
from backend.blocks import get_blocks
from backend.blocks._base import Block, BlockSchemaInput
from backend.blocks.io import AgentDropdownInputBlock, AgentInputBlock
from backend.data.graph import BaseGraph
from backend.data.model import SchemaField
from backend.util.test import execute_block_test
@@ -281,66 +279,3 @@ class TestAutoCredentialsFieldsValidation:
assert "Duplicate auto_credentials kwarg_name 'credentials'" in str(
exc_info.value
)
def test_agent_input_block_ignores_legacy_placeholder_values():
"""Verify AgentInputBlock.Input.model_construct tolerates extra placeholder_values
for backward compatibility with existing agent JSON."""
legacy_data = {
"name": "url",
"value": "",
"description": "Enter a URL",
"placeholder_values": ["https://example.com"],
}
instance = AgentInputBlock.Input.model_construct(**legacy_data)
schema = instance.generate_schema()
assert (
"enum" not in schema
), "AgentInputBlock should not produce enum from legacy placeholder_values"
def test_dropdown_input_block_produces_enum():
"""Verify AgentDropdownInputBlock.Input.generate_schema() produces enum."""
options = ["Option A", "Option B"]
instance = AgentDropdownInputBlock.Input.model_construct(
name="choice", value=None, placeholder_values=options
)
schema = instance.generate_schema()
assert schema.get("enum") == options
def test_generate_schema_integration_legacy_placeholder_values():
"""Test the full Graph._generate_schema path with legacy placeholder_values
on AgentInputBlock — verifies no enum leaks through the graph loading path."""
legacy_input_default = {
"name": "url",
"value": "",
"description": "Enter a URL",
"placeholder_values": ["https://example.com"],
}
result = BaseGraph._generate_schema(
(AgentInputBlock.Input, legacy_input_default),
)
url_props = result["properties"]["url"]
assert (
"enum" not in url_props
), "Graph schema should not contain enum from AgentInputBlock placeholder_values"
def test_generate_schema_integration_dropdown_produces_enum():
"""Test the full Graph._generate_schema path with AgentDropdownInputBlock
— verifies enum IS produced for dropdown blocks."""
dropdown_input_default = {
"name": "color",
"value": None,
"placeholder_values": ["Red", "Green", "Blue"],
}
result = BaseGraph._generate_schema(
(AgentDropdownInputBlock.Input, dropdown_input_default),
)
color_props = result["properties"]["color"]
assert color_props.get("enum") == [
"Red",
"Green",
"Blue",
], "Graph schema should contain enum from AgentDropdownInputBlock"


@@ -207,51 +207,6 @@ class TestXMLParserBlockSecurity:
pass
class TestXMLParserBlockSyntaxErrors:
"""XML syntax errors should raise ValueError (not SyntaxError).
This ensures the base Block.execute() wraps them as BlockExecutionError
(expected / user-caused) instead of BlockUnknownError (unexpected / alerts
Sentry).
"""
async def test_unclosed_tag_raises_value_error(self):
"""Unclosed tags should raise ValueError, not SyntaxError."""
block = XMLParserBlock()
bad_xml = "<root><unclosed>"
with pytest.raises(ValueError, match="Unclosed tag"):
async for _ in block.run(XMLParserBlock.Input(input_xml=bad_xml)):
pass
async def test_unexpected_closing_tag_raises_value_error(self):
"""Extra closing tags should raise ValueError, not SyntaxError."""
block = XMLParserBlock()
bad_xml = "</unexpected>"
with pytest.raises(ValueError):
async for _ in block.run(XMLParserBlock.Input(input_xml=bad_xml)):
pass
async def test_empty_xml_raises_value_error(self):
"""Empty XML input should raise ValueError."""
block = XMLParserBlock()
with pytest.raises(ValueError, match="XML input is empty"):
async for _ in block.run(XMLParserBlock.Input(input_xml="")):
pass
async def test_syntax_error_from_parser_becomes_value_error(self):
"""SyntaxErrors from gravitasml library become ValueError (BlockExecutionError)."""
block = XMLParserBlock()
# Malformed XML that might trigger a SyntaxError from the parser
bad_xml = "<root><child>no closing"
with pytest.raises(ValueError):
async for _ in block.run(XMLParserBlock.Input(input_xml=bad_xml)):
pass
class TestStoreMediaFileSecurity:
"""Test file storage security limits."""


@@ -488,154 +488,6 @@ class TestLLMStatsTracking:
assert outputs["response"] == {"result": "test"}
class TestAIConversationBlockValidation:
"""Test that AIConversationBlock validates inputs before calling the LLM."""
@pytest.mark.asyncio
async def test_empty_messages_and_empty_prompt_raises_error(self):
"""Empty messages with no prompt should raise ValueError, not a cryptic API error."""
block = llm.AIConversationBlock()
input_data = llm.AIConversationBlock.Input(
messages=[],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
with pytest.raises(ValueError, match="no messages and no prompt"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
@pytest.mark.asyncio
async def test_empty_messages_with_prompt_succeeds(self):
"""Empty messages but a non-empty prompt should proceed without error."""
block = llm.AIConversationBlock()
async def mock_llm_call(input_data, credentials):
return {"response": "OK"}
with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
input_data = llm.AIConversationBlock.Input(
messages=[],
prompt="Hello, how are you?",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
outputs = {}
async for name, data in block.run(
input_data, credentials=llm.TEST_CREDENTIALS
):
outputs[name] = data
assert outputs["response"] == "OK"
@pytest.mark.asyncio
async def test_nonempty_messages_with_empty_prompt_succeeds(self):
"""Non-empty messages with no prompt should proceed without error."""
block = llm.AIConversationBlock()
async def mock_llm_call(input_data, credentials):
return {"response": "response from conversation"}
with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
input_data = llm.AIConversationBlock.Input(
messages=[{"role": "user", "content": "Hello"}],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
outputs = {}
async for name, data in block.run(
input_data, credentials=llm.TEST_CREDENTIALS
):
outputs[name] = data
assert outputs["response"] == "response from conversation"
@pytest.mark.asyncio
async def test_messages_with_empty_content_raises_error(self):
"""Messages with empty content strings should be treated as no messages."""
block = llm.AIConversationBlock()
input_data = llm.AIConversationBlock.Input(
messages=[{"role": "user", "content": ""}],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
with pytest.raises(ValueError, match="no messages and no prompt"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
@pytest.mark.asyncio
async def test_messages_with_whitespace_content_raises_error(self):
"""Messages with whitespace-only content should be treated as no messages."""
block = llm.AIConversationBlock()
input_data = llm.AIConversationBlock.Input(
messages=[{"role": "user", "content": " "}],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
with pytest.raises(ValueError, match="no messages and no prompt"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
@pytest.mark.asyncio
async def test_messages_with_none_entry_raises_error(self):
"""Messages list containing None should be treated as no messages."""
block = llm.AIConversationBlock()
input_data = llm.AIConversationBlock.Input(
messages=[None],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
with pytest.raises(ValueError, match="no messages and no prompt"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
@pytest.mark.asyncio
async def test_messages_with_empty_dict_raises_error(self):
"""Messages list containing empty dict should be treated as no messages."""
block = llm.AIConversationBlock()
input_data = llm.AIConversationBlock.Input(
messages=[{}],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
with pytest.raises(ValueError, match="no messages and no prompt"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
@pytest.mark.asyncio
async def test_messages_with_none_content_raises_error(self):
"""Messages with content=None should not crash with AttributeError."""
block = llm.AIConversationBlock()
input_data = llm.AIConversationBlock.Input(
messages=[{"role": "user", "content": None}],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
with pytest.raises(ValueError, match="no messages and no prompt"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
class TestAITextSummarizerValidation:
"""Test that AITextSummarizerBlock validates LLM responses are strings."""
@@ -957,33 +809,3 @@ class TestUserErrorStatusCodeHandling:
mock_warning.assert_called_once()
mock_exception.assert_not_called()
class TestLlmModelMissing:
"""Test that LlmModel handles provider-prefixed model names."""
def test_provider_prefixed_model_resolves(self):
"""Provider-prefixed model string should resolve to the correct enum member."""
assert (
llm.LlmModel("anthropic/claude-sonnet-4-6")
== llm.LlmModel.CLAUDE_4_6_SONNET
)
def test_bare_model_still_works(self):
"""Bare (non-prefixed) model string should still resolve correctly."""
assert llm.LlmModel("claude-sonnet-4-6") == llm.LlmModel.CLAUDE_4_6_SONNET
def test_invalid_prefixed_model_raises(self):
"""Unknown provider-prefixed model string should raise ValueError."""
with pytest.raises(ValueError):
llm.LlmModel("invalid/nonexistent-model")
def test_slash_containing_value_direct_lookup(self):
"""Enum values with '/' (e.g., OpenRouter models) should resolve via direct lookup, not _missing_."""
assert llm.LlmModel("google/gemini-2.5-pro") == llm.LlmModel.GEMINI_2_5_PRO
def test_double_prefixed_slash_model(self):
"""Double-prefixed value should still resolve by stripping first prefix."""
assert (
llm.LlmModel("extra/google/gemini-2.5-pro") == llm.LlmModel.GEMINI_2_5_PRO
)


@@ -1,87 +0,0 @@
"""Tests for empty-choices guard in extract_openai_tool_calls() and extract_openai_reasoning()."""
from unittest.mock import MagicMock
from backend.blocks.llm import extract_openai_reasoning, extract_openai_tool_calls
class TestExtractOpenaiToolCallsEmptyChoices:
"""extract_openai_tool_calls() must return None when choices is empty."""
def test_returns_none_for_empty_choices(self):
response = MagicMock()
response.choices = []
assert extract_openai_tool_calls(response) is None
def test_returns_none_for_none_choices(self):
response = MagicMock()
response.choices = None
assert extract_openai_tool_calls(response) is None
def test_returns_tool_calls_when_choices_present(self):
tool = MagicMock()
tool.id = "call_1"
tool.type = "function"
tool.function.name = "my_func"
tool.function.arguments = '{"a": 1}'
message = MagicMock()
message.tool_calls = [tool]
choice = MagicMock()
choice.message = message
response = MagicMock()
response.choices = [choice]
result = extract_openai_tool_calls(response)
assert result is not None
assert len(result) == 1
assert result[0].function.name == "my_func"
def test_returns_none_when_no_tool_calls(self):
message = MagicMock()
message.tool_calls = None
choice = MagicMock()
choice.message = message
response = MagicMock()
response.choices = [choice]
assert extract_openai_tool_calls(response) is None
class TestExtractOpenaiReasoningEmptyChoices:
"""extract_openai_reasoning() must return None when choices is empty."""
def test_returns_none_for_empty_choices(self):
response = MagicMock()
response.choices = []
assert extract_openai_reasoning(response) is None
def test_returns_none_for_none_choices(self):
response = MagicMock()
response.choices = None
assert extract_openai_reasoning(response) is None
def test_returns_reasoning_from_choice(self):
choice = MagicMock()
choice.reasoning = "Step-by-step reasoning"
choice.message = MagicMock(spec=[]) # no 'reasoning' attr on message
response = MagicMock(spec=[]) # no 'reasoning' attr on response
response.choices = [choice]
result = extract_openai_reasoning(response)
assert result == "Step-by-step reasoning"
def test_returns_none_when_no_reasoning(self):
choice = MagicMock(spec=[]) # no 'reasoning' attr
choice.message = MagicMock(spec=[]) # no 'reasoning' attr
response = MagicMock(spec=[]) # no 'reasoning' attr
response.choices = [choice]
result = extract_openai_reasoning(response)
assert result is None


@@ -1074,7 +1074,6 @@ async def test_orchestrator_uses_customized_name_for_blocks():
mock_node.block_id = StoreValueBlock().id
mock_node.metadata = {"customized_name": "My Custom Tool Name"}
mock_node.block = StoreValueBlock()
mock_node.input_default = {}
# Create a mock link
mock_link = MagicMock(spec=Link)
@@ -1106,7 +1105,6 @@ async def test_orchestrator_falls_back_to_block_name():
mock_node.block_id = StoreValueBlock().id
mock_node.metadata = {} # No customized_name
mock_node.block = StoreValueBlock()
mock_node.input_default = {}
# Create a mock link
mock_link = MagicMock(spec=Link)


@@ -1,202 +0,0 @@
"""Tests for ExecutionMode enum and provider validation in the orchestrator.
Covers:
- ExecutionMode enum members exist and have stable values
- EXTENDED_THINKING provider validation (anthropic/open_router allowed, others rejected)
- EXTENDED_THINKING model-name validation (must start with "claude")
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.blocks.llm import LlmModel
from backend.blocks.orchestrator import ExecutionMode, OrchestratorBlock
# ---------------------------------------------------------------------------
# ExecutionMode enum integrity
# ---------------------------------------------------------------------------
class TestExecutionModeEnum:
"""Guard against accidental renames or removals of enum members."""
def test_built_in_exists(self):
assert hasattr(ExecutionMode, "BUILT_IN")
assert ExecutionMode.BUILT_IN.value == "built_in"
def test_extended_thinking_exists(self):
assert hasattr(ExecutionMode, "EXTENDED_THINKING")
assert ExecutionMode.EXTENDED_THINKING.value == "extended_thinking"
def test_exactly_two_members(self):
"""If a new mode is added, this test should be updated intentionally."""
assert set(ExecutionMode.__members__.keys()) == {
"BUILT_IN",
"EXTENDED_THINKING",
}
def test_string_enum(self):
"""ExecutionMode is a str enum so it serialises cleanly to JSON."""
assert isinstance(ExecutionMode.BUILT_IN, str)
assert isinstance(ExecutionMode.EXTENDED_THINKING, str)
def test_round_trip_from_value(self):
"""Constructing from the string value should return the same member."""
assert ExecutionMode("built_in") is ExecutionMode.BUILT_IN
assert ExecutionMode("extended_thinking") is ExecutionMode.EXTENDED_THINKING
# ---------------------------------------------------------------------------
# Provider validation (inline in OrchestratorBlock.run)
# ---------------------------------------------------------------------------
def _make_model_stub(provider: str, value: str):
"""Create a lightweight stub that behaves like LlmModel for validation."""
metadata = MagicMock()
metadata.provider = provider
stub = MagicMock()
stub.metadata = metadata
stub.value = value
return stub
class TestExtendedThinkingProviderValidation:
"""The orchestrator rejects EXTENDED_THINKING for non-Anthropic providers."""
def test_anthropic_provider_accepted(self):
"""provider='anthropic' + claude model should not raise."""
model = _make_model_stub("anthropic", "claude-opus-4-6")
provider = model.metadata.provider
model_name = model.value
assert provider in ("anthropic", "open_router")
assert model_name.startswith("claude")
def test_open_router_provider_accepted(self):
"""provider='open_router' + claude model should not raise."""
model = _make_model_stub("open_router", "claude-sonnet-4-6")
provider = model.metadata.provider
model_name = model.value
assert provider in ("anthropic", "open_router")
assert model_name.startswith("claude")
def test_openai_provider_rejected(self):
"""provider='openai' should be rejected for EXTENDED_THINKING."""
model = _make_model_stub("openai", "gpt-4o")
provider = model.metadata.provider
assert provider not in ("anthropic", "open_router")
def test_groq_provider_rejected(self):
model = _make_model_stub("groq", "llama-3.3-70b-versatile")
provider = model.metadata.provider
assert provider not in ("anthropic", "open_router")
def test_non_claude_model_rejected_even_if_anthropic_provider(self):
"""A hypothetical non-Claude model with provider='anthropic' is rejected."""
model = _make_model_stub("anthropic", "not-a-claude-model")
model_name = model.value
assert not model_name.startswith("claude")
def test_real_gpt4o_model_rejected(self):
"""Verify a real LlmModel enum member (GPT4O) fails the provider check."""
model = LlmModel.GPT4O
provider = model.metadata.provider
assert provider not in ("anthropic", "open_router")
def test_real_claude_model_passes(self):
"""Verify a real LlmModel enum member (CLAUDE_4_6_SONNET) passes."""
model = LlmModel.CLAUDE_4_6_SONNET
provider = model.metadata.provider
model_name = model.value
assert provider in ("anthropic", "open_router")
assert model_name.startswith("claude")
# ---------------------------------------------------------------------------
# Integration-style: exercise the validation branch via OrchestratorBlock.run
# ---------------------------------------------------------------------------
def _make_input_data(model, execution_mode=ExecutionMode.EXTENDED_THINKING):
"""Build a minimal MagicMock that satisfies OrchestratorBlock.run's early path."""
inp = MagicMock()
inp.execution_mode = execution_mode
inp.model = model
inp.prompt = "test"
inp.sys_prompt = ""
inp.conversation_history = []
inp.last_tool_output = None
inp.prompt_values = {}
return inp
async def _collect_run_outputs(block, input_data, **kwargs):
"""Exhaust the OrchestratorBlock.run async generator, collecting outputs."""
outputs = []
async for item in block.run(input_data, **kwargs):
outputs.append(item)
return outputs
class TestExtendedThinkingValidationRaisesInBlock:
"""Call OrchestratorBlock.run far enough to trigger the ValueError."""
@pytest.mark.asyncio
async def test_non_anthropic_provider_raises_valueerror(self):
"""EXTENDED_THINKING + openai provider raises ValueError."""
block = OrchestratorBlock()
input_data = _make_input_data(model=LlmModel.GPT4O)
with (
patch.object(
block,
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=[],
),
pytest.raises(ValueError, match="Anthropic-compatible"),
):
await _collect_run_outputs(
block,
input_data,
credentials=MagicMock(),
graph_id="g",
node_id="n",
graph_exec_id="ge",
node_exec_id="ne",
user_id="u",
graph_version=1,
execution_context=MagicMock(),
execution_processor=MagicMock(),
)
@pytest.mark.asyncio
async def test_non_claude_model_with_anthropic_provider_raises(self):
"""A model with anthropic provider but non-claude name raises ValueError."""
block = OrchestratorBlock()
fake_model = _make_model_stub("anthropic", "not-a-claude-model")
input_data = _make_input_data(model=fake_model)
with (
patch.object(
block,
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=[],
),
pytest.raises(ValueError, match="only supports Claude models"),
):
await _collect_run_outputs(
block,
input_data,
credentials=MagicMock(),
graph_id="g",
node_id="n",
graph_exec_id="ge",
node_exec_id="ne",
user_id="u",
graph_version=1,
execution_context=MagicMock(),
execution_processor=MagicMock(),
)


@@ -44,7 +44,7 @@ class XMLParserBlock(Block):
elif token.type == "TAG_CLOSE":
depth -= 1
if depth < 0:
raise ValueError("Unexpected closing tag in XML input.")
raise SyntaxError("Unexpected closing tag in XML input.")
elif token.type in {"TEXT", "ESCAPE"}:
if depth == 0 and token.value:
raise ValueError(
@@ -53,7 +53,7 @@ class XMLParserBlock(Block):
)
if depth != 0:
raise ValueError("Unclosed tag detected in XML input.")
raise SyntaxError("Unclosed tag detected in XML input.")
if not root_seen:
raise ValueError("XML must include a root element.")
@@ -76,7 +76,4 @@ class XMLParserBlock(Block):
except ValueError as val_e:
raise ValueError(f"Validation error for dict:{val_e}") from val_e
except SyntaxError as syn_e:
# Raise as ValueError so the base Block.execute() wraps it as
# BlockExecutionError (expected user-caused failure) instead of
# BlockUnknownError (unexpected platform error that alerts Sentry).
raise ValueError(f"Error in input xml syntax: {syn_e}") from syn_e
raise SyntaxError(f"Error in input xml syntax: {syn_e}") from syn_e


@@ -9,16 +9,12 @@ shared tool registry as the SDK path.
import asyncio
import logging
import uuid
from collections.abc import AsyncGenerator, Sequence
from dataclasses import dataclass, field
from functools import partial
from typing import Any, cast
from collections.abc import AsyncGenerator
from typing import Any
import orjson
from langfuse import propagate_attributes
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolParam
from backend.copilot.context import set_execution_context
from backend.copilot.model import (
ChatMessage,
ChatSession,
@@ -52,17 +48,7 @@ from backend.copilot.token_tracking import persist_and_record_usage
from backend.copilot.tools import execute_tool, get_available_tools
from backend.copilot.tracking import track_user_message
from backend.util.exceptions import NotFoundError
from backend.util.prompt import (
compress_context,
estimate_token_count,
estimate_token_count_str,
)
from backend.util.tool_call_loop import (
LLMLoopResponse,
LLMToolCall,
ToolCallResult,
tool_call_loop,
)
from backend.util.prompt import compress_context
logger = logging.getLogger(__name__)
@@ -73,247 +59,6 @@ _background_tasks: set[asyncio.Task[Any]] = set()
_MAX_TOOL_ROUNDS = 30
@dataclass
class _BaselineStreamState:
"""Mutable state shared between the tool-call loop callbacks.
Extracted from ``stream_chat_completion_baseline`` so that the callbacks
can be module-level functions instead of deeply nested closures.
"""
pending_events: list[StreamBaseResponse] = field(default_factory=list)
assistant_text: str = ""
text_block_id: str = field(default_factory=lambda: str(uuid.uuid4()))
text_started: bool = False
turn_prompt_tokens: int = 0
turn_completion_tokens: int = 0
async def _baseline_llm_caller(
messages: list[dict[str, Any]],
tools: Sequence[Any],
*,
state: _BaselineStreamState,
) -> LLMLoopResponse:
"""Stream an OpenAI-compatible response and collect results.
Extracted from ``stream_chat_completion_baseline`` for readability.
"""
state.pending_events.append(StreamStartStep())
round_text = ""
try:
client = _get_openai_client()
typed_messages = cast(list[ChatCompletionMessageParam], messages)
if tools:
typed_tools = cast(list[ChatCompletionToolParam], tools)
response = await client.chat.completions.create(
model=config.model,
messages=typed_messages,
tools=typed_tools,
stream=True,
stream_options={"include_usage": True},
)
else:
response = await client.chat.completions.create(
model=config.model,
messages=typed_messages,
stream=True,
stream_options={"include_usage": True},
)
tool_calls_by_index: dict[int, dict[str, str]] = {}
async for chunk in response:
if chunk.usage:
state.turn_prompt_tokens += chunk.usage.prompt_tokens or 0
state.turn_completion_tokens += chunk.usage.completion_tokens or 0
delta = chunk.choices[0].delta if chunk.choices else None
if not delta:
continue
if delta.content:
if not state.text_started:
state.pending_events.append(StreamTextStart(id=state.text_block_id))
state.text_started = True
round_text += delta.content
state.pending_events.append(
StreamTextDelta(id=state.text_block_id, delta=delta.content)
)
if delta.tool_calls:
for tc in delta.tool_calls:
idx = tc.index
if idx not in tool_calls_by_index:
tool_calls_by_index[idx] = {
"id": "",
"name": "",
"arguments": "",
}
entry = tool_calls_by_index[idx]
if tc.id:
entry["id"] = tc.id
if tc.function and tc.function.name:
entry["name"] = tc.function.name
if tc.function and tc.function.arguments:
entry["arguments"] += tc.function.arguments
# Close text block
if state.text_started:
state.pending_events.append(StreamTextEnd(id=state.text_block_id))
state.text_started = False
state.text_block_id = str(uuid.uuid4())
finally:
# Always persist partial text so the session history stays consistent,
# even when the stream is interrupted by an exception.
state.assistant_text += round_text
# Always emit StreamFinishStep to match the StreamStartStep,
# even if an exception occurred during streaming.
state.pending_events.append(StreamFinishStep())
# Convert to shared format
llm_tool_calls = [
LLMToolCall(
id=tc["id"],
name=tc["name"],
arguments=tc["arguments"] or "{}",
)
for tc in tool_calls_by_index.values()
]
return LLMLoopResponse(
response_text=round_text or None,
tool_calls=llm_tool_calls,
raw_response=None, # Not needed for baseline conversation updater
prompt_tokens=0, # Tracked via state accumulators
completion_tokens=0,
)
async def _baseline_tool_executor(
tool_call: LLMToolCall,
tools: Sequence[Any],
*,
state: _BaselineStreamState,
user_id: str | None,
session: ChatSession,
) -> ToolCallResult:
"""Execute a tool via the copilot tool registry.
Extracted from ``stream_chat_completion_baseline`` for readability.
"""
tool_call_id = tool_call.id
tool_name = tool_call.name
raw_args = tool_call.arguments or "{}"
try:
tool_args = orjson.loads(raw_args)
except orjson.JSONDecodeError as parse_err:
parse_error = f"Invalid JSON arguments for tool '{tool_name}': {parse_err}"
logger.warning("[Baseline] %s", parse_error)
state.pending_events.append(
StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
output=parse_error,
success=False,
)
)
return ToolCallResult(
tool_call_id=tool_call_id,
tool_name=tool_name,
content=parse_error,
is_error=True,
)
state.pending_events.append(
StreamToolInputStart(toolCallId=tool_call_id, toolName=tool_name)
)
state.pending_events.append(
StreamToolInputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
input=tool_args,
)
)
try:
result: StreamToolOutputAvailable = await execute_tool(
tool_name=tool_name,
parameters=tool_args,
user_id=user_id,
session=session,
tool_call_id=tool_call_id,
)
state.pending_events.append(result)
tool_output = (
result.output if isinstance(result.output, str) else str(result.output)
)
return ToolCallResult(
tool_call_id=tool_call_id,
tool_name=tool_name,
content=tool_output,
)
except Exception as e:
error_output = f"Tool execution error: {e}"
logger.error(
"[Baseline] Tool %s failed: %s",
tool_name,
error_output,
exc_info=True,
)
state.pending_events.append(
StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
output=error_output,
success=False,
)
)
return ToolCallResult(
tool_call_id=tool_call_id,
tool_name=tool_name,
content=error_output,
is_error=True,
)
def _baseline_conversation_updater(
messages: list[dict[str, Any]],
response: LLMLoopResponse,
tool_results: list[ToolCallResult] | None = None,
) -> None:
"""Update OpenAI message list with assistant response + tool results.
Extracted from ``stream_chat_completion_baseline`` for readability.
"""
if tool_results:
# Build assistant message with tool_calls
assistant_msg: dict[str, Any] = {"role": "assistant"}
if response.response_text:
assistant_msg["content"] = response.response_text
assistant_msg["tool_calls"] = [
{
"id": tc.id,
"type": "function",
"function": {"name": tc.name, "arguments": tc.arguments},
}
for tc in response.tool_calls
]
messages.append(assistant_msg)
for tr in tool_results:
messages.append(
{
"role": "tool",
"tool_call_id": tr.tool_call_id,
"content": tr.content,
}
)
else:
if response.response_text:
messages.append({"role": "assistant", "content": response.response_text})
async def _update_title_async(
session_id: str, message: str, user_id: str | None
) -> None:
@@ -458,9 +203,6 @@ async def stream_chat_completion_baseline(
tools = get_available_tools()
# Propagate execution context so tool handlers can read session-level flags.
set_execution_context(user_id, session)
yield StreamStart(messageId=message_id, sessionId=session_id)
# Propagate user/session context to Langfuse so all LLM calls within
@@ -477,32 +219,191 @@ async def stream_chat_completion_baseline(
except Exception:
logger.warning("[Baseline] Langfuse trace context setup failed")
assistant_text = ""
text_block_id = str(uuid.uuid4())
text_started = False
step_open = False
# Token usage accumulators — populated from streaming chunks
turn_prompt_tokens = 0
turn_completion_tokens = 0
_stream_error = False # Track whether an error occurred during streaming
state = _BaselineStreamState()
# Bind extracted module-level callbacks to this request's state/session
# using functools.partial so they satisfy the Protocol signatures.
_bound_llm_caller = partial(_baseline_llm_caller, state=state)
_bound_tool_executor = partial(
_baseline_tool_executor, state=state, user_id=user_id, session=session
)
try:
loop_result = None
async for loop_result in tool_call_loop(
messages=openai_messages,
tools=tools,
llm_call=_bound_llm_caller,
execute_tool=_bound_tool_executor,
update_conversation=_baseline_conversation_updater,
max_iterations=_MAX_TOOL_ROUNDS,
):
# Drain buffered events after each iteration (real-time streaming)
for evt in state.pending_events:
yield evt
state.pending_events.clear()
for _round in range(_MAX_TOOL_ROUNDS):
# Open a new step for each LLM round
yield StreamStartStep()
step_open = True
if loop_result and not loop_result.finished_naturally:
# Stream a response from the model
create_kwargs: dict[str, Any] = dict(
model=config.model,
messages=openai_messages,
stream=True,
stream_options={"include_usage": True},
)
if tools:
create_kwargs["tools"] = tools
response = await _get_openai_client().chat.completions.create(**create_kwargs) # type: ignore[arg-type] # dynamic kwargs
# Accumulate streamed response (text + tool calls)
round_text = ""
tool_calls_by_index: dict[int, dict[str, str]] = {}
async for chunk in response:
# Capture token usage from the streaming chunk.
# OpenRouter normalises all providers into OpenAI format
# where prompt_tokens already includes cached tokens
# (unlike Anthropic's native API). Use += to sum all
# tool-call rounds since each API call is independent.
# NOTE: stream_options={"include_usage": True} is not
# universally supported — some providers (Mistral, Llama
# via OpenRouter) always return chunk.usage=None. When
# that happens, tokens stay 0 and the tiktoken fallback
# below activates. Fail-open: one round is estimated.
if chunk.usage:
turn_prompt_tokens += chunk.usage.prompt_tokens or 0
turn_completion_tokens += chunk.usage.completion_tokens or 0
delta = chunk.choices[0].delta if chunk.choices else None
if not delta:
continue
# Text content
if delta.content:
if not text_started:
yield StreamTextStart(id=text_block_id)
text_started = True
round_text += delta.content
yield StreamTextDelta(id=text_block_id, delta=delta.content)
# Tool call fragments (streamed incrementally)
if delta.tool_calls:
for tc in delta.tool_calls:
idx = tc.index
if idx not in tool_calls_by_index:
tool_calls_by_index[idx] = {
"id": "",
"name": "",
"arguments": "",
}
entry = tool_calls_by_index[idx]
if tc.id:
entry["id"] = tc.id
if tc.function and tc.function.name:
entry["name"] = tc.function.name
if tc.function and tc.function.arguments:
entry["arguments"] += tc.function.arguments
# Close text block if we had one this round
if text_started:
yield StreamTextEnd(id=text_block_id)
text_started = False
text_block_id = str(uuid.uuid4())
# Accumulate text for session persistence
assistant_text += round_text
# No tool calls -> model is done
if not tool_calls_by_index:
yield StreamFinishStep()
step_open = False
break
# Close step before tool execution
yield StreamFinishStep()
step_open = False
# Append the assistant message with tool_calls to context.
assistant_msg: dict[str, Any] = {"role": "assistant"}
if round_text:
assistant_msg["content"] = round_text
assistant_msg["tool_calls"] = [
{
"id": tc["id"],
"type": "function",
"function": {
"name": tc["name"],
"arguments": tc["arguments"] or "{}",
},
}
for tc in tool_calls_by_index.values()
]
openai_messages.append(assistant_msg)
# Execute each tool call and stream events
for tc in tool_calls_by_index.values():
tool_call_id = tc["id"]
tool_name = tc["name"]
raw_args = tc["arguments"] or "{}"
try:
tool_args = orjson.loads(raw_args)
except orjson.JSONDecodeError as parse_err:
parse_error = (
f"Invalid JSON arguments for tool '{tool_name}': {parse_err}"
)
logger.warning("[Baseline] %s", parse_error)
yield StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
output=parse_error,
success=False,
)
openai_messages.append(
{
"role": "tool",
"tool_call_id": tool_call_id,
"content": parse_error,
}
)
continue
yield StreamToolInputStart(toolCallId=tool_call_id, toolName=tool_name)
yield StreamToolInputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
input=tool_args,
)
# Execute via shared tool registry
try:
result: StreamToolOutputAvailable = await execute_tool(
tool_name=tool_name,
parameters=tool_args,
user_id=user_id,
session=session,
tool_call_id=tool_call_id,
)
yield result
tool_output = (
result.output
if isinstance(result.output, str)
else str(result.output)
)
except Exception as e:
error_output = f"Tool execution error: {e}"
logger.error(
"[Baseline] Tool %s failed: %s",
tool_name,
error_output,
exc_info=True,
)
yield StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
output=error_output,
success=False,
)
tool_output = error_output
# Append tool result to context for next round
openai_messages.append(
{
"role": "tool",
"tool_call_id": tool_call_id,
"content": tool_output,
}
)
else:
# for-loop exhausted without break -> tool-round limit hit
limit_msg = (
f"Exceeded {_MAX_TOOL_ROUNDS} tool-call rounds "
"without a final response."
@@ -517,28 +418,11 @@ async def stream_chat_completion_baseline(
_stream_error = True
error_msg = str(e) or type(e).__name__
logger.error("[Baseline] Streaming error: %s", error_msg, exc_info=True)
# Close any open text block. The llm_caller's finally block
# already appended StreamFinishStep to pending_events, so we must
# insert StreamTextEnd *before* StreamFinishStep to preserve the
# protocol ordering:
# StreamStartStep -> StreamTextStart -> ...deltas... ->
# StreamTextEnd -> StreamFinishStep
# Appending (or yielding directly) would place it after
# StreamFinishStep, violating the protocol.
if state.text_started:
# Find the last StreamFinishStep and insert before it.
insert_pos = len(state.pending_events)
for i in range(len(state.pending_events) - 1, -1, -1):
if isinstance(state.pending_events[i], StreamFinishStep):
insert_pos = i
break
state.pending_events.insert(
insert_pos, StreamTextEnd(id=state.text_block_id)
)
# Drain pending events in correct order
for evt in state.pending_events:
yield evt
state.pending_events.clear()
# Close any open text/step before emitting error
if text_started:
yield StreamTextEnd(id=text_block_id)
if step_open:
yield StreamFinishStep()
yield StreamError(errorText=error_msg, code="baseline_error")
# Still persist whatever we got
finally:
@@ -558,21 +442,26 @@ async def stream_chat_completion_baseline(
# Skip fallback when an error occurred and no output was produced —
# charging rate-limit tokens for completely failed requests is unfair.
if (
state.turn_prompt_tokens == 0
and state.turn_completion_tokens == 0
and not (_stream_error and not state.assistant_text)
turn_prompt_tokens == 0
and turn_completion_tokens == 0
and not (_stream_error and not assistant_text)
):
state.turn_prompt_tokens = max(
from backend.util.prompt import (
estimate_token_count,
estimate_token_count_str,
)
turn_prompt_tokens = max(
estimate_token_count(openai_messages, model=config.model), 1
)
state.turn_completion_tokens = estimate_token_count_str(
state.assistant_text, model=config.model
turn_completion_tokens = estimate_token_count_str(
assistant_text, model=config.model
)
logger.info(
"[Baseline] No streaming usage reported; estimated tokens: "
"prompt=%d, completion=%d",
state.turn_prompt_tokens,
state.turn_completion_tokens,
turn_prompt_tokens,
turn_completion_tokens,
)
# Persist token usage to session and record for rate limiting.
@@ -582,15 +471,15 @@ async def stream_chat_completion_baseline(
await persist_and_record_usage(
session=session,
user_id=user_id,
prompt_tokens=state.turn_prompt_tokens,
completion_tokens=state.turn_completion_tokens,
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
log_prefix="[Baseline]",
)
# Persist assistant response
if state.assistant_text:
if assistant_text:
session.messages.append(
ChatMessage(role="assistant", content=state.assistant_text)
ChatMessage(role="assistant", content=assistant_text)
)
try:
await upsert_chat_session(session)
@@ -602,11 +491,11 @@ async def stream_chat_completion_baseline(
# aclose() — doing so raises RuntimeError on client disconnect.
# On GeneratorExit the client is already gone, so unreachable yields
# are harmless; on normal completion they reach the SSE stream.
if state.turn_prompt_tokens > 0 or state.turn_completion_tokens > 0:
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
yield StreamUsage(
prompt_tokens=state.turn_prompt_tokens,
completion_tokens=state.turn_completion_tokens,
total_tokens=state.turn_prompt_tokens + state.turn_completion_tokens,
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
total_tokens=turn_prompt_tokens + turn_completion_tokens,
)
yield StreamFinish()
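The two sides of this file differ mainly in shape: one inlines the streaming and tool-execution loop inside `stream_chat_completion_baseline`, the other extracts `_baseline_llm_caller`, `_baseline_tool_executor`, and `_baseline_conversation_updater` as module-level callbacks, binds per-request state with `functools.partial`, and drives them through a shared `tool_call_loop`. Below is a self-contained sketch of that callback-and-partial shape, using plain dicts in place of `LLMLoopResponse`/`LLMToolCall`/`ToolCallResult` and a toy loop driver; the real `tool_call_loop` in `backend.util.tool_call_loop` is not reproduced here.

```python
import asyncio
from dataclasses import dataclass, field
from functools import partial
from typing import Any


@dataclass
class LoopState:
    # Request-scoped state shared by the bound callbacks, in the spirit of
    # _BaselineStreamState above.
    events: list[str] = field(default_factory=list)


async def llm_caller(messages: list[dict[str, Any]], *, state: LoopState) -> dict[str, Any]:
    # Stand-in for the streaming LLM call: request one tool call on the
    # first round, then finish with plain text.
    state.events.append("llm_round")
    if len(state.events) == 1:
        return {"text": None, "tool_calls": [{"id": "t1", "name": "echo", "arguments": "{}"}]}
    return {"text": "done", "tool_calls": []}


async def tool_executor(call: dict[str, Any], *, state: LoopState) -> dict[str, Any]:
    state.events.append(f"tool:{call['name']}")
    return {"tool_call_id": call["id"], "content": "ok"}


def update_conversation(messages, response, tool_results=None) -> None:
    if response["text"]:
        messages.append({"role": "assistant", "content": response["text"]})
    for tr in tool_results or []:
        messages.append({"role": "tool", "tool_call_id": tr["tool_call_id"], "content": tr["content"]})


async def tool_call_loop(messages, *, llm_call, execute_tool, update_conversation, max_iterations=5):
    # Toy driver: call the LLM, run any requested tools, feed the results
    # back, and stop when no tool calls remain or the round limit is hit.
    for _ in range(max_iterations):
        response = await llm_call(messages)
        results = [await execute_tool(tc) for tc in response["tool_calls"]]
        update_conversation(messages, response, results)
        if not response["tool_calls"]:
            return


async def main() -> None:
    state = LoopState()
    await tool_call_loop(
        [{"role": "user", "content": "hi"}],
        llm_call=partial(llm_caller, state=state),
        execute_tool=partial(tool_executor, state=state),
        update_conversation=update_conversation,
    )
    print(state.events)  # ['llm_round', 'tool:echo', 'llm_round']


asyncio.run(main())
```

The key point is that the callbacks stay module-level and testable while `partial` supplies the per-request `state`, `user_id`, and `session` at bind time.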

View File

@@ -31,7 +31,7 @@ async def test_baseline_multi_turn(setup_test_user, test_user_id):
if not api_key:
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
session = await create_chat_session(test_user_id, dry_run=False)
session = await create_chat_session(test_user_id)
session = await upsert_chat_session(session)
# --- Turn 1: send a message with a unique keyword ---

View File

@@ -91,20 +91,6 @@ class ChatConfig(BaseSettings):
description="Max tokens per week, resets Monday 00:00 UTC (0 = unlimited)",
)
# Cost (in credits / cents) to reset the daily rate limit using credits.
# When a user hits their daily limit, they can spend this amount to reset
# the daily counter and keep working. Set to 0 to disable the feature.
rate_limit_reset_cost: int = Field(
default=500,
ge=0,
description="Credit cost (in cents) for resetting the daily rate limit. 0 = disabled.",
)
max_daily_resets: int = Field(
default=5,
ge=0,
description="Maximum number of credit-based rate limit resets per user per day. 0 = unlimited.",
)
# Claude Agent SDK Configuration
use_claude_agent_sdk: bool = Field(
default=True,
@@ -178,7 +164,7 @@ class ChatConfig(BaseSettings):
Single source of truth for "will the SDK route through OpenRouter?".
Checks the flag *and* that ``api_key`` + a valid ``base_url`` are
present — mirrors the fallback logic in ``build_sdk_env``.
present — mirrors the fallback logic in ``_build_sdk_env``.
"""
if not self.use_openrouter:
return False
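`openrouter_active` is described as the single source of truth for whether the SDK routes through OpenRouter: the flag alone is not enough, the API key and a usable base URL must also be set. A small sketch of that kind of derived property on a simplified config model (field names follow the diff, the exact validation is an assumption):

```python
from pydantic import BaseModel


class ProxyConfig(BaseModel):
    # Simplified stand-in for ChatConfig.
    use_openrouter: bool = False
    api_key: str | None = None
    base_url: str | None = None

    @property
    def openrouter_active(self) -> bool:
        # The flag only wins when both credentials are actually usable,
        # mirroring the fallback logic in the env builder.
        if not self.use_openrouter:
            return False
        return bool(self.api_key) and bool(self.base_url and self.base_url.startswith("http"))


assert not ProxyConfig(use_openrouter=True).openrouter_active
assert ProxyConfig(
    use_openrouter=True, api_key="sk-or-key", base_url="https://openrouter.ai/api/v1"
).openrouter_active
```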

View File

@@ -18,13 +18,7 @@ from prisma.types import (
from backend.data import db
from backend.util.json import SafeJson, sanitize_string
from .model import (
ChatMessage,
ChatSession,
ChatSessionInfo,
ChatSessionMetadata,
invalidate_session_cache,
)
from .model import ChatMessage, ChatSession, ChatSessionInfo
logger = logging.getLogger(__name__)
@@ -41,7 +35,6 @@ async def get_chat_session(session_id: str) -> ChatSession | None:
async def create_chat_session(
session_id: str,
user_id: str,
metadata: ChatSessionMetadata | None = None,
) -> ChatSessionInfo:
"""Create a new chat session in the database."""
data = ChatSessionCreateInput(
@@ -50,7 +43,6 @@ async def create_chat_session(
credentials=SafeJson({}),
successfulAgentRuns=SafeJson({}),
successfulAgentSchedules=SafeJson({}),
metadata=SafeJson((metadata or ChatSessionMetadata()).model_dump()),
)
prisma_session = await PrismaChatSession.prisma().create(data=data)
return ChatSessionInfo.from_db(prisma_session)
@@ -65,12 +57,7 @@ async def update_chat_session(
total_completion_tokens: int | None = None,
title: str | None = None,
) -> ChatSession | None:
"""Update a chat session's mutable fields.
Note: ``metadata`` (which includes ``dry_run``) is intentionally omitted —
it is set once at creation time and treated as immutable for the lifetime
of the session.
"""
"""Update a chat session's metadata."""
data: ChatSessionUpdateInput = {"updatedAt": datetime.now(UTC)}
if credentials is not None:
@@ -230,9 +217,6 @@ async def add_chat_messages_batch(
if msg.get("function_call") is not None:
data["functionCall"] = SafeJson(msg["function_call"])
if msg.get("duration_ms") is not None:
data["durationMs"] = msg["duration_ms"]
messages_data.append(data)
# Run create_many and session update in parallel within transaction
@@ -375,22 +359,3 @@ async def update_tool_message_content(
f"tool_call_id {tool_call_id}: {e}"
)
return False
async def set_turn_duration(session_id: str, duration_ms: int) -> None:
"""Set durationMs on the last assistant message in a session.
Also invalidates the Redis session cache so the next GET returns
the updated duration.
"""
last_msg = await PrismaChatMessage.prisma().find_first(
where={"sessionId": session_id, "role": "assistant"},
order={"sequence": "desc"},
)
if last_msg:
await PrismaChatMessage.prisma().update(
where={"id": last_msg.id},
data={"durationMs": duration_ms},
)
# Invalidate cache so the session is re-fetched from DB with durationMs
await invalidate_session_cache(session_id)

View File

@@ -46,16 +46,6 @@ def _get_session_cache_key(session_id: str) -> str:
# ===================== Chat data models ===================== #
class ChatSessionMetadata(BaseModel):
"""Typed metadata stored in the ``metadata`` JSON column of ChatSession.
Add new session-level flags here instead of adding DB columns —
no migration required for new fields as long as a default is provided.
"""
dry_run: bool = False
class ChatMessage(BaseModel):
role: str
content: str | None = None
@@ -64,7 +54,6 @@ class ChatMessage(BaseModel):
refusal: str | None = None
tool_calls: list[dict] | None = None
function_call: dict | None = None
duration_ms: int | None = None
@staticmethod
def from_db(prisma_message: PrismaChatMessage) -> "ChatMessage":
@@ -77,7 +66,6 @@ class ChatMessage(BaseModel):
refusal=prisma_message.refusal,
tool_calls=_parse_json_field(prisma_message.toolCalls),
function_call=_parse_json_field(prisma_message.functionCall),
duration_ms=prisma_message.durationMs,
)
@@ -100,12 +88,6 @@ class ChatSessionInfo(BaseModel):
updated_at: datetime
successful_agent_runs: dict[str, int] = {}
successful_agent_schedules: dict[str, int] = {}
metadata: ChatSessionMetadata = ChatSessionMetadata()
@property
def dry_run(self) -> bool:
"""Convenience accessor for ``metadata.dry_run``."""
return self.metadata.dry_run
@classmethod
def from_db(cls, prisma_session: PrismaChatSession) -> Self:
@@ -119,10 +101,6 @@ class ChatSessionInfo(BaseModel):
prisma_session.successfulAgentSchedules, default={}
)
# Parse typed metadata from the JSON column.
raw_metadata = _parse_json_field(prisma_session.metadata, default={})
metadata = ChatSessionMetadata.model_validate(raw_metadata)
# Calculate usage from token counts.
# NOTE: Per-turn cache_read_tokens / cache_creation_tokens breakdown
# is lost after persistence — the DB only stores aggregate prompt and
@@ -148,7 +126,6 @@ class ChatSessionInfo(BaseModel):
updated_at=prisma_session.updatedAt,
successful_agent_runs=successful_agent_runs,
successful_agent_schedules=successful_agent_schedules,
metadata=metadata,
)
@@ -156,7 +133,7 @@ class ChatSession(ChatSessionInfo):
messages: list[ChatMessage]
@classmethod
def new(cls, user_id: str, *, dry_run: bool) -> Self:
def new(cls, user_id: str) -> Self:
return cls(
session_id=str(uuid.uuid4()),
user_id=user_id,
@@ -166,7 +143,6 @@ class ChatSession(ChatSessionInfo):
credentials={},
started_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
metadata=ChatSessionMetadata(dry_run=dry_run),
)
@classmethod
@@ -554,7 +530,6 @@ async def _save_session_to_db(
await db.create_chat_session(
session_id=session.session_id,
user_id=session.user_id,
metadata=session.metadata,
)
existing_message_count = 0
@@ -632,27 +607,21 @@ async def append_and_save_message(session_id: str, message: ChatMessage) -> Chat
return session
async def create_chat_session(user_id: str, *, dry_run: bool) -> ChatSession:
async def create_chat_session(user_id: str) -> ChatSession:
"""Create a new chat session and persist it.
Args:
user_id: The authenticated user ID.
dry_run: When True, run_block and run_agent tool calls in this
session are forced to use dry-run simulation mode.
Raises:
DatabaseError: If the database write fails. We fail fast to ensure
callers never receive a non-persisted session that only exists
in cache (which would be lost when the cache expires).
"""
session = ChatSession.new(user_id, dry_run=dry_run)
session = ChatSession.new(user_id)
# Create in database first - fail fast if this fails
try:
await chat_db().create_chat_session(
session_id=session.session_id,
user_id=user_id,
metadata=session.metadata,
)
except Exception as e:
logger.error(f"Failed to create session {session.session_id} in database: {e}")

View File

@@ -46,7 +46,7 @@ messages = [
@pytest.mark.asyncio(loop_scope="session")
async def test_chatsession_serialization_deserialization():
s = ChatSession.new(user_id="abc123", dry_run=False)
s = ChatSession.new(user_id="abc123")
s.messages = messages
s.usage = [Usage(prompt_tokens=100, completion_tokens=200, total_tokens=300)]
serialized = s.model_dump_json()
@@ -57,7 +57,7 @@ async def test_chatsession_serialization_deserialization():
@pytest.mark.asyncio(loop_scope="session")
async def test_chatsession_redis_storage(setup_test_user, test_user_id):
s = ChatSession.new(user_id=test_user_id, dry_run=False)
s = ChatSession.new(user_id=test_user_id)
s.messages = messages
s = await upsert_chat_session(s)
@@ -75,7 +75,7 @@ async def test_chatsession_redis_storage_user_id_mismatch(
setup_test_user, test_user_id
):
s = ChatSession.new(user_id=test_user_id, dry_run=False)
s = ChatSession.new(user_id=test_user_id)
s.messages = messages
s = await upsert_chat_session(s)
@@ -90,7 +90,7 @@ async def test_chatsession_db_storage(setup_test_user, test_user_id):
from backend.data.redis_client import get_redis_async
# Create session with messages including assistant message
s = ChatSession.new(user_id=test_user_id, dry_run=False)
s = ChatSession.new(user_id=test_user_id)
s.messages = messages # Contains user, assistant, and tool messages
assert s.session_id is not None, "Session id is not set"
# Upsert to save to both cache and DB
@@ -241,7 +241,7 @@ _raw_tc2 = {
def test_add_tool_call_appends_to_existing_assistant():
"""When the last assistant is from the current turn, tool_call is added to it."""
session = ChatSession.new(user_id="u", dry_run=False)
session = ChatSession.new(user_id="u")
session.messages = [
ChatMessage(role="user", content="hi"),
ChatMessage(role="assistant", content="working on it"),
@@ -254,7 +254,7 @@ def test_add_tool_call_appends_to_existing_assistant():
def test_add_tool_call_creates_assistant_when_none_exists():
"""When there's no current-turn assistant, a new one is created."""
session = ChatSession.new(user_id="u", dry_run=False)
session = ChatSession.new(user_id="u")
session.messages = [
ChatMessage(role="user", content="hi"),
]
@@ -267,7 +267,7 @@ def test_add_tool_call_creates_assistant_when_none_exists():
def test_add_tool_call_does_not_cross_user_boundary():
"""A user message acts as a boundary — previous assistant is not modified."""
session = ChatSession.new(user_id="u", dry_run=False)
session = ChatSession.new(user_id="u")
session.messages = [
ChatMessage(role="assistant", content="old turn"),
ChatMessage(role="user", content="new message"),
@@ -282,7 +282,7 @@ def test_add_tool_call_does_not_cross_user_boundary():
def test_add_tool_call_multiple_times():
"""Multiple long-running tool calls accumulate on the same assistant."""
session = ChatSession.new(user_id="u", dry_run=False)
session = ChatSession.new(user_id="u")
session.messages = [
ChatMessage(role="user", content="hi"),
ChatMessage(role="assistant", content="doing stuff"),
@@ -300,7 +300,7 @@ def test_add_tool_call_multiple_times():
def test_to_openai_messages_merges_split_assistants():
"""End-to-end: session with split assistants produces valid OpenAI messages."""
session = ChatSession.new(user_id="u", dry_run=False)
session = ChatSession.new(user_id="u")
session.messages = [
ChatMessage(role="user", content="build agent"),
ChatMessage(role="assistant", content="Let me build that"),
@@ -352,7 +352,7 @@ async def test_concurrent_saves_collision_detection(setup_test_user, test_user_i
import asyncio
# Create a session with initial messages
session = ChatSession.new(user_id=test_user_id, dry_run=False)
session = ChatSession.new(user_id=test_user_id)
for i in range(3):
session.messages.append(
ChatMessage(

View File

@@ -107,13 +107,6 @@ Do not re-fetch or re-generate data you already have from prior tool calls.
After building the file, reference it with `@@agptfile:` in other tools:
`@@agptfile:/home/user/report.md`
### Web search best practices
- If 3 similar web searches don't return the specific data you need, conclude
it isn't publicly available and work with what you have.
- Prefer fewer, well-targeted searches over many variations of the same query.
- When spawning sub-agents for research, ensure each has a distinct
non-overlapping scope to avoid redundant searches.
### Sub-agent tasks
- When using the Task tool, NEVER set `run_in_background` to true.
All tasks must run in the foreground.
@@ -212,10 +205,9 @@ Important files (code, configs, outputs) should be saved to workspace to ensure
### SDK tool-result files
When tool outputs are large, the SDK truncates them and saves the full output to
a local file under `~/.claude/projects/.../tool-results/`. To read these files,
always use `Read` (NOT `bash_exec`, NOT `read_workspace_file`).
These files are on the host filesystem — `bash_exec` runs in the sandbox and
CANNOT access them. `read_workspace_file` reads from cloud workspace storage,
where SDK tool-results are NOT stored.
always use `read_file` or `Read` (NOT `read_workspace_file`).
`read_workspace_file` reads from cloud workspace storage, where SDK
tool-results are NOT stored.
{_SHARED_TOOL_NOTES}{extra_notes}"""

View File

@@ -36,10 +36,6 @@ class CoPilotUsageStatus(BaseModel):
daily: UsageWindow
weekly: UsageWindow
reset_cost: int = Field(
default=0,
description="Credit cost (in cents) to reset the daily limit. 0 = feature disabled.",
)
class RateLimitExceeded(Exception):
@@ -65,7 +61,6 @@ async def get_usage_status(
user_id: str,
daily_token_limit: int,
weekly_token_limit: int,
rate_limit_reset_cost: int = 0,
) -> CoPilotUsageStatus:
"""Get current usage status for a user.
@@ -73,7 +68,6 @@ async def get_usage_status(
user_id: The user's ID.
daily_token_limit: Max tokens per day (0 = unlimited).
weekly_token_limit: Max tokens per week (0 = unlimited).
rate_limit_reset_cost: Credit cost (cents) to reset daily limit (0 = disabled).
Returns:
CoPilotUsageStatus with current usage and limits.
@@ -103,7 +97,6 @@ async def get_usage_status(
limit=weekly_token_limit,
resets_at=_weekly_reset_time(now=now),
),
reset_cost=rate_limit_reset_cost,
)
@@ -148,110 +141,6 @@ async def check_rate_limit(
raise RateLimitExceeded("weekly", _weekly_reset_time(now=now))
async def reset_daily_usage(user_id: str, daily_token_limit: int = 0) -> bool:
"""Reset a user's daily token usage counter in Redis.
Called after a user pays credits to extend their daily limit.
Also reduces the weekly usage counter by ``daily_token_limit`` tokens
(clamped to 0) so the user effectively gets one extra day's worth of
weekly capacity.
Args:
user_id: The user's ID.
daily_token_limit: The configured daily token limit. When positive,
the weekly counter is reduced by this amount.
Fails open: returns False if Redis is unavailable (consistent with
the fail-open design of this module).
"""
now = datetime.now(UTC)
try:
redis = await get_redis_async()
# Use a MULTI/EXEC transaction so that DELETE (daily) and DECRBY
# (weekly) either both execute or neither does. This prevents the
# scenario where the daily counter is cleared but the weekly
# counter is not decremented — which would let the caller refund
# credits even though the daily limit was already reset.
d_key = _daily_key(user_id, now=now)
w_key = _weekly_key(user_id, now=now) if daily_token_limit > 0 else None
pipe = redis.pipeline(transaction=True)
pipe.delete(d_key)
if w_key is not None:
pipe.decrby(w_key, daily_token_limit)
results = await pipe.execute()
# Clamp negative weekly counter to 0 (best-effort; not critical).
if w_key is not None:
new_val = results[1] # DECRBY result
if new_val < 0:
await redis.set(w_key, 0, keepttl=True)
logger.info("Reset daily usage for user %s", user_id[:8])
return True
except (RedisError, ConnectionError, OSError):
logger.warning("Redis unavailable for resetting daily usage")
return False
_RESET_LOCK_PREFIX = "copilot:reset_lock"
_RESET_COUNT_PREFIX = "copilot:reset_count"
async def acquire_reset_lock(user_id: str, ttl_seconds: int = 10) -> bool:
"""Acquire a short-lived lock to serialize rate limit resets per user."""
try:
redis = await get_redis_async()
key = f"{_RESET_LOCK_PREFIX}:{user_id}"
return bool(await redis.set(key, "1", nx=True, ex=ttl_seconds))
except (RedisError, ConnectionError, OSError) as exc:
logger.warning("Redis unavailable for reset lock, rejecting reset: %s", exc)
return False
async def release_reset_lock(user_id: str) -> None:
"""Release the per-user reset lock."""
try:
redis = await get_redis_async()
await redis.delete(f"{_RESET_LOCK_PREFIX}:{user_id}")
except (RedisError, ConnectionError, OSError):
pass # Lock will expire via TTL
async def get_daily_reset_count(user_id: str) -> int | None:
"""Get how many times the user has reset today.
Returns None when Redis is unavailable so callers can fail-closed
for billed operations (as opposed to failing open for read-only
rate-limit checks).
"""
now = datetime.now(UTC)
try:
redis = await get_redis_async()
key = f"{_RESET_COUNT_PREFIX}:{user_id}:{now.strftime('%Y-%m-%d')}"
val = await redis.get(key)
return int(val or 0)
except (RedisError, ConnectionError, OSError):
logger.warning("Redis unavailable for reading daily reset count")
return None
async def increment_daily_reset_count(user_id: str) -> None:
"""Increment and track how many resets this user has done today."""
now = datetime.now(UTC)
try:
redis = await get_redis_async()
key = f"{_RESET_COUNT_PREFIX}:{user_id}:{now.strftime('%Y-%m-%d')}"
pipe = redis.pipeline(transaction=True)
pipe.incr(key)
seconds_until_reset = int((_daily_reset_time(now=now) - now).total_seconds())
pipe.expire(key, max(seconds_until_reset, 1))
await pipe.execute()
except (RedisError, ConnectionError, OSError):
logger.warning("Redis unavailable for tracking reset count")
async def record_token_usage(
user_id: str,
prompt_tokens: int,
@@ -342,67 +231,6 @@ async def record_token_usage(
)
async def get_global_rate_limits(
user_id: str,
config_daily: int,
config_weekly: int,
) -> tuple[int, int]:
"""Resolve global rate limits from LaunchDarkly, falling back to config.
Args:
user_id: User ID for LD flag evaluation context.
config_daily: Fallback daily limit from ChatConfig.
config_weekly: Fallback weekly limit from ChatConfig.
Returns:
(daily_token_limit, weekly_token_limit) tuple.
"""
# Lazy import to avoid circular dependency:
# rate_limit -> feature_flag -> settings -> ... -> rate_limit
from backend.util.feature_flag import Flag, get_feature_flag_value
daily_raw = await get_feature_flag_value(
Flag.COPILOT_DAILY_TOKEN_LIMIT.value, user_id, config_daily
)
weekly_raw = await get_feature_flag_value(
Flag.COPILOT_WEEKLY_TOKEN_LIMIT.value, user_id, config_weekly
)
try:
daily = max(0, int(daily_raw))
except (TypeError, ValueError):
logger.warning("Invalid LD value for daily token limit: %r", daily_raw)
daily = config_daily
try:
weekly = max(0, int(weekly_raw))
except (TypeError, ValueError):
logger.warning("Invalid LD value for weekly token limit: %r", weekly_raw)
weekly = config_weekly
return daily, weekly
async def reset_user_usage(user_id: str, *, reset_weekly: bool = False) -> None:
"""Reset a user's usage counters.
Always deletes the daily Redis key. When *reset_weekly* is ``True``,
the weekly key is deleted as well.
Unlike read paths (``get_usage_status``, ``check_rate_limit``) which
fail-open on Redis errors, resets intentionally re-raise so the caller
knows the operation did not succeed. A silent failure here would leave
the admin believing the counters were zeroed when they were not.
"""
now = datetime.now(UTC)
keys_to_delete = [_daily_key(user_id, now=now)]
if reset_weekly:
keys_to_delete.append(_weekly_key(user_id, now=now))
try:
redis = await get_redis_async()
await redis.delete(*keys_to_delete)
except (RedisError, ConnectionError, OSError):
logger.warning("Redis unavailable for resetting user usage")
raise
# ---------------------------------------------------------------------------
# Private helpers
# ---------------------------------------------------------------------------
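`reset_daily_usage` relies on a Redis MULTI/EXEC pipeline so clearing the daily counter and crediting back a day of weekly budget either both happen or neither does, and the reset lock is a plain SET NX EX key that expires on its own. A compact sketch of those two Redis patterns with `redis.asyncio`; the usage key names here are illustrative, the real ones come from the `_daily_key`/`_weekly_key` helpers.

```python
import redis.asyncio as redis


async def reset_daily(r: redis.Redis, user_id: str, daily_limit: int) -> bool:
    # Both writes go through one MULTI/EXEC so a crash cannot clear the
    # daily counter without also reducing the weekly one.
    # (The real helper skips the weekly step when daily_limit is 0.)
    d_key = f"copilot:usage:daily:{user_id}"    # illustrative key names
    w_key = f"copilot:usage:weekly:{user_id}"
    pipe = r.pipeline(transaction=True)
    pipe.delete(d_key)
    pipe.decrby(w_key, daily_limit)
    _, new_weekly = await pipe.execute()
    if new_weekly < 0:
        # Best-effort clamp outside the transaction, preserving the TTL.
        await r.set(w_key, 0, keepttl=True)
    return True


async def acquire_lock(r: redis.Redis, user_id: str, ttl_seconds: int = 10) -> bool:
    # SET NX EX: only the first caller gets the lock, and it auto-expires
    # if the holder never releases it.
    return bool(await r.set(f"copilot:reset_lock:{user_id}", "1", nx=True, ex=ttl_seconds))
```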

View File

@@ -12,7 +12,6 @@ from .rate_limit import (
check_rate_limit,
get_usage_status,
record_token_usage,
reset_daily_usage,
)
_USER = "test-user-rl"
@@ -333,91 +332,3 @@ class TestRecordTokenUsage:
):
# Should not raise — fail-open
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
# ---------------------------------------------------------------------------
# reset_daily_usage
# ---------------------------------------------------------------------------
class TestResetDailyUsage:
@staticmethod
def _make_pipeline_mock(decrby_result: int = 0) -> MagicMock:
"""Create a pipeline mock that returns [delete_result, decrby_result]."""
pipe = MagicMock()
pipe.execute = AsyncMock(return_value=[1, decrby_result])
return pipe
@pytest.mark.asyncio
async def test_deletes_daily_key(self):
mock_pipe = self._make_pipeline_mock(decrby_result=0)
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
result = await reset_daily_usage(_USER, daily_token_limit=10000)
assert result is True
mock_pipe.delete.assert_called_once()
@pytest.mark.asyncio
async def test_reduces_weekly_usage_via_decrby(self):
"""Weekly counter should be reduced via DECRBY in the pipeline."""
mock_pipe = self._make_pipeline_mock(decrby_result=35000)
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await reset_daily_usage(_USER, daily_token_limit=10000)
mock_pipe.decrby.assert_called_once()
mock_redis.set.assert_not_called() # 35000 > 0, no clamp needed
@pytest.mark.asyncio
async def test_clamps_negative_weekly_to_zero(self):
"""If DECRBY goes negative, SET to 0 (outside the pipeline)."""
mock_pipe = self._make_pipeline_mock(decrby_result=-5000)
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await reset_daily_usage(_USER, daily_token_limit=10000)
mock_pipe.decrby.assert_called_once()
mock_redis.set.assert_called_once()
@pytest.mark.asyncio
async def test_no_weekly_reduction_when_daily_limit_zero(self):
"""When daily_token_limit is 0, weekly counter should not be touched."""
mock_pipe = self._make_pipeline_mock()
mock_pipe.execute = AsyncMock(return_value=[1]) # only delete result
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await reset_daily_usage(_USER, daily_token_limit=0)
mock_pipe.delete.assert_called_once()
mock_pipe.decrby.assert_not_called()
@pytest.mark.asyncio
async def test_returns_false_when_redis_unavailable(self):
with patch(
"backend.copilot.rate_limit.get_redis_async",
side_effect=ConnectionError("Redis down"),
):
result = await reset_daily_usage(_USER, daily_token_limit=10000)
assert result is False

View File

@@ -1,294 +0,0 @@
"""Unit tests for the POST /usage/reset endpoint."""
from __future__ import annotations
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import HTTPException
from backend.api.features.chat.routes import reset_copilot_usage
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
from backend.util.exceptions import InsufficientBalanceError
# Minimal config mock matching ChatConfig fields used by the endpoint.
def _make_config(
rate_limit_reset_cost: int = 500,
daily_token_limit: int = 2_500_000,
weekly_token_limit: int = 12_500_000,
max_daily_resets: int = 5,
):
cfg = MagicMock()
cfg.rate_limit_reset_cost = rate_limit_reset_cost
cfg.daily_token_limit = daily_token_limit
cfg.weekly_token_limit = weekly_token_limit
cfg.max_daily_resets = max_daily_resets
return cfg
def _usage(daily_used: int = 3_000_000, daily_limit: int = 2_500_000):
return CoPilotUsageStatus(
daily=UsageWindow(
used=daily_used,
limit=daily_limit,
resets_at=datetime.now(UTC) + timedelta(hours=6),
),
weekly=UsageWindow(
used=5_000_000,
limit=12_500_000,
resets_at=datetime.now(UTC) + timedelta(days=3),
),
)
_MODULE = "backend.api.features.chat.routes"
def _mock_settings(enable_credit: bool = True):
"""Return a mock Settings object with the given enable_credit flag."""
mock = MagicMock()
mock.config.enable_credit = enable_credit
return mock
@pytest.mark.asyncio
class TestResetCopilotUsage:
async def test_feature_disabled_returns_400(self):
"""When rate_limit_reset_cost=0, endpoint returns 400."""
with patch(f"{_MODULE}.config", _make_config(rate_limit_reset_cost=0)):
with pytest.raises(HTTPException) as exc_info:
await reset_copilot_usage(user_id="user-1")
assert exc_info.value.status_code == 400
assert "not available" in exc_info.value.detail
async def test_no_daily_limit_returns_400(self):
"""When daily_token_limit=0 (unlimited), endpoint returns 400."""
with (
patch(f"{_MODULE}.config", _make_config(daily_token_limit=0)),
patch(f"{_MODULE}.settings", _mock_settings()),
):
with pytest.raises(HTTPException) as exc_info:
await reset_copilot_usage(user_id="user-1")
assert exc_info.value.status_code == 400
assert "nothing to reset" in exc_info.value.detail.lower()
async def test_not_at_limit_returns_400(self):
"""When user hasn't hit their daily limit, returns 400."""
cfg = _make_config()
with (
patch(f"{_MODULE}.config", cfg),
patch(f"{_MODULE}.settings", _mock_settings()),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
patch(f"{_MODULE}.release_reset_lock", AsyncMock()) as mock_release,
patch(
f"{_MODULE}.get_usage_status",
AsyncMock(return_value=_usage(daily_used=1_000_000)),
),
):
with pytest.raises(HTTPException) as exc_info:
await reset_copilot_usage(user_id="user-1")
assert exc_info.value.status_code == 400
assert "not reached" in exc_info.value.detail
mock_release.assert_awaited_once()
async def test_insufficient_credits_returns_402(self):
"""When user doesn't have enough credits, returns 402."""
mock_credit_model = AsyncMock()
mock_credit_model.spend_credits.side_effect = InsufficientBalanceError(
message="Insufficient balance",
user_id="user-1",
balance=50,
amount=200,
)
cfg = _make_config()
with (
patch(f"{_MODULE}.config", cfg),
patch(f"{_MODULE}.settings", _mock_settings()),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
patch(f"{_MODULE}.release_reset_lock", AsyncMock()) as mock_release,
patch(
f"{_MODULE}.get_usage_status",
AsyncMock(return_value=_usage()),
),
patch(
f"{_MODULE}.get_user_credit_model",
AsyncMock(return_value=mock_credit_model),
),
):
with pytest.raises(HTTPException) as exc_info:
await reset_copilot_usage(user_id="user-1")
assert exc_info.value.status_code == 402
mock_release.assert_awaited_once()
async def test_happy_path(self):
"""Successful reset: charges credits, resets usage, returns response."""
mock_credit_model = AsyncMock()
mock_credit_model.spend_credits.return_value = 1500 # remaining balance
cfg = _make_config()
updated_usage = _usage(daily_used=0)
with (
patch(f"{_MODULE}.config", cfg),
patch(f"{_MODULE}.settings", _mock_settings()),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
patch(f"{_MODULE}.release_reset_lock", AsyncMock()),
patch(
f"{_MODULE}.get_usage_status",
AsyncMock(side_effect=[_usage(), updated_usage]),
),
patch(
f"{_MODULE}.get_user_credit_model",
AsyncMock(return_value=mock_credit_model),
),
patch(
f"{_MODULE}.reset_daily_usage", AsyncMock(return_value=True)
) as mock_reset,
patch(f"{_MODULE}.increment_daily_reset_count", AsyncMock()) as mock_incr,
):
result = await reset_copilot_usage(user_id="user-1")
assert result.success is True
assert result.credits_charged == 500
assert result.remaining_balance == 1500
mock_reset.assert_awaited_once()
mock_incr.assert_awaited_once()
async def test_max_daily_resets_exceeded(self):
"""When user has exhausted daily resets, returns 429."""
cfg = _make_config(max_daily_resets=3)
with (
patch(f"{_MODULE}.config", cfg),
patch(f"{_MODULE}.settings", _mock_settings()),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=3)),
):
with pytest.raises(HTTPException) as exc_info:
await reset_copilot_usage(user_id="user-1")
assert exc_info.value.status_code == 429
async def test_credit_system_disabled_returns_400(self):
"""When enable_credit=False, endpoint returns 400."""
with (
patch(f"{_MODULE}.config", _make_config()),
patch(f"{_MODULE}.settings", _mock_settings(enable_credit=False)),
):
with pytest.raises(HTTPException) as exc_info:
await reset_copilot_usage(user_id="user-1")
assert exc_info.value.status_code == 400
assert "credit system is disabled" in exc_info.value.detail.lower()
async def test_weekly_limit_exhausted_returns_400(self):
"""When the weekly limit is also exhausted, resetting daily won't help."""
cfg = _make_config()
weekly_exhausted = CoPilotUsageStatus(
daily=UsageWindow(
used=3_000_000,
limit=2_500_000,
resets_at=datetime.now(UTC) + timedelta(hours=6),
),
weekly=UsageWindow(
used=12_500_000,
limit=12_500_000,
resets_at=datetime.now(UTC) + timedelta(days=3),
),
)
with (
patch(f"{_MODULE}.config", cfg),
patch(f"{_MODULE}.settings", _mock_settings()),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
patch(f"{_MODULE}.release_reset_lock", AsyncMock()) as mock_release,
patch(
f"{_MODULE}.get_usage_status",
AsyncMock(return_value=weekly_exhausted),
),
):
with pytest.raises(HTTPException) as exc_info:
await reset_copilot_usage(user_id="user-1")
assert exc_info.value.status_code == 400
assert "weekly" in exc_info.value.detail.lower()
mock_release.assert_awaited_once()
async def test_redis_failure_for_reset_count_returns_503(self):
"""When Redis is unavailable for get_daily_reset_count, returns 503."""
with (
patch(f"{_MODULE}.config", _make_config()),
patch(f"{_MODULE}.settings", _mock_settings()),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=None)),
):
with pytest.raises(HTTPException) as exc_info:
await reset_copilot_usage(user_id="user-1")
assert exc_info.value.status_code == 503
assert "verify" in exc_info.value.detail.lower()
async def test_redis_reset_failure_refunds_credits(self):
"""When reset_daily_usage fails, credits are refunded and 503 returned."""
mock_credit_model = AsyncMock()
mock_credit_model.spend_credits.return_value = 1500
cfg = _make_config()
with (
patch(f"{_MODULE}.config", cfg),
patch(f"{_MODULE}.settings", _mock_settings()),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
patch(f"{_MODULE}.release_reset_lock", AsyncMock()),
patch(
f"{_MODULE}.get_usage_status",
AsyncMock(return_value=_usage()),
),
patch(
f"{_MODULE}.get_user_credit_model",
AsyncMock(return_value=mock_credit_model),
),
patch(f"{_MODULE}.reset_daily_usage", AsyncMock(return_value=False)),
):
with pytest.raises(HTTPException) as exc_info:
await reset_copilot_usage(user_id="user-1")
assert exc_info.value.status_code == 503
assert "not been charged" in exc_info.value.detail
mock_credit_model.top_up_credits.assert_awaited_once()
async def test_redis_reset_failure_refund_also_fails(self):
"""When both reset and refund fail, error message reflects the truth."""
mock_credit_model = AsyncMock()
mock_credit_model.spend_credits.return_value = 1500
mock_credit_model.top_up_credits.side_effect = RuntimeError("db down")
cfg = _make_config()
with (
patch(f"{_MODULE}.config", cfg),
patch(f"{_MODULE}.settings", _mock_settings()),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
patch(f"{_MODULE}.release_reset_lock", AsyncMock()),
patch(
f"{_MODULE}.get_usage_status",
AsyncMock(return_value=_usage()),
),
patch(
f"{_MODULE}.get_user_credit_model",
AsyncMock(return_value=mock_credit_model),
),
patch(f"{_MODULE}.reset_daily_usage", AsyncMock(return_value=False)),
):
with pytest.raises(HTTPException) as exc_info:
await reset_copilot_usage(user_id="user-1")
assert exc_info.value.status_code == 503
assert "contact support" in exc_info.value.detail.lower()

View File

@@ -67,17 +67,9 @@ These define the agent's interface — what it accepts and what it produces.
**AgentInputBlock** (ID: `c0a8e994-ebf1-4a9c-a4d8-89d09c86741b`):
- Defines a user-facing input field on the agent
- Required `input_default` fields: `name` (str), `value` (default: null)
- Optional: `title`, `description`
- Optional: `title`, `description`, `placeholder_values` (for dropdowns)
- Output: `result` — the user-provided value at runtime
- Create one AgentInputBlock per distinct input the agent needs
- For dropdown/select inputs, use **AgentDropdownInputBlock** instead (see below)
**AgentDropdownInputBlock** (ID: `655d6fdf-a334-421c-b733-520549c07cd1`):
- Specialized input block that presents a dropdown/select to the user
- Required `input_default` fields: `name` (str), `placeholder_values` (list of options, must have at least one)
- Optional: `title`, `description`, `value` (default selection)
- Output: `result` — the user-selected value at runtime
- Use this instead of AgentInputBlock when the user should pick from a fixed set of options
**AgentOutputBlock** (ID: `363ae599-353e-4804-937e-b2ee3cef3da4`):
- Defines a user-facing output displayed after the agent runs
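For concreteness, here is what `input_default` could look like for the two input block flavours described above when building a graph node. The outer node structure (a bare `block_id` key, no position or link data) is simplified and only illustrative; the field names inside `input_default` follow the lists above.

```python
# Illustrative node payloads only.
text_input_node = {
    "block_id": "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b",  # AgentInputBlock
    "input_default": {
        "name": "topic",
        "value": None,
        "title": "Topic",
        "description": "What should the agent write about?",
    },
}

dropdown_input_node = {
    "block_id": "655d6fdf-a334-421c-b733-520549c07cd1",  # AgentDropdownInputBlock
    "input_default": {
        "name": "tone",
        "placeholder_values": ["formal", "casual", "playful"],  # at least one required
        "value": "formal",  # optional default selection
        "title": "Tone",
    },
}
```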

View File

@@ -25,7 +25,7 @@ from backend.copilot.sdk.compaction import (
def _make_session() -> ChatSession:
return ChatSession.new(user_id="test-user", dry_run=False)
return ChatSession.new(user_id="test-user")
# ---------------------------------------------------------------------------

View File

@@ -25,64 +25,24 @@ def build_test_transcript(pairs: list[tuple[str, str]]) -> str:
Use this helper in any copilot SDK test that needs a well-formed
transcript without hitting the real storage layer.
Delegates to ``build_structured_transcript`` — plain content strings
are automatically wrapped in ``[{"type": "text", "text": ...}]`` for
assistant messages.
"""
# Cast widening: tuple[str, str] is structurally compatible with
# tuple[str, str | list[dict]] but list invariance requires explicit
# annotation.
widened: list[tuple[str, str | list[dict]]] = list(pairs)
return build_structured_transcript(widened)
def build_structured_transcript(
entries: list[tuple[str, str | list[dict]]],
) -> str:
"""Build a JSONL transcript with structured content blocks.
Each entry is (role, content) where content is either a plain string
(for user messages) or a list of content block dicts (for assistant
messages with thinking/tool_use/text blocks).
Example::
build_structured_transcript([
("user", "Hello"),
("assistant", [
{"type": "thinking", "thinking": "...", "signature": "sig1"},
{"type": "text", "text": "Hi there"},
]),
])
"""
lines: list[str] = []
last_uuid: str | None = None
for role, content in entries:
for role, content in pairs:
uid = str(uuid4())
entry_type = "assistant" if role == "assistant" else "user"
if role == "assistant" and isinstance(content, list):
msg: dict = {
"role": "assistant",
"model": "claude-test",
"id": f"msg_{uid[:8]}",
"type": "message",
"content": content,
"stop_reason": "end_turn",
"stop_sequence": None,
}
elif role == "assistant":
msg = {
"role": "assistant",
"model": "claude-test",
"id": f"msg_{uid[:8]}",
"type": "message",
"content": [{"type": "text", "text": content}],
"stop_reason": "end_turn",
"stop_sequence": None,
}
else:
msg = {"role": role, "content": content}
msg: dict = {"role": role, "content": content}
if role == "assistant":
msg.update(
{
"model": "",
"id": f"msg_{uid[:8]}",
"type": "message",
"content": [{"type": "text", "text": content}],
"stop_reason": "end_turn",
"stop_sequence": None,
}
)
entry = {
"type": entry_type,
"uuid": uid,

View File

@@ -275,7 +275,7 @@ class TestCompactionE2E:
# --- Step 7: CompactionTracker receives PreCompact hook ---
tracker = CompactionTracker()
session = ChatSession.new(user_id="test-user", dry_run=False)
session = ChatSession.new(user_id="test-user")
tracker.on_compact(str(session_file))
# --- Step 8: Next SDK message arrives → emit_start ---
@@ -376,7 +376,7 @@ class TestCompactionE2E:
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
tracker = CompactionTracker()
session = ChatSession.new(user_id="test", dry_run=False)
session = ChatSession.new(user_id="test")
builder = TranscriptBuilder()
# --- First query with compaction ---

View File

@@ -1,68 +0,0 @@
"""SDK environment variable builder — importable without circular deps.
Extracted from ``service.py`` so that ``backend.blocks.orchestrator``
can reuse the same subscription / OpenRouter / direct-Anthropic logic
without pulling in the full copilot service module (which would create a
circular import through ``executor`` → ``credit`` → ``block_cost_config``).
"""
from __future__ import annotations
from backend.copilot.config import ChatConfig
from backend.copilot.sdk.subscription import validate_subscription
# ChatConfig is stateless (reads env vars) — a separate instance is fine.
# A singleton would require importing service.py which causes the circular dep
# this module was created to avoid.
config = ChatConfig()
def build_sdk_env(
session_id: str | None = None,
user_id: str | None = None,
) -> dict[str, str]:
"""Build env vars for the SDK CLI subprocess.
Three modes (checked in order):
1. **Subscription** — clears all keys; CLI uses ``claude login`` auth.
2. **Direct Anthropic** — returns ``{}``; subprocess inherits
``ANTHROPIC_API_KEY`` from the parent environment.
3. **OpenRouter** (default) — overrides base URL and auth token to
route through the proxy, with Langfuse trace headers.
"""
# --- Mode 1: Claude Code subscription auth ---
if config.use_claude_code_subscription:
validate_subscription()
return {
"ANTHROPIC_API_KEY": "",
"ANTHROPIC_AUTH_TOKEN": "",
"ANTHROPIC_BASE_URL": "",
}
# --- Mode 2: Direct Anthropic (no proxy hop) ---
if not config.openrouter_active:
return {}
# --- Mode 3: OpenRouter proxy ---
base = (config.base_url or "").rstrip("/")
if base.endswith("/v1"):
base = base[:-3]
env: dict[str, str] = {
"ANTHROPIC_BASE_URL": base,
"ANTHROPIC_AUTH_TOKEN": config.api_key or "",
"ANTHROPIC_API_KEY": "", # force CLI to use AUTH_TOKEN
}
# Inject broadcast headers so OpenRouter forwards traces to Langfuse.
def _safe(v: str) -> str:
return v.replace("\r", "").replace("\n", "").strip()[:128]
parts = []
if session_id:
parts.append(f"x-session-id: {_safe(session_id)}")
if user_id:
parts.append(f"x-user-id: {_safe(user_id)}")
if parts:
env["ANTHROPIC_CUSTOM_HEADERS"] = "\n".join(parts)
return env
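A short sketch of how a caller would consume `build_sdk_env()`: the returned mapping is layered over the inherited environment, so empty-string values deliberately blank keys the parent process might have set. The subprocess invocation and the `claude` executable name are assumptions; the real spawning code is not part of this diff.

```python
import os
import subprocess

from backend.copilot.sdk.env import build_sdk_env  # module shown above

# Overlay the overrides on the parent environment. In OpenRouter mode
# ANTHROPIC_API_KEY is set to "" so the CLI falls back to
# ANTHROPIC_AUTH_TOKEN; in subscription mode all three keys are blanked.
env = {**os.environ, **build_sdk_env(session_id="sess-1", user_id="user-1")}

subprocess.run(["claude", "--version"], env=env, check=False)
```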

View File

@@ -1,242 +0,0 @@
"""Tests for build_sdk_env() — the SDK subprocess environment builder."""
from unittest.mock import patch
import pytest
from backend.copilot.config import ChatConfig
# ---------------------------------------------------------------------------
# Helpers — build a ChatConfig with explicit field values so tests don't
# depend on real environment variables.
# ---------------------------------------------------------------------------
def _make_config(**overrides) -> ChatConfig:
"""Create a ChatConfig with safe defaults, applying *overrides*."""
defaults = {
"use_claude_code_subscription": False,
"use_openrouter": False,
"api_key": None,
"base_url": None,
}
defaults.update(overrides)
return ChatConfig(**defaults)
# ---------------------------------------------------------------------------
# Mode 1 — Subscription auth
# ---------------------------------------------------------------------------
class TestBuildSdkEnvSubscription:
"""When ``use_claude_code_subscription`` is True, keys are blanked."""
@patch("backend.copilot.sdk.env.validate_subscription")
def test_returns_blanked_keys(self, mock_validate):
"""Subscription mode clears API_KEY, AUTH_TOKEN, and BASE_URL."""
cfg = _make_config(use_claude_code_subscription=True)
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env()
assert result == {
"ANTHROPIC_API_KEY": "",
"ANTHROPIC_AUTH_TOKEN": "",
"ANTHROPIC_BASE_URL": "",
}
mock_validate.assert_called_once()
@patch(
"backend.copilot.sdk.env.validate_subscription",
side_effect=RuntimeError("CLI not found"),
)
def test_propagates_validation_error(self, mock_validate):
"""If validate_subscription fails, the error bubbles up."""
cfg = _make_config(use_claude_code_subscription=True)
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
with pytest.raises(RuntimeError, match="CLI not found"):
build_sdk_env()
# ---------------------------------------------------------------------------
# Mode 2 — Direct Anthropic (no OpenRouter)
# ---------------------------------------------------------------------------
class TestBuildSdkEnvDirectAnthropic:
"""When OpenRouter is inactive, return empty dict (inherit parent env)."""
def test_returns_empty_dict_when_openrouter_inactive(self):
cfg = _make_config(use_openrouter=False)
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env()
assert result == {}
def test_returns_empty_dict_when_openrouter_flag_true_but_no_key(self):
"""OpenRouter flag is True but no api_key => openrouter_active is False."""
cfg = _make_config(use_openrouter=True, base_url="https://openrouter.ai/api/v1")
# Force api_key to None after construction (field_validator may pick up env vars)
object.__setattr__(cfg, "api_key", None)
assert not cfg.openrouter_active
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env()
assert result == {}
# ---------------------------------------------------------------------------
# Mode 3 — OpenRouter proxy
# ---------------------------------------------------------------------------
class TestBuildSdkEnvOpenRouter:
"""When OpenRouter is active, return proxy env vars."""
def _openrouter_config(self, **overrides):
defaults = {
"use_openrouter": True,
"api_key": "sk-or-test-key",
"base_url": "https://openrouter.ai/api/v1",
}
defaults.update(overrides)
return _make_config(**defaults)
def test_basic_openrouter_env(self):
cfg = self._openrouter_config()
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env()
assert result["ANTHROPIC_BASE_URL"] == "https://openrouter.ai/api"
assert result["ANTHROPIC_AUTH_TOKEN"] == "sk-or-test-key"
assert result["ANTHROPIC_API_KEY"] == ""
assert "ANTHROPIC_CUSTOM_HEADERS" not in result
def test_strips_trailing_v1(self):
"""The /v1 suffix is stripped from the base URL."""
cfg = self._openrouter_config(base_url="https://openrouter.ai/api/v1")
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env()
assert result["ANTHROPIC_BASE_URL"] == "https://openrouter.ai/api"
def test_strips_trailing_v1_and_slash(self):
"""Trailing slash before /v1 strip is handled."""
cfg = self._openrouter_config(base_url="https://openrouter.ai/api/v1/")
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env()
# rstrip("/") first, then remove /v1
assert result["ANTHROPIC_BASE_URL"] == "https://openrouter.ai/api"
def test_no_v1_suffix_left_alone(self):
"""A base URL without /v1 is used as-is."""
cfg = self._openrouter_config(base_url="https://custom-proxy.example.com")
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env()
assert result["ANTHROPIC_BASE_URL"] == "https://custom-proxy.example.com"
def test_session_id_header(self):
cfg = self._openrouter_config()
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env(session_id="sess-123")
assert "ANTHROPIC_CUSTOM_HEADERS" in result
assert "x-session-id: sess-123" in result["ANTHROPIC_CUSTOM_HEADERS"]
def test_user_id_header(self):
cfg = self._openrouter_config()
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env(user_id="user-456")
assert "x-user-id: user-456" in result["ANTHROPIC_CUSTOM_HEADERS"]
def test_both_headers(self):
cfg = self._openrouter_config()
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env(session_id="s1", user_id="u2")
headers = result["ANTHROPIC_CUSTOM_HEADERS"]
assert "x-session-id: s1" in headers
assert "x-user-id: u2" in headers
# They should be newline-separated
assert "\n" in headers
def test_header_sanitisation_strips_newlines(self):
"""Newlines/carriage-returns in header values are stripped."""
cfg = self._openrouter_config()
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env(session_id="bad\r\nvalue")
header_val = result["ANTHROPIC_CUSTOM_HEADERS"]
# The _safe helper removes \r and \n
assert "\r" not in header_val.split(": ", 1)[1]
assert "badvalue" in header_val
def test_header_value_truncated_to_128_chars(self):
"""Header values are truncated to 128 characters."""
cfg = self._openrouter_config()
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
long_id = "x" * 200
result = build_sdk_env(session_id=long_id)
# The value after "x-session-id: " should be at most 128 chars
header_line = result["ANTHROPIC_CUSTOM_HEADERS"]
value = header_line.split(": ", 1)[1]
assert len(value) == 128
# ---------------------------------------------------------------------------
# Mode priority
# ---------------------------------------------------------------------------
class TestBuildSdkEnvModePriority:
"""Subscription mode takes precedence over OpenRouter."""
@patch("backend.copilot.sdk.env.validate_subscription")
def test_subscription_overrides_openrouter(self, mock_validate):
cfg = _make_config(
use_claude_code_subscription=True,
use_openrouter=True,
api_key="sk-or-key",
base_url="https://openrouter.ai/api/v1",
)
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env()
# Should get subscription result, not OpenRouter
assert result == {
"ANTHROPIC_API_KEY": "",
"ANTHROPIC_AUTH_TOKEN": "",
"ANTHROPIC_BASE_URL": "",
}
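The suite above pins down what each mode returns but not how the mapping is consumed. For orientation, a minimal sketch of the usual pattern, assuming the dict is layered over the parent environment before spawning the CLI — the direct subprocess.run call and the `claude --version` command are purely illustrative, not how the service actually wires up the SDK:
import os
import subprocess
from backend.copilot.sdk.env import build_sdk_env
# Mode 2 returns {} so the parent ANTHROPIC_API_KEY is inherited untouched;
# Mode 1 blanks the ANTHROPIC_* keys, and Mode 3 points them at OpenRouter,
# e.g. {"ANTHROPIC_BASE_URL": "https://openrouter.ai/api",
#       "ANTHROPIC_AUTH_TOKEN": "sk-or-...", "ANTHROPIC_API_KEY": "",
#       "ANTHROPIC_CUSTOM_HEADERS": "x-session-id: sess-1\nx-user-id: user-1"}.
sdk_env = build_sdk_env(session_id="sess-1", user_id="user-1")
merged_env = {**os.environ, **sdk_env}  # later keys win: overrides replace inherited values
subprocess.run(["claude", "--version"], env=merged_env, check=True)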

View File

@@ -442,11 +442,8 @@ class TestCompactTranscript:
assert result is not None
assert validate_transcript(result)
msgs = _transcript_to_messages(result)
# 3 messages: compressed prefix (2) + preserved last assistant (1)
assert len(msgs) == 3
assert len(msgs) == 2
assert msgs[1]["content"] == "Summarized response"
# The last assistant entry is preserved verbatim from original
assert msgs[2]["content"] == "Details"
@pytest.mark.asyncio
async def test_returns_none_on_compression_failure(self, mock_chat_config):

View File

@@ -15,7 +15,6 @@ from claude_agent_sdk import (
ResultMessage,
SystemMessage,
TextBlock,
ThinkingBlock,
ToolResultBlock,
ToolUseBlock,
UserMessage,
@@ -27,7 +26,6 @@ from backend.copilot.response_model import (
StreamError,
StreamFinish,
StreamFinishStep,
StreamHeartbeat,
StreamStart,
StreamStartStep,
StreamTextDelta,
@@ -77,12 +75,6 @@ class SDKResponseAdapter:
# Open the first step (matches non-SDK: StreamStart then StreamStartStep)
responses.append(StreamStartStep())
self.step_open = True
elif sdk_message.subtype == "task_progress":
# Emit a heartbeat so publish_chunk is called during long
# sub-agent runs. Without this, the Redis stream and meta
# key TTLs expire during gaps where no real chunks are
# produced (task_progress events were previously silent).
responses.append(StreamHeartbeat())
elif isinstance(sdk_message, AssistantMessage):
# Flush any SDK built-in tool calls that didn't get a UserMessage
@@ -108,11 +100,6 @@ class SDKResponseAdapter:
StreamTextDelta(id=self.text_block_id, delta=block.text)
)
elif isinstance(block, ThinkingBlock):
# Thinking blocks are preserved in the transcript but
# not streamed to the frontend — skip silently.
pass
elif isinstance(block, ToolUseBlock):
self._end_text_if_open(responses)

View File

@@ -18,7 +18,6 @@ from backend.copilot.response_model import (
StreamError,
StreamFinish,
StreamFinishStep,
StreamHeartbeat,
StreamStart,
StreamStartStep,
StreamTextDelta,
@@ -60,14 +59,6 @@ def test_system_non_init_emits_nothing():
assert results == []
def test_task_progress_emits_heartbeat():
"""task_progress events emit a StreamHeartbeat to keep Redis TTL alive."""
adapter = _adapter()
results = adapter.convert_message(SystemMessage(subtype="task_progress", data={}))
assert len(results) == 1
assert isinstance(results[0], StreamHeartbeat)
# -- AssistantMessage with TextBlock -----------------------------------------

View File

@@ -124,11 +124,8 @@ class TestScenarioCompactAndRetry:
assert result != original # Must be different
assert validate_transcript(result)
msgs = _transcript_to_messages(result)
# 3 messages: compressed prefix (2) + preserved last assistant (1)
assert len(msgs) == 3
assert len(msgs) == 2
assert msgs[0]["content"] == "[summary of conversation]"
# Last assistant preserved verbatim
assert msgs[2]["content"] == "Long answer 2"
def test_compacted_transcript_loads_into_builder(self):
"""TranscriptBuilder can load a compacted transcript and continue."""
@@ -740,10 +737,7 @@ class TestRetryEdgeCases:
assert result is not None
assert result != transcript
msgs = _transcript_to_messages(result)
# 3 messages: compressed prefix (2) + preserved last assistant (1)
assert len(msgs) == 3
# Last assistant preserved verbatim
assert msgs[2]["content"] == "Answer 19"
assert len(msgs) == 2
def test_messages_to_transcript_roundtrip_preserves_content(self):
"""Verify messages → transcript → messages preserves all content."""
@@ -1010,7 +1004,7 @@ def _make_sdk_patches(
(f"{_SVC}.create_security_hooks", dict(return_value=MagicMock())),
(f"{_SVC}.get_copilot_tool_names", dict(return_value=[])),
(f"{_SVC}.get_sdk_disallowed_tools", dict(return_value=[])),
(f"{_SVC}.build_sdk_env", dict(return_value=None)),
(f"{_SVC}._build_sdk_env", dict(return_value=None)),
(f"{_SVC}._resolve_sdk_model", dict(return_value=None)),
(f"{_SVC}.set_execution_context", {}),
(

View File

@@ -313,7 +313,8 @@ def create_security_hooks(
.replace("\r", "")
)
logger.info(
"[SDK] Context compaction triggered: %s, user=%s, transcript_path=%s",
"[SDK] Context compaction triggered: %s, user=%s, "
"transcript_path=%s",
trigger,
user_id,
transcript_path,

View File

@@ -11,11 +11,7 @@ import pytest
from backend.copilot.context import _current_project_dir
from .security_hooks import (
_validate_tool_access,
_validate_user_isolation,
create_security_hooks,
)
from .security_hooks import _validate_tool_access, _validate_user_isolation
SDK_CWD = "/tmp/copilot-abc123"
@@ -224,6 +220,8 @@ def test_bash_builtin_blocked_message_clarity():
@pytest.fixture()
def _hooks():
"""Create security hooks and return (pre, post, post_failure) handlers."""
from .security_hooks import create_security_hooks
hooks = create_security_hooks(user_id="u1", sdk_cwd=SDK_CWD, max_subtasks=2)
pre = hooks["PreToolUse"][0].hooks[0]
post = hooks["PostToolUse"][0].hooks[0]

View File

@@ -77,9 +77,9 @@ from ..tools.e2b_sandbox import get_or_create_sandbox, pause_sandbox_direct
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
from ..tracking import track_user_message
from .compaction import CompactionTracker, filter_compaction_messages
from .env import build_sdk_env # noqa: F401 — re-export for backward compat
from .response_adapter import SDKResponseAdapter
from .security_hooks import create_security_hooks
from .subscription import validate_subscription as _validate_claude_code_subscription
from .tool_adapter import (
cancel_pending_tool_tasks,
create_copilot_mcp_server,
@@ -185,24 +185,6 @@ def _is_prompt_too_long(err: BaseException) -> bool:
return False
def _is_sdk_disconnect_error(exc: BaseException) -> bool:
"""Return True if *exc* is an expected SDK cleanup error from client disconnect.
Two known patterns occur when ``GeneratorExit`` tears down the async
generator and the SDK's ``__aexit__`` runs in a different context/task:
* ``RuntimeError``: cancel scope exited in wrong task (anyio)
* ``ValueError``: ContextVar token created in a different Context (OTEL)
These are suppressed to avoid polluting Sentry with non-actionable noise.
"""
if isinstance(exc, RuntimeError) and "cancel scope" in str(exc):
return True
if isinstance(exc, ValueError) and "was created in a different Context" in str(exc):
return True
return False
def _is_tool_only_message(sdk_msg: object) -> bool:
"""Return True if *sdk_msg* is an AssistantMessage containing only ToolUseBlocks.
@@ -427,63 +409,6 @@ _HEARTBEAT_INTERVAL = 10.0 # seconds
STREAM_LOCK_PREFIX = "copilot:stream:lock:"
async def _safe_close_sdk_client(
sdk_client: ClaudeSDKClient,
log_prefix: str,
) -> None:
"""Close a ClaudeSDKClient, suppressing errors from client disconnect.
When the SSE client disconnects mid-stream, ``GeneratorExit`` propagates
through the async generator stack and causes ``ClaudeSDKClient.__aexit__``
to run in a different async context or task than where the client was
opened. This triggers two known error classes:
* ``ValueError``: ``<Token var=<ContextVar name='current_context'>>
was created in a different Context`` — OpenTelemetry's
``context.detach()`` fails because the OTEL context token was
created in the original generator coroutine but detach runs in
the GC / cleanup coroutine (Sentry: AUTOGPT-SERVER-8BT).
* ``RuntimeError``: ``Attempted to exit cancel scope in a different
task than it was entered in`` — anyio's ``TaskGroup.__aexit__``
detects that the cancel scope was entered in one task but is
being exited in another (Sentry: AUTOGPT-SERVER-8BW).
Both are harmless — the TCP connection is already dead and no
resources leak. Logging them at ``debug`` level keeps observability
without polluting Sentry.
"""
try:
await sdk_client.__aexit__(None, None, None)
except (ValueError, RuntimeError) as exc:
if _is_sdk_disconnect_error(exc):
# Expected during client disconnect — suppress to avoid Sentry noise.
logger.debug(
"%s SDK client cleanup error suppressed (client disconnect): %s: %s",
log_prefix,
type(exc).__name__,
exc,
)
else:
raise
except GeneratorExit:
# GeneratorExit can propagate through __aexit__ — suppress it here
# since the generator is already being torn down.
logger.debug(
"%s SDK client cleanup GeneratorExit suppressed (client disconnect)",
log_prefix,
)
except Exception:
# Unexpected cleanup error — log at error level so Sentry captures it
# (via its logging integration), but don't propagate since we're in
# teardown and the caller cannot meaningfully handle this.
logger.error(
"%s Unexpected SDK client cleanup error",
log_prefix,
exc_info=True,
)
async def _iter_sdk_messages(
client: ClaudeSDKClient,
) -> AsyncGenerator[Any, None]:
@@ -567,6 +492,60 @@ def _resolve_sdk_model() -> str | None:
return model
def _build_sdk_env(
session_id: str | None = None,
user_id: str | None = None,
) -> dict[str, str]:
"""Build env vars for the SDK CLI subprocess.
Three modes (checked in order):
1. **Subscription** — clears all keys; CLI uses `claude login` auth.
2. **Direct Anthropic** — returns `{}`; subprocess inherits
`ANTHROPIC_API_KEY` from the parent environment.
3. **OpenRouter** (default) — overrides base URL and auth token to
route through the proxy, with Langfuse trace headers.
"""
# --- Mode 1: Claude Code subscription auth ---
if config.use_claude_code_subscription:
_validate_claude_code_subscription()
return {
"ANTHROPIC_API_KEY": "",
"ANTHROPIC_AUTH_TOKEN": "",
"ANTHROPIC_BASE_URL": "",
}
# --- Mode 2: Direct Anthropic (no proxy hop) ---
# `openrouter_active` checks the flag *and* credential presence.
if not config.openrouter_active:
return {}
# --- Mode 3: OpenRouter proxy ---
# Strip /v1 suffix — SDK expects the base URL without a version path.
base = (config.base_url or "").rstrip("/")
if base.endswith("/v1"):
base = base[:-3]
env: dict[str, str] = {
"ANTHROPIC_BASE_URL": base,
"ANTHROPIC_AUTH_TOKEN": config.api_key or "",
"ANTHROPIC_API_KEY": "", # force CLI to use AUTH_TOKEN
}
# Inject broadcast headers so OpenRouter forwards traces to Langfuse.
def _safe(v: str) -> str:
"""Sanitise a header value: strip newlines/whitespace and cap length."""
return v.replace("\r", "").replace("\n", "").strip()[:128]
parts = []
if session_id:
parts.append(f"x-session-id: {_safe(session_id)}")
if user_id:
parts.append(f"x-user-id: {_safe(user_id)}")
if parts:
env["ANTHROPIC_CUSTOM_HEADERS"] = "\n".join(parts)
return env
def _make_sdk_cwd(session_id: str) -> str:
"""Create a safe, session-specific working directory path.
@@ -616,9 +595,7 @@ def _format_sdk_content_blocks(blocks: list) -> list[dict[str, Any]]:
"""Convert SDK content blocks to transcript format.
Handles TextBlock, ToolUseBlock, ToolResultBlock, and ThinkingBlock.
Raw dicts (e.g. ``redacted_thinking`` blocks that the SDK may not have
a typed class for) are passed through verbatim to preserve them in the
transcript. Unknown typed block objects are logged and skipped.
Unknown block types are logged and skipped.
"""
result: list[dict[str, Any]] = []
for block in blocks or []:
@@ -650,9 +627,6 @@ def _format_sdk_content_blocks(blocks: list) -> list[dict[str, Any]]:
"signature": block.signature,
}
)
elif isinstance(block, dict) and "type" in block:
# Preserve raw dict blocks (e.g. redacted_thinking) verbatim.
result.append(block)
else:
logger.warning(
f"[SDK] Unknown content block type: {type(block).__name__}. "
@@ -1214,25 +1188,7 @@ async def _run_stream_attempt(
consecutive_empty_tool_calls = 0
# --- Intermediate persistence tracking ---
# Flush session messages to DB periodically so page reloads show progress
# during long-running turns (see incident d2f7cba3: 82-min turn lost on refresh).
_last_flush_time = time.monotonic()
_msgs_since_flush = 0
_FLUSH_INTERVAL_SECONDS = 30.0
_FLUSH_MESSAGE_THRESHOLD = 10
# Use manual __aenter__/__aexit__ instead of ``async with`` so we can
# suppress SDK cleanup errors that occur when the SSE client disconnects
# mid-stream. GeneratorExit causes the SDK's ``__aexit__`` to run in a
# different async context/task than where the client was opened, which
# triggers:
# - ValueError: ContextVar token mismatch (AUTOGPT-SERVER-8BT)
# - RuntimeError: cancel scope in wrong task (AUTOGPT-SERVER-8BW)
# Both are harmless — the TCP connection is already dead.
sdk_client = ClaudeSDKClient(options=state.options)
client = await sdk_client.__aenter__()
try:
async with ClaudeSDKClient(options=state.options) as client:
logger.info(
"%s Sending query — resume=%s, total_msgs=%d, "
"query_len=%d, attached_files=%d, image_blocks=%d",
@@ -1490,38 +1446,8 @@ async def _run_stream_attempt(
model=sdk_msg.model,
)
# --- Intermediate persistence ---
# Flush session messages to DB periodically so page reloads
# show progress during long-running turns.
_msgs_since_flush += 1
now = time.monotonic()
if (
_msgs_since_flush >= _FLUSH_MESSAGE_THRESHOLD
or (now - _last_flush_time) >= _FLUSH_INTERVAL_SECONDS
):
try:
await asyncio.shield(upsert_chat_session(ctx.session))
logger.debug(
"%s Intermediate flush: %d messages "
"(msgs_since=%d, elapsed=%.1fs)",
ctx.log_prefix,
len(ctx.session.messages),
_msgs_since_flush,
now - _last_flush_time,
)
except Exception as flush_err:
logger.warning(
"%s Intermediate flush failed: %s",
ctx.log_prefix,
flush_err,
)
_last_flush_time = now
_msgs_since_flush = 0
if acc.stream_completed:
break
finally:
await _safe_close_sdk_client(sdk_client, ctx.log_prefix)
# --- Post-stream processing (only on success) ---
if state.adapter.has_unresolved_tool_calls:
@@ -1849,7 +1775,7 @@ async def stream_chat_completion_sdk(
)
# Fail fast when no API credentials are available at all.
sdk_env = build_sdk_env(session_id=session_id, user_id=user_id)
sdk_env = _build_sdk_env(session_id=session_id, user_id=user_id)
if not config.api_key and not config.use_claude_code_subscription:
raise RuntimeError(
"No API key configured. Set OPEN_ROUTER_API_KEY, "
@@ -2243,16 +2169,9 @@ async def stream_chat_completion_sdk(
error_msg = "Operation cancelled"
else:
error_msg = str(e) or type(e).__name__
# SDK cleanup errors are expected during client disconnect —
# log as warning rather than error to reduce Sentry noise.
# These are normally caught by _safe_close_sdk_client but
# can escape in edge cases (e.g. GeneratorExit timing).
if _is_sdk_disconnect_error(e):
logger.warning(
"%s SDK cleanup error (client disconnect): %s",
log_prefix,
error_msg,
)
# SDK cleanup RuntimeError is expected during cancellation, so log it as a warning
if isinstance(e, RuntimeError) and "cancel scope" in str(e):
logger.warning("%s SDK cleanup error: %s", log_prefix, error_msg)
else:
logger.error("%s Error: %s", log_prefix, error_msg, exc_info=True)
@@ -2274,11 +2193,10 @@ async def stream_chat_completion_sdk(
)
# Yield StreamError for immediate feedback (only for non-cancellation errors)
# Skip for CancelledError and SDK disconnect cleanup errors — these
# are not actionable by the user and the SSE connection is already dead.
is_cancellation = isinstance(
e, asyncio.CancelledError
) or _is_sdk_disconnect_error(e)
# Skip for CancelledError and RuntimeError cleanup issues (both are cancellations)
is_cancellation = isinstance(e, asyncio.CancelledError) or (
isinstance(e, RuntimeError) and "cancel scope" in str(e)
)
if not is_cancellation:
yield StreamError(errorText=display_msg, code=code)

View File

@@ -8,12 +8,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from .service import (
_is_sdk_disconnect_error,
_prepare_file_attachments,
_resolve_sdk_model,
_safe_close_sdk_client,
)
from .service import _prepare_file_attachments, _resolve_sdk_model
@dataclass
@@ -504,111 +499,3 @@ class TestResolveSdkModel:
)
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
assert _resolve_sdk_model() == "claude-opus-4-6"
# ---------------------------------------------------------------------------
# _is_sdk_disconnect_error — classify client disconnect cleanup errors
# ---------------------------------------------------------------------------
class TestIsSdkDisconnectError:
"""Tests for _is_sdk_disconnect_error — identifies expected SDK cleanup errors."""
def test_cancel_scope_runtime_error(self):
"""RuntimeError about cancel scope in wrong task is a disconnect error."""
exc = RuntimeError(
"Attempted to exit cancel scope in a different task than it was entered in"
)
assert _is_sdk_disconnect_error(exc) is True
def test_context_var_value_error(self):
"""ValueError about ContextVar token mismatch is a disconnect error."""
exc = ValueError(
"<Token var=<ContextVar name='current_context'>> "
"was created in a different Context"
)
assert _is_sdk_disconnect_error(exc) is True
def test_unrelated_runtime_error(self):
"""Unrelated RuntimeError should NOT be classified as disconnect error."""
exc = RuntimeError("something else went wrong")
assert _is_sdk_disconnect_error(exc) is False
def test_unrelated_value_error(self):
"""Unrelated ValueError should NOT be classified as disconnect error."""
exc = ValueError("invalid argument")
assert _is_sdk_disconnect_error(exc) is False
def test_other_exception_types(self):
"""Non-RuntimeError/ValueError should NOT be classified as disconnect error."""
assert _is_sdk_disconnect_error(TypeError("bad type")) is False
assert _is_sdk_disconnect_error(OSError("network down")) is False
assert _is_sdk_disconnect_error(asyncio.CancelledError()) is False
# ---------------------------------------------------------------------------
# _safe_close_sdk_client — suppress cleanup errors during disconnect
# ---------------------------------------------------------------------------
class TestSafeCloseSdkClient:
"""Tests for _safe_close_sdk_client — suppresses expected SDK cleanup errors."""
@pytest.mark.asyncio
async def test_clean_exit(self):
"""Normal __aexit__ (no error) should succeed silently."""
client = AsyncMock()
client.__aexit__ = AsyncMock(return_value=None)
await _safe_close_sdk_client(client, "[test]")
client.__aexit__.assert_awaited_once_with(None, None, None)
@pytest.mark.asyncio
async def test_cancel_scope_runtime_error_suppressed(self):
"""RuntimeError from cancel scope mismatch should be suppressed."""
client = AsyncMock()
client.__aexit__ = AsyncMock(
side_effect=RuntimeError(
"Attempted to exit cancel scope in a different task"
)
)
# Should NOT raise
await _safe_close_sdk_client(client, "[test]")
@pytest.mark.asyncio
async def test_context_var_value_error_suppressed(self):
"""ValueError from ContextVar token mismatch should be suppressed."""
client = AsyncMock()
client.__aexit__ = AsyncMock(
side_effect=ValueError(
"<Token var=<ContextVar name='current_context'>> "
"was created in a different Context"
)
)
# Should NOT raise
await _safe_close_sdk_client(client, "[test]")
@pytest.mark.asyncio
async def test_unexpected_exception_suppressed_with_error_log(self):
"""Unexpected exceptions should be caught (not propagated) but logged at error."""
client = AsyncMock()
client.__aexit__ = AsyncMock(side_effect=OSError("unexpected"))
# Should NOT raise — unexpected errors are also suppressed to
# avoid crashing the generator during teardown. Logged at error
# level so Sentry captures them via its logging integration.
await _safe_close_sdk_client(client, "[test]")
@pytest.mark.asyncio
async def test_unrelated_runtime_error_propagates(self):
"""Non-cancel-scope RuntimeError should propagate (not suppressed)."""
client = AsyncMock()
client.__aexit__ = AsyncMock(side_effect=RuntimeError("something unrelated"))
with pytest.raises(RuntimeError, match="something unrelated"):
await _safe_close_sdk_client(client, "[test]")
@pytest.mark.asyncio
async def test_unrelated_value_error_propagates(self):
"""Non-disconnect ValueError should propagate (not suppressed)."""
client = AsyncMock()
client.__aexit__ = AsyncMock(side_effect=ValueError("invalid argument"))
with pytest.raises(ValueError, match="invalid argument"):
await _safe_close_sdk_client(client, "[test]")

View File

@@ -1,822 +0,0 @@
"""Tests for thinking/redacted_thinking block preservation.
Validates the fix for the Anthropic API error:
"thinking or redacted_thinking blocks in the latest assistant message
cannot be modified. These blocks must remain as they were in the
original response."
The API requires that thinking blocks in the LAST assistant message are
preserved value-identical. Older assistant messages may have thinking blocks
stripped entirely. This test suite covers:
1. _flatten_assistant_content — strips thinking from older messages
2. compact_transcript — preserves last assistant's thinking blocks
3. response_adapter — handles ThinkingBlock without error
4. _format_sdk_content_blocks — preserves redacted_thinking blocks
"""
from __future__ import annotations
from unittest.mock import AsyncMock, patch
import pytest
from claude_agent_sdk import AssistantMessage, TextBlock, ThinkingBlock
from backend.copilot.response_model import (
StreamStartStep,
StreamTextDelta,
StreamTextStart,
)
from backend.util import json
from .conftest import build_structured_transcript
from .response_adapter import SDKResponseAdapter
from .service import _format_sdk_content_blocks
from .transcript import (
_find_last_assistant_entry,
_flatten_assistant_content,
_messages_to_transcript,
_rechain_tail,
_transcript_to_messages,
compact_transcript,
validate_transcript,
)
# ---------------------------------------------------------------------------
# Fixtures: realistic thinking block content
# ---------------------------------------------------------------------------
THINKING_BLOCK = {
"type": "thinking",
"thinking": "Let me analyze the user's request carefully...",
"signature": "ErUBCkYIAxgCIkD0V2MsRXPkuGolGexaW9V1kluijxXGF",
}
REDACTED_THINKING_BLOCK = {
"type": "redacted_thinking",
"data": "EmwKAhgBEgy2VEE8PJaS2oLJCPkaT...",
}
def _make_thinking_transcript() -> str:
"""Build a transcript with thinking blocks in multiple assistant turns.
Layout:
User 1 → Assistant 1 (thinking + text + tool_use)
User 2 (tool_result) → Assistant 2 (thinking + text)
User 3 → Assistant 3 (thinking + redacted_thinking + text) ← LAST
"""
return build_structured_transcript(
[
("user", "What files are in this project?"),
(
"assistant",
[
{
"type": "thinking",
"thinking": "I should list the files.",
"signature": "sig_old_1",
},
{"type": "text", "text": "Let me check the files."},
{
"type": "tool_use",
"id": "tu1",
"name": "list_files",
"input": {"path": "/"},
},
],
),
("user", "Here are the files: a.py, b.py"),
(
"assistant",
[
{
"type": "thinking",
"thinking": "Good, I see two Python files.",
"signature": "sig_old_2",
},
{"type": "text", "text": "I found a.py and b.py."},
],
),
("user", "Tell me about a.py"),
(
"assistant",
[
THINKING_BLOCK,
REDACTED_THINKING_BLOCK,
{"type": "text", "text": "a.py contains the main entry point."},
],
),
]
)
def _last_assistant_content(transcript_jsonl: str) -> list[dict] | None:
"""Extract the content blocks of the last assistant entry in a transcript."""
last_content = None
for line in transcript_jsonl.strip().split("\n"):
entry = json.loads(line)
msg = entry.get("message", {})
if msg.get("role") == "assistant":
last_content = msg.get("content")
return last_content
# ---------------------------------------------------------------------------
# _find_last_assistant_entry — unit tests
# ---------------------------------------------------------------------------
class TestFindLastAssistantEntry:
def test_splits_at_last_assistant(self):
"""Prefix contains everything before last assistant; tail starts at it."""
transcript = build_structured_transcript(
[
("user", "Hello"),
("assistant", [{"type": "text", "text": "Hi"}]),
("user", "More"),
("assistant", [{"type": "text", "text": "Details"}]),
]
)
prefix, tail = _find_last_assistant_entry(transcript)
# 3 entries in prefix (user, assistant, user), 1 in tail (last assistant)
assert len(prefix) == 3
assert len(tail) == 1
def test_no_assistant_returns_all_in_prefix(self):
"""When there's no assistant, all lines are in prefix, tail is empty."""
transcript = build_structured_transcript(
[("user", "Hello"), ("user", "Another question")]
)
prefix, tail = _find_last_assistant_entry(transcript)
assert len(prefix) == 2
assert tail == []
def test_assistant_at_index_zero(self):
"""When assistant is the first entry, prefix is empty."""
transcript = build_structured_transcript(
[("assistant", [{"type": "text", "text": "Start"}])]
)
prefix, tail = _find_last_assistant_entry(transcript)
assert prefix == []
assert len(tail) == 1
def test_trailing_user_included_in_tail(self):
"""User message after last assistant is part of the tail."""
transcript = build_structured_transcript(
[
("user", "Q1"),
("assistant", [{"type": "text", "text": "A1"}]),
("user", "Q2"),
]
)
prefix, tail = _find_last_assistant_entry(transcript)
assert len(prefix) == 1 # first user
assert len(tail) == 2 # last assistant + trailing user
def test_multi_entry_turn_fully_preserved(self):
"""An assistant turn spanning multiple JSONL entries (same message.id)
must be entirely in the tail, not split across prefix and tail."""
# Build manually because build_structured_transcript generates unique ids
lines = [
json.dumps(
{
"type": "user",
"uuid": "u1",
"parentUuid": "",
"message": {"role": "user", "content": "Hello"},
}
),
json.dumps(
{
"type": "assistant",
"uuid": "a1-think",
"parentUuid": "u1",
"message": {
"role": "assistant",
"id": "msg_same_turn",
"type": "message",
"content": [THINKING_BLOCK],
"stop_reason": None,
"stop_sequence": None,
},
}
),
json.dumps(
{
"type": "assistant",
"uuid": "a1-tool",
"parentUuid": "u1",
"message": {
"role": "assistant",
"id": "msg_same_turn",
"type": "message",
"content": [
{
"type": "tool_use",
"id": "tu1",
"name": "Bash",
"input": {},
},
],
"stop_reason": "tool_use",
"stop_sequence": None,
},
}
),
]
transcript = "\n".join(lines) + "\n"
prefix, tail = _find_last_assistant_entry(transcript)
# Both assistant entries share msg_same_turn → both in tail
assert len(prefix) == 1 # only the user entry
assert len(tail) == 2 # both assistant entries (thinking + tool_use)
def test_no_message_id_preserves_last_assistant(self):
"""When the last assistant entry has no message.id, it should still
be preserved in the tail (fail closed) rather than being compressed."""
lines = [
json.dumps(
{
"type": "user",
"uuid": "u1",
"parentUuid": "",
"message": {"role": "user", "content": "Hello"},
}
),
json.dumps(
{
"type": "assistant",
"uuid": "a1",
"parentUuid": "u1",
"message": {
"role": "assistant",
"content": [THINKING_BLOCK, {"type": "text", "text": "Hi"}],
},
}
),
]
transcript = "\n".join(lines) + "\n"
prefix, tail = _find_last_assistant_entry(transcript)
assert len(prefix) == 1 # user entry
assert len(tail) == 1 # assistant entry preserved
# ---------------------------------------------------------------------------
# _rechain_tail — UUID chain patching
# ---------------------------------------------------------------------------
class TestRechainTail:
def test_patches_first_entry_parentuuid(self):
"""First tail entry's parentUuid should point to last prefix uuid."""
prefix = _messages_to_transcript(
[
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi"},
]
)
# Get the last uuid from the prefix
last_prefix_uuid = None
for line in prefix.strip().split("\n"):
entry = json.loads(line)
last_prefix_uuid = entry.get("uuid")
tail_lines = [
json.dumps(
{
"type": "assistant",
"uuid": "tail-a1",
"parentUuid": "old-parent",
"message": {
"role": "assistant",
"content": [{"type": "text", "text": "Tail msg"}],
},
}
)
]
result = _rechain_tail(prefix, tail_lines)
entry = json.loads(result.strip())
assert entry["parentUuid"] == last_prefix_uuid
assert entry["uuid"] == "tail-a1" # uuid preserved
def test_chains_multiple_tail_entries(self):
"""Subsequent tail entries chain to each other."""
prefix = _messages_to_transcript([{"role": "user", "content": "Hi"}])
tail_lines = [
json.dumps(
{
"type": "assistant",
"uuid": "t1",
"parentUuid": "old1",
"message": {"role": "assistant", "content": []},
}
),
json.dumps(
{
"type": "user",
"uuid": "t2",
"parentUuid": "old2",
"message": {"role": "user", "content": "Follow-up"},
}
),
]
result = _rechain_tail(prefix, tail_lines)
entries = [json.loads(ln) for ln in result.strip().split("\n")]
assert len(entries) == 2
# Second entry's parentUuid should be first entry's uuid
assert entries[1]["parentUuid"] == "t1"
def test_empty_tail_returns_empty(self):
"""No tail entries → empty string."""
prefix = _messages_to_transcript([{"role": "user", "content": "Hi"}])
assert _rechain_tail(prefix, []) == ""
def test_preserves_message_content_verbatim(self):
"""Tail message content (including thinking blocks) must not be modified."""
prefix = _messages_to_transcript([{"role": "user", "content": "Hi"}])
original_content = [
THINKING_BLOCK,
REDACTED_THINKING_BLOCK,
{"type": "text", "text": "Response"},
]
tail_lines = [
json.dumps(
{
"type": "assistant",
"uuid": "t1",
"parentUuid": "old",
"message": {
"role": "assistant",
"content": original_content,
},
}
)
]
result = _rechain_tail(prefix, tail_lines)
entry = json.loads(result.strip())
assert entry["message"]["content"] == original_content
# ---------------------------------------------------------------------------
# _flatten_assistant_content — thinking blocks
# ---------------------------------------------------------------------------
class TestFlattenThinkingBlocks:
def test_thinking_blocks_are_stripped(self):
"""Thinking blocks should not appear in flattened text for compression."""
blocks = [
{"type": "thinking", "thinking": "secret thoughts", "signature": "sig"},
{"type": "text", "text": "Hello user"},
]
result = _flatten_assistant_content(blocks)
assert "secret thoughts" not in result
assert "Hello user" in result
def test_redacted_thinking_blocks_are_stripped(self):
"""Redacted thinking blocks should not appear in flattened text."""
blocks = [
{"type": "redacted_thinking", "data": "encrypted_data"},
{"type": "text", "text": "Response text"},
]
result = _flatten_assistant_content(blocks)
assert "encrypted_data" not in result
assert "Response text" in result
def test_thinking_only_message_flattens_to_empty(self):
"""A message with only thinking blocks flattens to empty string."""
blocks = [
{"type": "thinking", "thinking": "just thinking...", "signature": "sig"},
]
result = _flatten_assistant_content(blocks)
assert result == ""
def test_mixed_thinking_text_tool(self):
"""Mixed blocks: only text and tool_use survive flattening."""
blocks = [
{"type": "thinking", "thinking": "hmm", "signature": "sig"},
{"type": "redacted_thinking", "data": "xyz"},
{"type": "text", "text": "I'll read the file."},
{"type": "tool_use", "name": "Read", "input": {"path": "/x"}},
]
result = _flatten_assistant_content(blocks)
assert "hmm" not in result
assert "xyz" not in result
assert "I'll read the file." in result
assert "[tool_use: Read]" in result
# ---------------------------------------------------------------------------
# compact_transcript — thinking block preservation
# ---------------------------------------------------------------------------
class TestCompactTranscriptThinkingBlocks:
"""Verify that compact_transcript preserves thinking blocks in the
last assistant message while stripping them from older messages."""
@pytest.mark.asyncio
async def test_last_assistant_thinking_blocks_preserved(self, mock_chat_config):
"""After compaction, the last assistant entry must retain its
original thinking and redacted_thinking blocks verbatim."""
transcript = _make_thinking_transcript()
compacted_msgs = [
{"role": "user", "content": "[conversation summary]"},
{"role": "assistant", "content": "Summarized response"},
]
mock_result = type(
"CompressResult",
(),
{
"was_compacted": True,
"messages": compacted_msgs,
"original_token_count": 800,
"token_count": 200,
"messages_summarized": 4,
"messages_dropped": 0,
},
)()
with patch(
"backend.copilot.sdk.transcript._run_compression",
new_callable=AsyncMock,
return_value=mock_result,
):
result = await compact_transcript(transcript, model="test-model")
assert result is not None
assert validate_transcript(result)
last_content = _last_assistant_content(result)
assert last_content is not None, "No assistant entry found"
assert isinstance(last_content, list)
# The last assistant must have the thinking blocks preserved
block_types = [b["type"] for b in last_content]
assert (
"thinking" in block_types
), "thinking block missing from last assistant message"
assert (
"redacted_thinking" in block_types
), "redacted_thinking block missing from last assistant message"
assert "text" in block_types
# Verify the thinking block content is value-identical
thinking_blocks = [b for b in last_content if b["type"] == "thinking"]
assert len(thinking_blocks) == 1
assert thinking_blocks[0]["thinking"] == THINKING_BLOCK["thinking"]
assert thinking_blocks[0]["signature"] == THINKING_BLOCK["signature"]
redacted_blocks = [b for b in last_content if b["type"] == "redacted_thinking"]
assert len(redacted_blocks) == 1
assert redacted_blocks[0]["data"] == REDACTED_THINKING_BLOCK["data"]
@pytest.mark.asyncio
async def test_older_assistant_thinking_blocks_stripped(self, mock_chat_config):
"""Older assistant messages should NOT retain thinking blocks
after compaction (they're compressed into summaries)."""
transcript = _make_thinking_transcript()
# The compressor will receive messages where older assistant
# entries have already had thinking blocks stripped.
captured_messages: list[dict] = []
async def mock_compression(messages, model, log_prefix):
captured_messages.extend(messages)
return type(
"CompressResult",
(),
{
"was_compacted": True,
"messages": messages,
"original_token_count": 800,
"token_count": 400,
"messages_summarized": 2,
"messages_dropped": 0,
},
)()
with patch(
"backend.copilot.sdk.transcript._run_compression",
side_effect=mock_compression,
):
await compact_transcript(transcript, model="test-model")
# Check that the messages sent to compression don't contain
# thinking content from older assistant messages
for msg in captured_messages:
if msg["role"] == "assistant":
content = msg.get("content", "")
assert (
"I should list the files." not in content
), "Old thinking block content leaked into compression input"
assert (
"Good, I see two Python files." not in content
), "Old thinking block content leaked into compression input"
@pytest.mark.asyncio
async def test_trailing_user_message_after_last_assistant(self, mock_chat_config):
"""When the last entry is a user message, the last *assistant*
message's thinking blocks should still be preserved."""
transcript = build_structured_transcript(
[
("user", "Hello"),
(
"assistant",
[
THINKING_BLOCK,
{"type": "text", "text": "Hi there"},
],
),
("user", "Follow-up question"),
]
)
# The compressor only receives the prefix (1 user message); the
# tail (assistant + trailing user) is preserved verbatim.
compacted_msgs = [
{"role": "user", "content": "Hello"},
]
mock_result = type(
"CompressResult",
(),
{
"was_compacted": True,
"messages": compacted_msgs,
"original_token_count": 400,
"token_count": 100,
"messages_summarized": 0,
"messages_dropped": 0,
},
)()
with patch(
"backend.copilot.sdk.transcript._run_compression",
new_callable=AsyncMock,
return_value=mock_result,
):
result = await compact_transcript(transcript, model="test-model")
assert result is not None
last_content = _last_assistant_content(result)
assert last_content is not None
assert isinstance(last_content, list)
block_types = [b["type"] for b in last_content]
assert (
"thinking" in block_types
), "thinking block lost from last assistant despite trailing user msg"
@pytest.mark.asyncio
async def test_single_assistant_with_thinking_preserved(self, mock_chat_config):
"""When there's only one assistant message (which is also the last),
its thinking blocks must be preserved."""
transcript = build_structured_transcript(
[
("user", "Hello"),
(
"assistant",
[
THINKING_BLOCK,
{"type": "text", "text": "World"},
],
),
]
)
compacted_msgs = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "World"},
]
mock_result = type(
"CompressResult",
(),
{
"was_compacted": True,
"messages": compacted_msgs,
"original_token_count": 200,
"token_count": 100,
"messages_summarized": 0,
"messages_dropped": 0,
},
)()
with patch(
"backend.copilot.sdk.transcript._run_compression",
new_callable=AsyncMock,
return_value=mock_result,
):
result = await compact_transcript(transcript, model="test-model")
assert result is not None
last_content = _last_assistant_content(result)
assert last_content is not None
assert isinstance(last_content, list)
block_types = [b["type"] for b in last_content]
assert "thinking" in block_types
@pytest.mark.asyncio
async def test_tail_parentuuid_rewired_to_prefix(self, mock_chat_config):
"""After compaction, the first tail entry's parentUuid must point to
the last entry in the compressed prefix — not its original parent."""
transcript = _make_thinking_transcript()
compacted_msgs = [
{"role": "user", "content": "[conversation summary]"},
{"role": "assistant", "content": "Summarized response"},
]
mock_result = type(
"CompressResult",
(),
{
"was_compacted": True,
"messages": compacted_msgs,
"original_token_count": 800,
"token_count": 200,
"messages_summarized": 4,
"messages_dropped": 0,
},
)()
with patch(
"backend.copilot.sdk.transcript._run_compression",
new_callable=AsyncMock,
return_value=mock_result,
):
result = await compact_transcript(transcript, model="test-model")
assert result is not None
lines = [ln for ln in result.strip().split("\n") if ln.strip()]
entries = [json.loads(ln) for ln in lines]
# Find the boundary: the compressed prefix ends just before the
# first tail entry (last assistant in original transcript).
tail_start = None
for i, entry in enumerate(entries):
msg = entry.get("message", {})
if isinstance(msg.get("content"), list):
# Structured content = preserved tail entry
tail_start = i
break
assert tail_start is not None, "Could not find preserved tail entry"
assert tail_start > 0, "Tail should not be the first entry"
# The tail entry's parentUuid must be the uuid of the preceding entry
prefix_last_uuid = entries[tail_start - 1]["uuid"]
tail_first_parent = entries[tail_start]["parentUuid"]
assert tail_first_parent == prefix_last_uuid, (
f"Tail parentUuid {tail_first_parent!r} != "
f"last prefix uuid {prefix_last_uuid!r}"
)
@pytest.mark.asyncio
async def test_no_thinking_blocks_still_works(self, mock_chat_config):
"""Compaction should still work normally when there are no thinking
blocks in the transcript."""
transcript = build_structured_transcript(
[
("user", "Hello"),
("assistant", [{"type": "text", "text": "Hi"}]),
("user", "More"),
("assistant", [{"type": "text", "text": "Details"}]),
]
)
compacted_msgs = [
{"role": "user", "content": "[summary]"},
{"role": "assistant", "content": "Summary"},
]
mock_result = type(
"CompressResult",
(),
{
"was_compacted": True,
"messages": compacted_msgs,
"original_token_count": 200,
"token_count": 50,
"messages_summarized": 2,
"messages_dropped": 0,
},
)()
with patch(
"backend.copilot.sdk.transcript._run_compression",
new_callable=AsyncMock,
return_value=mock_result,
):
result = await compact_transcript(transcript, model="test-model")
assert result is not None
assert validate_transcript(result)
# Verify last assistant content is preserved even without thinking blocks
last_content = _last_assistant_content(result)
assert last_content is not None
assert last_content == [{"type": "text", "text": "Details"}]
# ---------------------------------------------------------------------------
# _transcript_to_messages — thinking block handling
# ---------------------------------------------------------------------------
class TestTranscriptToMessagesThinking:
def test_thinking_blocks_excluded_from_flattened_content(self):
"""When _transcript_to_messages flattens content, thinking block
text should not leak into the message content string."""
transcript = build_structured_transcript(
[
("user", "Hello"),
(
"assistant",
[
{
"type": "thinking",
"thinking": "SECRET_THOUGHT",
"signature": "sig",
},
{"type": "text", "text": "Visible response"},
],
),
]
)
messages = _transcript_to_messages(transcript)
assistant_msg = [m for m in messages if m["role"] == "assistant"][0]
assert "SECRET_THOUGHT" not in assistant_msg["content"]
assert "Visible response" in assistant_msg["content"]
# ---------------------------------------------------------------------------
# response_adapter — ThinkingBlock handling
# ---------------------------------------------------------------------------
class TestResponseAdapterThinkingBlock:
def test_thinking_block_does_not_crash(self):
"""ThinkingBlock in AssistantMessage should not cause an error."""
adapter = SDKResponseAdapter(message_id="msg-1", session_id="sess-1")
msg = AssistantMessage(
content=[
ThinkingBlock(
thinking="Let me think about this...",
signature="sig_test_123",
),
TextBlock(text="Here is my response."),
],
model="claude-test",
)
results = adapter.convert_message(msg)
# Should produce stream events for text only, no crash
types = [type(r) for r in results]
assert StreamStartStep in types
assert StreamTextStart in types or StreamTextDelta in types
def test_thinking_block_does_not_emit_stream_events(self):
"""ThinkingBlock should NOT produce any StreamTextDelta events
containing thinking content."""
adapter = SDKResponseAdapter(message_id="msg-1", session_id="sess-1")
msg = AssistantMessage(
content=[
ThinkingBlock(
thinking="My secret thoughts",
signature="sig_test_456",
),
TextBlock(text="Public response"),
],
model="claude-test",
)
results = adapter.convert_message(msg)
text_deltas = [r for r in results if isinstance(r, StreamTextDelta)]
for delta in text_deltas:
assert "secret thoughts" not in (delta.delta or "")
# ---------------------------------------------------------------------------
# _format_sdk_content_blocks — redacted_thinking handling
# ---------------------------------------------------------------------------
class TestFormatSdkContentBlocks:
def test_thinking_block_preserved(self):
"""ThinkingBlock should be serialized with type, thinking, and signature."""
blocks = [
ThinkingBlock(thinking="My thoughts", signature="sig123"),
TextBlock(text="Response"),
]
result = _format_sdk_content_blocks(blocks)
assert len(result) == 2
assert result[0] == {
"type": "thinking",
"thinking": "My thoughts",
"signature": "sig123",
}
assert result[1] == {"type": "text", "text": "Response"}
def test_raw_dict_redacted_thinking_preserved(self):
"""Raw dict blocks (e.g. redacted_thinking) pass through unchanged."""
raw_block = {"type": "redacted_thinking", "data": "EmwKAh...encrypted"}
blocks = [
raw_block,
TextBlock(text="Response"),
]
result = _format_sdk_content_blocks(blocks)
assert len(result) == 2
assert result[0] == raw_block
assert result[1] == {"type": "text", "text": "Response"}

View File

@@ -12,7 +12,6 @@ from backend.util.truncate import truncate
from .tool_adapter import (
_MCP_MAX_CHARS,
SDK_DISALLOWED_TOOLS,
_text_from_mcp_result,
cancel_pending_tool_tasks,
create_tool_handler,
@@ -773,19 +772,3 @@ class TestFiveConcurrentPrelaunchAllComplete:
assert result["isError"] is False, f"Result {i} should not be an error"
text = result["content"][0]["text"]
assert "done-" in text, f"Result {i} missing expected output: {text}"
# ---------------------------------------------------------------------------
# SDK_DISALLOWED_TOOLS
# ---------------------------------------------------------------------------
class TestSDKDisallowedTools:
"""Verify that dangerous SDK built-in tools are in the disallowed list."""
def test_bash_tool_is_disallowed(self):
assert "Bash" in SDK_DISALLOWED_TOOLS
def test_webfetch_tool_is_disallowed(self):
"""WebFetch is disallowed due to SSRF risk."""
assert "WebFetch" in SDK_DISALLOWED_TOOLS

View File

@@ -605,31 +605,20 @@ COMPACT_MSG_ID_PREFIX = "msg_compact_"
ENTRY_TYPE_MESSAGE = "message"
_THINKING_BLOCK_TYPES = frozenset({"thinking", "redacted_thinking"})
def _flatten_assistant_content(blocks: list) -> str:
"""Flatten assistant content blocks into a single plain-text string.
Structured ``tool_use`` blocks are converted to ``[tool_use: name]``
placeholders. ``thinking`` and ``redacted_thinking`` blocks are
silently dropped — they carry no useful context for compression
summaries and must not leak into compacted transcripts (the Anthropic
API requires thinking blocks in the last assistant message to be
value-identical to the original response; including stale thinking
text would violate that constraint).
This is intentional: ``compress_context`` requires plain text for
token counting and LLM summarization. The structural loss is
acceptable because compaction only runs when the original transcript
was already too large for the model.
placeholders. This is intentional: ``compress_context`` requires plain
text for token counting and LLM summarization. The structural loss is
acceptable because compaction only runs when the original transcript was
already too large for the model — a summarized plain-text version is
better than no context at all.
"""
parts: list[str] = []
for block in blocks:
if isinstance(block, dict):
btype = block.get("type", "")
if btype in _THINKING_BLOCK_TYPES:
continue
if btype == "text":
parts.append(block.get("text", ""))
elif btype == "tool_use":
@@ -816,68 +805,6 @@ async def _run_compression(
)
def _find_last_assistant_entry(
content: str,
) -> tuple[list[str], list[str]]:
"""Split JSONL lines into (compressible_prefix, preserved_tail).
The tail starts at the **first** entry of the last assistant turn and
includes everything after it (typically trailing user messages). An
assistant turn can span multiple consecutive JSONL entries sharing the
same ``message.id`` (e.g., a thinking entry followed by a tool_use
entry). All entries of the turn are preserved verbatim.
The Anthropic API requires that ``thinking`` and ``redacted_thinking``
blocks in the **last** assistant message remain value-identical to the
original response (the API validates parsed signature values, not raw
JSON bytes). By excluding the entire turn from compression we
guarantee those blocks are never altered.
Returns ``(all_lines, [])`` when no assistant entry is found.
"""
lines = [ln for ln in content.strip().split("\n") if ln.strip()]
# Parse all lines once to avoid double JSON deserialization.
# json.loads with fallback=None returns Any; non-dict entries are
# safely skipped by the isinstance(entry, dict) guards below.
parsed: list = [json.loads(ln, fallback=None) for ln in lines]
# Reverse scan: find the message.id and index of the last assistant entry.
last_asst_msg_id: str | None = None
last_asst_idx: int | None = None
for i in range(len(parsed) - 1, -1, -1):
entry = parsed[i]
if not isinstance(entry, dict):
continue
msg = entry.get("message", {})
if msg.get("role") == "assistant":
last_asst_idx = i
last_asst_msg_id = msg.get("id")
break
if last_asst_idx is None:
return lines, []
# If the assistant entry has no message.id, fall back to preserving
# from that single entry onward — safer than compressing everything.
if last_asst_msg_id is None:
return lines[:last_asst_idx], lines[last_asst_idx:]
# Forward scan: find the first entry of this turn (same message.id).
first_turn_idx: int | None = None
for i, entry in enumerate(parsed):
if not isinstance(entry, dict):
continue
msg = entry.get("message", {})
if msg.get("role") == "assistant" and msg.get("id") == last_asst_msg_id:
first_turn_idx = i
break
if first_turn_idx is None:
return lines, []
return lines[:first_turn_idx], lines[first_turn_idx:]
async def compact_transcript(
content: str,
*,
@@ -889,50 +816,42 @@ async def compact_transcript(
Converts transcript entries to plain messages, runs ``compress_context``
(the same compressor used for pre-query history), and rebuilds JSONL.
The **last assistant entry** (and any entries after it) are preserved
verbatim — never flattened or compressed. The Anthropic API requires
``thinking`` and ``redacted_thinking`` blocks in the latest assistant
message to be value-identical to the original response (the API
validates parsed signature values, not raw JSON bytes); compressing
them would destroy the cryptographic signatures and cause
``invalid_request_error``.
Structured content (``tool_use`` blocks, ``tool_result`` nesting, images)
is flattened to plain text for compression. This matches the fidelity of
the Plan C (DB compression) fallback path, where
``_format_conversation_context`` similarly renders tool calls as
``You called tool: name(args)`` and results as ``Tool result: ...``.
Neither path preserves structured API content blocks — the compacted
context serves as text history for the LLM, which creates proper
structured tool calls going forward.
Structured content in *older* assistant entries (``tool_use`` blocks,
``thinking`` blocks, ``tool_result`` nesting, images) is flattened to
plain text for compression. This matches the fidelity of the Plan C
(DB compression) fallback path.
Images are per-turn attachments loaded from workspace storage by file ID
(via ``_prepare_file_attachments``), not part of the conversation history.
They are re-attached each turn and are unaffected by compaction.
Returns the compacted JSONL string, or ``None`` on failure.
See also:
``_compress_messages`` in ``service.py`` — compresses ``ChatMessage``
lists for pre-query DB history.
lists for pre-query DB history. Both share ``compress_context()``
but operate on different input formats (JSONL transcript entries
here vs. ChatMessage dicts there).
"""
prefix_lines, tail_lines = _find_last_assistant_entry(content)
# Build the JSONL string for the compressible prefix
prefix_content = "\n".join(prefix_lines) + "\n" if prefix_lines else ""
messages = _transcript_to_messages(prefix_content) if prefix_content else []
if len(messages) + len(tail_lines) < 2:
total = len(messages) + len(tail_lines)
logger.warning("%s Too few messages to compact (%d)", log_prefix, total)
return None
if not messages:
logger.warning("%s Nothing to compress (only tail entries remain)", log_prefix)
messages = _transcript_to_messages(content)
if len(messages) < 2:
logger.warning("%s Too few messages to compact (%d)", log_prefix, len(messages))
return None
try:
result = await _run_compression(messages, model, log_prefix)
if not result.was_compacted:
# Compressor says it's within budget, but the SDK rejected it.
# Return None so the caller falls through to DB fallback.
logger.warning(
"%s Compressor reports within budget but SDK rejected — "
"signalling failure",
log_prefix,
)
return None
if not result.messages:
logger.warning("%s Compressor returned empty messages", log_prefix)
return None
logger.info(
"%s Compacted transcript: %d->%d tokens (%d summarized, %d dropped)",
log_prefix,
@@ -941,29 +860,7 @@ async def compact_transcript(
result.messages_summarized,
result.messages_dropped,
)
compressed_part = _messages_to_transcript(result.messages)
# Re-append the preserved tail (last assistant + trailing entries)
# with parentUuid patched to chain onto the compressed prefix.
tail_part = _rechain_tail(compressed_part, tail_lines)
compacted = compressed_part + tail_part
if len(compacted) >= len(content):
# Byte count can increase due to preserved tail entries
# (thinking blocks, JSON overhead) even when token count
# decreased. Log a warning but still return — the API
# validates tokens not bytes, and the caller falls through
# to DB fallback if the transcript is still too large.
logger.warning(
"%s Compacted transcript (%d bytes) is not smaller than "
"original (%d bytes) — may still reduce token count",
log_prefix,
len(compacted),
len(content),
)
# Authoritative validation — the caller (_reduce_context) also
# validates, but this is the canonical check that guarantees we
# never return a malformed transcript from this function.
compacted = _messages_to_transcript(result.messages)
if not validate_transcript(compacted):
logger.warning("%s Compacted transcript failed validation", log_prefix)
return None
@@ -973,43 +870,3 @@ async def compact_transcript(
"%s Transcript compaction failed: %s", log_prefix, e, exc_info=True
)
return None
def _rechain_tail(compressed_prefix: str, tail_lines: list[str]) -> str:
"""Patch tail entries so their parentUuid chain links to the compressed prefix.
The first tail entry's ``parentUuid`` is set to the ``uuid`` of the
last entry in the compressed prefix. Subsequent tail entries are
rechained to point to their predecessor in the tail — their original
``parentUuid`` values may reference entries that were compressed away.
"""
if not tail_lines:
return ""
# Find the last uuid in the compressed prefix
last_prefix_uuid = ""
for line in reversed(compressed_prefix.strip().split("\n")):
if not line.strip():
continue
entry = json.loads(line, fallback=None)
if isinstance(entry, dict) and "uuid" in entry:
last_prefix_uuid = entry["uuid"]
break
result_lines: list[str] = []
prev_uuid: str | None = None
for i, line in enumerate(tail_lines):
entry = json.loads(line, fallback=None)
if not isinstance(entry, dict):
# Safety guard: _find_last_assistant_entry already filters empty
# lines, and well-formed JSONL always parses to dicts. Non-dict
# lines are passed through unchanged; prev_uuid is intentionally
# NOT updated so the next dict entry chains to the last known uuid.
result_lines.append(line)
continue
if i == 0:
entry["parentUuid"] = last_prefix_uuid
elif prev_uuid is not None:
entry["parentUuid"] = prev_uuid
prev_uuid = entry.get("uuid")
result_lines.append(json.dumps(entry, separators=(",", ":")))
return "\n".join(result_lines) + "\n"

View File

@@ -30,7 +30,7 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
if not cfg.claude_agent_use_resume:
return pytest.skip("CLAUDE_AGENT_USE_RESUME is not enabled, skipping test")
session = await create_chat_session(test_user_id, dry_run=False)
session = await create_chat_session(test_user_id)
session = await upsert_chat_session(session)
# --- Turn 1: send a message with a unique keyword ---

View File

@@ -26,7 +26,6 @@ import orjson
from redis.exceptions import RedisError
from backend.api.model import CopilotCompletionPayload
from backend.data.db_accessors import chat_db
from backend.data.notification_bus import (
AsyncRedisNotificationEventBus,
NotificationEvent,
@@ -112,14 +111,6 @@ def _parse_session_meta(meta: dict[Any, Any], session_id: str = "") -> ActiveSes
``session_id`` is used as a fallback for ``turn_id`` when the meta hash
pre-dates the turn_id field (backward compat for in-flight sessions).
"""
created_at = datetime.now(timezone.utc)
created_at_raw = meta.get("created_at")
if created_at_raw:
try:
created_at = datetime.fromisoformat(str(created_at_raw))
except (ValueError, TypeError):
pass
return ActiveSession(
session_id=meta.get("session_id", "") or session_id,
user_id=meta.get("user_id", "") or None,
@@ -128,7 +119,6 @@ def _parse_session_meta(meta: dict[Any, Any], session_id: str = "") -> ActiveSes
turn_id=meta.get("turn_id", "") or session_id,
blocking=meta.get("blocking") == "1",
status=meta.get("status", "running"), # type: ignore[arg-type]
created_at=created_at,
)
@@ -221,21 +211,9 @@ async def create_session(
return session
_meta_ttl_refresh_at: dict[str, float] = {}
"""Tracks the last time the session meta key TTL was refreshed.
Used by `publish_chunk` to avoid refreshing on every single chunk
(expensive). Refreshes at most once every 60 seconds per session.
"""
_META_TTL_REFRESH_INTERVAL = 60 # seconds
async def publish_chunk(
turn_id: str,
chunk: StreamBaseResponse,
*,
session_id: str | None = None,
) -> str:
"""Publish a chunk to Redis Stream.
@@ -244,9 +222,6 @@ async def publish_chunk(
Args:
turn_id: Turn ID (per-turn UUID) identifying the stream
chunk: The stream response chunk to publish
session_id: Chat session ID — when provided, the session meta key
TTL is refreshed periodically to prevent expiration during
long-running turns (see SECRT-2178).
Returns:
The Redis Stream message ID
@@ -280,23 +255,6 @@ async def publish_chunk(
# Set TTL on stream to match session metadata TTL
await redis.expire(stream_key, config.stream_ttl)
# Periodically refresh session-related TTLs so they don't expire
# during long-running turns. Without this, turns exceeding stream_ttl
# (default 1h) lose their "running" status and stream data, making
# the session invisible to the resume endpoint (empty on page reload).
# Both meta key AND stream key are refreshed: the stream key's expire
# above only fires when publish_chunk is called, but during long
# sub-agent gaps (task_progress events don't produce chunks), neither
# key gets refreshed.
if session_id:
now = time.perf_counter()
last_refresh = _meta_ttl_refresh_at.get(session_id, 0)
if now - last_refresh >= _META_TTL_REFRESH_INTERVAL:
meta_key = _get_session_meta_key(session_id)
await redis.expire(meta_key, config.stream_ttl)
await redis.expire(stream_key, config.stream_ttl)
_meta_ttl_refresh_at[session_id] = now
total_time = (time.perf_counter() - start_time) * 1000
# Only log timing for significant chunks or slow operations
if (
@@ -363,7 +321,7 @@ async def stream_and_publish(
async for event in stream:
if turn_id and not isinstance(event, (StreamFinish, StreamError)):
try:
await publish_chunk(turn_id, event, session_id=session_id)
await publish_chunk(turn_id, event)
except (RedisError, ConnectionError, OSError):
if not publish_failed_once:
publish_failed_once = True
@@ -832,9 +790,6 @@ async def mark_session_completed(
# Atomic compare-and-swap: only update if status is "running"
result = await redis.eval(COMPLETE_SESSION_SCRIPT, 1, meta_key, status) # type: ignore[misc]
# Clean up the in-memory TTL refresh tracker to prevent unbounded growth.
_meta_ttl_refresh_at.pop(session_id, None)
if result == 0:
logger.debug(f"Session {session_id} already completed/failed, skipping")
return False
@@ -847,33 +802,6 @@ async def mark_session_completed(
f"Failed to publish error event for session {session_id}: {e}"
)
# Compute wall-clock duration from session created_at.
# Only persist when (a) the session completed successfully and
# (b) created_at was actually present in Redis meta (not a fallback).
duration_ms: int | None = None
if meta and not error_message:
created_at_raw = meta.get("created_at")
if created_at_raw:
try:
created_at = datetime.fromisoformat(str(created_at_raw))
if created_at.tzinfo is None:
created_at = created_at.replace(tzinfo=timezone.utc)
elapsed = datetime.now(timezone.utc) - created_at
duration_ms = max(0, int(elapsed.total_seconds() * 1000))
except (ValueError, TypeError):
logger.warning(
"Failed to compute session duration for %s (created_at=%r)",
session_id,
created_at_raw,
)
# Persist duration on the last assistant message
if duration_ms is not None:
try:
await chat_db().set_turn_duration(session_id, duration_ms)
except Exception as e:
logger.warning(f"Failed to save turn duration for {session_id}: {e}")
# Publish StreamFinish AFTER status is set to "completed"/"failed".
# This is the SINGLE place that publishes StreamFinish — services and
# the processor must NOT publish it themselves.

View File

@@ -102,6 +102,7 @@ async def setup_test_data(server):
"value": "",
"advanced": False,
"description": "Test input field",
"placeholder_values": [],
},
metadata={"position": {"x": 0, "y": 0}},
)
@@ -241,6 +242,7 @@ async def setup_llm_test_data(server):
"value": "",
"advanced": False,
"description": "Prompt for the LLM",
"placeholder_values": [],
},
metadata={"position": {"x": 0, "y": 0}},
)
@@ -394,6 +396,7 @@ async def setup_firecrawl_test_data(server):
"value": "",
"advanced": False,
"description": "URL for Firecrawl to scrape",
"placeholder_values": [],
},
metadata={"position": {"x": 0, "y": 0}},
)

View File

@@ -68,9 +68,6 @@ class AddUnderstandingTool(BaseTool):
Each call merges new data with existing understanding:
- String fields are overwritten if provided
- List fields are appended (with deduplication)
Note: This tool accepts **kwargs because its parameters are derived
dynamically from the BusinessUnderstandingInput model schema.
"""
session_id = session.session_id
@@ -80,21 +77,23 @@ class AddUnderstandingTool(BaseTool):
session_id=session_id,
)
# Build input model from kwargs (only include fields defined in the model)
valid_fields = set(BusinessUnderstandingInput.model_fields.keys())
filtered = {k: v for k, v in kwargs.items() if k in valid_fields}
# Check if any data was provided
if not any(v is not None for v in filtered.values()):
if not any(v is not None for v in kwargs.values()):
return ErrorResponse(
message="Please provide at least one field to update.",
session_id=session_id,
)
input_data = BusinessUnderstandingInput(**filtered)
# Build input model from kwargs (only include fields defined in the model)
valid_fields = set(BusinessUnderstandingInput.model_fields.keys())
input_data = BusinessUnderstandingInput(
**{k: v for k, v in kwargs.items() if k in valid_fields}
)
# Track which fields were updated
updated_fields = [k for k, v in filtered.items() if v is not None]
updated_fields = [
k for k, v in kwargs.items() if k in valid_fields and v is not None
]
# Upsert with merge
understanding = await understanding_db().upsert_business_understanding(

View File

@@ -180,14 +180,12 @@ async def _save_browser_state(
"""
try:
# Gather state in parallel
(
(rc_url, url_out, _),
(rc_ck, ck_out, _),
(rc_ls, ls_out, _),
) = await asyncio.gather(
_run(session_name, "get", "url", timeout=10),
_run(session_name, "cookies", "get", "--json", timeout=10),
_run(session_name, "storage", "local", "--json", timeout=10),
(rc_url, url_out, _), (rc_ck, ck_out, _), (rc_ls, ls_out, _) = (
await asyncio.gather(
_run(session_name, "get", "url", timeout=10),
_run(session_name, "cookies", "get", "--json", timeout=10),
_run(session_name, "storage", "local", "--json", timeout=10),
)
)
state = {
@@ -450,8 +448,6 @@ class BrowserNavigateTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
url: str = "",
wait_for: str = "networkidle",
**kwargs: Any,
) -> ToolResponseBase:
"""Navigate to *url*, wait for the page to settle, and return a snapshot.
@@ -460,8 +456,8 @@ class BrowserNavigateTool(BaseTool):
Note: for slow SPAs that never fully idle, the snapshot may reflect a
partially-loaded state (the wait is best-effort).
"""
url = url.strip()
wait_for = wait_for or "networkidle"
url: str = (kwargs.get("url") or "").strip()
wait_for: str = kwargs.get("wait_for") or "networkidle"
session_name = session.session_id
if not url:
@@ -616,10 +612,6 @@ class BrowserActTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
action: str = "",
target: str = "",
value: str = "",
direction: str = "down",
**kwargs: Any,
) -> ToolResponseBase:
"""Perform a browser action and return an updated page snapshot.
@@ -628,10 +620,10 @@ class BrowserActTool(BaseTool):
``agent-browser``, waits for the page to settle, and returns the
accessibility-tree snapshot so the LLM can plan the next step.
"""
action = action.strip()
target = target.strip()
value = value.strip()
direction = direction.strip()
action: str = (kwargs.get("action") or "").strip()
target: str = (kwargs.get("target") or "").strip()
value: str = (kwargs.get("value") or "").strip()
direction: str = (kwargs.get("direction") or "down").strip()
session_name = session.session_id
if not action:
@@ -785,8 +777,6 @@ class BrowserScreenshotTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
annotate: bool | str = True,
filename: str = "screenshot.png",
**kwargs: Any,
) -> ToolResponseBase:
"""Capture a PNG screenshot and upload it to the workspace.
@@ -796,12 +786,12 @@ class BrowserScreenshotTool(BaseTool):
Returns a :class:`BrowserScreenshotResponse` with the workspace
``file_id`` the LLM should pass to ``read_workspace_file``.
"""
raw_annotate = annotate
raw_annotate = kwargs.get("annotate", True)
if isinstance(raw_annotate, str):
annotate = raw_annotate.strip().lower() in {"1", "true", "yes", "on"}
else:
annotate = bool(raw_annotate)
filename = filename.strip()
filename: str = (kwargs.get("filename") or "screenshot.png").strip()
session_name = session.session_id
# Restore browser state from cloud if this is a different pod

View File

@@ -4,8 +4,6 @@ import logging
import re
from typing import Any
from backend.data.dynamic_fields import DICT_SPLIT
from .helpers import (
AGENT_EXECUTOR_BLOCK_ID,
MCP_TOOL_BLOCK_ID,
@@ -1538,8 +1536,8 @@ class AgentFixer:
for link in links:
sink_name = link.get("sink_name", "")
if DICT_SPLIT in sink_name:
parent, child = sink_name.split(DICT_SPLIT, 1)
if "_#_" in sink_name:
parent, child = sink_name.split("_#_", 1)
# Check if child is a numeric index (invalid for _#_ notation)
if child.isdigit():

View File

@@ -4,8 +4,6 @@ import re
import uuid
from typing import Any
from backend.data.dynamic_fields import DICT_SPLIT
from .blocks import get_blocks_as_dicts
__all__ = [
@@ -53,8 +51,8 @@ def generate_uuid() -> str:
def get_defined_property_type(schema: dict[str, Any], name: str) -> str | None:
"""Get property type from a schema, handling nested `_#_` notation."""
if DICT_SPLIT in name:
parent, child = name.split(DICT_SPLIT, 1)
if "_#_" in name:
parent, child = name.split("_#_", 1)
parent_schema = schema.get(parent, {})
if "properties" in parent_schema and isinstance(
parent_schema["properties"], dict

View File

@@ -5,8 +5,6 @@ import logging
import re
from typing import Any
from backend.data.dynamic_fields import DICT_SPLIT
from .helpers import (
AGENT_EXECUTOR_BLOCK_ID,
AGENT_INPUT_BLOCK_ID,
@@ -258,6 +256,95 @@ class AgentValidator:
return valid
def validate_nested_sink_links(
self,
agent: AgentDict,
blocks: list[dict[str, Any]],
node_lookup: dict[str, dict[str, Any]] | None = None,
) -> bool:
"""
Validate nested sink links (links with _#_ notation).
Returns True if all nested links are valid, False otherwise.
"""
valid = True
block_input_schemas = {
block.get("id", ""): block.get("inputSchema", {}).get("properties", {})
for block in blocks
}
block_names = {
block.get("id", ""): block.get("name", "Unknown Block") for block in blocks
}
if node_lookup is None:
node_lookup = self._build_node_lookup(agent)
for link in agent.get("links", []):
sink_name = link.get("sink_name", "")
sink_id = link.get("sink_id")
if not sink_name or not sink_id:
continue
if "_#_" in sink_name:
parent, child = sink_name.split("_#_", 1)
sink_node = node_lookup.get(sink_id)
if not sink_node:
continue
block_id = sink_node.get("block_id")
input_props = block_input_schemas.get(block_id, {})
parent_schema = input_props.get(parent)
if not parent_schema:
block_name = block_names.get(block_id, "Unknown Block")
self.add_error(
f"Invalid nested sink link '{sink_name}' for "
f"node '{sink_id}' (block "
f"'{block_name}' - {block_id}): Parent property "
f"'{parent}' does not exist in the block's "
f"input schema."
)
valid = False
continue
# Check if additionalProperties is allowed either directly
# or via anyOf
allows_additional_properties = parent_schema.get(
"additionalProperties", False
)
# Check anyOf for additionalProperties
if not allows_additional_properties and "anyOf" in parent_schema:
any_of_schemas = parent_schema.get("anyOf", [])
if isinstance(any_of_schemas, list):
for schema_option in any_of_schemas:
if isinstance(schema_option, dict) and schema_option.get(
"additionalProperties"
):
allows_additional_properties = True
break
if not allows_additional_properties:
if not (
isinstance(parent_schema, dict)
and "properties" in parent_schema
and isinstance(parent_schema["properties"], dict)
and child in parent_schema["properties"]
):
block_name = block_names.get(block_id, "Unknown Block")
self.add_error(
f"Invalid nested sink link '{sink_name}' "
f"for node '{link.get('sink_id', '')}' (block "
f"'{block_name}' - {block_id}): Child "
f"property '{child}' does not exist in "
f"parent '{parent}' schema. Available "
f"properties: "
f"{list(parent_schema.get('properties', {}).keys())}"
)
valid = False
return valid
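
A sketch of the parent-schema shapes the check above accepts for an arbitrary child key (schemas invented for illustration):

# Accepted: additionalProperties set directly on the parent schema.
parent_a = {"type": "object", "additionalProperties": True}

# Accepted: additionalProperties allowed via one of the anyOf options
# (how Pydantic emits optional/union dict fields).
parent_b = {
    "anyOf": [
        {"type": "object", "additionalProperties": True},
        {"type": "null"},
    ]
}

# Otherwise the child must be an explicitly declared property:
parent_c = {"type": "object", "properties": {"key": {"type": "string"}}}
# -> a link like "config_#_other" fails validation; "config_#_key" passes.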
def validate_prompt_double_curly_braces_spaces(self, agent: AgentDict) -> bool:
"""
Validate that prompt parameters do not contain spaces in double curly
@@ -384,8 +471,8 @@ class AgentValidator:
output_props = block_output_schemas.get(block_id, {})
# Handle nested source names (with _#_ notation)
if DICT_SPLIT in source_name:
parent, child = source_name.split(DICT_SPLIT, 1)
if "_#_" in source_name:
parent, child = source_name.split("_#_", 1)
parent_schema = output_props.get(parent)
if not parent_schema:
@@ -466,195 +553,6 @@ class AgentValidator:
return valid
def validate_sink_input_existence(
self,
agent: AgentDict,
blocks: list[dict[str, Any]],
node_lookup: dict[str, dict[str, Any]] | None = None,
) -> bool:
"""
Validate that all sink_names in links and input_default keys in nodes
exist in the corresponding block's input schema.
Checks that for each link the sink_name references a valid input
property in the sink block's inputSchema, and that every key in a
node's input_default is a recognised input property. Also handles
nested inputs with _#_ notation and dynamic schemas for
AgentExecutorBlock.
Args:
agent: The agent dictionary to validate
blocks: List of available blocks with their schemas
node_lookup: Optional pre-built node-id → node dict
Returns:
True if all sink input fields exist, False otherwise
"""
valid = True
block_input_schemas = {
block.get("id", ""): block.get("inputSchema", {}).get("properties", {})
for block in blocks
}
block_names = {
block.get("id", ""): block.get("name", "Unknown Block") for block in blocks
}
if node_lookup is None:
node_lookup = self._build_node_lookup(agent)
def get_input_props(node: dict[str, Any]) -> dict[str, Any]:
block_id = node.get("block_id", "")
if block_id == AGENT_EXECUTOR_BLOCK_ID:
input_default = node.get("input_default", {})
dynamic_input_schema = input_default.get("input_schema", {})
if not isinstance(dynamic_input_schema, dict):
dynamic_input_schema = {}
dynamic_props = dynamic_input_schema.get("properties", {})
if not isinstance(dynamic_props, dict):
dynamic_props = {}
static_props = block_input_schemas.get(block_id, {})
return {**static_props, **dynamic_props}
return block_input_schemas.get(block_id, {})
def check_nested_input(
input_props: dict[str, Any],
field_name: str,
context: str,
block_name: str,
block_id: str,
) -> bool:
parent, child = field_name.split(DICT_SPLIT, 1)
parent_schema = input_props.get(parent)
if not parent_schema:
self.add_error(
f"{context}: Parent property '{parent}' does not "
f"exist in block '{block_name}' ({block_id}) input "
f"schema."
)
return False
allows_additional = parent_schema.get("additionalProperties", False)
# Only anyOf is checked here because Pydantic's JSON schema
# emits optional/union fields via anyOf. allOf and oneOf are
# not currently used by any block's dict-typed inputs, so
# false positives from them are not a concern in practice.
if not allows_additional and "anyOf" in parent_schema:
for schema_option in parent_schema.get("anyOf", []):
if not isinstance(schema_option, dict):
continue
if schema_option.get("additionalProperties"):
allows_additional = True
break
items_schema = schema_option.get("items")
if isinstance(items_schema, dict) and items_schema.get(
"additionalProperties"
):
allows_additional = True
break
if not allows_additional:
if not (
isinstance(parent_schema, dict)
and "properties" in parent_schema
and isinstance(parent_schema["properties"], dict)
and child in parent_schema["properties"]
):
available = (
list(parent_schema.get("properties", {}).keys())
if isinstance(parent_schema, dict)
else []
)
self.add_error(
f"{context}: Child property '{child}' does not "
f"exist in parent '{parent}' of block "
f"'{block_name}' ({block_id}) input schema. "
f"Available properties: {available}"
)
return False
return True
for link in agent.get("links", []):
sink_id = link.get("sink_id")
sink_name = link.get("sink_name", "")
link_id = link.get("id", "Unknown")
if not sink_name:
# Missing sink_name is caught by validate_data_type_compatibility
continue
sink_node = node_lookup.get(sink_id)
if not sink_node:
# Already caught by validate_link_node_references
continue
block_id = sink_node.get("block_id", "")
block_name = block_names.get(block_id, "Unknown Block")
input_props = get_input_props(sink_node)
context = (
f"Invalid sink input field '{sink_name}' in link "
f"'{link_id}' to node '{sink_id}'"
)
if DICT_SPLIT in sink_name:
if not check_nested_input(
input_props, sink_name, context, block_name, block_id
):
valid = False
else:
if sink_name not in input_props:
available_inputs = list(input_props.keys())
self.add_error(
f"{context} (block '{block_name}' - {block_id}): "
f"Input property '{sink_name}' does not exist in "
f"the block's input schema. "
f"Available inputs: {available_inputs}"
)
valid = False
for node in agent.get("nodes", []):
node_id = node.get("id")
block_id = node.get("block_id", "")
block_name = block_names.get(block_id, "Unknown Block")
input_default = node.get("input_default", {})
if not isinstance(input_default, dict) or not input_default:
continue
if (
block_id not in block_input_schemas
and block_id != AGENT_EXECUTOR_BLOCK_ID
):
continue
input_props = get_input_props(node)
for key in input_default:
if key == "credentials":
continue
context = (
f"Node '{node_id}' (block '{block_name}' - {block_id}) "
f"has unknown input_default key '{key}'"
)
if DICT_SPLIT in key:
if not check_nested_input(
input_props, key, context, block_name, block_id
):
valid = False
else:
if key not in input_props:
available_inputs = list(input_props.keys())
self.add_error(
f"{context} which does not exist in the "
f"block's input schema. "
f"Available inputs: {available_inputs}"
)
valid = False
return valid
def validate_io_blocks(self, agent: AgentDict) -> bool:
"""
Validate that the agent has at least one AgentInputBlock and one
@@ -1101,12 +999,12 @@ class AgentValidator:
self.validate_data_type_compatibility(agent, blocks, node_lookup),
),
(
"Source output existence",
self.validate_source_output_existence(agent, blocks, node_lookup),
"Nested sink links",
self.validate_nested_sink_links(agent, blocks, node_lookup),
),
(
"Sink input existence",
self.validate_sink_input_existence(agent, blocks, node_lookup),
"Source output existence",
self.validate_source_output_existence(agent, blocks, node_lookup),
),
(
"Prompt double curly braces spaces",

View File

@@ -331,6 +331,43 @@ class TestValidatePromptDoubleCurlyBracesSpaces:
assert any("spaces" in e for e in v.errors)
# ============================================================================
# validate_nested_sink_links
# ============================================================================
class TestValidateNestedSinkLinks:
def test_valid_nested_link_passes(self):
v = AgentValidator()
block = _make_block(
block_id="b1",
input_schema={
"properties": {
"config": {
"type": "object",
"properties": {"key": {"type": "string"}},
}
},
"required": [],
},
)
node = _make_node(node_id="n1", block_id="b1")
link = _make_link(sink_id="n1", sink_name="config_#_key", source_id="n2")
agent = _make_agent(nodes=[node], links=[link])
assert v.validate_nested_sink_links(agent, [block]) is True
def test_invalid_parent_fails(self):
v = AgentValidator()
block = _make_block(block_id="b1")
node = _make_node(node_id="n1", block_id="b1")
link = _make_link(sink_id="n1", sink_name="nonexistent_#_key", source_id="n2")
agent = _make_agent(nodes=[node], links=[link])
assert v.validate_nested_sink_links(agent, [block]) is False
assert any("does not exist" in e for e in v.errors)
# ============================================================================
# validate_agent_executor_block_schemas
# ============================================================================
@@ -558,28 +595,11 @@ class TestValidate:
input_block = _make_block(
block_id=AGENT_INPUT_BLOCK_ID,
name="AgentInputBlock",
input_schema={
"properties": {
"name": {"type": "string"},
"title": {"type": "string"},
"value": {},
"description": {"type": "string"},
},
"required": ["name"],
},
output_schema={"properties": {"result": {}}},
)
output_block = _make_block(
block_id=AGENT_OUTPUT_BLOCK_ID,
name="AgentOutputBlock",
input_schema={
"properties": {
"name": {"type": "string"},
"title": {"type": "string"},
"value": {},
},
"required": ["name"],
},
)
input_node = _make_node(
node_id="n-in",
@@ -630,201 +650,6 @@ class TestValidate:
assert "AgentOutputBlock" in error_message
class TestValidateSinkInputExistence:
"""Tests for validate_sink_input_existence."""
def test_valid_sink_name_passes(self):
v = AgentValidator()
block = _make_block(
block_id="b1",
input_schema={"properties": {"url": {"type": "string"}}, "required": []},
)
node = _make_node(node_id="n1", block_id="b1")
link = _make_link(
source_id="src", source_name="out", sink_id="n1", sink_name="url"
)
agent = _make_agent(nodes=[node], links=[link])
assert v.validate_sink_input_existence(agent, [block]) is True
def test_invalid_sink_name_fails(self):
v = AgentValidator()
block = _make_block(
block_id="b1",
input_schema={"properties": {"url": {"type": "string"}}, "required": []},
)
node = _make_node(node_id="n1", block_id="b1")
link = _make_link(
source_id="src", source_name="out", sink_id="n1", sink_name="nonexistent"
)
agent = _make_agent(nodes=[node], links=[link])
assert v.validate_sink_input_existence(agent, [block]) is False
assert any("nonexistent" in e for e in v.errors)
def test_valid_nested_link_passes(self):
v = AgentValidator()
block = _make_block(
block_id="b1",
input_schema={
"properties": {
"config": {
"type": "object",
"properties": {"key": {"type": "string"}},
}
},
"required": [],
},
)
node = _make_node(node_id="n1", block_id="b1")
link = _make_link(
source_id="src",
source_name="out",
sink_id="n1",
sink_name="config_#_key",
)
agent = _make_agent(nodes=[node], links=[link])
assert v.validate_sink_input_existence(agent, [block]) is True
def test_invalid_nested_child_fails(self):
v = AgentValidator()
block = _make_block(
block_id="b1",
input_schema={
"properties": {
"config": {
"type": "object",
"properties": {"key": {"type": "string"}},
}
},
"required": [],
},
)
node = _make_node(node_id="n1", block_id="b1")
link = _make_link(
source_id="src",
source_name="out",
sink_id="n1",
sink_name="config_#_missing",
)
agent = _make_agent(nodes=[node], links=[link])
assert v.validate_sink_input_existence(agent, [block]) is False
def test_unknown_input_default_key_fails(self):
v = AgentValidator()
block = _make_block(
block_id="b1",
input_schema={"properties": {"url": {"type": "string"}}, "required": []},
)
node = _make_node(
node_id="n1", block_id="b1", input_default={"nonexistent_key": "value"}
)
agent = _make_agent(nodes=[node])
assert v.validate_sink_input_existence(agent, [block]) is False
assert any("nonexistent_key" in e for e in v.errors)
def test_credentials_key_skipped(self):
v = AgentValidator()
block = _make_block(
block_id="b1",
input_schema={"properties": {"url": {"type": "string"}}, "required": []},
)
node = _make_node(
node_id="n1",
block_id="b1",
input_default={
"url": "http://example.com",
"credentials": {"api_key": "x"},
},
)
agent = _make_agent(nodes=[node])
assert v.validate_sink_input_existence(agent, [block]) is True
def test_agent_executor_dynamic_schema_passes(self):
v = AgentValidator()
block = _make_block(
block_id=AGENT_EXECUTOR_BLOCK_ID,
input_schema={
"properties": {
"graph_id": {"type": "string"},
"input_schema": {"type": "object"},
},
"required": ["graph_id"],
},
)
node = _make_node(
node_id="n1",
block_id=AGENT_EXECUTOR_BLOCK_ID,
input_default={
"graph_id": "abc",
"input_schema": {
"properties": {"query": {"type": "string"}},
"required": [],
},
},
)
link = _make_link(
source_id="src",
source_name="out",
sink_id="n1",
sink_name="query",
)
agent = _make_agent(nodes=[node], links=[link])
assert v.validate_sink_input_existence(agent, [block]) is True
def test_input_default_nested_invalid_child_fails(self):
v = AgentValidator()
block = _make_block(
block_id="b1",
input_schema={
"properties": {
"config": {
"type": "object",
"properties": {"key": {"type": "string"}},
}
},
"required": [],
},
)
node = _make_node(
node_id="n1",
block_id="b1",
input_default={"config_#_invalid_child": "value"},
)
agent = _make_agent(nodes=[node])
assert v.validate_sink_input_existence(agent, [block]) is False
assert any("invalid_child" in e for e in v.errors)
def test_input_default_nested_valid_child_passes(self):
v = AgentValidator()
block = _make_block(
block_id="b1",
input_schema={
"properties": {
"config": {
"type": "object",
"properties": {"key": {"type": "string"}},
}
},
"required": [],
},
)
node = _make_node(
node_id="n1",
block_id="b1",
input_default={"config_#_key": "value"},
)
agent = _make_agent(nodes=[node])
assert v.validate_sink_input_existence(agent, [block]) is True
class TestValidateMCPToolBlocks:
"""Tests for validate_mcp_tool_blocks."""

View File

@@ -411,12 +411,7 @@ class AgentOutputTool(BaseTool):
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""Execute the agent_output tool.
Note: This tool accepts **kwargs and delegates to AgentOutputInput
for validation because the parameter set has cross-field validators
defined in the Pydantic model.
"""
"""Execute the agent_output tool."""
session_id = session.session_id
# Parse and validate input

View File

@@ -76,8 +76,6 @@ class BashExecTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
command: str = "",
timeout: int = 30,
**kwargs: Any,
) -> ToolResponseBase:
"""Run a bash command on E2B (if available) or in a bubblewrap sandbox.
@@ -90,8 +88,8 @@ class BashExecTool(BaseTool):
"""
session_id = session.session_id if session else None
command = command.strip()
timeout = int(timeout)
command: str = (kwargs.get("command") or "").strip()
timeout: int = int(kwargs.get("timeout", 30))
if not command:
return ErrorResponse(

View File

@@ -115,9 +115,6 @@ class ConnectIntegrationTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
provider: str = "",
reason: str = "",
scopes: list[str] | None = None,
**kwargs: Any,
) -> ToolResponseBase:
"""Build and return a :class:`SetupRequirementsResponse` for the requested provider.
@@ -131,10 +128,12 @@ class ConnectIntegrationTool(BaseTool):
"""
_ = user_id # setup card is user-agnostic; auth is enforced via requires_auth
session_id = session.session_id if session else None
provider = (provider or "").strip().lower()
reason = (reason or "").strip()[:500] # cap LLM-controlled text
provider: str = (kwargs.get("provider") or "").strip().lower()
reason: str = (kwargs.get("reason") or "").strip()[
:500
] # cap LLM-controlled text
extra_scopes: list[str] = [
str(s).strip() for s in (scopes or []) if str(s).strip()
str(s).strip() for s in (kwargs.get("scopes") or []) if str(s).strip()
]
entry = SUPPORTED_PROVIDERS.get(provider)
@@ -142,7 +141,8 @@ class ConnectIntegrationTool(BaseTool):
supported = ", ".join(f"'{p}'" for p in SUPPORTED_PROVIDERS)
return ErrorResponse(
message=(
f"Unknown provider '{provider}'. Supported providers: {supported}."
f"Unknown provider '{provider}'. "
f"Supported providers: {supported}."
),
error="unknown_provider",
session_id=session_id,
@@ -153,11 +153,11 @@ class ConnectIntegrationTool(BaseTool):
# Merge agent-requested scopes with provider defaults (deduplicated, order preserved).
default_scopes: list[str] = entry["default_scopes"]
seen: set[str] = set()
merged_scopes: list[str] = []
scopes: list[str] = []
for s in default_scopes + extra_scopes:
if s not in seen:
seen.add(s)
merged_scopes.append(s)
scopes.append(s)
field_key = f"{provider}_credentials"
message_parts = [
@@ -171,7 +171,7 @@ class ConnectIntegrationTool(BaseTool):
"title": f"{display_name} Credentials",
"provider": provider,
"types": supported_types,
"scopes": merged_scopes,
"scopes": scopes,
}
missing_credentials: dict[str, _CredentialEntry] = {field_key: credential_entry}
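
A tiny example of the merge above, which keeps provider defaults first and deduplicates while preserving order (scope names invented):

default_scopes = ["repo", "read:user"]
extra_scopes = ["read:user", "gist"]

seen: set[str] = set()
scopes: list[str] = []
for s in default_scopes + extra_scopes:
    if s not in seen:
        seen.add(s)
        scopes.append(s)

# scopes == ["repo", "read:user", "gist"]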

View File

@@ -53,10 +53,11 @@ class ContinueRunBlockTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
review_id: str = "",
**kwargs,
) -> ToolResponseBase:
review_id = review_id.strip() if review_id else ""
review_id = (
kwargs.get("review_id", "").strip() if kwargs.get("review_id") else ""
)
session_id = session.session_id
if not review_id:

View File

@@ -62,12 +62,9 @@ class CreateAgentTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
agent_json: dict[str, Any] | None = None,
save: bool = True,
library_agent_ids: list[str] | None = None,
folder_id: str | None = None,
**kwargs,
) -> ToolResponseBase:
agent_json: dict[str, Any] | None = kwargs.get("agent_json")
session_id = session.session_id if session else None
if not agent_json:
@@ -80,8 +77,9 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)
if library_agent_ids is None:
library_agent_ids = []
save = kwargs.get("save", True)
library_agent_ids = kwargs.get("library_agent_ids", [])
folder_id: str | None = kwargs.get("folder_id")
nodes = agent_json.get("nodes", [])
if not nodes:

View File

@@ -61,12 +61,9 @@ class CustomizeAgentTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
agent_json: dict[str, Any] | None = None,
save: bool = True,
library_agent_ids: list[str] | None = None,
folder_id: str | None = None,
**kwargs,
) -> ToolResponseBase:
agent_json: dict[str, Any] | None = kwargs.get("agent_json")
session_id = session.session_id if session else None
if not agent_json:
@@ -78,8 +75,9 @@ class CustomizeAgentTool(BaseTool):
session_id=session_id,
)
if library_agent_ids is None:
library_agent_ids = []
save = kwargs.get("save", True)
library_agent_ids = kwargs.get("library_agent_ids", [])
folder_id: str | None = kwargs.get("folder_id")
nodes = agent_json.get("nodes", [])
if not nodes:

View File

@@ -62,15 +62,10 @@ class EditAgentTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
agent_id: str = "",
agent_json: dict[str, Any] | None = None,
save: bool = True,
library_agent_ids: list[str] | None = None,
**kwargs,
) -> ToolResponseBase:
agent_id = agent_id.strip()
if library_agent_ids is None:
library_agent_ids = []
agent_id = kwargs.get("agent_id", "").strip()
agent_json: dict[str, Any] | None = kwargs.get("agent_json")
session_id = session.session_id if session else None
if not agent_id:
@@ -89,6 +84,9 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)
save = kwargs.get("save", True)
library_agent_ids = kwargs.get("library_agent_ids", [])
nodes = agent_json.get("nodes", [])
if not nodes:
return ErrorResponse(

View File

@@ -157,10 +157,9 @@ class SearchFeatureRequestsTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
query: str = "",
**kwargs,
) -> ToolResponseBase:
query = (query or "").strip()
query = kwargs.get("query", "").strip()
session_id = session.session_id if session else None
if not query:
@@ -289,13 +288,11 @@ class CreateFeatureRequestTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
title: str = "",
description: str = "",
existing_issue_id: str | None = None,
**kwargs,
) -> ToolResponseBase:
title = (title or "").strip()
description = (description or "").strip()
title = kwargs.get("title", "").strip()
description = kwargs.get("description", "").strip()
existing_issue_id = kwargs.get("existing_issue_id")
session_id = session.session_id if session else None
if not title or not description:

View File

@@ -34,15 +34,11 @@ class FindAgentTool(BaseTool):
}
async def _execute(
self,
user_id: str | None,
session: ChatSession,
query: str = "",
**kwargs,
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:
"""Search marketplace for agents matching the query."""
return await search_agents(
query=query.strip(),
query=kwargs.get("query", "").strip(),
source="marketplace",
session_id=session.session_id,
user_id=user_id,

View File

@@ -86,8 +86,6 @@ class FindBlockTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
query: str = "",
include_schemas: bool = False,
**kwargs,
) -> ToolResponseBase:
"""Search for blocks matching the query.
@@ -96,14 +94,14 @@ class FindBlockTool(BaseTool):
user_id: User ID (required)
session: Chat session
query: Search query
include_schemas: Whether to include block schemas in results
Returns:
BlockListResponse: List of matching blocks
NoResultsResponse: No blocks found
ErrorResponse: Error message
"""
query = (query or "").strip()
query = kwargs.get("query", "").strip()
include_schemas = kwargs.get("include_schemas", False)
session_id = session.session_id
if not query:

View File

@@ -41,14 +41,10 @@ class FindLibraryAgentTool(BaseTool):
return True
async def _execute(
self,
user_id: str | None,
session: ChatSession,
query: str = "",
**kwargs,
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:
return await search_agents(
query=query.strip(),
query=(kwargs.get("query") or "").strip(),
source="library",
session_id=session.session_id,
user_id=user_id,

View File

@@ -51,9 +51,9 @@ class FixAgentGraphTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
agent_json: dict | None = None,
**kwargs,
) -> ToolResponseBase:
agent_json = kwargs.get("agent_json")
session_id = session.session_id if session else None
if not agent_json or not isinstance(agent_json, dict):
@@ -98,7 +98,8 @@ class FixAgentGraphTool(BaseTool):
if is_valid:
return FixResultResponse(
message=(
f"Applied {len(fixes_applied)} fix(es). Agent graph is now valid!"
f"Applied {len(fixes_applied)} fix(es). "
"Agent graph is now valid!"
),
fixed_agent_json=fixed_agent,
fixes_applied=fixes_applied,

View File

@@ -60,7 +60,7 @@ class GetAgentBuildingGuideTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
**kwargs, # no tool-specific params; accepts kwargs for forward-compat
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id if session else None
try:

View File

@@ -68,7 +68,6 @@ class GetDocPageTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
path: str = "",
**kwargs,
) -> ToolResponseBase:
"""Fetch full content of a documentation page.
@@ -82,7 +81,7 @@ class GetDocPageTool(BaseTool):
DocPageResponse: Full document content
ErrorResponse: Error message
"""
path = path.strip()
path = kwargs.get("path", "").strip()
session_id = session.session_id if session else None
if not path:

View File

@@ -56,7 +56,7 @@ class GetMCPGuideTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
**kwargs, # no tool-specific params; accepts kwargs for forward-compat
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id if session else None
try:

View File

@@ -81,7 +81,7 @@ async def execute_block(
node_exec_id: str,
matched_credentials: dict[str, CredentialsMetaInput],
sensitive_action_safe_mode: bool = False,
dry_run: bool,
dry_run: bool = False,
) -> ToolResponseBase:
"""Execute a block with full context setup, credential injection, and error handling.
@@ -114,9 +114,11 @@ async def execute_block(
error=sim_error[0],
session_id=session_id,
)
return BlockOutputResponse(
message=f"Block '{block.name}' executed successfully",
message=(
f"[DRY RUN] Block '{block.name}' simulated successfully "
"— no real execution occurred."
),
block_id=block_id,
block_name=block.name,
outputs=dict(outputs),
@@ -335,7 +337,7 @@ async def prepare_block_for_execution(
user_id: str,
session: ChatSession,
session_id: str,
dry_run: bool,
dry_run: bool = False,
) -> "BlockPreparation | ToolResponseBase":
"""Validate and prepare a block for execution.
@@ -535,7 +537,7 @@ async def check_hitl_review(
)
synthetic_node_exec_id = (
f"{synthetic_node_id}{COPILOT_NODE_EXEC_ID_SEPARATOR}{uuid.uuid4().hex[:8]}"
f"{synthetic_node_id}{COPILOT_NODE_EXEC_ID_SEPARATOR}" f"{uuid.uuid4().hex[:8]}"
)
review_context = ExecutionContext(
@@ -580,16 +582,7 @@ def _resolve_discriminated_credentials(
block: AnyBlockSchema,
input_data: dict[str, Any],
) -> dict[str, CredentialsFieldInfo]:
"""Resolve credential requirements, applying discriminator logic where needed.
Handles two discrimination modes:
1. **Provider-based** (``discriminator_mapping`` is set): the discriminator
field value selects the provider (e.g. an AI model name -> provider).
2. **URL/host-based** (``discriminator`` is set but ``discriminator_mapping``
is ``None``): the discriminator field value (typically a URL) is added to
``discriminator_values`` so that host-scoped credential matching can
compare the credential's host against the target URL.
"""
"""Resolve credential requirements, applying discriminator logic where needed."""
credentials_fields_info = block.input_schema.get_credentials_fields_info()
if not credentials_fields_info:
return {}
@@ -599,42 +592,25 @@ def _resolve_discriminated_credentials(
for field_name, field_info in credentials_fields_info.items():
effective_field_info = field_info
if field_info.discriminator:
if field_info.discriminator and field_info.discriminator_mapping:
discriminator_value = input_data.get(field_info.discriminator)
if discriminator_value is None:
field = block.input_schema.model_fields.get(field_info.discriminator)
if field and field.default is not PydanticUndefined:
discriminator_value = field.default
if discriminator_value is not None:
if field_info.discriminator_mapping:
# Provider-based discrimination (e.g. model -> provider)
if discriminator_value in field_info.discriminator_mapping:
effective_field_info = field_info.discriminate(
discriminator_value
)
effective_field_info.discriminator_values.add(
discriminator_value
)
# Model names are safe to log (not PII); URLs are
# intentionally omitted in the host-based branch below.
logger.debug(
"Discriminated provider for %s: %s -> %s",
field_name,
discriminator_value,
effective_field_info.provider,
)
else:
# URL/host-based discrimination (e.g. url -> host matching).
# Deep copy to avoid mutating the cached schema-level
# field_info (model_copy() is shallow — the mutable set
# would be shared).
effective_field_info = field_info.model_copy(deep=True)
effective_field_info.discriminator_values.add(discriminator_value)
logger.debug(
"Added discriminator value for host matching on %s",
field_name,
)
if (
discriminator_value
and discriminator_value in field_info.discriminator_mapping
):
effective_field_info = field_info.discriminate(discriminator_value)
effective_field_info.discriminator_values.add(discriminator_value)
logger.debug(
"Discriminated provider for %s: %s -> %s",
field_name,
discriminator_value,
effective_field_info.provider,
)
resolved[field_name] = effective_field_info

View File

@@ -102,7 +102,6 @@ class TestExecuteBlockCreditCharging:
session_id=_SESSION,
node_exec_id="exec-1",
matched_credentials={},
dry_run=False,
)
assert isinstance(result, BlockOutputResponse)
@@ -133,7 +132,6 @@ class TestExecuteBlockCreditCharging:
session_id=_SESSION,
node_exec_id="exec-1",
matched_credentials={},
dry_run=False,
)
assert isinstance(result, ErrorResponse)
@@ -160,7 +158,6 @@ class TestExecuteBlockCreditCharging:
session_id=_SESSION,
node_exec_id="exec-1",
matched_credentials={},
dry_run=False,
)
assert isinstance(result, BlockOutputResponse)
@@ -197,7 +194,6 @@ class TestExecuteBlockCreditCharging:
session_id=_SESSION,
node_exec_id="exec-1",
matched_credentials={},
dry_run=False,
)
# Block already executed (with side effects), so output is returned
@@ -281,7 +277,6 @@ async def test_coerce_json_string_to_nested_list():
session_id=_TEST_SESSION_ID,
node_exec_id="exec-1",
matched_credentials={},
dry_run=False,
)
assert isinstance(response, BlockOutputResponse)
@@ -322,7 +317,6 @@ async def test_coerce_json_string_to_list():
session_id=_TEST_SESSION_ID,
node_exec_id="exec-2",
matched_credentials={},
dry_run=False,
)
assert isinstance(response, BlockOutputResponse)
@@ -355,7 +349,6 @@ async def test_coerce_json_string_to_dict():
session_id=_TEST_SESSION_ID,
node_exec_id="exec-3",
matched_credentials={},
dry_run=False,
)
assert isinstance(response, BlockOutputResponse)
@@ -389,7 +382,6 @@ async def test_no_coercion_when_type_matches():
session_id=_TEST_SESSION_ID,
node_exec_id="exec-4",
matched_credentials={},
dry_run=False,
)
assert isinstance(response, BlockOutputResponse)
@@ -423,7 +415,6 @@ async def test_coerce_string_to_int():
session_id=_TEST_SESSION_ID,
node_exec_id="exec-5",
matched_credentials={},
dry_run=False,
)
assert isinstance(response, BlockOutputResponse)
@@ -457,7 +448,6 @@ async def test_coerce_skips_none_values():
session_id=_TEST_SESSION_ID,
node_exec_id="exec-6",
matched_credentials={},
dry_run=False,
)
assert isinstance(response, BlockOutputResponse)
@@ -491,7 +481,6 @@ async def test_coerce_union_type_preserves_valid_member():
session_id=_TEST_SESSION_ID,
node_exec_id="exec-7",
matched_credentials={},
dry_run=False,
)
assert isinstance(response, BlockOutputResponse)
@@ -527,7 +516,6 @@ async def test_coerce_inner_elements_of_generic():
session_id=_TEST_SESSION_ID,
node_exec_id="exec-8",
matched_credentials={},
dry_run=False,
)
assert isinstance(response, BlockOutputResponse)
@@ -604,7 +592,6 @@ async def test_prepare_block_not_found() -> None:
user_id=_PREP_USER,
session=_make_prep_session(),
session_id=_PREP_SESSION,
dry_run=False,
)
assert isinstance(result, ErrorResponse)
assert "not found" in result.message
@@ -625,7 +612,6 @@ async def test_prepare_block_disabled() -> None:
user_id=_PREP_USER,
session=_make_prep_session(),
session_id=_PREP_SESSION,
dry_run=False,
)
assert isinstance(result, ErrorResponse)
assert "disabled" in result.message
@@ -654,7 +640,6 @@ async def test_prepare_block_unrecognized_fields() -> None:
user_id=_PREP_USER,
session=_make_prep_session(),
session_id=_PREP_SESSION,
dry_run=False,
)
assert isinstance(result, InputValidationErrorResponse)
assert "unknown_field" in result.unrecognized_fields
@@ -684,7 +669,6 @@ async def test_prepare_block_missing_credentials() -> None:
user_id=_PREP_USER,
session=_make_prep_session(),
session_id=_PREP_SESSION,
dry_run=False,
)
assert isinstance(result, SetupRequirementsResponse)
@@ -714,7 +698,6 @@ async def test_prepare_block_success_returns_preparation() -> None:
user_id=_PREP_USER,
session=_make_prep_session(),
session_id=_PREP_SESSION,
dry_run=False,
)
assert isinstance(result, BlockPreparation)
assert result.required_non_credential_keys == {"text"}
@@ -819,7 +802,6 @@ async def test_prepare_block_excluded_by_type() -> None:
user_id=_PREP_USER,
session=_make_prep_session(),
session_id=_PREP_SESSION,
dry_run=False,
)
assert isinstance(result, ErrorResponse)
assert "cannot be run directly" in result.message
@@ -842,7 +824,6 @@ async def test_prepare_block_excluded_by_id() -> None:
user_id=_PREP_USER,
session=_make_prep_session(),
session_id=_PREP_SESSION,
dry_run=False,
)
assert isinstance(result, ErrorResponse)
assert "cannot be run directly" in result.message
@@ -876,7 +857,6 @@ async def test_prepare_block_file_ref_expansion_error() -> None:
user_id=_PREP_USER,
session=_make_prep_session(),
session_id=_PREP_SESSION,
dry_run=False,
)
assert isinstance(result, ErrorResponse)
assert "file reference" in result.message.lower()

View File

@@ -1,918 +0,0 @@
"""Tests for credential resolution across all credential types in the CoPilot.
These tests verify that:
1. `_resolve_discriminated_credentials` correctly populates discriminator_values
for URL-based (host-scoped) and provider-based (api_key) credential fields.
2. `find_matching_credential` correctly matches credentials for all types:
APIKeyCredentials, OAuth2Credentials, UserPasswordCredentials, and
HostScopedCredentials.
3. The full `resolve_block_credentials` flow correctly resolves matching
credentials or reports them as missing for each credential type.
4. `RunBlockTool._execute` end-to-end tests return correct response types.
"""
from unittest.mock import AsyncMock, patch
from pydantic import SecretStr
from backend.blocks.http import SendAuthenticatedWebRequestBlock
from backend.data.model import (
APIKeyCredentials,
CredentialsFieldInfo,
CredentialsType,
HostScopedCredentials,
OAuth2Credentials,
UserPasswordCredentials,
)
from backend.integrations.providers import ProviderName
from ._test_data import make_session
from .helpers import _resolve_discriminated_credentials, resolve_block_credentials
from .models import BlockDetailsResponse, SetupRequirementsResponse
from .run_block import RunBlockTool
from .utils import find_matching_credential
_TEST_USER_ID = "test-user-http-cred"
# Properly typed constants to avoid type: ignore on CredentialsFieldInfo construction.
_HOST_SCOPED_TYPES: frozenset[CredentialsType] = frozenset(["host_scoped"])
_API_KEY_TYPES: frozenset[CredentialsType] = frozenset(["api_key"])
_OAUTH2_TYPES: frozenset[CredentialsType] = frozenset(["oauth2"])
_USER_PASSWORD_TYPES: frozenset[CredentialsType] = frozenset(["user_password"])
# ---------------------------------------------------------------------------
# _resolve_discriminated_credentials tests
# ---------------------------------------------------------------------------
class TestResolveDiscriminatedCredentials:
"""Tests for _resolve_discriminated_credentials with URL-based discrimination."""
def _get_auth_block(self):
return SendAuthenticatedWebRequestBlock()
def test_url_discriminator_populates_discriminator_values(self):
"""When input_data contains a URL, discriminator_values should include it."""
block = self._get_auth_block()
input_data = {"url": "https://api.example.com/v1/data"}
result = _resolve_discriminated_credentials(block, input_data)
assert "credentials" in result
field_info = result["credentials"]
assert "https://api.example.com/v1/data" in field_info.discriminator_values
def test_url_discriminator_without_url_keeps_empty_values(self):
"""When no URL is provided, discriminator_values should remain empty."""
block = self._get_auth_block()
input_data = {}
result = _resolve_discriminated_credentials(block, input_data)
assert "credentials" in result
field_info = result["credentials"]
assert len(field_info.discriminator_values) == 0
def test_url_discriminator_does_not_mutate_original_field_info(self):
"""The original block schema field_info must not be mutated."""
block = self._get_auth_block()
# Grab a reference to the original schema-level field_info
original_info = block.input_schema.get_credentials_fields_info()["credentials"]
# Call with a URL, which adds to discriminator_values on the copy
_resolve_discriminated_credentials(
block, {"url": "https://api.example.com/v1/data"}
)
# The original object must remain unchanged
assert len(original_info.discriminator_values) == 0
# And a fresh call without URL should also return empty values
result = _resolve_discriminated_credentials(block, {})
field_info = result["credentials"]
assert len(field_info.discriminator_values) == 0
def test_url_discriminator_preserves_provider_and_type(self):
"""Provider and supported_types should be preserved after URL discrimination."""
block = self._get_auth_block()
input_data = {"url": "https://api.example.com/v1/data"}
result = _resolve_discriminated_credentials(block, input_data)
field_info = result["credentials"]
assert ProviderName.HTTP in field_info.provider
assert "host_scoped" in field_info.supported_types
def test_provider_discriminator_still_works(self):
"""Verify provider-based discrimination (e.g. model -> provider) is preserved.
The refactored conditional in _resolve_discriminated_credentials split the
original single ``if`` into nested ``if/else`` branches. This test ensures
the provider-based path still narrows the provider correctly.
"""
from backend.blocks.llm import AITextGeneratorBlock
block = AITextGeneratorBlock()
input_data = {"model": "gpt-4o-mini"}
result = _resolve_discriminated_credentials(block, input_data)
assert "credentials" in result
field_info = result["credentials"]
# Should narrow provider to openai
assert ProviderName.OPENAI in field_info.provider
assert "gpt-4o-mini" in field_info.discriminator_values
# ---------------------------------------------------------------------------
# find_matching_credential tests (host-scoped)
# ---------------------------------------------------------------------------
class TestFindMatchingHostScopedCredential:
"""Tests for find_matching_credential with host-scoped credentials."""
def _make_host_scoped_cred(
self, host: str, cred_id: str = "test-cred-id"
) -> HostScopedCredentials:
return HostScopedCredentials(
id=cred_id,
provider="http",
host=host,
headers={"Authorization": SecretStr("Bearer test-token")},
title=f"Cred for {host}",
)
def _make_field_info(
self, discriminator_values: set | None = None
) -> CredentialsFieldInfo:
return CredentialsFieldInfo(
credentials_provider=frozenset([ProviderName.HTTP]),
credentials_types=_HOST_SCOPED_TYPES,
credentials_scopes=None,
discriminator="url",
discriminator_values=discriminator_values or set(),
)
def test_matches_credential_for_correct_host(self):
"""A host-scoped credential matching the URL host should be returned."""
cred = self._make_host_scoped_cred("api.example.com")
field_info = self._make_field_info({"https://api.example.com/v1/data"})
result = find_matching_credential([cred], field_info)
assert result is not None
assert result.id == cred.id
def test_rejects_credential_for_wrong_host(self):
"""A host-scoped credential for a different host should not match."""
cred = self._make_host_scoped_cred("api.github.com")
field_info = self._make_field_info({"https://api.stripe.com/v1/charges"})
result = find_matching_credential([cred], field_info)
assert result is None
def test_matches_any_when_no_discriminator_values(self):
"""With empty discriminator_values, any host-scoped credential matches.
Note: this tests the current fallback behavior in _credential_is_for_host()
where empty discriminator_values means "no host constraint" and any
host-scoped credential is accepted. This is by design for the case where
the target URL is not yet known (e.g. schema preview with empty input).
"""
cred = self._make_host_scoped_cred("api.anything.com")
field_info = self._make_field_info(set())
result = find_matching_credential([cred], field_info)
assert result is not None
def test_wildcard_host_matching(self):
"""Wildcard host (*.example.com) should match subdomains."""
cred = self._make_host_scoped_cred("*.example.com")
field_info = self._make_field_info({"https://api.example.com/v1/data"})
result = find_matching_credential([cred], field_info)
assert result is not None
def test_selects_correct_credential_from_multiple(self):
"""When multiple host-scoped credentials exist, the correct one is selected."""
cred_github = self._make_host_scoped_cred("api.github.com", "github-cred")
cred_stripe = self._make_host_scoped_cred("api.stripe.com", "stripe-cred")
field_info = self._make_field_info({"https://api.stripe.com/v1/charges"})
result = find_matching_credential([cred_github, cred_stripe], field_info)
assert result is not None
assert result.id == "stripe-cred"
# ---------------------------------------------------------------------------
# find_matching_credential tests (api_key)
# ---------------------------------------------------------------------------
class TestFindMatchingAPIKeyCredential:
"""Tests for find_matching_credential with API key credentials."""
def _make_api_key_cred(
self, provider: str = "google_maps", cred_id: str = "test-api-key-id"
) -> APIKeyCredentials:
return APIKeyCredentials(
id=cred_id,
provider=provider,
api_key=SecretStr("sk-test-key-123"),
title=f"API key for {provider}",
expires_at=None,
)
def _make_field_info(
self, provider: ProviderName = ProviderName.GOOGLE_MAPS
) -> CredentialsFieldInfo:
return CredentialsFieldInfo(
credentials_provider=frozenset([provider]),
credentials_types=_API_KEY_TYPES,
credentials_scopes=None,
)
def test_matches_credential_for_correct_provider(self):
"""An API key credential matching the provider should be returned."""
cred = self._make_api_key_cred("google_maps")
field_info = self._make_field_info(ProviderName.GOOGLE_MAPS)
result = find_matching_credential([cred], field_info)
assert result is not None
assert result.id == cred.id
def test_rejects_credential_for_wrong_provider(self):
"""An API key credential for a different provider should not match."""
cred = self._make_api_key_cred("openai")
field_info = self._make_field_info(ProviderName.GOOGLE_MAPS)
result = find_matching_credential([cred], field_info)
assert result is None
def test_rejects_credential_for_wrong_type(self):
"""An OAuth2 credential should not match an api_key requirement."""
oauth_cred = OAuth2Credentials(
id="oauth-cred-id",
provider="google_maps",
access_token=SecretStr("mock-token"),
scopes=[],
title="OAuth cred (wrong type)",
)
field_info = self._make_field_info(ProviderName.GOOGLE_MAPS)
result = find_matching_credential([oauth_cred], field_info)
assert result is None
def test_selects_correct_credential_from_multiple(self):
"""When multiple API key credentials exist, the correct provider is selected."""
cred_maps = self._make_api_key_cred("google_maps", "maps-key")
cred_openai = self._make_api_key_cred("openai", "openai-key")
field_info = self._make_field_info(ProviderName.OPENAI)
result = find_matching_credential([cred_maps, cred_openai], field_info)
assert result is not None
assert result.id == "openai-key"
def test_returns_none_when_no_credentials(self):
"""Should return None when the credential list is empty."""
field_info = self._make_field_info(ProviderName.GOOGLE_MAPS)
result = find_matching_credential([], field_info)
assert result is None
# ---------------------------------------------------------------------------
# find_matching_credential tests (oauth2)
# ---------------------------------------------------------------------------
class TestFindMatchingOAuth2Credential:
"""Tests for find_matching_credential with OAuth2 credentials."""
def _make_oauth2_cred(
self,
provider: str = "google",
scopes: list[str] | None = None,
cred_id: str = "test-oauth2-id",
) -> OAuth2Credentials:
return OAuth2Credentials(
id=cred_id,
provider=provider,
access_token=SecretStr("mock-access-token"),
refresh_token=SecretStr("mock-refresh-token"),
access_token_expires_at=1234567890,
scopes=scopes or [],
title=f"OAuth2 cred for {provider}",
)
def _make_field_info(
self,
provider: ProviderName = ProviderName.GOOGLE,
required_scopes: frozenset[str] | None = None,
) -> CredentialsFieldInfo:
return CredentialsFieldInfo(
credentials_provider=frozenset([provider]),
credentials_types=_OAUTH2_TYPES,
credentials_scopes=required_scopes,
)
def test_matches_credential_for_correct_provider(self):
"""An OAuth2 credential matching the provider should be returned."""
cred = self._make_oauth2_cred("google")
field_info = self._make_field_info(ProviderName.GOOGLE)
result = find_matching_credential([cred], field_info)
assert result is not None
assert result.id == cred.id
def test_rejects_credential_for_wrong_provider(self):
"""An OAuth2 credential for a different provider should not match."""
cred = self._make_oauth2_cred("github")
field_info = self._make_field_info(ProviderName.GOOGLE)
result = find_matching_credential([cred], field_info)
assert result is None
def test_matches_credential_with_required_scopes(self):
"""An OAuth2 credential with all required scopes should match."""
cred = self._make_oauth2_cred(
"google",
scopes=[
"https://www.googleapis.com/auth/gmail.readonly",
"https://www.googleapis.com/auth/gmail.send",
],
)
field_info = self._make_field_info(
ProviderName.GOOGLE,
required_scopes=frozenset(
["https://www.googleapis.com/auth/gmail.readonly"]
),
)
result = find_matching_credential([cred], field_info)
assert result is not None
def test_rejects_credential_with_insufficient_scopes(self):
"""An OAuth2 credential missing required scopes should not match."""
cred = self._make_oauth2_cred(
"google",
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
)
field_info = self._make_field_info(
ProviderName.GOOGLE,
required_scopes=frozenset(
[
"https://www.googleapis.com/auth/gmail.readonly",
"https://www.googleapis.com/auth/gmail.send",
]
),
)
result = find_matching_credential([cred], field_info)
assert result is None
def test_matches_credential_when_no_scopes_required(self):
"""An OAuth2 credential should match when no scopes are required."""
cred = self._make_oauth2_cred("google", scopes=[])
field_info = self._make_field_info(ProviderName.GOOGLE)
result = find_matching_credential([cred], field_info)
assert result is not None
def test_selects_correct_credential_from_multiple(self):
"""When multiple OAuth2 credentials exist, the correct one is selected."""
cred_google = self._make_oauth2_cred("google", cred_id="google-cred")
cred_github = self._make_oauth2_cred("github", cred_id="github-cred")
field_info = self._make_field_info(ProviderName.GITHUB)
result = find_matching_credential([cred_google, cred_github], field_info)
assert result is not None
assert result.id == "github-cred"
def test_returns_none_when_no_credentials(self):
"""Should return None when the credential list is empty."""
field_info = self._make_field_info(ProviderName.GOOGLE)
result = find_matching_credential([], field_info)
assert result is None
# ---------------------------------------------------------------------------
# find_matching_credential tests (user_password)
# ---------------------------------------------------------------------------
class TestFindMatchingUserPasswordCredential:
"""Tests for find_matching_credential with user/password credentials."""
def _make_user_password_cred(
self, provider: str = "smtp", cred_id: str = "test-userpass-id"
) -> UserPasswordCredentials:
return UserPasswordCredentials(
id=cred_id,
provider=provider,
username=SecretStr("test-user"),
password=SecretStr("test-pass"),
title=f"Credentials for {provider}",
)
def _make_field_info(
self, provider: ProviderName = ProviderName.SMTP
) -> CredentialsFieldInfo:
return CredentialsFieldInfo(
credentials_provider=frozenset([provider]),
credentials_types=_USER_PASSWORD_TYPES,
credentials_scopes=None,
)
def test_matches_credential_for_correct_provider(self):
"""A user/password credential matching the provider should be returned."""
cred = self._make_user_password_cred("smtp")
field_info = self._make_field_info(ProviderName.SMTP)
result = find_matching_credential([cred], field_info)
assert result is not None
assert result.id == cred.id
def test_rejects_credential_for_wrong_provider(self):
"""A user/password credential for a different provider should not match."""
cred = self._make_user_password_cred("smtp")
field_info = self._make_field_info(ProviderName.HUBSPOT)
result = find_matching_credential([cred], field_info)
assert result is None
def test_rejects_credential_for_wrong_type(self):
"""An API key credential should not match a user_password requirement."""
api_key_cred = APIKeyCredentials(
id="api-key-cred-id",
provider="smtp",
api_key=SecretStr("wrong-type-key"),
title="API key cred (wrong type)",
)
field_info = self._make_field_info(ProviderName.SMTP)
result = find_matching_credential([api_key_cred], field_info)
assert result is None
def test_selects_correct_credential_from_multiple(self):
"""When multiple user/password credentials exist, the correct one is selected."""
cred_smtp = self._make_user_password_cred("smtp", "smtp-cred")
cred_hubspot = self._make_user_password_cred("hubspot", "hubspot-cred")
field_info = self._make_field_info(ProviderName.HUBSPOT)
result = find_matching_credential([cred_smtp, cred_hubspot], field_info)
assert result is not None
assert result.id == "hubspot-cred"
def test_returns_none_when_no_credentials(self):
"""Should return None when the credential list is empty."""
field_info = self._make_field_info(ProviderName.SMTP)
result = find_matching_credential([], field_info)
assert result is None
# ---------------------------------------------------------------------------
# find_matching_credential tests (mixed credential types)
# ---------------------------------------------------------------------------
class TestFindMatchingCredentialMixedTypes:
"""Tests that find_matching_credential correctly filters by type in a mixed list."""
def test_selects_api_key_from_mixed_list(self):
"""API key requirement should skip OAuth2 and user_password credentials."""
oauth_cred = OAuth2Credentials(
id="oauth-id",
provider="openai",
access_token=SecretStr("token"),
scopes=[],
)
userpass_cred = UserPasswordCredentials(
id="userpass-id",
provider="openai",
username=SecretStr("user"),
password=SecretStr("pass"),
)
api_key_cred = APIKeyCredentials(
id="apikey-id",
provider="openai",
api_key=SecretStr("sk-key"),
)
field_info = CredentialsFieldInfo(
credentials_provider=frozenset([ProviderName.OPENAI]),
credentials_types=_API_KEY_TYPES,
credentials_scopes=None,
)
result = find_matching_credential(
[oauth_cred, userpass_cred, api_key_cred], field_info
)
assert result is not None
assert result.id == "apikey-id"
def test_selects_oauth2_from_mixed_list(self):
"""OAuth2 requirement should skip API key and user_password credentials."""
api_key_cred = APIKeyCredentials(
id="apikey-id",
provider="google",
api_key=SecretStr("key"),
)
userpass_cred = UserPasswordCredentials(
id="userpass-id",
provider="google",
username=SecretStr("user"),
password=SecretStr("pass"),
)
oauth_cred = OAuth2Credentials(
id="oauth-id",
provider="google",
access_token=SecretStr("token"),
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
)
field_info = CredentialsFieldInfo(
credentials_provider=frozenset([ProviderName.GOOGLE]),
credentials_types=_OAUTH2_TYPES,
credentials_scopes=frozenset(
["https://www.googleapis.com/auth/gmail.readonly"]
),
)
result = find_matching_credential(
[api_key_cred, userpass_cred, oauth_cred], field_info
)
assert result is not None
assert result.id == "oauth-id"
def test_selects_user_password_from_mixed_list(self):
"""User/password requirement should skip API key and OAuth2 credentials."""
api_key_cred = APIKeyCredentials(
id="apikey-id",
provider="smtp",
api_key=SecretStr("key"),
)
oauth_cred = OAuth2Credentials(
id="oauth-id",
provider="smtp",
access_token=SecretStr("token"),
scopes=[],
)
userpass_cred = UserPasswordCredentials(
id="userpass-id",
provider="smtp",
username=SecretStr("user"),
password=SecretStr("pass"),
)
field_info = CredentialsFieldInfo(
credentials_provider=frozenset([ProviderName.SMTP]),
credentials_types=_USER_PASSWORD_TYPES,
credentials_scopes=None,
)
result = find_matching_credential(
[api_key_cred, oauth_cred, userpass_cred], field_info
)
assert result is not None
assert result.id == "userpass-id"
def test_returns_none_when_only_wrong_types_available(self):
"""Should return None when all available creds have the wrong type."""
oauth_cred = OAuth2Credentials(
id="oauth-id",
provider="google_maps",
access_token=SecretStr("token"),
scopes=[],
)
field_info = CredentialsFieldInfo(
credentials_provider=frozenset([ProviderName.GOOGLE_MAPS]),
credentials_types=_API_KEY_TYPES,
credentials_scopes=None,
)
result = find_matching_credential([oauth_cred], field_info)
assert result is None
# ---------------------------------------------------------------------------
# resolve_block_credentials tests (integration — all credential types)
# ---------------------------------------------------------------------------
class TestResolveBlockCredentials:
"""Integration tests for resolve_block_credentials across credential types."""
async def test_matches_host_scoped_credential_for_url(self):
"""resolve_block_credentials should match a host-scoped cred for the given URL."""
block = SendAuthenticatedWebRequestBlock()
input_data = {"url": "https://api.example.com/v1/data"}
mock_cred = HostScopedCredentials(
id="matching-cred-id",
provider="http",
host="api.example.com",
headers={"Authorization": SecretStr("Bearer token")},
title="Example API Cred",
)
with patch(
"backend.copilot.tools.utils.get_user_credentials",
new_callable=AsyncMock,
return_value=[mock_cred],
):
matched, missing = await resolve_block_credentials(
_TEST_USER_ID, block, input_data
)
assert "credentials" in matched
assert matched["credentials"].id == "matching-cred-id"
assert len(missing) == 0
async def test_reports_missing_when_no_matching_host(self):
"""resolve_block_credentials should report missing creds when host doesn't match."""
block = SendAuthenticatedWebRequestBlock()
input_data = {"url": "https://api.stripe.com/v1/charges"}
wrong_host_cred = HostScopedCredentials(
id="wrong-cred-id",
provider="http",
host="api.github.com",
headers={"Authorization": SecretStr("Bearer token")},
title="GitHub API Cred",
)
with patch(
"backend.copilot.tools.utils.get_user_credentials",
new_callable=AsyncMock,
return_value=[wrong_host_cred],
):
matched, missing = await resolve_block_credentials(
_TEST_USER_ID, block, input_data
)
assert len(matched) == 0
assert len(missing) == 1
async def test_reports_missing_when_no_credentials(self):
"""resolve_block_credentials should report missing when user has no creds at all."""
block = SendAuthenticatedWebRequestBlock()
input_data = {"url": "https://api.example.com/v1/data"}
with patch(
"backend.copilot.tools.utils.get_user_credentials",
new_callable=AsyncMock,
return_value=[],
):
matched, missing = await resolve_block_credentials(
_TEST_USER_ID, block, input_data
)
assert len(matched) == 0
assert len(missing) == 1
async def test_matches_api_key_credential_for_llm_block(self):
"""resolve_block_credentials should match an API key cred for an LLM block."""
from backend.blocks.llm import AITextGeneratorBlock
block = AITextGeneratorBlock()
input_data = {"model": "gpt-4o-mini"}
mock_cred = APIKeyCredentials(
id="openai-key-id",
provider="openai",
api_key=SecretStr("sk-test-key"),
title="OpenAI API Key",
)
with patch(
"backend.copilot.tools.utils.get_user_credentials",
new_callable=AsyncMock,
return_value=[mock_cred],
):
matched, missing = await resolve_block_credentials(
_TEST_USER_ID, block, input_data
)
assert "credentials" in matched
assert matched["credentials"].id == "openai-key-id"
assert len(missing) == 0
async def test_reports_missing_api_key_for_wrong_provider(self):
"""resolve_block_credentials should report missing when API key provider mismatches."""
from backend.blocks.llm import AITextGeneratorBlock
block = AITextGeneratorBlock()
input_data = {"model": "gpt-4o-mini"}
wrong_provider_cred = APIKeyCredentials(
id="wrong-key-id",
provider="google_maps",
api_key=SecretStr("sk-wrong"),
title="Google Maps Key",
)
with patch(
"backend.copilot.tools.utils.get_user_credentials",
new_callable=AsyncMock,
return_value=[wrong_provider_cred],
):
matched, missing = await resolve_block_credentials(
_TEST_USER_ID, block, input_data
)
assert len(matched) == 0
assert len(missing) == 1
async def test_matches_oauth2_credential_for_google_block(self):
"""resolve_block_credentials should match an OAuth2 cred for a Google block."""
from backend.blocks.google.gmail import GmailReadBlock
block = GmailReadBlock()
input_data = {}
mock_cred = OAuth2Credentials(
id="google-oauth-id",
provider="google",
access_token=SecretStr("mock-token"),
refresh_token=SecretStr("mock-refresh"),
access_token_expires_at=9999999999,
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
title="Google OAuth",
)
with patch(
"backend.copilot.tools.utils.get_user_credentials",
new_callable=AsyncMock,
return_value=[mock_cred],
):
matched, missing = await resolve_block_credentials(
_TEST_USER_ID, block, input_data
)
assert "credentials" in matched
assert matched["credentials"].id == "google-oauth-id"
assert len(missing) == 0
async def test_reports_missing_oauth2_with_insufficient_scopes(self):
"""resolve_block_credentials should report missing when OAuth2 scopes are insufficient."""
from backend.blocks.google.gmail import GmailSendBlock
block = GmailSendBlock()
input_data = {}
# GmailSendBlock requires gmail.send scope; provide only readonly
insufficient_cred = OAuth2Credentials(
id="limited-oauth-id",
provider="google",
access_token=SecretStr("mock-token"),
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
title="Google OAuth (limited)",
)
with patch(
"backend.copilot.tools.utils.get_user_credentials",
new_callable=AsyncMock,
return_value=[insufficient_cred],
):
matched, missing = await resolve_block_credentials(
_TEST_USER_ID, block, input_data
)
assert len(matched) == 0
assert len(missing) == 1
async def test_matches_user_password_credential_for_email_block(self):
"""resolve_block_credentials should match a user/password cred for an SMTP block."""
from backend.blocks.email_block import SendEmailBlock
block = SendEmailBlock()
input_data = {}
mock_cred = UserPasswordCredentials(
id="smtp-cred-id",
provider="smtp",
username=SecretStr("test-user"),
password=SecretStr("test-pass"),
title="SMTP Credentials",
)
with patch(
"backend.copilot.tools.utils.get_user_credentials",
new_callable=AsyncMock,
return_value=[mock_cred],
):
matched, missing = await resolve_block_credentials(
_TEST_USER_ID, block, input_data
)
assert "credentials" in matched
assert matched["credentials"].id == "smtp-cred-id"
assert len(missing) == 0
async def test_reports_missing_user_password_for_wrong_provider(self):
"""resolve_block_credentials should report missing when user/password provider mismatches."""
from backend.blocks.email_block import SendEmailBlock
block = SendEmailBlock()
input_data = {}
wrong_cred = UserPasswordCredentials(
id="wrong-cred-id",
provider="dataforseo",
username=SecretStr("user"),
password=SecretStr("pass"),
title="DataForSEO Creds",
)
with patch(
"backend.copilot.tools.utils.get_user_credentials",
new_callable=AsyncMock,
return_value=[wrong_cred],
):
matched, missing = await resolve_block_credentials(
_TEST_USER_ID, block, input_data
)
assert len(matched) == 0
assert len(missing) == 1
# ---------------------------------------------------------------------------
# RunBlockTool integration tests for authenticated HTTP
# ---------------------------------------------------------------------------
class TestRunBlockToolAuthenticatedHttp:
"""End-to-end tests for RunBlockTool with SendAuthenticatedWebRequestBlock."""
async def test_returns_setup_requirements_when_creds_missing(self):
"""When no matching host-scoped credential exists, return SetupRequirementsResponse."""
session = make_session(user_id=_TEST_USER_ID)
block = SendAuthenticatedWebRequestBlock()
with patch(
"backend.copilot.tools.helpers.get_block",
return_value=block,
):
with patch(
"backend.copilot.tools.utils.get_user_credentials",
new_callable=AsyncMock,
return_value=[],
):
tool = RunBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
block_id=block.id,
input_data={"url": "https://api.example.com/data", "method": "GET"},
dry_run=False,
)
assert isinstance(response, SetupRequirementsResponse)
assert "credentials" in response.message.lower()
async def test_returns_details_when_creds_matched_but_missing_required_inputs(self):
"""When creds present + required inputs missing -> BlockDetailsResponse.
Note: with input_data={}, no URL is provided so discriminator_values is
empty, meaning _credential_is_for_host() matches any host-scoped
credential vacuously. This test exercises the "creds present + inputs
missing" branch, not host-based matching (which is covered by
TestFindMatchingHostScopedCredential and TestResolveBlockCredentials).
"""
session = make_session(user_id=_TEST_USER_ID)
block = SendAuthenticatedWebRequestBlock()
mock_cred = HostScopedCredentials(
id="matching-cred-id",
provider="http",
host="api.example.com",
headers={"Authorization": SecretStr("Bearer token")},
title="Example API Cred",
)
with patch(
"backend.copilot.tools.helpers.get_block",
return_value=block,
):
with patch(
"backend.copilot.tools.utils.get_user_credentials",
new_callable=AsyncMock,
return_value=[mock_cred],
):
tool = RunBlockTool()
# Call with empty input to get schema
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
block_id=block.id,
input_data={},
dry_run=False,
)
assert isinstance(response, BlockDetailsResponse)
assert response.block.name == block.name
# The matched credential should be included in the details
assert len(response.block.credentials) > 0
assert response.block.credentials[0].id == "matching-cred-id"

View File

@@ -120,18 +120,14 @@ class CreateFolderTool(BaseTool):
}
async def _execute(
self,
user_id: str | None,
session: ChatSession,
name: str = "",
parent_id: str | None = None,
icon: str | None = None,
color: str | None = None,
**kwargs,
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:
"""Create a folder with the given name and optional parent/icon/color."""
assert user_id is not None # guaranteed by requires_auth
name = (name or "").strip()
name = (kwargs.get("name") or "").strip()
parent_id = kwargs.get("parent_id")
icon = kwargs.get("icon")
color = kwargs.get("color")
session_id = session.session_id if session else None
if not name:
@@ -200,15 +196,12 @@ class ListFoldersTool(BaseTool):
}
async def _execute(
self,
user_id: str | None,
session: ChatSession,
parent_id: str | None = None,
include_agents: bool = False,
**kwargs,
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:
"""List folders as a flat list (by parent) or full tree."""
assert user_id is not None # guaranteed by requires_auth
parent_id = kwargs.get("parent_id")
include_agents = kwargs.get("include_agents", False)
session_id = session.session_id if session else None
try:
@@ -300,18 +293,14 @@ class UpdateFolderTool(BaseTool):
}
async def _execute(
self,
user_id: str | None,
session: ChatSession,
folder_id: str = "",
name: str | None = None,
icon: str | None = None,
color: str | None = None,
**kwargs,
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:
"""Update a folder's name, icon, or color."""
assert user_id is not None # guaranteed by requires_auth
folder_id = (folder_id or "").strip()
folder_id = (kwargs.get("folder_id") or "").strip()
name = kwargs.get("name")
icon = kwargs.get("icon")
color = kwargs.get("color")
session_id = session.session_id if session else None
if not folder_id:
@@ -376,16 +365,12 @@ class MoveFolderTool(BaseTool):
}
async def _execute(
self,
user_id: str | None,
session: ChatSession,
folder_id: str = "",
target_parent_id: str | None = None,
**kwargs,
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:
"""Move a folder to a new parent or to root level."""
assert user_id is not None # guaranteed by requires_auth
folder_id = (folder_id or "").strip()
folder_id = (kwargs.get("folder_id") or "").strip()
target_parent_id = kwargs.get("target_parent_id")
session_id = session.session_id if session else None
if not folder_id:
@@ -446,15 +431,11 @@ class DeleteFolderTool(BaseTool):
}
async def _execute(
self,
user_id: str | None,
session: ChatSession,
folder_id: str = "",
**kwargs,
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:
"""Soft-delete a folder; agents inside are moved to root level."""
assert user_id is not None # guaranteed by requires_auth
folder_id = (folder_id or "").strip()
folder_id = (kwargs.get("folder_id") or "").strip()
session_id = session.session_id if session else None
if not folder_id:
@@ -518,17 +499,12 @@ class MoveAgentsToFolderTool(BaseTool):
}
async def _execute(
self,
user_id: str | None,
session: ChatSession,
agent_ids: list[str] | None = None,
folder_id: str | None = None,
**kwargs,
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:
"""Move one or more agents to a folder or to root level."""
assert user_id is not None # guaranteed by requires_auth
if agent_ids is None:
agent_ids = []
agent_ids = kwargs.get("agent_ids", [])
folder_id = kwargs.get("folder_id")
session_id = session.session_id if session else None
if not agent_ids:

View File

@@ -71,7 +71,7 @@ class RunAgentInput(BaseModel):
cron: str = ""
timezone: str = "UTC"
wait_for_result: int = Field(default=0, ge=0, le=300)
dry_run: bool
dry_run: bool = False
@field_validator(
"username_agent_slug",
@@ -153,10 +153,14 @@ class RunAgentTool(BaseTool):
},
"dry_run": {
"type": "boolean",
"description": "Execute in preview mode.",
"description": (
"When true, simulates the entire agent execution using an LLM "
"for each block — no real API calls, no credentials needed, "
"no credits charged. Useful for testing agent wiring end-to-end."
),
},
},
"required": ["dry_run"],
"required": [],
}
@property
@@ -170,16 +174,8 @@ class RunAgentTool(BaseTool):
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""Execute the tool with automatic state detection.
Note: This tool accepts **kwargs and delegates to RunAgentInput for
validation because the parameter set is complex with cross-field
validators defined in the Pydantic model.
"""
"""Execute the tool with automatic state detection."""
params = RunAgentInput(**kwargs)
# Session-level dry_run forces all tool calls to use dry-run mode.
if session.dry_run:
params.dry_run = True
session_id = session.session_id
# Validate at least one identifier is provided
@@ -205,18 +201,6 @@ class RunAgentTool(BaseTool):
# Determine if this is a schedule request
is_schedule = bool(params.schedule_name or params.cron)
# Session-level dry-run blocks scheduling — schedules create real
# side effects that cannot be simulated.
if params.dry_run and is_schedule:
return ErrorResponse(
message=(
"Scheduling is disabled in dry-run mode because it creates "
"real side effects. Remove cron/schedule_name to simulate "
"a run, or disable dry-run to create a real schedule."
),
session_id=session_id,
)
try:
# Step 1: Fetch agent details
graph: GraphModel | None = None
@@ -474,8 +458,8 @@ class RunAgentTool(BaseTool):
graph: GraphModel,
graph_credentials: dict[str, CredentialsMetaInput],
inputs: dict[str, Any],
dry_run: bool,
wait_for_result: int = 0,
dry_run: bool = False,
) -> ToolResponseBase:
"""Execute an agent immediately, optionally waiting for completion."""
session_id = session.session_id

View File

@@ -53,7 +53,6 @@ async def test_run_agent(setup_test_data):
tool_call_id=str(uuid.uuid4()),
username_agent_slug=agent_marketplace_id,
inputs={"test_input": "Hello World"},
dry_run=False,
session=session,
)
@@ -94,7 +93,6 @@ async def test_run_agent_missing_inputs(setup_test_data):
tool_call_id=str(uuid.uuid4()),
username_agent_slug=agent_marketplace_id,
inputs={}, # Missing required input
dry_run=False,
session=session,
)
@@ -127,7 +125,6 @@ async def test_run_agent_invalid_agent_id(setup_test_data):
tool_call_id=str(uuid.uuid4()),
username_agent_slug="invalid/agent-id",
inputs={"test_input": "Hello World"},
dry_run=False,
session=session,
)
@@ -168,7 +165,6 @@ async def test_run_agent_with_llm_credentials(setup_llm_test_data):
tool_call_id=str(uuid.uuid4()),
username_agent_slug=agent_marketplace_id,
inputs={"user_prompt": "What is 2+2?"},
dry_run=False,
session=session,
)
@@ -207,7 +203,6 @@ async def test_run_agent_shows_available_inputs_when_none_provided(setup_test_da
username_agent_slug=agent_marketplace_id,
inputs={},
use_defaults=False,
dry_run=False,
session=session,
)
@@ -243,7 +238,6 @@ async def test_run_agent_with_use_defaults(setup_test_data):
username_agent_slug=agent_marketplace_id,
inputs={},
use_defaults=True,
dry_run=False,
session=session,
)
@@ -274,7 +268,6 @@ async def test_run_agent_missing_credentials(setup_firecrawl_test_data):
tool_call_id=str(uuid.uuid4()),
username_agent_slug=agent_marketplace_id,
inputs={"url": "https://example.com"},
dry_run=False,
session=session,
)
@@ -307,7 +300,6 @@ async def test_run_agent_invalid_slug_format(setup_test_data):
tool_call_id=str(uuid.uuid4()),
username_agent_slug="no-slash-here",
inputs={},
dry_run=False,
session=session,
)
@@ -335,7 +327,6 @@ async def test_run_agent_unauthenticated():
tool_call_id=str(uuid.uuid4()),
username_agent_slug="test/test-agent",
inputs={},
dry_run=False,
session=session,
)
@@ -368,7 +359,6 @@ async def test_run_agent_schedule_without_cron(setup_test_data):
inputs={"test_input": "test"},
schedule_name="My Schedule",
cron="", # Empty cron
dry_run=False,
session=session,
)
@@ -401,7 +391,6 @@ async def test_run_agent_schedule_without_name(setup_test_data):
inputs={"test_input": "test"},
schedule_name="", # Empty name
cron="0 9 * * *",
dry_run=False,
session=session,
)
@@ -435,7 +424,6 @@ async def test_run_agent_rejects_unknown_input_fields(setup_test_data):
"unknown_field": "some value",
"another_unknown": "another value",
},
dry_run=False,
session=session,
)

View File

@@ -51,10 +51,14 @@ class RunBlockTool(BaseTool):
},
"dry_run": {
"type": "boolean",
"description": "Execute in preview mode.",
"description": (
"When true, simulates block execution using an LLM without making any "
"real API calls or producing side effects. Useful for testing agent "
"wiring and previewing outputs. Default: false."
),
},
},
"required": ["block_id", "input_data", "dry_run"],
"required": ["block_id", "input_data"],
}
@property
@@ -65,10 +69,6 @@ class RunBlockTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
*,
block_id: str = "",
input_data: dict | None = None,
dry_run: bool,
**kwargs,
) -> ToolResponseBase:
"""Execute a block with the given input data.
@@ -78,19 +78,15 @@ class RunBlockTool(BaseTool):
session: Chat session
block_id: Block UUID to execute
input_data: Input values for the block
dry_run: If True, simulate execution without side effects
Returns:
BlockOutputResponse: Block execution outputs
SetupRequirementsResponse: Missing credentials
ErrorResponse: Error message
"""
block_id = block_id.strip()
if input_data is None:
input_data = {}
# Session-level dry_run forces all tool calls to use dry-run mode.
if session.dry_run:
dry_run = True
block_id = kwargs.get("block_id", "").strip()
input_data = kwargs.get("input_data", {})
dry_run = bool(kwargs.get("dry_run", False))
session_id = session.session_id
if not block_id:

View File

@@ -103,7 +103,6 @@ class TestRunBlockFiltering:
session=session,
block_id="input-block-id",
input_data={},
dry_run=False,
)
assert isinstance(response, ErrorResponse)
@@ -130,7 +129,6 @@ class TestRunBlockFiltering:
session=session,
block_id=orchestrator_id,
input_data={},
dry_run=False,
)
assert isinstance(response, ErrorResponse)
@@ -156,7 +154,6 @@ class TestRunBlockFiltering:
session=session,
block_id=block_id,
input_data={},
dry_run=False,
)
finally:
_current_permissions.reset(token)
@@ -190,7 +187,6 @@ class TestRunBlockFiltering:
session=session,
block_id=block_id,
input_data={},
dry_run=False,
)
finally:
_current_permissions.reset(token)
@@ -226,7 +222,6 @@ class TestRunBlockFiltering:
session=session,
block_id="standard-id",
input_data={},
dry_run=False,
)
# Should NOT be an ErrorResponse about CoPilot exclusion
@@ -287,7 +282,6 @@ class TestRunBlockInputValidation:
"prompt": "Write a haiku about coding",
"LLM_Model": "claude-opus-4-6",
},
dry_run=False,
)
assert isinstance(response, InputValidationErrorResponse)
@@ -333,7 +327,6 @@ class TestRunBlockInputValidation:
"system_prompt": "Be helpful",
"retries": 5,
},
dry_run=False,
)
assert isinstance(response, InputValidationErrorResponse)
@@ -377,7 +370,6 @@ class TestRunBlockInputValidation:
input_data={
"LLM_Model": "claude-opus-4-6",
},
dry_run=False,
)
assert isinstance(response, InputValidationErrorResponse)
@@ -432,7 +424,6 @@ class TestRunBlockInputValidation:
"prompt": "Write a haiku",
"model": "gpt-4o-mini",
},
dry_run=False,
)
assert isinstance(response, BlockOutputResponse)
@@ -472,7 +463,6 @@ class TestRunBlockInputValidation:
input_data={
"model": "gpt-4o-mini",
},
dry_run=False,
)
assert isinstance(response, BlockDetailsResponse)
@@ -524,7 +514,6 @@ class TestRunBlockSensitiveAction:
session=session,
block_id="delete-branch-id",
input_data=input_data,
dry_run=False,
)
assert isinstance(response, ReviewRequiredResponse)
@@ -585,7 +574,6 @@ class TestRunBlockSensitiveAction:
session=session,
block_id="delete-branch-id",
input_data=input_data,
dry_run=False,
)
assert isinstance(response, BlockOutputResponse)
@@ -640,7 +628,6 @@ class TestRunBlockSensitiveAction:
session=session,
block_id="http-request-id",
input_data=input_data,
dry_run=False,
)
assert isinstance(response, BlockOutputResponse)

View File

@@ -91,40 +91,21 @@ class RunMCPToolTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
server_url: str = "",
tool_name: str = "",
tool_arguments: dict[str, Any] | None = None,
**kwargs,
) -> ToolResponseBase:
server_url = server_url.strip()
tool_name = tool_name.strip()
server_url: str = (kwargs.get("server_url") or "").strip()
tool_name: str = (kwargs.get("tool_name") or "").strip()
raw_tool_arguments = kwargs.get("tool_arguments")
tool_arguments: dict[str, Any] = (
raw_tool_arguments if isinstance(raw_tool_arguments, dict) else {}
)
session_id = session.session_id
# Session-level dry_run prevents real MCP tool execution.
# Discovery (no tool_name) is still allowed so the agent can inspect
# available tools, but actual execution is blocked.
if session.dry_run and tool_name:
return MCPToolOutputResponse(
message=(
f"[dry-run] MCP tool '{tool_name}' on "
f"{server_host(server_url)} was not executed "
"because the session is in dry-run mode."
),
server_url=server_url,
tool_name=tool_name,
result=None,
success=True,
session_id=session_id,
)
if tool_arguments is not None and not isinstance(tool_arguments, dict):
if raw_tool_arguments is not None and not isinstance(raw_tool_arguments, dict):
return ErrorResponse(
message="tool_arguments must be a JSON object.",
session_id=session_id,
)
resolved_tool_arguments: dict[str, Any] = (
tool_arguments if isinstance(tool_arguments, dict) else {}
)
if not server_url:
return ErrorResponse(
@@ -186,7 +167,7 @@ class RunMCPToolTool(BaseTool):
else:
# Stage 2: Execute the selected tool
return await self._execute_tool(
client, server_url, tool_name, resolved_tool_arguments, session_id
client, server_url, tool_name, tool_arguments, session_id
)
except HTTPClientError as e:

View File

@@ -85,7 +85,6 @@ class SearchDocsTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
query: str = "",
**kwargs,
) -> ToolResponseBase:
"""Search documentation and return relevant sections.
@@ -100,7 +99,7 @@ class SearchDocsTool(BaseTool):
NoResultsResponse: No results found
ErrorResponse: Error message
"""
query = query.strip()
query = kwargs.get("query", "").strip()
session_id = session.session_id if session else None
if not query:

View File

@@ -73,10 +73,7 @@ def make_openai_response(
@pytest.mark.asyncio
async def test_simulate_block_basic():
"""simulate_block returns correct (output_name, output_data) tuples.
Empty "error" pins are dropped at source — only non-empty errors are yielded.
"""
"""simulate_block returns correct (output_name, output_data) tuples."""
mock_block = make_mock_block()
mock_client = AsyncMock()
mock_client.chat.completions.create = AsyncMock(
@@ -91,8 +88,7 @@ async def test_simulate_block_basic():
outputs.append((name, data))
assert ("result", "simulated output") in outputs
# Empty error pin is dropped at the simulator level
assert ("error", "") not in outputs
assert ("error", "") in outputs
@pytest.mark.asyncio
@@ -117,8 +113,6 @@ async def test_simulate_block_json_retry():
assert mock_client.chat.completions.create.call_count == 3
assert ("result", "ok") in outputs
# Empty error pin is dropped
assert ("error", "") not in outputs
@pytest.mark.asyncio
@@ -147,7 +141,7 @@ async def test_simulate_block_all_retries_exhausted():
@pytest.mark.asyncio
async def test_simulate_block_missing_output_pins():
"""LLM response missing some output pins; verify non-error pins filled with None."""
"""LLM response missing some output pins; verify they're filled with None."""
mock_block = make_mock_block(
output_props={
"result": {"type": "string"},
@@ -170,29 +164,7 @@ async def test_simulate_block_missing_output_pins():
assert outputs["result"] == "hello"
assert outputs["count"] is None # missing pin filled with None
assert "error" not in outputs # missing error pin is omitted entirely
@pytest.mark.asyncio
async def test_simulate_block_keeps_nonempty_error():
"""simulate_block keeps non-empty error pins (simulated logical errors)."""
mock_block = make_mock_block()
mock_client = AsyncMock()
mock_client.chat.completions.create = AsyncMock(
return_value=make_openai_response(
'{"result": "", "error": "API rate limit exceeded"}'
)
)
with patch(
"backend.executor.simulator.get_openai_client", return_value=mock_client
):
outputs = []
async for name, data in simulate_block(mock_block, {"query": "test"}):
outputs.append((name, data))
assert ("result", "") in outputs
assert ("error", "API rate limit exceeded") in outputs
assert outputs["error"] == "" # "error" pin filled with ""
@pytest.mark.asyncio
@@ -228,19 +200,6 @@ async def test_simulate_block_truncates_long_inputs():
assert len(parsed["text"]) < 25000
def test_build_simulation_prompt_excludes_error_from_must_include():
"""The 'MUST include' prompt line should NOT list 'error' — the prompt
already instructs the LLM to OMIT error unless simulating a logical error.
Including it in 'MUST include' would be contradictory."""
block = make_mock_block() # default output_props has "result" and "error"
system_prompt, _ = build_simulation_prompt(block, {"query": "test"})
must_include_line = [
line for line in system_prompt.splitlines() if "MUST include" in line
][0]
assert '"result"' in must_include_line
assert '"error"' not in must_include_line
# ---------------------------------------------------------------------------
# execute_block dry-run tests
# ---------------------------------------------------------------------------
@@ -279,7 +238,7 @@ async def test_execute_block_dry_run_skips_real_execution():
@pytest.mark.asyncio
async def test_execute_block_dry_run_response_format():
"""Dry-run response should match real execution message format and have success=True."""
"""Dry-run response should contain [DRY RUN] in message and success=True."""
mock_block = make_mock_block()
async def fake_simulate(block, input_data):
@@ -300,8 +259,7 @@ async def test_execute_block_dry_run_response_format():
)
assert isinstance(response, BlockOutputResponse)
assert "executed successfully" in response.message
assert "[DRY RUN]" not in response.message # must not leak to LLM context
assert "[DRY RUN]" in response.message
assert response.success is True
assert response.outputs == {"result": ["simulated"]}
@@ -349,24 +307,23 @@ async def test_execute_block_real_execution_unchanged():
def test_run_block_tool_dry_run_param():
"""RunBlockTool parameters should include 'dry_run' as a required field."""
"""RunBlockTool parameters should include 'dry_run'."""
tool = RunBlockTool()
params = tool.parameters
assert "dry_run" in params["properties"]
assert params["properties"]["dry_run"]["type"] == "boolean"
assert "dry_run" in params["required"]
def test_run_block_tool_dry_run_calls_execute():
"""RunBlockTool._execute accepts dry_run as a typed parameter.
"""RunBlockTool._execute extracts dry_run from kwargs correctly.
We verify the parameter exists in the signature and is forwarded to
execute_block.
We verify the extraction logic directly by inspecting the source, then confirm
the kwarg is forwarded in the execute_block call site.
"""
source = inspect.getsource(run_block_module.RunBlockTool._execute)
# Verify dry_run is a typed parameter (not extracted from kwargs)
# Verify dry_run is extracted from kwargs
assert "dry_run" in source
assert "dry_run: bool" in source
assert 'kwargs.get("dry_run"' in source
# Scope to _execute method source only — module-wide search is brittle
# and can match unrelated text/comments.
@@ -375,107 +332,13 @@ def test_run_block_tool_dry_run_calls_execute():
assert "dry_run=dry_run" in source_execute
@pytest.mark.asyncio
async def test_execute_block_dry_run_no_empty_error_from_simulator():
"""The simulator no longer yields empty error pins, so execute_block
simply passes through whatever the simulator produces.
Since the fix is at the simulator level, even if a simulator somehow
yields only non-error outputs, they pass through unchanged.
"""
mock_block = make_mock_block()
async def fake_simulate(block, input_data):
# Simulator now omits empty error pins at source
yield "result", "simulated output"
with patch(
"backend.copilot.tools.helpers.simulate_block", side_effect=fake_simulate
):
response = await execute_block(
block=mock_block,
block_id="test-block-id",
input_data={"query": "hello"},
user_id="user-1",
session_id="session-1",
node_exec_id="node-exec-1",
matched_credentials={},
dry_run=True,
)
assert isinstance(response, BlockOutputResponse)
assert response.success is True
assert response.is_dry_run is True
assert "error" not in response.outputs
assert response.outputs == {"result": ["simulated output"]}
@pytest.mark.asyncio
async def test_execute_block_dry_run_keeps_nonempty_error_pin():
"""Dry-run should keep the 'error' pin when it contains a real error message."""
mock_block = make_mock_block()
async def fake_simulate(block, input_data):
yield "result", ""
yield "error", "API rate limit exceeded"
with patch(
"backend.copilot.tools.helpers.simulate_block", side_effect=fake_simulate
):
response = await execute_block(
block=mock_block,
block_id="test-block-id",
input_data={"query": "hello"},
user_id="user-1",
session_id="session-1",
node_exec_id="node-exec-1",
matched_credentials={},
dry_run=True,
)
assert isinstance(response, BlockOutputResponse)
assert response.success is True
# Non-empty error should be preserved
assert "error" in response.outputs
assert response.outputs["error"] == ["API rate limit exceeded"]
@pytest.mark.asyncio
async def test_execute_block_dry_run_message_includes_completed_status():
"""Dry-run message should clearly indicate COMPLETED status."""
mock_block = make_mock_block()
async def fake_simulate(block, input_data):
yield "result", "simulated"
with patch(
"backend.copilot.tools.helpers.simulate_block", side_effect=fake_simulate
):
response = await execute_block(
block=mock_block,
block_id="test-block-id",
input_data={"query": "hello"},
user_id="user-1",
session_id="session-1",
node_exec_id="node-exec-1",
matched_credentials={},
dry_run=True,
)
assert isinstance(response, BlockOutputResponse)
assert "executed successfully" in response.message
@pytest.mark.asyncio
async def test_execute_block_dry_run_simulator_error_returns_error_response():
"""When simulate_block yields a SIMULATOR ERROR tuple, execute_block returns ErrorResponse."""
mock_block = make_mock_block()
async def fake_simulate_error(block, input_data):
yield (
"error",
"[SIMULATOR ERROR — NOT A BLOCK FAILURE] No LLM client available (missing OpenAI/OpenRouter API key).",
)
yield "error", "[SIMULATOR ERROR — NOT A BLOCK FAILURE] No LLM client available (missing OpenAI/OpenRouter API key)."
with patch(
"backend.copilot.tools.helpers.simulate_block", side_effect=fake_simulate_error

View File

@@ -76,7 +76,6 @@ async def test_run_block_returns_details_when_no_input_provided():
session=session,
block_id="http-block-id",
input_data={}, # Empty input data
dry_run=False,
)
# Should return BlockDetailsResponse showing the schema
@@ -144,7 +143,6 @@ async def test_run_block_returns_details_when_only_credentials_provided():
session=session,
block_id="api-block-id",
input_data={"credentials": {"some": "cred"}}, # Only credential
dry_run=False,
)
# Should return details because no non-credential inputs provided

View File

@@ -151,7 +151,7 @@ async def test_non_dict_tool_arguments_returns_error():
session=session,
server_url=_SERVER_URL,
tool_name="fetch",
tool_arguments=["this", "is", "a", "list"], # type: ignore[arg-type] # intentionally wrong type to test validation
tool_arguments=["this", "is", "a", "list"], # wrong type
)
assert isinstance(response, ErrorResponse)

View File

@@ -1,499 +0,0 @@
"""Tests for session-level dry_run flag propagation.
Verifies that when a session has dry_run=True, run_block, run_agent, and
run_mcp_tool calls are forced to use dry-run mode, regardless of what the
individual tool call specifies. The single source of truth is
``session.dry_run``.
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.copilot.model import ChatSession
from backend.copilot.tools.models import ErrorResponse, MCPToolOutputResponse
from backend.copilot.tools.run_agent import RunAgentInput, RunAgentTool
from backend.copilot.tools.run_block import RunBlockTool
from backend.copilot.tools.run_mcp_tool import RunMCPToolTool
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_session(dry_run: bool = False) -> ChatSession:
"""Create a minimal ChatSession for testing."""
session = ChatSession.new("test-user", dry_run=dry_run)
return session
def _make_mock_block(name: str = "TestBlock"):
"""Create a minimal mock block with jsonschema() methods."""
block = MagicMock()
block.name = name
block.description = "A test block"
block.disabled = False
block.block_type = "STANDARD"
block.id = "test-block-id"
block.input_schema = MagicMock()
block.input_schema.jsonschema.return_value = {
"type": "object",
"properties": {"query": {"type": "string"}},
"required": ["query"],
}
block.input_schema.get_credentials_fields.return_value = {}
block.input_schema.get_credentials_fields_info.return_value = {}
block.output_schema = MagicMock()
block.output_schema.jsonschema.return_value = {
"type": "object",
"properties": {"result": {"type": "string"}},
"required": ["result"],
}
return block
# ---------------------------------------------------------------------------
# RunBlockTool tests
# ---------------------------------------------------------------------------
class TestRunBlockToolSessionDryRun:
"""Test that RunBlockTool respects session-level dry_run."""
@pytest.mark.asyncio
async def test_session_dry_run_forces_block_dry_run(self):
"""When session dry_run is True, run_block should force dry_run=True."""
tool = RunBlockTool()
session = _make_session(dry_run=True)
mock_block = _make_mock_block()
with (
patch(
"backend.copilot.tools.run_block.prepare_block_for_execution"
) as mock_prep,
patch("backend.copilot.tools.run_block.execute_block") as mock_exec,
patch(
"backend.copilot.tools.run_block.get_current_permissions",
return_value=None,
),
):
# Set up prepare_block_for_execution to return a mock prep
mock_prep_result = MagicMock()
mock_prep_result.block = mock_block
mock_prep_result.input_data = {"query": "test"}
mock_prep_result.matched_credentials = {}
mock_prep_result.synthetic_node_id = "node-1"
mock_prep.return_value = mock_prep_result
# Set up execute_block to return a success
mock_exec.return_value = MagicMock(
message="Block 'TestBlock' executed successfully",
success=True,
)
await tool._execute(
user_id="test-user",
session=session,
block_id="test-block-id",
input_data={"query": "test"},
dry_run=False, # User passed False, but session overrides
)
# Verify execute_block was called with dry_run=True
mock_exec.assert_called_once()
call_kwargs = mock_exec.call_args
assert call_kwargs.kwargs.get("dry_run") is True
@pytest.mark.asyncio
async def test_no_session_dry_run_respects_tool_param(self):
"""When session dry_run is False, tool-level dry_run should be respected."""
tool = RunBlockTool()
session = _make_session(dry_run=False)
mock_block = _make_mock_block()
with (
patch(
"backend.copilot.tools.run_block.prepare_block_for_execution"
) as mock_prep,
patch("backend.copilot.tools.run_block.execute_block") as mock_exec,
patch(
"backend.copilot.tools.run_block.get_current_permissions",
return_value=None,
),
patch("backend.copilot.tools.run_block.check_hitl_review") as mock_hitl,
):
mock_prep_result = MagicMock()
mock_prep_result.block = mock_block
mock_prep_result.input_data = {"query": "test"}
mock_prep_result.matched_credentials = {}
mock_prep_result.synthetic_node_id = "node-1"
mock_prep_result.required_non_credential_keys = {"query"}
mock_prep_result.provided_input_keys = {"query"}
mock_prep.return_value = mock_prep_result
mock_hitl.return_value = ("node-exec-1", {"query": "test"})
mock_exec.return_value = MagicMock(
message="Block executed",
success=True,
)
await tool._execute(
user_id="test-user",
session=session,
block_id="test-block-id",
input_data={"query": "test"},
dry_run=False,
)
# Verify execute_block was called with dry_run=False
mock_exec.assert_called_once()
call_kwargs = mock_exec.call_args
assert call_kwargs.kwargs.get("dry_run") is False
# ---------------------------------------------------------------------------
# RunAgentTool tests
# ---------------------------------------------------------------------------
class TestRunAgentToolSessionDryRun:
"""Test that RunAgentTool respects session-level dry_run."""
@pytest.mark.asyncio
async def test_session_dry_run_forces_agent_dry_run(self):
"""When session dry_run is True, run_agent params.dry_run should be forced True."""
tool = RunAgentTool()
session = _make_session(dry_run=True)
# Mock the graph and dependencies
mock_graph = MagicMock()
mock_graph.id = "graph-1"
mock_graph.name = "Test Agent"
mock_graph.description = "A test agent"
mock_graph.input_schema = {"properties": {}, "required": []}
mock_graph.trigger_setup_info = None
mock_library_agent = MagicMock()
mock_library_agent.id = "lib-1"
mock_library_agent.graph_id = "graph-1"
mock_library_agent.graph_version = 1
mock_library_agent.name = "Test Agent"
mock_execution = MagicMock()
mock_execution.id = "exec-1"
with (
patch("backend.copilot.tools.run_agent.graph_db"),
patch("backend.copilot.tools.run_agent.library_db"),
patch(
"backend.copilot.tools.run_agent.fetch_graph_from_store_slug",
return_value=(mock_graph, None),
),
patch(
"backend.copilot.tools.run_agent.match_user_credentials_to_graph",
return_value=({}, []),
),
patch(
"backend.copilot.tools.run_agent.get_or_create_library_agent",
return_value=mock_library_agent,
),
patch("backend.copilot.tools.run_agent.execution_utils") as mock_exec_utils,
patch("backend.copilot.tools.run_agent.track_agent_run_success"),
):
mock_exec_utils.add_graph_execution = AsyncMock(return_value=mock_execution)
await tool._execute(
user_id="test-user",
session=session,
username_agent_slug="user/test-agent",
dry_run=False, # User passed False, but session overrides
use_defaults=True,
)
# Verify add_graph_execution was called with dry_run=True
mock_exec_utils.add_graph_execution.assert_called_once()
call_kwargs = mock_exec_utils.add_graph_execution.call_args
assert call_kwargs.kwargs.get("dry_run") is True
@pytest.mark.asyncio
async def test_session_dry_run_blocks_scheduling(self):
"""When session dry_run is True, scheduling requests should be rejected."""
tool = RunAgentTool()
session = _make_session(dry_run=True)
result = await tool._execute(
user_id="test-user",
session=session,
username_agent_slug="user/test-agent",
schedule_name="daily-run",
cron="0 9 * * *",
dry_run=False, # Session overrides to True
)
assert isinstance(result, ErrorResponse)
assert "dry-run" in result.message.lower()
assert (
"scheduling" in result.message.lower()
or "schedule" in result.message.lower()
)
# ---------------------------------------------------------------------------
# ChatSession model tests
# ---------------------------------------------------------------------------
class TestChatSessionDryRun:
"""Test the dry_run field on ChatSession model."""
def test_new_session_default_dry_run_false(self):
session = ChatSession.new("test-user", dry_run=False)
assert session.dry_run is False
def test_new_session_dry_run_true(self):
session = ChatSession.new("test-user", dry_run=True)
assert session.dry_run is True
def test_new_session_dry_run_false_explicit(self):
session = ChatSession.new("test-user", dry_run=False)
assert session.dry_run is False
# ---------------------------------------------------------------------------
# RunAgentInput tests
# ---------------------------------------------------------------------------
class TestRunAgentInputDryRunOverride:
"""Test that RunAgentInput.dry_run can be mutated by session-level override."""
def test_explicit_dry_run_false(self):
params = RunAgentInput(username_agent_slug="user/agent", dry_run=False)
assert params.dry_run is False
def test_session_override(self):
params = RunAgentInput(username_agent_slug="user/agent", dry_run=False)
# Simulate session-level override
params.dry_run = True
assert params.dry_run is True
# ---------------------------------------------------------------------------
# RunMCPToolTool tests
# ---------------------------------------------------------------------------
class TestRunMCPToolToolSessionDryRun:
"""Test that RunMCPToolTool respects session-level dry_run."""
@pytest.mark.asyncio
async def test_session_dry_run_blocks_mcp_execution(self):
"""When session dry_run is True, MCP tool execution should be skipped."""
tool = RunMCPToolTool()
session = _make_session(dry_run=True)
result = await tool._execute(
user_id="test-user",
session=session,
server_url="https://mcp.example.com/sse",
tool_name="some_tool",
tool_arguments={"key": "value"},
)
assert isinstance(result, MCPToolOutputResponse)
assert result.success is True
assert "dry-run" in result.message
assert result.tool_name == "some_tool"
assert result.result is None
@pytest.mark.asyncio
async def test_session_dry_run_allows_discovery(self):
"""When session dry_run is True, tool discovery (no tool_name) should still work."""
tool = RunMCPToolTool()
session = _make_session(dry_run=True)
# Discovery requires a network call, so we mock the client
with (
patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
return_value=None,
),
patch(
"backend.copilot.tools.run_mcp_tool.validate_url_host",
return_value=None,
),
patch("backend.copilot.tools.run_mcp_tool.MCPClient") as mock_client_cls,
):
mock_client = AsyncMock()
mock_client_cls.return_value = mock_client
mock_tool = MagicMock()
mock_tool.name = "test_tool"
mock_tool.description = "A test tool"
mock_tool.input_schema = {"type": "object", "properties": {}}
mock_client.list_tools.return_value = [mock_tool]
result = await tool._execute(
user_id="test-user",
session=session,
server_url="https://mcp.example.com/sse",
tool_name="", # Discovery mode
)
# Discovery should proceed normally
mock_client.initialize.assert_called_once()
mock_client.list_tools.assert_called_once()
assert "Discovered" in result.message
@pytest.mark.asyncio
async def test_no_session_dry_run_allows_execution(self):
"""When session dry_run is False, MCP tool execution should proceed."""
tool = RunMCPToolTool()
session = _make_session(dry_run=False)
with (
patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
return_value=None,
),
patch(
"backend.copilot.tools.run_mcp_tool.validate_url_host",
return_value=None,
),
patch("backend.copilot.tools.run_mcp_tool.MCPClient") as mock_client_cls,
):
mock_client = AsyncMock()
mock_client_cls.return_value = mock_client
mock_result = MagicMock()
mock_result.is_error = False
mock_result.content = [{"type": "text", "text": "hello"}]
mock_client.call_tool.return_value = mock_result
result = await tool._execute(
user_id="test-user",
session=session,
server_url="https://mcp.example.com/sse",
tool_name="some_tool",
tool_arguments={"key": "value"},
)
# Execution should proceed
mock_client.initialize.assert_called_once()
mock_client.call_tool.assert_called_once_with("some_tool", {"key": "value"})
assert isinstance(result, MCPToolOutputResponse)
assert result.success is True
# ---------------------------------------------------------------------------
# Backward-compatibility tests for ChatSessionMetadata deserialization
# ---------------------------------------------------------------------------
class TestChatSessionMetadataBackwardCompat:
"""Verify that sessions created before the dry_run field existed still load.
The ``metadata`` JSON column in the DB may contain ``{}``, ``null``, or a
dict without the ``dry_run`` key for sessions created before the flag was
introduced. These must deserialize without errors and default to
``dry_run=False``.
"""
def test_metadata_default_construction(self):
"""ChatSessionMetadata() with no args should default dry_run=False."""
from backend.copilot.model import ChatSessionMetadata
meta = ChatSessionMetadata()
assert meta.dry_run is False
def test_metadata_from_empty_dict(self):
"""Deserializing an empty dict (old-format metadata) should succeed."""
from backend.copilot.model import ChatSessionMetadata
meta = ChatSessionMetadata.model_validate({})
assert meta.dry_run is False
def test_metadata_from_dict_without_dry_run_key(self):
"""A metadata dict with other keys but no dry_run should still work."""
from backend.copilot.model import ChatSessionMetadata
meta = ChatSessionMetadata.model_validate({"some_future_field": 42})
# dry_run should fall back to default
assert meta.dry_run is False
def test_metadata_round_trip_with_dry_run_false(self):
"""Serialize then deserialize with dry_run=False."""
from backend.copilot.model import ChatSessionMetadata
original = ChatSessionMetadata(dry_run=False)
raw = original.model_dump()
restored = ChatSessionMetadata.model_validate(raw)
assert restored.dry_run is False
def test_metadata_round_trip_with_dry_run_true(self):
"""Serialize then deserialize with dry_run=True."""
from backend.copilot.model import ChatSessionMetadata
original = ChatSessionMetadata(dry_run=True)
raw = original.model_dump()
restored = ChatSessionMetadata.model_validate(raw)
assert restored.dry_run is True
def test_metadata_json_round_trip(self):
"""Serialize to JSON string and back, simulating Redis cache flow."""
from backend.copilot.model import ChatSessionMetadata
original = ChatSessionMetadata(dry_run=True)
json_str = original.model_dump_json()
restored = ChatSessionMetadata.model_validate_json(json_str)
assert restored.dry_run is True
def test_session_dry_run_property_with_default_metadata(self):
"""ChatSession.dry_run returns False when metadata has no dry_run."""
from backend.copilot.model import ChatSessionMetadata
# Simulate building a session with metadata deserialized from an old row
meta = ChatSessionMetadata.model_validate({})
session = _make_session(dry_run=False)
session.metadata = meta
assert session.dry_run is False
def test_session_info_dry_run_property_with_default_metadata(self):
"""ChatSessionInfo.dry_run returns False when metadata is default."""
from datetime import UTC, datetime
from backend.copilot.model import ChatSessionInfo, ChatSessionMetadata
info = ChatSessionInfo(
session_id="old-session-id",
user_id="test-user",
usage=[],
started_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
metadata=ChatSessionMetadata.model_validate({}),
)
assert info.dry_run is False
def test_session_full_json_round_trip_without_dry_run(self):
"""A full ChatSession JSON round-trip preserves dry_run default."""
session = _make_session(dry_run=False)
json_bytes = session.model_dump_json()
restored = ChatSession.model_validate_json(json_bytes)
assert restored.dry_run is False
assert restored.metadata.dry_run is False
def test_session_full_json_round_trip_with_dry_run(self):
"""A full ChatSession JSON round-trip preserves dry_run=True."""
session = _make_session(dry_run=True)
json_bytes = session.model_dump_json()
restored = ChatSession.model_validate_json(json_bytes)
assert restored.dry_run is True
assert restored.metadata.dry_run is True

View File

@@ -121,7 +121,7 @@ def _serialize_missing_credential(
provider = next(iter(field_info.provider), "unknown")
scopes = sorted(field_info.required_scopes or [])
result: dict[str, Any] = {
return {
"id": field_key,
"title": field_key.replace("_", " ").title(),
"provider": provider,
@@ -131,17 +131,6 @@ def _serialize_missing_credential(
"scopes": scopes,
}
# Include discriminator info so the frontend can auto-match
# host-scoped credentials (e.g. SendAuthenticatedWebRequestBlock).
if field_info.discriminator:
result["discriminator"] = field_info.discriminator
if field_info.discriminator_values:
result["discriminator_values"] = sorted(
str(v) for v in field_info.discriminator_values
)
return result
def build_missing_credentials_from_graph(
graph: GraphModel, matched_credentials: dict[str, CredentialsMetaInput] | None

Some files were not shown because too many files have changed in this diff.