Merge branch 'dev' into fix/open-2895-sheets-missing-credentials

This commit is contained in:
Nicholas Tindle
2026-03-27 15:00:45 -05:00
committed by GitHub
41 changed files with 2930 additions and 454 deletions

View File

@@ -0,0 +1,106 @@
---
name: open-pr
description: Open a pull request with proper PR template, test coverage, and review workflow. Guides agents through creating a PR that follows repo conventions, ensures existing behaviors aren't broken, covers new behaviors with tests, and handles review via bot when local testing isn't possible. TRIGGER when user asks to "open a PR", "create a PR", "make a PR", "submit a PR", "open pull request", "push and create PR", or any variation of opening/submitting a pull request.
user-invocable: true
args: "[base-branch] — optional target branch (defaults to dev)."
metadata:
author: autogpt-team
version: "1.0.0"
---
# Open a Pull Request
## Step 1: Pre-flight checks
Before opening the PR:
1. Ensure all changes are committed
2. Ensure the branch is pushed to the remote (`git push -u origin <branch>`)
3. Run linters/formatters across the whole repo (not just changed files) and commit any fixes
## Step 2: Test coverage
**This is critical.** Before opening the PR, verify:
### Existing behavior is not broken
- Identify which modules/components your changes touch
- Run the existing test suites for those areas
- If tests fail, fix them before opening the PR — do not open a PR with known regressions
### New behavior has test coverage
- Every new feature, endpoint, or behavior change needs tests
- If you added a new block, add tests for that block
- If you changed API behavior, add or update API tests
- If you changed frontend behavior, verify it doesn't break existing flows
If you cannot run the full test suite locally, note which tests you ran and which you couldn't in the test plan.
## Step 3: Create the PR using the repo template
Read the canonical PR template at `.github/PULL_REQUEST_TEMPLATE.md` and use it **verbatim** as your PR body:
1. Read the template: `cat .github/PULL_REQUEST_TEMPLATE.md`
2. Preserve the exact section titles and formatting, including:
- `### Why / What / How`
- `### Changes 🏗️`
- `### Checklist 📋`
3. Replace HTML comment prompts (`<!-- ... -->`) with actual content; do not leave them in
4. **Do not pre-check boxes** — leave all checkboxes as `- [ ]` until each step is actually completed
5. Do not alter the template structure, rename sections, or remove any checklist items
**PR title must use conventional commit format** (e.g., `feat(backend): add new block`, `fix(frontend): resolve routing bug`, `dx(skills): update PR workflow`). See CLAUDE.md for the full list of scopes.
Use `gh pr create` with the base branch (defaults to `dev` if no `[base-branch]` was provided). Use `--body-file` to avoid shell interpretation of backticks and special characters:
```bash
BASE_BRANCH="${BASE_BRANCH:-dev}"
PR_BODY=$(mktemp)
cat > "$PR_BODY" << 'PREOF'
<filled-in template from .github/PULL_REQUEST_TEMPLATE.md>
PREOF
gh pr create --base "$BASE_BRANCH" --title "<type>(scope): short description" --body-file "$PR_BODY"
rm "$PR_BODY"
```
## Step 4: Review workflow
### If you have a workspace that allows testing (docker, running backend, etc.)
- Run `/pr-test` to do E2E manual testing of the PR using docker compose, agent-browser, and API calls. This is the most thorough way to validate your changes before review.
- After testing, run `/pr-review` to self-review the PR for correctness, security, code quality, and testing gaps before requesting human review.
### If you do NOT have a workspace that allows testing
This is common for agents running in worktrees without a full stack. In this case:
1. Run `/pr-review` locally to catch obvious issues before pushing
2. **Comment `/review` on the PR** after creating it to trigger the review bot
3. **Poll for the review** rather than blindly waiting — check for new review comments every 30 seconds using `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate` and the GraphQL inline threads query. The bot typically responds within 30 minutes, but polling lets the agent react as soon as it arrives.
4. Do NOT proceed or merge until the bot review comes back
5. Address any issues the bot raises — use `/pr-address` which has a full polling loop with CI + comment tracking
```bash
# After creating the PR:
PR_NUMBER=$(gh pr view --json number -q .number)
gh pr comment "$PR_NUMBER" --body "/review"
# Then use /pr-address to poll for and address the review when it arrives
```
## Step 5: Address review feedback
Once the review bot or human reviewers leave comments:
- Run `/pr-address` to address review comments. It will loop until CI is green and all comments are resolved.
- Do not merge without human approval.
## Related skills
| Skill | When to use |
|---|---|
| `/pr-test` | E2E testing with docker compose, agent-browser, API calls — use when you have a running workspace |
| `/pr-review` | Review for correctness, security, code quality — use before requesting human review |
| `/pr-address` | Address reviewer comments and loop until CI green — use after reviews come in |
## Step 6: Post-creation
After the PR is created and review is triggered:
- Share the PR URL with the user
- If waiting on the review bot, let the user know the expected wait time (~30 min)
- Do not merge without human approval

View File

@@ -0,0 +1,195 @@
---
name: setup-repo
description: Initialize a worktree-based repo layout for parallel development. Creates a main worktree, a reviews worktree for PR reviews, and N numbered work branches. Handles .env creation, dependency installation, and branchlet config. TRIGGER when user asks to set up the repo from scratch, initialize worktrees, bootstrap their dev environment, "setup repo", "setup worktrees", "initialize dev environment", "set up branches", or when a freshly cloned repo has no sibling worktrees.
user-invocable: true
args: "No arguments — interactive setup via prompts."
metadata:
author: autogpt-team
version: "1.0.0"
---
# Repository Setup
This skill sets up a worktree-based development layout from a freshly cloned repo. It creates:
- A **main** worktree (the primary checkout)
- A **reviews** worktree (for PR reviews)
- **N work branches** (branch1..branchN) for parallel development
## Step 1: Identify the repo
Determine the repo root and parent directory:
```bash
ROOT=$(git rev-parse --show-toplevel)
REPO_NAME=$(basename "$ROOT")
PARENT=$(dirname "$ROOT")
```
Detect if the repo is already inside a worktree layout by counting sibling worktrees (not just checking the directory name, which could be anything):
```bash
# Count worktrees that are siblings (live under $PARENT but aren't $ROOT itself)
SIBLING_COUNT=$(git worktree list --porcelain 2>/dev/null | grep "^worktree " | grep -c "$PARENT/" || true)
if [ "$SIBLING_COUNT" -gt 1 ]; then
echo "INFO: Existing worktree layout detected at $PARENT ($SIBLING_COUNT worktrees)"
# Use $ROOT as-is; skip renaming/restructuring
else
echo "INFO: Fresh clone detected, proceeding with setup"
fi
```
## Step 2: Ask the user questions
Use AskUserQuestion to gather setup preferences:
1. **How many parallel work branches do you need?** (Options: 4, 8, 16, or custom)
- These become `branch1` through `branchN`
2. **Which branch should be the base?** (Options: origin/master, origin/dev, or custom)
- All work branches and reviews will start from this
## Step 3: Fetch and set up branches
```bash
cd "$ROOT"
git fetch origin
# Create the reviews branch from base (skip if already exists)
if git show-ref --verify --quiet refs/heads/reviews; then
echo "INFO: Branch 'reviews' already exists, skipping"
else
git branch reviews <base-branch>
fi
# Create numbered work branches from base (skip if already exists)
for i in $(seq 1 "$COUNT"); do
if git show-ref --verify --quiet "refs/heads/branch$i"; then
echo "INFO: Branch 'branch$i' already exists, skipping"
else
git branch "branch$i" <base-branch>
fi
done
```
## Step 4: Create worktrees
Create worktrees as siblings to the main checkout:
```bash
if [ -d "$PARENT/reviews" ]; then
echo "INFO: Worktree '$PARENT/reviews' already exists, skipping"
else
git worktree add "$PARENT/reviews" reviews
fi
for i in $(seq 1 "$COUNT"); do
if [ -d "$PARENT/branch$i" ]; then
echo "INFO: Worktree '$PARENT/branch$i' already exists, skipping"
else
git worktree add "$PARENT/branch$i" "branch$i"
fi
done
```
## Step 5: Set up environment files
**Do NOT assume .env files exist.** For each worktree (including main if needed):
1. Check if `.env` exists in the source worktree for each path
2. If `.env` exists, copy it
3. If only `.env.default` or `.env.example` exists, copy that as `.env`
4. If neither exists, warn the user and list which env files are missing
Env file locations to check (same as the `/worktree` skill — keep these in sync):
- `autogpt_platform/.env`
- `autogpt_platform/backend/.env`
- `autogpt_platform/frontend/.env`
> **Note:** This env copying logic intentionally mirrors the `/worktree` skill's approach. If you update the path list or fallback logic here, update `/worktree` as well.
```bash
SOURCE="$ROOT"
WORKTREES="reviews"
for i in $(seq 1 "$COUNT"); do WORKTREES="$WORKTREES branch$i"; done
FOUND_ANY_ENV=0
for wt in $WORKTREES; do
TARGET="$PARENT/$wt"
for envpath in autogpt_platform autogpt_platform/backend autogpt_platform/frontend; do
if [ -f "$SOURCE/$envpath/.env" ]; then
FOUND_ANY_ENV=1
cp "$SOURCE/$envpath/.env" "$TARGET/$envpath/.env"
elif [ -f "$SOURCE/$envpath/.env.default" ]; then
FOUND_ANY_ENV=1
cp "$SOURCE/$envpath/.env.default" "$TARGET/$envpath/.env"
echo "NOTE: $wt/$envpath/.env was created from .env.default — you may need to edit it"
elif [ -f "$SOURCE/$envpath/.env.example" ]; then
FOUND_ANY_ENV=1
cp "$SOURCE/$envpath/.env.example" "$TARGET/$envpath/.env"
echo "NOTE: $wt/$envpath/.env was created from .env.example — you may need to edit it"
else
echo "WARNING: No .env, .env.default, or .env.example found at $SOURCE/$envpath/"
fi
done
done
if [ "$FOUND_ANY_ENV" -eq 0 ]; then
echo "WARNING: No environment files or templates were found in the source worktree."
# Use AskUserQuestion to confirm: "Continue setup without env files?"
# If the user declines, stop here and let them set up .env files first.
fi
```
## Step 6: Copy branchlet config
Copy `.branchlet.json` from main to each worktree so branchlet can manage sub-worktrees:
```bash
if [ -f "$ROOT/.branchlet.json" ]; then
for wt in $WORKTREES; do
cp "$ROOT/.branchlet.json" "$PARENT/$wt/.branchlet.json"
done
fi
```
## Step 7: Install dependencies
Install deps in all worktrees. Run these sequentially per worktree:
```bash
for wt in $WORKTREES; do
TARGET="$PARENT/$wt"
echo "=== Installing deps for $wt ==="
(cd "$TARGET/autogpt_platform/autogpt_libs" && poetry install) &&
(cd "$TARGET/autogpt_platform/backend" && poetry install && poetry run prisma generate) &&
(cd "$TARGET/autogpt_platform/frontend" && pnpm install) &&
echo "=== Done: $wt ===" ||
echo "=== FAILED: $wt ==="
done
```
This is slow. Run in background if possible and notify when complete.
## Step 8: Verify and report
After setup, verify and report to the user:
```bash
git worktree list
```
Summarize:
- Number of worktrees created
- Which env files were copied vs created from defaults vs missing
- Any warnings or errors encountered
## Final directory layout
```
parent/
main/ # Primary checkout (already exists)
reviews/ # PR review worktree
branch1/ # Work branch 1
branch2/ # Work branch 2
...
branchN/ # Work branch N
```

View File

@@ -1,6 +1,7 @@
"""Admin endpoints for checking and resetting user CoPilot rate limit usage."""
import logging
from typing import Optional
from autogpt_libs.auth import get_user_id, requires_admin_user
from fastapi import APIRouter, Body, HTTPException, Security
@@ -12,6 +13,7 @@ from backend.copilot.rate_limit import (
get_usage_status,
reset_user_usage,
)
from backend.data.user import get_user_by_email, get_user_email_by_id
logger = logging.getLogger(__name__)
@@ -26,31 +28,72 @@ router = APIRouter(
class UserRateLimitResponse(BaseModel):
user_id: str
user_email: Optional[str] = None
daily_token_limit: int
weekly_token_limit: int
daily_tokens_used: int
weekly_tokens_used: int
async def _resolve_user_id(
user_id: Optional[str], email: Optional[str]
) -> tuple[str, Optional[str]]:
"""Resolve a user_id and email from the provided parameters.
Returns (user_id, email). Accepts either user_id or email; at least one
must be provided. When both are provided, ``email`` takes precedence.
"""
if email:
user = await get_user_by_email(email)
if not user:
raise HTTPException(
status_code=404, detail="No user found with the provided email."
)
return user.id, email
if not user_id:
raise HTTPException(
status_code=400,
detail="Either user_id or email query parameter is required.",
)
# We have a user_id; try to look up their email for display purposes.
# This is non-critical -- a failure should not block the response.
try:
resolved_email = await get_user_email_by_id(user_id)
except Exception:
logger.warning("Failed to resolve email for user %s", user_id, exc_info=True)
resolved_email = None
return user_id, resolved_email
@router.get(
"/rate_limit",
response_model=UserRateLimitResponse,
summary="Get User Rate Limit",
)
async def get_user_rate_limit(
user_id: str,
user_id: Optional[str] = None,
email: Optional[str] = None,
admin_user_id: str = Security(get_user_id),
) -> UserRateLimitResponse:
"""Get a user's current usage and effective rate limits. Admin-only."""
logger.info(f"Admin {admin_user_id} checking rate limit for user {user_id}")
"""Get a user's current usage and effective rate limits. Admin-only.
Accepts either ``user_id`` or ``email`` as a query parameter.
When ``email`` is provided the user is looked up by email first.
"""
resolved_id, resolved_email = await _resolve_user_id(user_id, email)
logger.info("Admin %s checking rate limit for user %s", admin_user_id, resolved_id)
daily_limit, weekly_limit = await get_global_rate_limits(
user_id, config.daily_token_limit, config.weekly_token_limit
resolved_id, config.daily_token_limit, config.weekly_token_limit
)
usage = await get_usage_status(user_id, daily_limit, weekly_limit)
usage = await get_usage_status(resolved_id, daily_limit, weekly_limit)
return UserRateLimitResponse(
user_id=user_id,
user_id=resolved_id,
user_email=resolved_email,
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
daily_tokens_used=usage.daily.used,
@@ -70,8 +113,10 @@ async def reset_user_rate_limit(
) -> UserRateLimitResponse:
"""Reset a user's daily usage counter (and optionally weekly). Admin-only."""
logger.info(
f"Admin {admin_user_id} resetting rate limit for user {user_id} "
f"(reset_weekly={reset_weekly})"
"Admin %s resetting rate limit for user %s (reset_weekly=%s)",
admin_user_id,
user_id,
reset_weekly,
)
try:
@@ -85,8 +130,15 @@ async def reset_user_rate_limit(
)
usage = await get_usage_status(user_id, daily_limit, weekly_limit)
try:
resolved_email = await get_user_email_by_id(user_id)
except Exception:
logger.warning("Failed to resolve email for user %s", user_id, exc_info=True)
resolved_email = None
return UserRateLimitResponse(
user_id=user_id,
user_email=resolved_email,
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
daily_tokens_used=usage.daily.used,

View File

@@ -1,4 +1,5 @@
import json
from types import SimpleNamespace
from unittest.mock import AsyncMock
import fastapi
@@ -19,6 +20,8 @@ client = fastapi.testclient.TestClient(app)
_MOCK_MODULE = "backend.api.features.admin.rate_limit_admin_routes"
_TARGET_EMAIL = "target@example.com"
@pytest.fixture(autouse=True)
def setup_app_admin_auth(mock_jwt_admin):
@@ -44,12 +47,13 @@ def _mock_usage_status(
)
def test_get_rate_limit(
def _patch_rate_limit_deps(
mocker: pytest_mock.MockerFixture,
configured_snapshot: Snapshot,
target_user_id: str,
) -> None:
"""Test getting rate limit and usage for a user."""
daily_used: int = 500_000,
weekly_used: int = 3_000_000,
):
"""Patch the common rate-limit + user-lookup dependencies."""
mocker.patch(
f"{_MOCK_MODULE}.get_global_rate_limits",
new_callable=AsyncMock,
@@ -58,14 +62,29 @@ def test_get_rate_limit(
mocker.patch(
f"{_MOCK_MODULE}.get_usage_status",
new_callable=AsyncMock,
return_value=_mock_usage_status(),
return_value=_mock_usage_status(daily_used=daily_used, weekly_used=weekly_used),
)
mocker.patch(
f"{_MOCK_MODULE}.get_user_email_by_id",
new_callable=AsyncMock,
return_value=_TARGET_EMAIL,
)
def test_get_rate_limit(
mocker: pytest_mock.MockerFixture,
configured_snapshot: Snapshot,
target_user_id: str,
) -> None:
"""Test getting rate limit and usage for a user."""
_patch_rate_limit_deps(mocker, target_user_id)
response = client.get("/admin/rate_limit", params={"user_id": target_user_id})
assert response.status_code == 200
data = response.json()
assert data["user_id"] == target_user_id
assert data["user_email"] == _TARGET_EMAIL
assert data["daily_token_limit"] == 2_500_000
assert data["weekly_token_limit"] == 12_500_000
assert data["daily_tokens_used"] == 500_000
@@ -77,6 +96,50 @@ def test_get_rate_limit(
)
def test_get_rate_limit_by_email(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
) -> None:
"""Test looking up rate limits via email instead of user_id."""
_patch_rate_limit_deps(mocker, target_user_id)
mock_user = SimpleNamespace(id=target_user_id, email=_TARGET_EMAIL)
mocker.patch(
f"{_MOCK_MODULE}.get_user_by_email",
new_callable=AsyncMock,
return_value=mock_user,
)
response = client.get("/admin/rate_limit", params={"email": _TARGET_EMAIL})
assert response.status_code == 200
data = response.json()
assert data["user_id"] == target_user_id
assert data["user_email"] == _TARGET_EMAIL
assert data["daily_token_limit"] == 2_500_000
def test_get_rate_limit_by_email_not_found(
mocker: pytest_mock.MockerFixture,
) -> None:
"""Test that looking up a non-existent email returns 404."""
mocker.patch(
f"{_MOCK_MODULE}.get_user_by_email",
new_callable=AsyncMock,
return_value=None,
)
response = client.get("/admin/rate_limit", params={"email": "nobody@example.com"})
assert response.status_code == 404
def test_get_rate_limit_no_params() -> None:
"""Test that omitting both user_id and email returns 400."""
response = client.get("/admin/rate_limit")
assert response.status_code == 400
def test_reset_user_usage_daily_only(
mocker: pytest_mock.MockerFixture,
configured_snapshot: Snapshot,
@@ -87,16 +150,7 @@ def test_reset_user_usage_daily_only(
f"{_MOCK_MODULE}.reset_user_usage",
new_callable=AsyncMock,
)
mocker.patch(
f"{_MOCK_MODULE}.get_global_rate_limits",
new_callable=AsyncMock,
return_value=(2_500_000, 12_500_000),
)
mocker.patch(
f"{_MOCK_MODULE}.get_usage_status",
new_callable=AsyncMock,
return_value=_mock_usage_status(daily_used=0, weekly_used=3_000_000),
)
_patch_rate_limit_deps(mocker, target_user_id, daily_used=0, weekly_used=3_000_000)
response = client.post(
"/admin/rate_limit/reset",
@@ -127,16 +181,7 @@ def test_reset_user_usage_daily_and_weekly(
f"{_MOCK_MODULE}.reset_user_usage",
new_callable=AsyncMock,
)
mocker.patch(
f"{_MOCK_MODULE}.get_global_rate_limits",
new_callable=AsyncMock,
return_value=(2_500_000, 12_500_000),
)
mocker.patch(
f"{_MOCK_MODULE}.get_usage_status",
new_callable=AsyncMock,
return_value=_mock_usage_status(daily_used=0, weekly_used=0),
)
_patch_rate_limit_deps(mocker, target_user_id, daily_used=0, weekly_used=0)
response = client.post(
"/admin/rate_limit/reset",
@@ -175,6 +220,35 @@ def test_reset_user_usage_redis_failure(
assert response.status_code == 500
def test_get_rate_limit_email_lookup_failure(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
) -> None:
"""Test that failing to resolve a user email degrades gracefully."""
mocker.patch(
f"{_MOCK_MODULE}.get_global_rate_limits",
new_callable=AsyncMock,
return_value=(2_500_000, 12_500_000),
)
mocker.patch(
f"{_MOCK_MODULE}.get_usage_status",
new_callable=AsyncMock,
return_value=_mock_usage_status(),
)
mocker.patch(
f"{_MOCK_MODULE}.get_user_email_by_id",
new_callable=AsyncMock,
side_effect=Exception("DB connection lost"),
)
response = client.get("/admin/rate_limit", params={"user_id": target_user_id})
assert response.status_code == 200
data = response.json()
assert data["user_id"] == target_user_id
assert data["user_email"] is None
def test_admin_endpoints_require_admin_role(mock_jwt_user) -> None:
"""Test that rate limit admin endpoints require admin role."""
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]

View File

@@ -17,8 +17,6 @@ from backend.data.includes import library_agent_include
from backend.util.exceptions import NotFoundError
from backend.util.json import SafeJson
from .db import get_library_agent_by_graph_id, update_library_agent
logger = logging.getLogger(__name__)
@@ -61,28 +59,17 @@ async def add_graph_to_library(
graph_model: GraphModel,
user_id: str,
) -> library_model.LibraryAgent:
"""Check existing / restore soft-deleted / create new LibraryAgent."""
if existing := await get_library_agent_by_graph_id(
user_id, graph_model.id, graph_model.version
):
return existing
"""Check existing / restore soft-deleted / create new LibraryAgent.
deleted_agent = await prisma.models.LibraryAgent.prisma().find_unique(
where={
"userId_agentGraphId_agentGraphVersion": {
"userId": user_id,
"agentGraphId": graph_model.id,
"agentGraphVersion": graph_model.version,
}
},
Uses a create-then-catch-UniqueViolationError-then-update pattern on
the (userId, agentGraphId, agentGraphVersion) composite unique constraint.
This is more robust than ``upsert`` because Prisma's upsert atomicity
guarantees are not well-documented for all versions.
"""
settings_json = SafeJson(GraphSettings.from_graph(graph_model).model_dump())
_include = library_agent_include(
user_id, include_nodes=False, include_executions=False
)
if deleted_agent and (deleted_agent.isDeleted or deleted_agent.isArchived):
return await update_library_agent(
deleted_agent.id,
user_id,
is_deleted=False,
is_archived=False,
)
try:
added_agent = await prisma.models.LibraryAgent.prisma().create(
@@ -98,23 +85,32 @@ async def add_graph_to_library(
},
"isCreatedByUser": False,
"useGraphIsActiveVersion": False,
"settings": SafeJson(
GraphSettings.from_graph(graph_model).model_dump()
),
"settings": settings_json,
},
include=library_agent_include(
user_id, include_nodes=False, include_executions=False
),
include=_include,
)
except prisma.errors.UniqueViolationError:
# Race condition: concurrent request created the row between our
# check and create. Re-read instead of crashing.
existing = await get_library_agent_by_graph_id(
user_id, graph_model.id, graph_model.version
# Already exists — update to restore if previously soft-deleted/archived
added_agent = await prisma.models.LibraryAgent.prisma().update(
where={
"userId_agentGraphId_agentGraphVersion": {
"userId": user_id,
"agentGraphId": graph_model.id,
"agentGraphVersion": graph_model.version,
}
},
data={
"isDeleted": False,
"isArchived": False,
"settings": settings_json,
},
include=_include,
)
if existing:
return existing
raise # Shouldn't happen, but don't swallow unexpected errors
if added_agent is None:
raise NotFoundError(
f"LibraryAgent for graph #{graph_model.id} "
f"v{graph_model.version} not found after UniqueViolationError"
)
logger.debug(
f"Added graph #{graph_model.id} v{graph_model.version} "

View File

@@ -1,71 +1,80 @@
from unittest.mock import AsyncMock, MagicMock, patch
import prisma.errors
import pytest
from ._add_to_library import add_graph_to_library
@pytest.mark.asyncio
async def test_add_graph_to_library_restores_archived_agent() -> None:
graph_model = MagicMock(id="graph-id", version=2)
archived_agent = MagicMock(id="library-agent-id", isDeleted=False, isArchived=True)
restored_agent = MagicMock(name="LibraryAgentModel")
async def test_add_graph_to_library_create_new_agent() -> None:
"""When no matching LibraryAgent exists, create inserts a new one."""
graph_model = MagicMock(id="graph-id", version=2, nodes=[])
created_agent = MagicMock(name="CreatedLibraryAgent")
converted_agent = MagicMock(name="ConvertedLibraryAgent")
with (
patch(
"backend.api.features.library._add_to_library.get_library_agent_by_graph_id",
new=AsyncMock(return_value=None),
),
patch(
"backend.api.features.library._add_to_library.prisma.models.LibraryAgent.prisma"
) as mock_prisma,
patch(
"backend.api.features.library._add_to_library.update_library_agent",
new=AsyncMock(return_value=restored_agent),
) as mock_update,
"backend.api.features.library._add_to_library.library_model.LibraryAgent.from_db",
return_value=converted_agent,
) as mock_from_db,
):
mock_prisma.return_value.find_unique = AsyncMock(return_value=archived_agent)
mock_prisma.return_value.create = AsyncMock(return_value=created_agent)
result = await add_graph_to_library("slv-id", graph_model, "user-id")
assert result is restored_agent
mock_update.assert_awaited_once_with(
"library-agent-id",
"user-id",
is_deleted=False,
is_archived=False,
)
mock_prisma.return_value.create.assert_not_called()
assert result is converted_agent
mock_from_db.assert_called_once_with(created_agent)
# Verify create was called with correct data
create_call = mock_prisma.return_value.create.call_args
create_data = create_call.kwargs["data"]
assert create_data["User"] == {"connect": {"id": "user-id"}}
assert create_data["AgentGraph"] == {
"connect": {"graphVersionId": {"id": "graph-id", "version": 2}}
}
assert create_data["isCreatedByUser"] is False
assert create_data["useGraphIsActiveVersion"] is False
@pytest.mark.asyncio
async def test_add_graph_to_library_restores_deleted_agent() -> None:
graph_model = MagicMock(id="graph-id", version=2)
deleted_agent = MagicMock(id="library-agent-id", isDeleted=True, isArchived=False)
restored_agent = MagicMock(name="LibraryAgentModel")
async def test_add_graph_to_library_unique_violation_updates_existing() -> None:
"""UniqueViolationError on create falls back to update."""
graph_model = MagicMock(id="graph-id", version=2, nodes=[])
updated_agent = MagicMock(name="UpdatedLibraryAgent")
converted_agent = MagicMock(name="ConvertedLibraryAgent")
with (
patch(
"backend.api.features.library._add_to_library.get_library_agent_by_graph_id",
new=AsyncMock(return_value=None),
),
patch(
"backend.api.features.library._add_to_library.prisma.models.LibraryAgent.prisma"
) as mock_prisma,
patch(
"backend.api.features.library._add_to_library.update_library_agent",
new=AsyncMock(return_value=restored_agent),
) as mock_update,
"backend.api.features.library._add_to_library.library_model.LibraryAgent.from_db",
return_value=converted_agent,
) as mock_from_db,
):
mock_prisma.return_value.find_unique = AsyncMock(return_value=deleted_agent)
mock_prisma.return_value.create = AsyncMock(
side_effect=prisma.errors.UniqueViolationError(
MagicMock(), message="unique constraint"
)
)
mock_prisma.return_value.update = AsyncMock(return_value=updated_agent)
result = await add_graph_to_library("slv-id", graph_model, "user-id")
assert result is restored_agent
mock_update.assert_awaited_once_with(
"library-agent-id",
"user-id",
is_deleted=False,
is_archived=False,
)
mock_prisma.return_value.create.assert_not_called()
assert result is converted_agent
mock_from_db.assert_called_once_with(updated_agent)
# Verify update was called with correct where and data
update_call = mock_prisma.return_value.update.call_args
assert update_call.kwargs["where"] == {
"userId_agentGraphId_agentGraphVersion": {
"userId": "user-id",
"agentGraphId": "graph-id",
"agentGraphVersion": 2,
}
}
update_data = update_call.kwargs["data"]
assert update_data["isDeleted"] is False
assert update_data["isArchived"] is False

View File

@@ -436,32 +436,53 @@ async def create_library_agent(
async with transaction() as tx:
library_agents = await asyncio.gather(
*(
prisma.models.LibraryAgent.prisma(tx).create(
data=prisma.types.LibraryAgentCreateInput(
isCreatedByUser=(user_id == user_id),
useGraphIsActiveVersion=True,
User={"connect": {"id": user_id}},
AgentGraph={
"connect": {
"graphVersionId": {
"id": graph_entry.id,
"version": graph_entry.version,
prisma.models.LibraryAgent.prisma(tx).upsert(
where={
"userId_agentGraphId_agentGraphVersion": {
"userId": user_id,
"agentGraphId": graph_entry.id,
"agentGraphVersion": graph_entry.version,
}
},
data={
"create": prisma.types.LibraryAgentCreateInput(
isCreatedByUser=(user_id == graph.user_id),
useGraphIsActiveVersion=True,
User={"connect": {"id": user_id}},
AgentGraph={
"connect": {
"graphVersionId": {
"id": graph_entry.id,
"version": graph_entry.version,
}
}
}
},
settings=SafeJson(
GraphSettings.from_graph(
graph_entry,
hitl_safe_mode=hitl_safe_mode,
sensitive_action_safe_mode=sensitive_action_safe_mode,
).model_dump()
),
**(
{"Folder": {"connect": {"id": folder_id}}}
if folder_id and graph_entry is graph
else {}
),
),
"update": {
"isDeleted": False,
"isArchived": False,
"useGraphIsActiveVersion": True,
"settings": SafeJson(
GraphSettings.from_graph(
graph_entry,
hitl_safe_mode=hitl_safe_mode,
sensitive_action_safe_mode=sensitive_action_safe_mode,
).model_dump()
),
},
settings=SafeJson(
GraphSettings.from_graph(
graph_entry,
hitl_safe_mode=hitl_safe_mode,
sensitive_action_safe_mode=sensitive_action_safe_mode,
).model_dump()
),
**(
{"Folder": {"connect": {"id": folder_id}}}
if folder_id and graph_entry is graph
else {}
),
),
},
include=library_agent_include(
user_id, include_nodes=False, include_executions=False
),

View File

@@ -1,4 +1,6 @@
from contextlib import asynccontextmanager
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch
import prisma.enums
import prisma.models
@@ -85,10 +87,6 @@ async def test_get_library_agents(mocker):
async def test_add_agent_to_library(mocker):
await connect()
# Mock the transaction context
mock_transaction = mocker.patch("backend.api.features.library.db.transaction")
mock_transaction.return_value.__aenter__ = mocker.AsyncMock(return_value=None)
mock_transaction.return_value.__aexit__ = mocker.AsyncMock(return_value=None)
# Mock data
mock_store_listing_data = prisma.models.StoreListingVersion(
id="version123",
@@ -143,13 +141,11 @@ async def test_add_agent_to_library(mocker):
)
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
mock_library_agent.return_value.find_unique = mocker.AsyncMock(return_value=None)
mock_library_agent.return_value.create = mocker.AsyncMock(
return_value=mock_library_agent_data
)
# Mock graph_db.get_graph function that's called to check for HITL blocks
# Mock graph_db.get_graph function that's called in resolve_graph_for_library
# (lives in _add_to_library.py after refactor, not db.py)
mock_graph_db = mocker.patch(
"backend.api.features.library._add_to_library.graph_db"
@@ -175,37 +171,27 @@ async def test_add_agent_to_library(mocker):
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
where={"id": "version123"}, include={"AgentGraph": True}
)
mock_library_agent.return_value.find_unique.assert_called_once_with(
where={
"userId_agentGraphId_agentGraphVersion": {
"userId": "test-user",
"agentGraphId": "agent1",
"agentGraphVersion": 1,
}
},
)
# Check that create was called with the expected data including settings
create_call_args = mock_library_agent.return_value.create.call_args
assert create_call_args is not None
# Verify the main structure
expected_data = {
# Verify the create data structure
create_data = create_call_args.kwargs["data"]
expected_create = {
"User": {"connect": {"id": "test-user"}},
"AgentGraph": {"connect": {"graphVersionId": {"id": "agent1", "version": 1}}},
"isCreatedByUser": False,
"useGraphIsActiveVersion": False,
}
actual_data = create_call_args[1]["data"]
# Check that all expected fields are present
for key, value in expected_data.items():
assert actual_data[key] == value
for key, value in expected_create.items():
assert create_data[key] == value
# Check that settings field is present and is a SafeJson object
assert "settings" in actual_data
assert hasattr(actual_data["settings"], "__class__") # Should be a SafeJson object
assert "settings" in create_data
assert hasattr(create_data["settings"], "__class__") # Should be a SafeJson object
# Check include parameter
assert create_call_args[1]["include"] == library_agent_include(
assert create_call_args.kwargs["include"] == library_agent_include(
"test-user", include_nodes=False, include_executions=False
)
@@ -320,3 +306,50 @@ async def test_update_graph_in_library_allows_archived_library_agent(mocker):
include_archived=True,
)
mock_update_library_agent.assert_awaited_once_with("test-user", created_graph)
@pytest.mark.asyncio
async def test_create_library_agent_uses_upsert():
"""create_library_agent should use upsert (not create) to handle duplicates."""
mock_graph = MagicMock()
mock_graph.id = "graph-1"
mock_graph.version = 1
mock_graph.user_id = "user-1"
mock_graph.nodes = []
mock_graph.sub_graphs = []
mock_upserted = MagicMock(name="UpsertedLibraryAgent")
@asynccontextmanager
async def fake_tx():
yield None
with (
patch("backend.api.features.library.db.transaction", fake_tx),
patch("prisma.models.LibraryAgent.prisma") as mock_prisma,
patch(
"backend.api.features.library.db.add_generated_agent_image",
new=AsyncMock(),
),
patch(
"backend.api.features.library.model.LibraryAgent.from_db",
return_value=MagicMock(),
),
):
mock_prisma.return_value.upsert = AsyncMock(return_value=mock_upserted)
result = await db.create_library_agent(mock_graph, "user-1")
assert len(result) == 1
upsert_call = mock_prisma.return_value.upsert.call_args
assert upsert_call is not None
# Verify the upsert where clause uses the composite unique key
where = upsert_call.kwargs["where"]
assert "userId_agentGraphId_agentGraphVersion" in where
# Verify the upsert data has both create and update branches
data = upsert_call.kwargs["data"]
assert "create" in data
assert "update" in data
# Verify update branch restores soft-deleted/archived agents
assert data["update"]["isDeleted"] is False
assert data["update"]["isArchived"] is False

View File

@@ -1,3 +1,4 @@
import re
from typing import Any
from backend.blocks._base import (
@@ -19,6 +20,33 @@ from backend.blocks.llm import (
)
from backend.data.model import APIKeyCredentials, NodeExecutionStats, SchemaField
# Minimum max_output_tokens accepted by OpenAI-compatible APIs.
# A true/false answer fits comfortably within this budget.
MIN_LLM_OUTPUT_TOKENS = 16
def _parse_boolean_response(response_text: str) -> tuple[bool, str | None]:
"""Parse an LLM response into a boolean result.
Returns a ``(result, error)`` tuple. *error* is ``None`` when the
response is unambiguous; otherwise it contains a diagnostic message
and *result* defaults to ``False``.
"""
text = response_text.strip().lower()
if text == "true":
return True, None
if text == "false":
return False, None
# Fuzzy match use word boundaries to avoid false positives like "untrue".
tokens = set(re.findall(r"\b(true|false|yes|no|1|0)\b", text))
if tokens == {"true"} or tokens == {"yes"} or tokens == {"1"}:
return True, None
if tokens == {"false"} or tokens == {"no"} or tokens == {"0"}:
return False, None
return False, f"Unclear AI response: '{response_text}'"
class AIConditionBlock(AIBlockBase):
"""
@@ -162,54 +190,26 @@ class AIConditionBlock(AIBlockBase):
]
# Call the LLM
try:
response = await self.llm_call(
credentials=credentials,
llm_model=input_data.model,
prompt=prompt,
max_tokens=10, # We only expect a true/false response
response = await self.llm_call(
credentials=credentials,
llm_model=input_data.model,
prompt=prompt,
max_tokens=MIN_LLM_OUTPUT_TOKENS,
)
# Extract the boolean result from the response
result, error = _parse_boolean_response(response.response)
if error:
yield "error", error
# Update internal stats
self.merge_stats(
NodeExecutionStats(
input_token_count=response.prompt_tokens,
output_token_count=response.completion_tokens,
)
# Extract the boolean result from the response
response_text = response.response.strip().lower()
if response_text == "true":
result = True
elif response_text == "false":
result = False
else:
# If the response is not clear, try to interpret it using word boundaries
import re
# Use word boundaries to avoid false positives like 'untrue' or '10'
tokens = set(re.findall(r"\b(true|false|yes|no|1|0)\b", response_text))
if tokens == {"true"} or tokens == {"yes"} or tokens == {"1"}:
result = True
elif tokens == {"false"} or tokens == {"no"} or tokens == {"0"}:
result = False
else:
# Unclear or conflicting response - default to False and yield error
result = False
yield "error", f"Unclear AI response: '{response.response}'"
# Update internal stats
self.merge_stats(
NodeExecutionStats(
input_token_count=response.prompt_tokens,
output_token_count=response.completion_tokens,
)
)
self.prompt = response.prompt
except Exception as e:
# In case of any error, default to False to be safe
result = False
# Log the error but don't fail the block execution
import logging
logger = logging.getLogger(__name__)
logger.error(f"AI condition evaluation failed: {str(e)}")
yield "error", f"AI evaluation failed: {str(e)}"
)
self.prompt = response.prompt
# Yield results
yield "result", result

View File

@@ -0,0 +1,147 @@
"""Tests for AIConditionBlock regression coverage for max_tokens and error propagation."""
from __future__ import annotations
from typing import cast
import pytest
from backend.blocks.ai_condition import (
MIN_LLM_OUTPUT_TOKENS,
AIConditionBlock,
_parse_boolean_response,
)
from backend.blocks.llm import (
DEFAULT_LLM_MODEL,
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
AICredentials,
LLMResponse,
)
_TEST_AI_CREDENTIALS = cast(AICredentials, TEST_CREDENTIALS_INPUT)
# ---------------------------------------------------------------------------
# Helper to collect all yields from the async generator
# ---------------------------------------------------------------------------
async def _collect_outputs(block: AIConditionBlock, input_data, credentials):
outputs: dict[str, object] = {}
async for name, value in block.run(input_data, credentials=credentials):
outputs[name] = value
return outputs
def _make_input(**overrides) -> AIConditionBlock.Input:
defaults: dict = {
"input_value": "hello@example.com",
"condition": "the input is an email address",
"yes_value": "yes!",
"no_value": "no!",
"model": DEFAULT_LLM_MODEL,
"credentials": TEST_CREDENTIALS_INPUT,
}
defaults.update(overrides)
return AIConditionBlock.Input(**defaults)
def _mock_llm_response(response_text: str) -> LLMResponse:
return LLMResponse(
raw_response="",
prompt=[],
response=response_text,
tool_calls=None,
prompt_tokens=10,
completion_tokens=5,
reasoning=None,
)
# ---------------------------------------------------------------------------
# _parse_boolean_response unit tests
# ---------------------------------------------------------------------------
class TestParseBooleanResponse:
def test_true_exact(self):
assert _parse_boolean_response("true") == (True, None)
def test_false_exact(self):
assert _parse_boolean_response("false") == (False, None)
def test_true_with_whitespace(self):
assert _parse_boolean_response(" True ") == (True, None)
def test_yes_fuzzy(self):
assert _parse_boolean_response("Yes") == (True, None)
def test_no_fuzzy(self):
assert _parse_boolean_response("no") == (False, None)
def test_one_fuzzy(self):
assert _parse_boolean_response("1") == (True, None)
def test_zero_fuzzy(self):
assert _parse_boolean_response("0") == (False, None)
def test_unclear_response(self):
result, error = _parse_boolean_response("I'm not sure")
assert result is False
assert error is not None
assert "Unclear" in error
def test_conflicting_tokens(self):
result, error = _parse_boolean_response("true and false")
assert result is False
assert error is not None
# ---------------------------------------------------------------------------
# Regression: max_tokens is set to MIN_LLM_OUTPUT_TOKENS
# ---------------------------------------------------------------------------
class TestMaxTokensRegression:
@pytest.mark.asyncio
async def test_llm_call_receives_min_output_tokens(self):
"""max_tokens must be MIN_LLM_OUTPUT_TOKENS (16) the previous value
of 1 was too low and caused OpenAI to reject the request."""
block = AIConditionBlock()
captured_kwargs: dict = {}
async def spy_llm_call(**kwargs):
captured_kwargs.update(kwargs)
return _mock_llm_response("true")
block.llm_call = spy_llm_call # type: ignore[assignment]
input_data = _make_input()
await _collect_outputs(block, input_data, credentials=TEST_CREDENTIALS)
assert captured_kwargs["max_tokens"] == MIN_LLM_OUTPUT_TOKENS
assert captured_kwargs["max_tokens"] == 16
# ---------------------------------------------------------------------------
# Regression: exceptions from llm_call must propagate
# ---------------------------------------------------------------------------
class TestExceptionPropagation:
@pytest.mark.asyncio
async def test_llm_call_exception_propagates(self):
"""If llm_call raises, the exception must NOT be swallowed.
Previously the block caught all exceptions and silently returned
result=False."""
block = AIConditionBlock()
async def boom(**kwargs):
raise RuntimeError("LLM provider error")
block.llm_call = boom # type: ignore[assignment]
input_data = _make_input()
with pytest.raises(RuntimeError, match="LLM provider error"):
await _collect_outputs(block, input_data, credentials=TEST_CREDENTIALS)

View File

@@ -104,6 +104,18 @@ class LlmModelMeta(EnumMeta):
class LlmModel(str, Enum, metaclass=LlmModelMeta):
@classmethod
def _missing_(cls, value: object) -> "LlmModel | None":
"""Handle provider-prefixed model names like 'anthropic/claude-sonnet-4-6'."""
if isinstance(value, str) and "/" in value:
stripped = value.split("/", 1)[1]
try:
return cls(stripped)
except ValueError:
return None
return None
# OpenAI models
O3_MINI = "o3-mini"
O3 = "o3-2025-04-16"

View File

@@ -207,6 +207,51 @@ class TestXMLParserBlockSecurity:
pass
class TestXMLParserBlockSyntaxErrors:
"""XML syntax errors should raise ValueError (not SyntaxError).
This ensures the base Block.execute() wraps them as BlockExecutionError
(expected / user-caused) instead of BlockUnknownError (unexpected / alerts
Sentry).
"""
async def test_unclosed_tag_raises_value_error(self):
"""Unclosed tags should raise ValueError, not SyntaxError."""
block = XMLParserBlock()
bad_xml = "<root><unclosed>"
with pytest.raises(ValueError, match="Unclosed tag"):
async for _ in block.run(XMLParserBlock.Input(input_xml=bad_xml)):
pass
async def test_unexpected_closing_tag_raises_value_error(self):
"""Extra closing tags should raise ValueError, not SyntaxError."""
block = XMLParserBlock()
bad_xml = "</unexpected>"
with pytest.raises(ValueError):
async for _ in block.run(XMLParserBlock.Input(input_xml=bad_xml)):
pass
async def test_empty_xml_raises_value_error(self):
"""Empty XML input should raise ValueError."""
block = XMLParserBlock()
with pytest.raises(ValueError, match="XML input is empty"):
async for _ in block.run(XMLParserBlock.Input(input_xml="")):
pass
async def test_syntax_error_from_parser_becomes_value_error(self):
"""SyntaxErrors from gravitasml library become ValueError (BlockExecutionError)."""
block = XMLParserBlock()
# Malformed XML that might trigger a SyntaxError from the parser
bad_xml = "<root><child>no closing"
with pytest.raises(ValueError):
async for _ in block.run(XMLParserBlock.Input(input_xml=bad_xml)):
pass
class TestStoreMediaFileSecurity:
"""Test file storage security limits."""

View File

@@ -809,3 +809,33 @@ class TestUserErrorStatusCodeHandling:
mock_warning.assert_called_once()
mock_exception.assert_not_called()
class TestLlmModelMissing:
"""Test that LlmModel handles provider-prefixed model names."""
def test_provider_prefixed_model_resolves(self):
"""Provider-prefixed model string should resolve to the correct enum member."""
assert (
llm.LlmModel("anthropic/claude-sonnet-4-6")
== llm.LlmModel.CLAUDE_4_6_SONNET
)
def test_bare_model_still_works(self):
"""Bare (non-prefixed) model string should still resolve correctly."""
assert llm.LlmModel("claude-sonnet-4-6") == llm.LlmModel.CLAUDE_4_6_SONNET
def test_invalid_prefixed_model_raises(self):
"""Unknown provider-prefixed model string should raise ValueError."""
with pytest.raises(ValueError):
llm.LlmModel("invalid/nonexistent-model")
def test_slash_containing_value_direct_lookup(self):
"""Enum values with '/' (e.g., OpenRouter models) should resolve via direct lookup, not _missing_."""
assert llm.LlmModel("google/gemini-2.5-pro") == llm.LlmModel.GEMINI_2_5_PRO
def test_double_prefixed_slash_model(self):
"""Double-prefixed value should still resolve by stripping first prefix."""
assert (
llm.LlmModel("extra/google/gemini-2.5-pro") == llm.LlmModel.GEMINI_2_5_PRO
)

View File

@@ -44,7 +44,7 @@ class XMLParserBlock(Block):
elif token.type == "TAG_CLOSE":
depth -= 1
if depth < 0:
raise SyntaxError("Unexpected closing tag in XML input.")
raise ValueError("Unexpected closing tag in XML input.")
elif token.type in {"TEXT", "ESCAPE"}:
if depth == 0 and token.value:
raise ValueError(
@@ -53,7 +53,7 @@ class XMLParserBlock(Block):
)
if depth != 0:
raise SyntaxError("Unclosed tag detected in XML input.")
raise ValueError("Unclosed tag detected in XML input.")
if not root_seen:
raise ValueError("XML must include a root element.")
@@ -76,4 +76,7 @@ class XMLParserBlock(Block):
except ValueError as val_e:
raise ValueError(f"Validation error for dict:{val_e}") from val_e
except SyntaxError as syn_e:
raise SyntaxError(f"Error in input xml syntax: {syn_e}") from syn_e
# Raise as ValueError so the base Block.execute() wraps it as
# BlockExecutionError (expected user-caused failure) instead of
# BlockUnknownError (unexpected platform error that alerts Sentry).
raise ValueError(f"Error in input xml syntax: {syn_e}") from syn_e

View File

@@ -25,24 +25,64 @@ def build_test_transcript(pairs: list[tuple[str, str]]) -> str:
Use this helper in any copilot SDK test that needs a well-formed
transcript without hitting the real storage layer.
Delegates to ``build_structured_transcript`` — plain content strings
are automatically wrapped in ``[{"type": "text", "text": ...}]`` for
assistant messages.
"""
# Cast widening: tuple[str, str] is structurally compatible with
# tuple[str, str | list[dict]] but list invariance requires explicit
# annotation.
widened: list[tuple[str, str | list[dict]]] = list(pairs)
return build_structured_transcript(widened)
def build_structured_transcript(
entries: list[tuple[str, str | list[dict]]],
) -> str:
"""Build a JSONL transcript with structured content blocks.
Each entry is (role, content) where content is either a plain string
(for user messages) or a list of content block dicts (for assistant
messages with thinking/tool_use/text blocks).
Example::
build_structured_transcript([
("user", "Hello"),
("assistant", [
{"type": "thinking", "thinking": "...", "signature": "sig1"},
{"type": "text", "text": "Hi there"},
]),
])
"""
lines: list[str] = []
last_uuid: str | None = None
for role, content in pairs:
for role, content in entries:
uid = str(uuid4())
entry_type = "assistant" if role == "assistant" else "user"
msg: dict = {"role": role, "content": content}
if role == "assistant":
msg.update(
{
"model": "",
"id": f"msg_{uid[:8]}",
"type": "message",
"content": [{"type": "text", "text": content}],
"stop_reason": "end_turn",
"stop_sequence": None,
}
)
if role == "assistant" and isinstance(content, list):
msg: dict = {
"role": "assistant",
"model": "claude-test",
"id": f"msg_{uid[:8]}",
"type": "message",
"content": content,
"stop_reason": "end_turn",
"stop_sequence": None,
}
elif role == "assistant":
msg = {
"role": "assistant",
"model": "claude-test",
"id": f"msg_{uid[:8]}",
"type": "message",
"content": [{"type": "text", "text": content}],
"stop_reason": "end_turn",
"stop_sequence": None,
}
else:
msg = {"role": role, "content": content}
entry = {
"type": entry_type,
"uuid": uid,

View File

@@ -442,8 +442,11 @@ class TestCompactTranscript:
assert result is not None
assert validate_transcript(result)
msgs = _transcript_to_messages(result)
assert len(msgs) == 2
# 3 messages: compressed prefix (2) + preserved last assistant (1)
assert len(msgs) == 3
assert msgs[1]["content"] == "Summarized response"
# The last assistant entry is preserved verbatim from original
assert msgs[2]["content"] == "Details"
@pytest.mark.asyncio
async def test_returns_none_on_compression_failure(self, mock_chat_config):

View File

@@ -15,6 +15,7 @@ from claude_agent_sdk import (
ResultMessage,
SystemMessage,
TextBlock,
ThinkingBlock,
ToolResultBlock,
ToolUseBlock,
UserMessage,
@@ -100,6 +101,11 @@ class SDKResponseAdapter:
StreamTextDelta(id=self.text_block_id, delta=block.text)
)
elif isinstance(block, ThinkingBlock):
# Thinking blocks are preserved in the transcript but
# not streamed to the frontend — skip silently.
pass
elif isinstance(block, ToolUseBlock):
self._end_text_if_open(responses)

View File

@@ -124,8 +124,11 @@ class TestScenarioCompactAndRetry:
assert result != original # Must be different
assert validate_transcript(result)
msgs = _transcript_to_messages(result)
assert len(msgs) == 2
# 3 messages: compressed prefix (2) + preserved last assistant (1)
assert len(msgs) == 3
assert msgs[0]["content"] == "[summary of conversation]"
# Last assistant preserved verbatim
assert msgs[2]["content"] == "Long answer 2"
def test_compacted_transcript_loads_into_builder(self):
"""TranscriptBuilder can load a compacted transcript and continue."""
@@ -737,7 +740,10 @@ class TestRetryEdgeCases:
assert result is not None
assert result != transcript
msgs = _transcript_to_messages(result)
assert len(msgs) == 2
# 3 messages: compressed prefix (2) + preserved last assistant (1)
assert len(msgs) == 3
# Last assistant preserved verbatim
assert msgs[2]["content"] == "Answer 19"
def test_messages_to_transcript_roundtrip_preserves_content(self):
"""Verify messages → transcript → messages preserves all content."""

View File

@@ -185,6 +185,24 @@ def _is_prompt_too_long(err: BaseException) -> bool:
return False
def _is_sdk_disconnect_error(exc: BaseException) -> bool:
"""Return True if *exc* is an expected SDK cleanup error from client disconnect.
Two known patterns occur when ``GeneratorExit`` tears down the async
generator and the SDK's ``__aexit__`` runs in a different context/task:
* ``RuntimeError``: cancel scope exited in wrong task (anyio)
* ``ValueError``: ContextVar token created in a different Context (OTEL)
These are suppressed to avoid polluting Sentry with non-actionable noise.
"""
if isinstance(exc, RuntimeError) and "cancel scope" in str(exc):
return True
if isinstance(exc, ValueError) and "was created in a different Context" in str(exc):
return True
return False
def _is_tool_only_message(sdk_msg: object) -> bool:
"""Return True if *sdk_msg* is an AssistantMessage containing only ToolUseBlocks.
@@ -409,6 +427,63 @@ _HEARTBEAT_INTERVAL = 10.0 # seconds
STREAM_LOCK_PREFIX = "copilot:stream:lock:"
async def _safe_close_sdk_client(
sdk_client: ClaudeSDKClient,
log_prefix: str,
) -> None:
"""Close a ClaudeSDKClient, suppressing errors from client disconnect.
When the SSE client disconnects mid-stream, ``GeneratorExit`` propagates
through the async generator stack and causes ``ClaudeSDKClient.__aexit__``
to run in a different async context or task than where the client was
opened. This triggers two known error classes:
* ``ValueError``: ``<Token var=<ContextVar name='current_context'>>
was created in a different Context`` — OpenTelemetry's
``context.detach()`` fails because the OTEL context token was
created in the original generator coroutine but detach runs in
the GC / cleanup coroutine (Sentry: AUTOGPT-SERVER-8BT).
* ``RuntimeError``: ``Attempted to exit cancel scope in a different
task than it was entered in`` — anyio's ``TaskGroup.__aexit__``
detects that the cancel scope was entered in one task but is
being exited in another (Sentry: AUTOGPT-SERVER-8BW).
Both are harmless — the TCP connection is already dead and no
resources leak. Logging them at ``debug`` level keeps observability
without polluting Sentry.
"""
try:
await sdk_client.__aexit__(None, None, None)
except (ValueError, RuntimeError) as exc:
if _is_sdk_disconnect_error(exc):
# Expected during client disconnect — suppress to avoid Sentry noise.
logger.debug(
"%s SDK client cleanup error suppressed (client disconnect): %s: %s",
log_prefix,
type(exc).__name__,
exc,
)
else:
raise
except GeneratorExit:
# GeneratorExit can propagate through __aexit__ — suppress it here
# since the generator is already being torn down.
logger.debug(
"%s SDK client cleanup GeneratorExit suppressed (client disconnect)",
log_prefix,
)
except Exception:
# Unexpected cleanup error — log at error level so Sentry captures it
# (via its logging integration), but don't propagate since we're in
# teardown and the caller cannot meaningfully handle this.
logger.error(
"%s Unexpected SDK client cleanup error",
log_prefix,
exc_info=True,
)
async def _iter_sdk_messages(
client: ClaudeSDKClient,
) -> AsyncGenerator[Any, None]:
@@ -595,7 +670,9 @@ def _format_sdk_content_blocks(blocks: list) -> list[dict[str, Any]]:
"""Convert SDK content blocks to transcript format.
Handles TextBlock, ToolUseBlock, ToolResultBlock, and ThinkingBlock.
Unknown block types are logged and skipped.
Raw dicts (e.g. ``redacted_thinking`` blocks that the SDK may not have
a typed class for) are passed through verbatim to preserve them in the
transcript. Unknown typed block objects are logged and skipped.
"""
result: list[dict[str, Any]] = []
for block in blocks or []:
@@ -627,6 +704,9 @@ def _format_sdk_content_blocks(blocks: list) -> list[dict[str, Any]]:
"signature": block.signature,
}
)
elif isinstance(block, dict) and "type" in block:
# Preserve raw dict blocks (e.g. redacted_thinking) verbatim.
result.append(block)
else:
logger.warning(
f"[SDK] Unknown content block type: {type(block).__name__}. "
@@ -1188,7 +1268,17 @@ async def _run_stream_attempt(
consecutive_empty_tool_calls = 0
async with ClaudeSDKClient(options=state.options) as client:
# Use manual __aenter__/__aexit__ instead of ``async with`` so we can
# suppress SDK cleanup errors that occur when the SSE client disconnects
# mid-stream. GeneratorExit causes the SDK's ``__aexit__`` to run in a
# different async context/task than where the client was opened, which
# triggers:
# - ValueError: ContextVar token mismatch (AUTOGPT-SERVER-8BT)
# - RuntimeError: cancel scope in wrong task (AUTOGPT-SERVER-8BW)
# Both are harmless — the TCP connection is already dead.
sdk_client = ClaudeSDKClient(options=state.options)
client = await sdk_client.__aenter__()
try:
logger.info(
"%s Sending query — resume=%s, total_msgs=%d, "
"query_len=%d, attached_files=%d, image_blocks=%d",
@@ -1448,6 +1538,8 @@ async def _run_stream_attempt(
if acc.stream_completed:
break
finally:
await _safe_close_sdk_client(sdk_client, ctx.log_prefix)
# --- Post-stream processing (only on success) ---
if state.adapter.has_unresolved_tool_calls:
@@ -2169,9 +2261,16 @@ async def stream_chat_completion_sdk(
error_msg = "Operation cancelled"
else:
error_msg = str(e) or type(e).__name__
# SDK cleanup RuntimeError is expected during cancellation, log as warning
if isinstance(e, RuntimeError) and "cancel scope" in str(e):
logger.warning("%s SDK cleanup error: %s", log_prefix, error_msg)
# SDK cleanup errors are expected during client disconnect —
# log as warning rather than error to reduce Sentry noise.
# These are normally caught by _safe_close_sdk_client but
# can escape in edge cases (e.g. GeneratorExit timing).
if _is_sdk_disconnect_error(e):
logger.warning(
"%s SDK cleanup error (client disconnect): %s",
log_prefix,
error_msg,
)
else:
logger.error("%s Error: %s", log_prefix, error_msg, exc_info=True)
@@ -2193,10 +2292,11 @@ async def stream_chat_completion_sdk(
)
# Yield StreamError for immediate feedback (only for non-cancellation errors)
# Skip for CancelledError and RuntimeError cleanup issues (both are cancellations)
is_cancellation = isinstance(e, asyncio.CancelledError) or (
isinstance(e, RuntimeError) and "cancel scope" in str(e)
)
# Skip for CancelledError and SDK disconnect cleanup errors — these
# are not actionable by the user and the SSE connection is already dead.
is_cancellation = isinstance(
e, asyncio.CancelledError
) or _is_sdk_disconnect_error(e)
if not is_cancellation:
yield StreamError(errorText=display_msg, code=code)

View File

@@ -8,7 +8,12 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from .service import _prepare_file_attachments, _resolve_sdk_model
from .service import (
_is_sdk_disconnect_error,
_prepare_file_attachments,
_resolve_sdk_model,
_safe_close_sdk_client,
)
@dataclass
@@ -499,3 +504,111 @@ class TestResolveSdkModel:
)
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
assert _resolve_sdk_model() == "claude-opus-4-6"
# ---------------------------------------------------------------------------
# _is_sdk_disconnect_error — classify client disconnect cleanup errors
# ---------------------------------------------------------------------------
class TestIsSdkDisconnectError:
"""Tests for _is_sdk_disconnect_error — identifies expected SDK cleanup errors."""
def test_cancel_scope_runtime_error(self):
"""RuntimeError about cancel scope in wrong task is a disconnect error."""
exc = RuntimeError(
"Attempted to exit cancel scope in a different task than it was entered in"
)
assert _is_sdk_disconnect_error(exc) is True
def test_context_var_value_error(self):
"""ValueError about ContextVar token mismatch is a disconnect error."""
exc = ValueError(
"<Token var=<ContextVar name='current_context'>> "
"was created in a different Context"
)
assert _is_sdk_disconnect_error(exc) is True
def test_unrelated_runtime_error(self):
"""Unrelated RuntimeError should NOT be classified as disconnect error."""
exc = RuntimeError("something else went wrong")
assert _is_sdk_disconnect_error(exc) is False
def test_unrelated_value_error(self):
"""Unrelated ValueError should NOT be classified as disconnect error."""
exc = ValueError("invalid argument")
assert _is_sdk_disconnect_error(exc) is False
def test_other_exception_types(self):
"""Non-RuntimeError/ValueError should NOT be classified as disconnect error."""
assert _is_sdk_disconnect_error(TypeError("bad type")) is False
assert _is_sdk_disconnect_error(OSError("network down")) is False
assert _is_sdk_disconnect_error(asyncio.CancelledError()) is False
# ---------------------------------------------------------------------------
# _safe_close_sdk_client — suppress cleanup errors during disconnect
# ---------------------------------------------------------------------------
class TestSafeCloseSdkClient:
"""Tests for _safe_close_sdk_client — suppresses expected SDK cleanup errors."""
@pytest.mark.asyncio
async def test_clean_exit(self):
"""Normal __aexit__ (no error) should succeed silently."""
client = AsyncMock()
client.__aexit__ = AsyncMock(return_value=None)
await _safe_close_sdk_client(client, "[test]")
client.__aexit__.assert_awaited_once_with(None, None, None)
@pytest.mark.asyncio
async def test_cancel_scope_runtime_error_suppressed(self):
"""RuntimeError from cancel scope mismatch should be suppressed."""
client = AsyncMock()
client.__aexit__ = AsyncMock(
side_effect=RuntimeError(
"Attempted to exit cancel scope in a different task"
)
)
# Should NOT raise
await _safe_close_sdk_client(client, "[test]")
@pytest.mark.asyncio
async def test_context_var_value_error_suppressed(self):
"""ValueError from ContextVar token mismatch should be suppressed."""
client = AsyncMock()
client.__aexit__ = AsyncMock(
side_effect=ValueError(
"<Token var=<ContextVar name='current_context'>> "
"was created in a different Context"
)
)
# Should NOT raise
await _safe_close_sdk_client(client, "[test]")
@pytest.mark.asyncio
async def test_unexpected_exception_suppressed_with_error_log(self):
"""Unexpected exceptions should be caught (not propagated) but logged at error."""
client = AsyncMock()
client.__aexit__ = AsyncMock(side_effect=OSError("unexpected"))
# Should NOT raise — unexpected errors are also suppressed to
# avoid crashing the generator during teardown. Logged at error
# level so Sentry captures them via its logging integration.
await _safe_close_sdk_client(client, "[test]")
@pytest.mark.asyncio
async def test_unrelated_runtime_error_propagates(self):
"""Non-cancel-scope RuntimeError should propagate (not suppressed)."""
client = AsyncMock()
client.__aexit__ = AsyncMock(side_effect=RuntimeError("something unrelated"))
with pytest.raises(RuntimeError, match="something unrelated"):
await _safe_close_sdk_client(client, "[test]")
@pytest.mark.asyncio
async def test_unrelated_value_error_propagates(self):
"""Non-disconnect ValueError should propagate (not suppressed)."""
client = AsyncMock()
client.__aexit__ = AsyncMock(side_effect=ValueError("invalid argument"))
with pytest.raises(ValueError, match="invalid argument"):
await _safe_close_sdk_client(client, "[test]")

View File

@@ -0,0 +1,822 @@
"""Tests for thinking/redacted_thinking block preservation.
Validates the fix for the Anthropic API error:
"thinking or redacted_thinking blocks in the latest assistant message
cannot be modified. These blocks must remain as they were in the
original response."
The API requires that thinking blocks in the LAST assistant message are
preserved value-identical. Older assistant messages may have thinking blocks
stripped entirely. This test suite covers:
1. _flatten_assistant_content — strips thinking from older messages
2. compact_transcript — preserves last assistant's thinking blocks
3. response_adapter — handles ThinkingBlock without error
4. _format_sdk_content_blocks — preserves redacted_thinking blocks
"""
from __future__ import annotations
from unittest.mock import AsyncMock, patch
import pytest
from claude_agent_sdk import AssistantMessage, TextBlock, ThinkingBlock
from backend.copilot.response_model import (
StreamStartStep,
StreamTextDelta,
StreamTextStart,
)
from backend.util import json
from .conftest import build_structured_transcript
from .response_adapter import SDKResponseAdapter
from .service import _format_sdk_content_blocks
from .transcript import (
_find_last_assistant_entry,
_flatten_assistant_content,
_messages_to_transcript,
_rechain_tail,
_transcript_to_messages,
compact_transcript,
validate_transcript,
)
# ---------------------------------------------------------------------------
# Fixtures: realistic thinking block content
# ---------------------------------------------------------------------------
THINKING_BLOCK = {
"type": "thinking",
"thinking": "Let me analyze the user's request carefully...",
"signature": "ErUBCkYIAxgCIkD0V2MsRXPkuGolGexaW9V1kluijxXGF",
}
REDACTED_THINKING_BLOCK = {
"type": "redacted_thinking",
"data": "EmwKAhgBEgy2VEE8PJaS2oLJCPkaT...",
}
def _make_thinking_transcript() -> str:
"""Build a transcript with thinking blocks in multiple assistant turns.
Layout:
User 1 → Assistant 1 (thinking + text + tool_use)
User 2 (tool_result) → Assistant 2 (thinking + text)
User 3 → Assistant 3 (thinking + redacted_thinking + text) ← LAST
"""
return build_structured_transcript(
[
("user", "What files are in this project?"),
(
"assistant",
[
{
"type": "thinking",
"thinking": "I should list the files.",
"signature": "sig_old_1",
},
{"type": "text", "text": "Let me check the files."},
{
"type": "tool_use",
"id": "tu1",
"name": "list_files",
"input": {"path": "/"},
},
],
),
("user", "Here are the files: a.py, b.py"),
(
"assistant",
[
{
"type": "thinking",
"thinking": "Good, I see two Python files.",
"signature": "sig_old_2",
},
{"type": "text", "text": "I found a.py and b.py."},
],
),
("user", "Tell me about a.py"),
(
"assistant",
[
THINKING_BLOCK,
REDACTED_THINKING_BLOCK,
{"type": "text", "text": "a.py contains the main entry point."},
],
),
]
)
def _last_assistant_content(transcript_jsonl: str) -> list[dict] | None:
"""Extract the content blocks of the last assistant entry in a transcript."""
last_content = None
for line in transcript_jsonl.strip().split("\n"):
entry = json.loads(line)
msg = entry.get("message", {})
if msg.get("role") == "assistant":
last_content = msg.get("content")
return last_content
# ---------------------------------------------------------------------------
# _find_last_assistant_entry — unit tests
# ---------------------------------------------------------------------------
class TestFindLastAssistantEntry:
def test_splits_at_last_assistant(self):
"""Prefix contains everything before last assistant; tail starts at it."""
transcript = build_structured_transcript(
[
("user", "Hello"),
("assistant", [{"type": "text", "text": "Hi"}]),
("user", "More"),
("assistant", [{"type": "text", "text": "Details"}]),
]
)
prefix, tail = _find_last_assistant_entry(transcript)
# 3 entries in prefix (user, assistant, user), 1 in tail (last assistant)
assert len(prefix) == 3
assert len(tail) == 1
def test_no_assistant_returns_all_in_prefix(self):
"""When there's no assistant, all lines are in prefix, tail is empty."""
transcript = build_structured_transcript(
[("user", "Hello"), ("user", "Another question")]
)
prefix, tail = _find_last_assistant_entry(transcript)
assert len(prefix) == 2
assert tail == []
def test_assistant_at_index_zero(self):
"""When assistant is the first entry, prefix is empty."""
transcript = build_structured_transcript(
[("assistant", [{"type": "text", "text": "Start"}])]
)
prefix, tail = _find_last_assistant_entry(transcript)
assert prefix == []
assert len(tail) == 1
def test_trailing_user_included_in_tail(self):
"""User message after last assistant is part of the tail."""
transcript = build_structured_transcript(
[
("user", "Q1"),
("assistant", [{"type": "text", "text": "A1"}]),
("user", "Q2"),
]
)
prefix, tail = _find_last_assistant_entry(transcript)
assert len(prefix) == 1 # first user
assert len(tail) == 2 # last assistant + trailing user
def test_multi_entry_turn_fully_preserved(self):
"""An assistant turn spanning multiple JSONL entries (same message.id)
must be entirely in the tail, not split across prefix and tail."""
# Build manually because build_structured_transcript generates unique ids
lines = [
json.dumps(
{
"type": "user",
"uuid": "u1",
"parentUuid": "",
"message": {"role": "user", "content": "Hello"},
}
),
json.dumps(
{
"type": "assistant",
"uuid": "a1-think",
"parentUuid": "u1",
"message": {
"role": "assistant",
"id": "msg_same_turn",
"type": "message",
"content": [THINKING_BLOCK],
"stop_reason": None,
"stop_sequence": None,
},
}
),
json.dumps(
{
"type": "assistant",
"uuid": "a1-tool",
"parentUuid": "u1",
"message": {
"role": "assistant",
"id": "msg_same_turn",
"type": "message",
"content": [
{
"type": "tool_use",
"id": "tu1",
"name": "Bash",
"input": {},
},
],
"stop_reason": "tool_use",
"stop_sequence": None,
},
}
),
]
transcript = "\n".join(lines) + "\n"
prefix, tail = _find_last_assistant_entry(transcript)
# Both assistant entries share msg_same_turn → both in tail
assert len(prefix) == 1 # only the user entry
assert len(tail) == 2 # both assistant entries (thinking + tool_use)
def test_no_message_id_preserves_last_assistant(self):
"""When the last assistant entry has no message.id, it should still
be preserved in the tail (fail closed) rather than being compressed."""
lines = [
json.dumps(
{
"type": "user",
"uuid": "u1",
"parentUuid": "",
"message": {"role": "user", "content": "Hello"},
}
),
json.dumps(
{
"type": "assistant",
"uuid": "a1",
"parentUuid": "u1",
"message": {
"role": "assistant",
"content": [THINKING_BLOCK, {"type": "text", "text": "Hi"}],
},
}
),
]
transcript = "\n".join(lines) + "\n"
prefix, tail = _find_last_assistant_entry(transcript)
assert len(prefix) == 1 # user entry
assert len(tail) == 1 # assistant entry preserved
# ---------------------------------------------------------------------------
# _rechain_tail — UUID chain patching
# ---------------------------------------------------------------------------
class TestRechainTail:
def test_patches_first_entry_parentuuid(self):
"""First tail entry's parentUuid should point to last prefix uuid."""
prefix = _messages_to_transcript(
[
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi"},
]
)
# Get the last uuid from the prefix
last_prefix_uuid = None
for line in prefix.strip().split("\n"):
entry = json.loads(line)
last_prefix_uuid = entry.get("uuid")
tail_lines = [
json.dumps(
{
"type": "assistant",
"uuid": "tail-a1",
"parentUuid": "old-parent",
"message": {
"role": "assistant",
"content": [{"type": "text", "text": "Tail msg"}],
},
}
)
]
result = _rechain_tail(prefix, tail_lines)
entry = json.loads(result.strip())
assert entry["parentUuid"] == last_prefix_uuid
assert entry["uuid"] == "tail-a1" # uuid preserved
def test_chains_multiple_tail_entries(self):
"""Subsequent tail entries chain to each other."""
prefix = _messages_to_transcript([{"role": "user", "content": "Hi"}])
tail_lines = [
json.dumps(
{
"type": "assistant",
"uuid": "t1",
"parentUuid": "old1",
"message": {"role": "assistant", "content": []},
}
),
json.dumps(
{
"type": "user",
"uuid": "t2",
"parentUuid": "old2",
"message": {"role": "user", "content": "Follow-up"},
}
),
]
result = _rechain_tail(prefix, tail_lines)
entries = [json.loads(ln) for ln in result.strip().split("\n")]
assert len(entries) == 2
# Second entry's parentUuid should be first entry's uuid
assert entries[1]["parentUuid"] == "t1"
def test_empty_tail_returns_empty(self):
"""No tail entries → empty string."""
prefix = _messages_to_transcript([{"role": "user", "content": "Hi"}])
assert _rechain_tail(prefix, []) == ""
def test_preserves_message_content_verbatim(self):
"""Tail message content (including thinking blocks) must not be modified."""
prefix = _messages_to_transcript([{"role": "user", "content": "Hi"}])
original_content = [
THINKING_BLOCK,
REDACTED_THINKING_BLOCK,
{"type": "text", "text": "Response"},
]
tail_lines = [
json.dumps(
{
"type": "assistant",
"uuid": "t1",
"parentUuid": "old",
"message": {
"role": "assistant",
"content": original_content,
},
}
)
]
result = _rechain_tail(prefix, tail_lines)
entry = json.loads(result.strip())
assert entry["message"]["content"] == original_content
# ---------------------------------------------------------------------------
# _flatten_assistant_content — thinking blocks
# ---------------------------------------------------------------------------
class TestFlattenThinkingBlocks:
def test_thinking_blocks_are_stripped(self):
"""Thinking blocks should not appear in flattened text for compression."""
blocks = [
{"type": "thinking", "thinking": "secret thoughts", "signature": "sig"},
{"type": "text", "text": "Hello user"},
]
result = _flatten_assistant_content(blocks)
assert "secret thoughts" not in result
assert "Hello user" in result
def test_redacted_thinking_blocks_are_stripped(self):
"""Redacted thinking blocks should not appear in flattened text."""
blocks = [
{"type": "redacted_thinking", "data": "encrypted_data"},
{"type": "text", "text": "Response text"},
]
result = _flatten_assistant_content(blocks)
assert "encrypted_data" not in result
assert "Response text" in result
def test_thinking_only_message_flattens_to_empty(self):
"""A message with only thinking blocks flattens to empty string."""
blocks = [
{"type": "thinking", "thinking": "just thinking...", "signature": "sig"},
]
result = _flatten_assistant_content(blocks)
assert result == ""
def test_mixed_thinking_text_tool(self):
"""Mixed blocks: only text and tool_use survive flattening."""
blocks = [
{"type": "thinking", "thinking": "hmm", "signature": "sig"},
{"type": "redacted_thinking", "data": "xyz"},
{"type": "text", "text": "I'll read the file."},
{"type": "tool_use", "name": "Read", "input": {"path": "/x"}},
]
result = _flatten_assistant_content(blocks)
assert "hmm" not in result
assert "xyz" not in result
assert "I'll read the file." in result
assert "[tool_use: Read]" in result
# ---------------------------------------------------------------------------
# compact_transcript — thinking block preservation
# ---------------------------------------------------------------------------
class TestCompactTranscriptThinkingBlocks:
"""Verify that compact_transcript preserves thinking blocks in the
last assistant message while stripping them from older messages."""
@pytest.mark.asyncio
async def test_last_assistant_thinking_blocks_preserved(self, mock_chat_config):
"""After compaction, the last assistant entry must retain its
original thinking and redacted_thinking blocks verbatim."""
transcript = _make_thinking_transcript()
compacted_msgs = [
{"role": "user", "content": "[conversation summary]"},
{"role": "assistant", "content": "Summarized response"},
]
mock_result = type(
"CompressResult",
(),
{
"was_compacted": True,
"messages": compacted_msgs,
"original_token_count": 800,
"token_count": 200,
"messages_summarized": 4,
"messages_dropped": 0,
},
)()
with patch(
"backend.copilot.sdk.transcript._run_compression",
new_callable=AsyncMock,
return_value=mock_result,
):
result = await compact_transcript(transcript, model="test-model")
assert result is not None
assert validate_transcript(result)
last_content = _last_assistant_content(result)
assert last_content is not None, "No assistant entry found"
assert isinstance(last_content, list)
# The last assistant must have the thinking blocks preserved
block_types = [b["type"] for b in last_content]
assert (
"thinking" in block_types
), "thinking block missing from last assistant message"
assert (
"redacted_thinking" in block_types
), "redacted_thinking block missing from last assistant message"
assert "text" in block_types
# Verify the thinking block content is value-identical
thinking_blocks = [b for b in last_content if b["type"] == "thinking"]
assert len(thinking_blocks) == 1
assert thinking_blocks[0]["thinking"] == THINKING_BLOCK["thinking"]
assert thinking_blocks[0]["signature"] == THINKING_BLOCK["signature"]
redacted_blocks = [b for b in last_content if b["type"] == "redacted_thinking"]
assert len(redacted_blocks) == 1
assert redacted_blocks[0]["data"] == REDACTED_THINKING_BLOCK["data"]
@pytest.mark.asyncio
async def test_older_assistant_thinking_blocks_stripped(self, mock_chat_config):
"""Older assistant messages should NOT retain thinking blocks
after compaction (they're compressed into summaries)."""
transcript = _make_thinking_transcript()
# The compressor will receive messages where older assistant
# entries have already had thinking blocks stripped.
captured_messages: list[dict] = []
async def mock_compression(messages, model, log_prefix):
captured_messages.extend(messages)
return type(
"CompressResult",
(),
{
"was_compacted": True,
"messages": messages,
"original_token_count": 800,
"token_count": 400,
"messages_summarized": 2,
"messages_dropped": 0,
},
)()
with patch(
"backend.copilot.sdk.transcript._run_compression",
side_effect=mock_compression,
):
await compact_transcript(transcript, model="test-model")
# Check that the messages sent to compression don't contain
# thinking content from older assistant messages
for msg in captured_messages:
if msg["role"] == "assistant":
content = msg.get("content", "")
assert (
"I should list the files." not in content
), "Old thinking block content leaked into compression input"
assert (
"Good, I see two Python files." not in content
), "Old thinking block content leaked into compression input"
@pytest.mark.asyncio
async def test_trailing_user_message_after_last_assistant(self, mock_chat_config):
"""When the last entry is a user message, the last *assistant*
message's thinking blocks should still be preserved."""
transcript = build_structured_transcript(
[
("user", "Hello"),
(
"assistant",
[
THINKING_BLOCK,
{"type": "text", "text": "Hi there"},
],
),
("user", "Follow-up question"),
]
)
# The compressor only receives the prefix (1 user message); the
# tail (assistant + trailing user) is preserved verbatim.
compacted_msgs = [
{"role": "user", "content": "Hello"},
]
mock_result = type(
"CompressResult",
(),
{
"was_compacted": True,
"messages": compacted_msgs,
"original_token_count": 400,
"token_count": 100,
"messages_summarized": 0,
"messages_dropped": 0,
},
)()
with patch(
"backend.copilot.sdk.transcript._run_compression",
new_callable=AsyncMock,
return_value=mock_result,
):
result = await compact_transcript(transcript, model="test-model")
assert result is not None
last_content = _last_assistant_content(result)
assert last_content is not None
assert isinstance(last_content, list)
block_types = [b["type"] for b in last_content]
assert (
"thinking" in block_types
), "thinking block lost from last assistant despite trailing user msg"
@pytest.mark.asyncio
async def test_single_assistant_with_thinking_preserved(self, mock_chat_config):
"""When there's only one assistant message (which is also the last),
its thinking blocks must be preserved."""
transcript = build_structured_transcript(
[
("user", "Hello"),
(
"assistant",
[
THINKING_BLOCK,
{"type": "text", "text": "World"},
],
),
]
)
compacted_msgs = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "World"},
]
mock_result = type(
"CompressResult",
(),
{
"was_compacted": True,
"messages": compacted_msgs,
"original_token_count": 200,
"token_count": 100,
"messages_summarized": 0,
"messages_dropped": 0,
},
)()
with patch(
"backend.copilot.sdk.transcript._run_compression",
new_callable=AsyncMock,
return_value=mock_result,
):
result = await compact_transcript(transcript, model="test-model")
assert result is not None
last_content = _last_assistant_content(result)
assert last_content is not None
assert isinstance(last_content, list)
block_types = [b["type"] for b in last_content]
assert "thinking" in block_types
@pytest.mark.asyncio
async def test_tail_parentuuid_rewired_to_prefix(self, mock_chat_config):
"""After compaction, the first tail entry's parentUuid must point to
the last entry in the compressed prefix — not its original parent."""
transcript = _make_thinking_transcript()
compacted_msgs = [
{"role": "user", "content": "[conversation summary]"},
{"role": "assistant", "content": "Summarized response"},
]
mock_result = type(
"CompressResult",
(),
{
"was_compacted": True,
"messages": compacted_msgs,
"original_token_count": 800,
"token_count": 200,
"messages_summarized": 4,
"messages_dropped": 0,
},
)()
with patch(
"backend.copilot.sdk.transcript._run_compression",
new_callable=AsyncMock,
return_value=mock_result,
):
result = await compact_transcript(transcript, model="test-model")
assert result is not None
lines = [ln for ln in result.strip().split("\n") if ln.strip()]
entries = [json.loads(ln) for ln in lines]
# Find the boundary: the compressed prefix ends just before the
# first tail entry (last assistant in original transcript).
tail_start = None
for i, entry in enumerate(entries):
msg = entry.get("message", {})
if isinstance(msg.get("content"), list):
# Structured content = preserved tail entry
tail_start = i
break
assert tail_start is not None, "Could not find preserved tail entry"
assert tail_start > 0, "Tail should not be the first entry"
# The tail entry's parentUuid must be the uuid of the preceding entry
prefix_last_uuid = entries[tail_start - 1]["uuid"]
tail_first_parent = entries[tail_start]["parentUuid"]
assert tail_first_parent == prefix_last_uuid, (
f"Tail parentUuid {tail_first_parent!r} != "
f"last prefix uuid {prefix_last_uuid!r}"
)
@pytest.mark.asyncio
async def test_no_thinking_blocks_still_works(self, mock_chat_config):
"""Compaction should still work normally when there are no thinking
blocks in the transcript."""
transcript = build_structured_transcript(
[
("user", "Hello"),
("assistant", [{"type": "text", "text": "Hi"}]),
("user", "More"),
("assistant", [{"type": "text", "text": "Details"}]),
]
)
compacted_msgs = [
{"role": "user", "content": "[summary]"},
{"role": "assistant", "content": "Summary"},
]
mock_result = type(
"CompressResult",
(),
{
"was_compacted": True,
"messages": compacted_msgs,
"original_token_count": 200,
"token_count": 50,
"messages_summarized": 2,
"messages_dropped": 0,
},
)()
with patch(
"backend.copilot.sdk.transcript._run_compression",
new_callable=AsyncMock,
return_value=mock_result,
):
result = await compact_transcript(transcript, model="test-model")
assert result is not None
assert validate_transcript(result)
# Verify last assistant content is preserved even without thinking blocks
last_content = _last_assistant_content(result)
assert last_content is not None
assert last_content == [{"type": "text", "text": "Details"}]
# ---------------------------------------------------------------------------
# _transcript_to_messages — thinking block handling
# ---------------------------------------------------------------------------
class TestTranscriptToMessagesThinking:
def test_thinking_blocks_excluded_from_flattened_content(self):
"""When _transcript_to_messages flattens content, thinking block
text should not leak into the message content string."""
transcript = build_structured_transcript(
[
("user", "Hello"),
(
"assistant",
[
{
"type": "thinking",
"thinking": "SECRET_THOUGHT",
"signature": "sig",
},
{"type": "text", "text": "Visible response"},
],
),
]
)
messages = _transcript_to_messages(transcript)
assistant_msg = [m for m in messages if m["role"] == "assistant"][0]
assert "SECRET_THOUGHT" not in assistant_msg["content"]
assert "Visible response" in assistant_msg["content"]
# ---------------------------------------------------------------------------
# response_adapter — ThinkingBlock handling
# ---------------------------------------------------------------------------
class TestResponseAdapterThinkingBlock:
def test_thinking_block_does_not_crash(self):
"""ThinkingBlock in AssistantMessage should not cause an error."""
adapter = SDKResponseAdapter(message_id="msg-1", session_id="sess-1")
msg = AssistantMessage(
content=[
ThinkingBlock(
thinking="Let me think about this...",
signature="sig_test_123",
),
TextBlock(text="Here is my response."),
],
model="claude-test",
)
results = adapter.convert_message(msg)
# Should produce stream events for text only, no crash
types = [type(r) for r in results]
assert StreamStartStep in types
assert StreamTextStart in types or StreamTextDelta in types
def test_thinking_block_does_not_emit_stream_events(self):
"""ThinkingBlock should NOT produce any StreamTextDelta events
containing thinking content."""
adapter = SDKResponseAdapter(message_id="msg-1", session_id="sess-1")
msg = AssistantMessage(
content=[
ThinkingBlock(
thinking="My secret thoughts",
signature="sig_test_456",
),
TextBlock(text="Public response"),
],
model="claude-test",
)
results = adapter.convert_message(msg)
text_deltas = [r for r in results if isinstance(r, StreamTextDelta)]
for delta in text_deltas:
assert "secret thoughts" not in (delta.delta or "")
# ---------------------------------------------------------------------------
# _format_sdk_content_blocks — redacted_thinking handling
# ---------------------------------------------------------------------------
class TestFormatSdkContentBlocks:
def test_thinking_block_preserved(self):
"""ThinkingBlock should be serialized with type, thinking, and signature."""
blocks = [
ThinkingBlock(thinking="My thoughts", signature="sig123"),
TextBlock(text="Response"),
]
result = _format_sdk_content_blocks(blocks)
assert len(result) == 2
assert result[0] == {
"type": "thinking",
"thinking": "My thoughts",
"signature": "sig123",
}
assert result[1] == {"type": "text", "text": "Response"}
def test_raw_dict_redacted_thinking_preserved(self):
"""Raw dict blocks (e.g. redacted_thinking) pass through unchanged."""
raw_block = {"type": "redacted_thinking", "data": "EmwKAh...encrypted"}
blocks = [
raw_block,
TextBlock(text="Response"),
]
result = _format_sdk_content_blocks(blocks)
assert len(result) == 2
assert result[0] == raw_block
assert result[1] == {"type": "text", "text": "Response"}

View File

@@ -605,20 +605,31 @@ COMPACT_MSG_ID_PREFIX = "msg_compact_"
ENTRY_TYPE_MESSAGE = "message"
_THINKING_BLOCK_TYPES = frozenset({"thinking", "redacted_thinking"})
def _flatten_assistant_content(blocks: list) -> str:
"""Flatten assistant content blocks into a single plain-text string.
Structured ``tool_use`` blocks are converted to ``[tool_use: name]``
placeholders. This is intentional: ``compress_context`` requires plain
text for token counting and LLM summarization. The structural loss is
acceptable because compaction only runs when the original transcript was
already too large for the model — a summarized plain-text version is
better than no context at all.
placeholders. ``thinking`` and ``redacted_thinking`` blocks are
silently dropped — they carry no useful context for compression
summaries and must not leak into compacted transcripts (the Anthropic
API requires thinking blocks in the last assistant message to be
value-identical to the original response; including stale thinking
text would violate that constraint).
This is intentional: ``compress_context`` requires plain text for
token counting and LLM summarization. The structural loss is
acceptable because compaction only runs when the original transcript
was already too large for the model.
"""
parts: list[str] = []
for block in blocks:
if isinstance(block, dict):
btype = block.get("type", "")
if btype in _THINKING_BLOCK_TYPES:
continue
if btype == "text":
parts.append(block.get("text", ""))
elif btype == "tool_use":
@@ -805,6 +816,68 @@ async def _run_compression(
)
def _find_last_assistant_entry(
content: str,
) -> tuple[list[str], list[str]]:
"""Split JSONL lines into (compressible_prefix, preserved_tail).
The tail starts at the **first** entry of the last assistant turn and
includes everything after it (typically trailing user messages). An
assistant turn can span multiple consecutive JSONL entries sharing the
same ``message.id`` (e.g., a thinking entry followed by a tool_use
entry). All entries of the turn are preserved verbatim.
The Anthropic API requires that ``thinking`` and ``redacted_thinking``
blocks in the **last** assistant message remain value-identical to the
original response (the API validates parsed signature values, not raw
JSON bytes). By excluding the entire turn from compression we
guarantee those blocks are never altered.
Returns ``(all_lines, [])`` when no assistant entry is found.
"""
lines = [ln for ln in content.strip().split("\n") if ln.strip()]
# Parse all lines once to avoid double JSON deserialization.
# json.loads with fallback=None returns Any; non-dict entries are
# safely skipped by the isinstance(entry, dict) guards below.
parsed: list = [json.loads(ln, fallback=None) for ln in lines]
# Reverse scan: find the message.id and index of the last assistant entry.
last_asst_msg_id: str | None = None
last_asst_idx: int | None = None
for i in range(len(parsed) - 1, -1, -1):
entry = parsed[i]
if not isinstance(entry, dict):
continue
msg = entry.get("message", {})
if msg.get("role") == "assistant":
last_asst_idx = i
last_asst_msg_id = msg.get("id")
break
if last_asst_idx is None:
return lines, []
# If the assistant entry has no message.id, fall back to preserving
# from that single entry onward — safer than compressing everything.
if last_asst_msg_id is None:
return lines[:last_asst_idx], lines[last_asst_idx:]
# Forward scan: find the first entry of this turn (same message.id).
first_turn_idx: int | None = None
for i, entry in enumerate(parsed):
if not isinstance(entry, dict):
continue
msg = entry.get("message", {})
if msg.get("role") == "assistant" and msg.get("id") == last_asst_msg_id:
first_turn_idx = i
break
if first_turn_idx is None:
return lines, []
return lines[:first_turn_idx], lines[first_turn_idx:]
async def compact_transcript(
content: str,
*,
@@ -816,42 +889,50 @@ async def compact_transcript(
Converts transcript entries to plain messages, runs ``compress_context``
(the same compressor used for pre-query history), and rebuilds JSONL.
Structured content (``tool_use`` blocks, ``tool_result`` nesting, images)
is flattened to plain text for compression. This matches the fidelity of
the Plan C (DB compression) fallback path, where
``_format_conversation_context`` similarly renders tool calls as
``You called tool: name(args)`` and results as ``Tool result: ...``.
Neither path preserves structured API content blocks — the compacted
context serves as text history for the LLM, which creates proper
structured tool calls going forward.
The **last assistant entry** (and any entries after it) are preserved
verbatim — never flattened or compressed. The Anthropic API requires
``thinking`` and ``redacted_thinking`` blocks in the latest assistant
message to be value-identical to the original response (the API
validates parsed signature values, not raw JSON bytes); compressing
them would destroy the cryptographic signatures and cause
``invalid_request_error``.
Images are per-turn attachments loaded from workspace storage by file ID
(via ``_prepare_file_attachments``), not part of the conversation history.
They are re-attached each turn and are unaffected by compaction.
Structured content in *older* assistant entries (``tool_use`` blocks,
``thinking`` blocks, ``tool_result`` nesting, images) is flattened to
plain text for compression. This matches the fidelity of the Plan C
(DB compression) fallback path.
Returns the compacted JSONL string, or ``None`` on failure.
See also:
``_compress_messages`` in ``service.py`` — compresses ``ChatMessage``
lists for pre-query DB history. Both share ``compress_context()``
but operate on different input formats (JSONL transcript entries
here vs. ChatMessage dicts there).
lists for pre-query DB history.
"""
messages = _transcript_to_messages(content)
if len(messages) < 2:
logger.warning("%s Too few messages to compact (%d)", log_prefix, len(messages))
prefix_lines, tail_lines = _find_last_assistant_entry(content)
# Build the JSONL string for the compressible prefix
prefix_content = "\n".join(prefix_lines) + "\n" if prefix_lines else ""
messages = _transcript_to_messages(prefix_content) if prefix_content else []
if len(messages) + len(tail_lines) < 2:
total = len(messages) + len(tail_lines)
logger.warning("%s Too few messages to compact (%d)", log_prefix, total)
return None
if not messages:
logger.warning("%s Nothing to compress (only tail entries remain)", log_prefix)
return None
try:
result = await _run_compression(messages, model, log_prefix)
if not result.was_compacted:
# Compressor says it's within budget, but the SDK rejected it.
# Return None so the caller falls through to DB fallback.
logger.warning(
"%s Compressor reports within budget but SDK rejected — "
"signalling failure",
log_prefix,
)
return None
if not result.messages:
logger.warning("%s Compressor returned empty messages", log_prefix)
return None
logger.info(
"%s Compacted transcript: %d->%d tokens (%d summarized, %d dropped)",
log_prefix,
@@ -860,7 +941,29 @@ async def compact_transcript(
result.messages_summarized,
result.messages_dropped,
)
compacted = _messages_to_transcript(result.messages)
compressed_part = _messages_to_transcript(result.messages)
# Re-append the preserved tail (last assistant + trailing entries)
# with parentUuid patched to chain onto the compressed prefix.
tail_part = _rechain_tail(compressed_part, tail_lines)
compacted = compressed_part + tail_part
if len(compacted) >= len(content):
# Byte count can increase due to preserved tail entries
# (thinking blocks, JSON overhead) even when token count
# decreased. Log a warning but still return — the API
# validates tokens not bytes, and the caller falls through
# to DB fallback if the transcript is still too large.
logger.warning(
"%s Compacted transcript (%d bytes) is not smaller than "
"original (%d bytes) — may still reduce token count",
log_prefix,
len(compacted),
len(content),
)
# Authoritative validation — the caller (_reduce_context) also
# validates, but this is the canonical check that guarantees we
# never return a malformed transcript from this function.
if not validate_transcript(compacted):
logger.warning("%s Compacted transcript failed validation", log_prefix)
return None
@@ -870,3 +973,43 @@ async def compact_transcript(
"%s Transcript compaction failed: %s", log_prefix, e, exc_info=True
)
return None
def _rechain_tail(compressed_prefix: str, tail_lines: list[str]) -> str:
"""Patch tail entries so their parentUuid chain links to the compressed prefix.
The first tail entry's ``parentUuid`` is set to the ``uuid`` of the
last entry in the compressed prefix. Subsequent tail entries are
rechained to point to their predecessor in the tail — their original
``parentUuid`` values may reference entries that were compressed away.
"""
if not tail_lines:
return ""
# Find the last uuid in the compressed prefix
last_prefix_uuid = ""
for line in reversed(compressed_prefix.strip().split("\n")):
if not line.strip():
continue
entry = json.loads(line, fallback=None)
if isinstance(entry, dict) and "uuid" in entry:
last_prefix_uuid = entry["uuid"]
break
result_lines: list[str] = []
prev_uuid: str | None = None
for i, line in enumerate(tail_lines):
entry = json.loads(line, fallback=None)
if not isinstance(entry, dict):
# Safety guard: _find_last_assistant_entry already filters empty
# lines, and well-formed JSONL always parses to dicts. Non-dict
# lines are passed through unchanged; prev_uuid is intentionally
# NOT updated so the next dict entry chains to the last known uuid.
result_lines.append(line)
continue
if i == 0:
entry["parentUuid"] = last_prefix_uuid
elif prev_uuid is not None:
entry["parentUuid"] = prev_uuid
prev_uuid = entry.get("uuid")
result_lines.append(json.dumps(entry, separators=(",", ":")))
return "\n".join(result_lines) + "\n"

View File

@@ -581,7 +581,6 @@ class GraphModel(Graph, GraphMeta):
field_name,
field_info,
) in node.block.input_schema.get_credentials_fields_info().items():
discriminator = field_info.discriminator
if not discriminator:
node_credential_data.append((field_info, (node.id, field_name)))
@@ -1529,6 +1528,28 @@ async def fork_graph(graph_id: str, graph_version: int, user_id: str) -> GraphMo
async def __create_graph(tx, graph: Graph, user_id: str):
graphs = [graph] + graph.sub_graphs
# Auto-increment version for any graph entry (parent or sub-graph) whose
# (id, version) already exists. This prevents UniqueViolationError when
# the copilot re-saves an agent that already exists at the requested version.
# NOTE: This issues one find_first query per graph entry (N+1 pattern).
# Sub-graph counts are typically small (< 5), so the overhead is negligible.
for g in graphs:
existing = await AgentGraph.prisma(tx).find_first(
where={"id": g.id},
order={"version": "desc"},
)
if existing and existing.version >= g.version:
old_version = g.version
g.version = existing.version + 1
logger.warning(
"Auto-incremented graph %s version from %d to %d "
"(version %d already exists)",
g.id,
old_version,
g.version,
existing.version,
)
await AgentGraph.prisma(tx).create_many(
data=[
AgentGraphCreateInput(

View File

@@ -76,20 +76,24 @@ async def get_or_create_workspace(user_id: str) -> Workspace:
"""
Get user's workspace, creating one if it doesn't exist.
Uses upsert to atomically handle concurrent creation attempts.
Args:
user_id: The user's ID
Returns:
Workspace instance
"""
workspace = await UserWorkspace.prisma().find_unique(where={"userId": user_id})
if workspace:
return Workspace.from_db(workspace)
try:
workspace = await UserWorkspace.prisma().create(data={"userId": user_id})
workspace = await UserWorkspace.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id},
"update": {}, # No-op update; workspace already exists
},
)
except UniqueViolationError:
# Concurrent request already created it
# Defense-in-depth: should not happen with upsert, but handle gracefully
workspace = await UserWorkspace.prisma().find_unique(where={"userId": user_id})
if workspace is None:
raise
@@ -125,6 +129,13 @@ async def create_workspace_file(
"""
Create a new workspace file record.
Raises ``UniqueViolationError`` if a record with the same
``(workspaceId, path)`` already exists. The caller
(``WorkspaceManager._persist_db_record``) relies on this to trigger
its delete-old-file-then-retry flow, which cleans up the old storage
blob before re-creating the DB record. Using ``upsert`` here would
silently overwrite ``storagePath`` and orphan the old blob in storage.
Args:
workspace_id: The workspace ID
file_id: The file ID (same as used in storage path for consistency)

View File

@@ -84,6 +84,14 @@ def _before_send(event, hint):
):
return None
# Prisma UniqueViolationError — always caught and handled in our codebase.
# These arise from concurrent create operations racing on unique constraints
# (workspace files, credits, library folders, store listings, chat messages).
# Every call site has an except handler; the global FastAPI handler also
# catches them and returns 400. Safe to drop unconditionally.
if exc_type and exc_type.__name__ == "UniqueViolationError":
return None
# Google metadata DNS errors — expected in non-GCP environments
if (
"metadata.google.internal" in exc_msg

View File

@@ -227,10 +227,16 @@ class AppService(BaseAppService, ABC):
def _handle_internal_http_error(status_code: int = 500, log_error: bool = True):
def handler(request: Request, exc: Exception):
if log_error:
logger.error(
f"{request.method} {request.url.path} failed: {exc}",
exc_info=exc if status_code == 500 else None,
)
if status_code >= 500:
logger.error(
f"{request.method} {request.url.path} failed: {exc}",
exc_info=exc,
)
else:
logger.warning(
f"{request.method} {request.url.path} failed: {exc}",
exc_info=exc,
)
return responses.JSONResponse(
status_code=status_code,
content=RemoteCallError(
@@ -563,7 +569,6 @@ def get_service_client(
self._connection_failure_count >= 3
and current_time - self._last_client_reset > 30
):
logger.warning(
f"Connection failures detected ({self._connection_failure_count}), recreating HTTP clients"
)

View File

@@ -108,6 +108,9 @@ class VirusScannerService:
return VirusScanResult(
is_clean=True, scan_time_ms=0, file_size=len(content)
)
if len(content) == 0:
logger.debug(f"Skipping virus scan for empty file {filename}")
return VirusScanResult(is_clean=True, scan_time_ms=0, file_size=0)
if len(content) > self.settings.max_scan_size:
logger.warning(
f"File {filename} ({len(content)} bytes) exceeds client max scan size ({self.settings.max_scan_size}); Stopping virus scan"
@@ -123,7 +126,7 @@ class VirusScannerService:
raise RuntimeError("ClamAV service is unreachable")
start = time.monotonic()
chunk_size = len(content) # Start with full content length
chunk_size = max(1, len(content)) # Start with full content length
for retry in range(self.settings.max_retries):
# For small files, don't check min_chunk_size limit
if chunk_size < self.settings.min_chunk_size and chunk_size < len(content):
@@ -212,5 +215,5 @@ async def scan_content_safe(content: bytes, *, filename: str = "unknown") -> Non
except VirusDetectedError:
raise
except Exception as e:
logger.error(f"Virus scanning failed for {filename}: {str(e)}")
logger.warning(f"Virus scanning failed for {filename}: {str(e)}")
raise VirusScanError(f"Virus scanning failed: {str(e)}") from e

View File

@@ -85,7 +85,36 @@ class TestVirusScannerService:
)
assert result_dirty.is_clean is False
# Note: ping method was removed from current implementation
@pytest.mark.asyncio
async def test_scan_empty_file(self, scanner):
"""Empty files (0 bytes) should be accepted without hitting ClamAV."""
content = b""
result = await scanner.scan_file(content, filename="empty.png")
assert result.is_clean is True
assert result.threat_name is None
assert result.file_size == 0
assert result.scan_time_ms == 0
@pytest.mark.asyncio
async def test_scan_single_byte_file(self, scanner):
"""A 1-byte file should be scanned normally (regression: chunk_size must not be 0)."""
async def mock_instream(_):
await asyncio.sleep(0.001)
return None
mock_client = Mock()
mock_client.ping = AsyncMock(return_value=True)
mock_client.instream = AsyncMock(side_effect=mock_instream)
scanner._client = mock_client
content = b"\x00"
result = await scanner.scan_file(content, filename="tiny.bin")
assert result.is_clean is True
assert result.file_size == 1
assert result.scan_time_ms > 0
@pytest.mark.asyncio
async def test_scan_clean_file(self, scanner):
@@ -251,3 +280,27 @@ class TestHelperFunctions:
with pytest.raises(VirusScanError, match="Virus scanning failed"):
await scan_content_safe(b"test content", filename="test.txt")
@pytest.mark.asyncio
async def test_scan_content_safe_logs_warning_not_error_on_failure(self):
"""Scan failures should log at WARNING level, not ERROR, to avoid paging on-call."""
with patch("backend.util.virus_scanner.get_virus_scanner") as mock_get_scanner:
mock_scanner = Mock()
mock_scanner.scan_file = AsyncMock()
mock_scanner.scan_file.side_effect = Exception(
"range() arg 3 must not be zero"
)
mock_get_scanner.return_value = mock_scanner
with (
pytest.raises(VirusScanError),
patch("backend.util.virus_scanner.logger") as mock_logger,
):
await scan_content_safe(b"test", filename="screenshot.png")
mock_logger.warning.assert_called_once()
# Check the formatted log message contains the error text.
# Use str() to handle both f-string and %-style logging formats.
log_msg = str(mock_logger.warning.call_args)
assert "range()" in log_msg
mock_logger.error.assert_not_called()

View File

@@ -1,6 +1,7 @@
{
"daily_token_limit": 2500000,
"daily_tokens_used": 500000,
"user_email": "target@example.com",
"user_id": "5e53486c-cf57-477e-ba2a-cb02dc828e1c",
"weekly_token_limit": 12500000,
"weekly_tokens_used": 3000000

View File

@@ -1,6 +1,7 @@
{
"daily_token_limit": 2500000,
"daily_tokens_used": 0,
"user_email": "target@example.com",
"user_id": "5e53486c-cf57-477e-ba2a-cb02dc828e1c",
"weekly_token_limit": 12500000,
"weekly_tokens_used": 0

View File

@@ -1,6 +1,7 @@
{
"daily_token_limit": 2500000,
"daily_tokens_used": 0,
"user_email": "target@example.com",
"user_id": "5e53486c-cf57-477e-ba2a-cb02dc828e1c",
"weekly_token_limit": 12500000,
"weekly_tokens_used": 3000000

View File

@@ -25,7 +25,11 @@ Sentry.init({
// Suppress cross-origin stylesheet errors from Sentry Replay (rrweb)
// serializing DOM snapshots with cross-origin stylesheets
// (e.g., from browser extensions or CDN-loaded CSS)
ignoreErrors: [/Not allowed to access cross-origin stylesheet/],
ignoreErrors: [
/Not allowed to access cross-origin stylesheet/,
// Sentry SDK internal issue on some mobile browsers
/Error invoking postEvent: Method not found/,
],
// Add optional integrations for additional features
integrations: [

View File

@@ -0,0 +1,71 @@
"use client";
import { useState } from "react";
import { Input } from "@/components/__legacy__/ui/input";
import { Button } from "@/components/atoms/Button/Button";
import { MagnifyingGlass } from "@phosphor-icons/react";
export interface AdminUserSearchProps {
/** Current search query value (controlled). Falls back to internal state if omitted. */
value?: string;
/** Called when the input text changes */
onChange?: (value: string) => void;
/** Called when the user presses Enter or clicks the search button */
onSearch: (query: string) => void;
/** Placeholder text for the input */
placeholder?: string;
/** Disables the input and button while a search is in progress */
isLoading?: boolean;
}
/**
* Shared admin user search input.
* Supports searching users by name, email, or partial/fuzzy text.
* Can be used as controlled (value + onChange) or uncontrolled (internal state).
*/
export function AdminUserSearch({
value: controlledValue,
onChange,
onSearch,
placeholder = "Search users by Name or Email...",
isLoading = false,
}: AdminUserSearchProps) {
const [internalValue, setInternalValue] = useState("");
const isControlled = controlledValue !== undefined;
const currentValue = isControlled ? controlledValue : internalValue;
function handleChange(newValue: string) {
if (isControlled) {
onChange?.(newValue);
} else {
setInternalValue(newValue);
}
}
function handleSearch() {
onSearch(currentValue.trim());
}
return (
<div className="flex w-full items-center gap-2">
<Input
placeholder={placeholder}
aria-label={placeholder}
value={currentValue}
onChange={(e) => handleChange(e.target.value)}
onKeyDown={(e) => e.key === "Enter" && handleSearch()}
disabled={isLoading}
/>
<Button
variant="outline"
size="small"
onClick={handleSearch}
disabled={isLoading || !currentValue.trim()}
loading={isLoading}
>
{isLoading ? "Searching..." : <MagnifyingGlass size={16} />}
</Button>
</div>
);
}

View File

@@ -0,0 +1,34 @@
"use client";
export function formatTokens(tokens: number): string {
if (tokens >= 1_000_000) return `${(tokens / 1_000_000).toFixed(1)}M`;
if (tokens >= 1_000) return `${(tokens / 1_000).toFixed(0)}K`;
return tokens.toString();
}
export function UsageBar({ used, limit }: { used: number; limit: number }) {
if (limit === 0) {
return <span className="text-sm text-gray-500">Unlimited</span>;
}
const pct = Math.min(Math.max(0, (used / limit) * 100), 100);
const color =
pct >= 90 ? "bg-red-500" : pct >= 70 ? "bg-yellow-500" : "bg-green-500";
return (
<div className="space-y-1">
<div className="flex justify-between text-sm">
<span>{formatTokens(used)} used</span>
<span>{formatTokens(limit)} limit</span>
</div>
<div className="h-2 w-full rounded-full bg-gray-200">
<div
className={`h-2 rounded-full ${color}`}
style={{ width: `${pct}%` }}
/>
</div>
<div className="text-right text-xs text-gray-500">
{pct.toFixed(1)}% used
</div>
</div>
);
}

View File

@@ -3,46 +3,16 @@
import { useState } from "react";
import { Button } from "@/components/atoms/Button/Button";
import type { UserRateLimitResponse } from "@/app/api/__generated__/models/userRateLimitResponse";
function formatTokens(tokens: number): string {
if (tokens >= 1_000_000) return `${(tokens / 1_000_000).toFixed(1)}M`;
if (tokens >= 1_000) return `${(tokens / 1_000).toFixed(0)}K`;
return tokens.toString();
}
function UsageBar({ used, limit }: { used: number; limit: number }) {
if (limit === 0) {
return <span className="text-sm text-gray-500">Unlimited</span>;
}
const pct = Math.min((used / limit) * 100, 100);
const color =
pct >= 90 ? "bg-red-500" : pct >= 70 ? "bg-yellow-500" : "bg-green-500";
return (
<div className="space-y-1">
<div className="flex justify-between text-sm">
<span>{formatTokens(used)} used</span>
<span>{formatTokens(limit)} limit</span>
</div>
<div className="h-2 w-full rounded-full bg-gray-200 dark:bg-gray-700">
<div
className={`h-2 rounded-full ${color}`}
style={{ width: `${pct}%` }}
/>
</div>
<div className="text-right text-xs text-gray-500">
{pct.toFixed(1)}% used
</div>
</div>
);
}
import { UsageBar } from "../../components/UsageBar";
interface Props {
data: UserRateLimitResponse;
onReset: (resetWeekly: boolean) => Promise<void>;
/** Override the outer container classes (default: bordered card). */
className?: string;
}
export function RateLimitDisplay({ data, onReset }: Props) {
export function RateLimitDisplay({ data, onReset, className }: Props) {
const [isResetting, setIsResetting] = useState(false);
const [resetWeekly, setResetWeekly] = useState(false);
@@ -65,25 +35,25 @@ export function RateLimitDisplay({ data, onReset }: Props) {
: data.daily_tokens_used === 0;
return (
<div className="rounded-md border bg-white p-6 dark:bg-gray-900">
<h2 className="mb-4 text-lg font-semibold">
Rate Limits for {data.user_id}
<div className={className ?? "rounded-md border bg-white p-6"}>
<h2 className="mb-1 text-lg font-semibold">
Rate Limits for {data.user_email ?? data.user_id}
</h2>
{data.user_email && (
<p className="mb-4 text-xs text-gray-500">User ID: {data.user_id}</p>
)}
{!data.user_email && <div className="mb-4" />}
<div className="grid grid-cols-2 gap-6">
<div className="space-y-2">
<h3 className="text-sm font-medium text-gray-700 dark:text-gray-300">
Daily Usage
</h3>
<h3 className="text-sm font-medium text-gray-700">Daily Usage</h3>
<UsageBar
used={data.daily_tokens_used}
limit={data.daily_token_limit}
/>
</div>
<div className="space-y-2">
<h3 className="text-sm font-medium text-gray-700 dark:text-gray-300">
Weekly Usage
</h3>
<h3 className="text-sm font-medium text-gray-700">Weekly Usage</h3>
<UsageBar
used={data.weekly_tokens_used}
limit={data.weekly_token_limit}
@@ -93,9 +63,10 @@ export function RateLimitDisplay({ data, onReset }: Props) {
<div className="mt-6 flex items-center gap-3 border-t pt-4">
<select
aria-label="Reset scope"
value={resetWeekly ? "both" : "daily"}
onChange={(e) => setResetWeekly(e.target.value === "both")}
className="rounded-md border bg-white px-3 py-1.5 text-sm dark:bg-gray-800 dark:text-gray-200"
className="rounded-md border bg-white px-3 py-1.5 text-sm"
disabled={isResetting}
>
<option value="daily">Reset daily only</option>

View File

@@ -1,101 +1,78 @@
"use client";
import { useState } from "react";
import { Button } from "@/components/atoms/Button/Button";
import { Input } from "@/components/__legacy__/ui/input";
import { Label } from "@/components/__legacy__/ui/label";
import { MagnifyingGlass } from "@phosphor-icons/react";
import { useToast } from "@/components/molecules/Toast/use-toast";
import type { UserRateLimitResponse } from "@/app/api/__generated__/models/userRateLimitResponse";
import {
getV2GetUserRateLimit,
postV2ResetUserRateLimitUsage,
} from "@/app/api/__generated__/endpoints/admin/admin";
import { AdminUserSearch } from "../../components/AdminUserSearch";
import { RateLimitDisplay } from "./RateLimitDisplay";
import { useRateLimitManager } from "./useRateLimitManager";
export function RateLimitManager() {
const { toast } = useToast();
const [userIdInput, setUserIdInput] = useState("");
const [isLoading, setIsLoading] = useState(false);
const [rateLimitData, setRateLimitData] =
useState<UserRateLimitResponse | null>(null);
async function handleLookup() {
const trimmed = userIdInput.trim();
if (!trimmed) return;
setIsLoading(true);
try {
const response = await getV2GetUserRateLimit({ user_id: trimmed });
if (response.status !== 200) {
throw new Error("Failed to fetch rate limit");
}
setRateLimitData(response.data);
} catch (error) {
console.error("Error fetching rate limit:", error);
toast({
title: "Error",
description: "Failed to fetch user rate limit. Check the user ID.",
variant: "destructive",
});
setRateLimitData(null);
} finally {
setIsLoading(false);
}
}
async function handleReset(resetWeekly: boolean) {
if (!rateLimitData) return;
try {
const response = await postV2ResetUserRateLimitUsage({
user_id: rateLimitData.user_id,
reset_weekly: resetWeekly,
});
if (response.status !== 200) {
throw new Error("Failed to reset usage");
}
setRateLimitData(response.data);
toast({
title: "Success",
description: resetWeekly
? "Daily and weekly usage reset to zero."
: "Daily usage reset to zero.",
});
} catch (error) {
console.error("Error resetting rate limit:", error);
toast({
title: "Error",
description: "Failed to reset rate limit usage.",
variant: "destructive",
});
}
}
const {
isSearching,
isLoadingRateLimit,
searchResults,
selectedUser,
rateLimitData,
handleSearch,
handleSelectUser,
handleReset,
} = useRateLimitManager();
return (
<div className="space-y-6">
<div className="rounded-md border bg-white p-6 dark:bg-gray-900">
<Label htmlFor="userId" className="mb-2 block text-sm font-medium">
User ID
</Label>
<div className="flex items-center gap-2">
<Input
id="userId"
placeholder="Enter user ID to look up rate limits..."
value={userIdInput}
onChange={(e) => setUserIdInput(e.target.value)}
onKeyDown={(e) => e.key === "Enter" && handleLookup()}
/>
<Button
variant="outline"
onClick={handleLookup}
disabled={isLoading || !userIdInput.trim()}
>
{isLoading ? "Loading..." : <MagnifyingGlass size={16} />}
</Button>
</div>
<div className="rounded-md border bg-white p-6">
<label className="mb-2 block text-sm font-medium">Search User</label>
<AdminUserSearch
onSearch={handleSearch}
placeholder="Search by name, email, or user ID..."
isLoading={isSearching}
/>
<p className="mt-1.5 text-xs text-gray-500">
Exact email or user ID does a direct lookup. Partial text searches
user history.
</p>
</div>
{/* User selection list -- always require explicit selection */}
{searchResults.length >= 1 && !selectedUser && (
<div className="rounded-md border bg-white p-4">
<h3 className="mb-2 text-sm font-medium text-gray-700">
Select a user ({searchResults.length}{" "}
{searchResults.length === 1 ? "result" : "results"})
</h3>
<ul className="divide-y">
{searchResults.map((user) => (
<li key={user.user_id}>
<button
className="w-full px-2 py-2 text-left text-sm hover:bg-gray-100"
onClick={() => handleSelectUser(user)}
>
<span className="font-medium">{user.user_email}</span>
<span className="ml-2 text-xs text-gray-500">
{user.user_id}
</span>
</button>
</li>
))}
</ul>
</div>
)}
{/* Show selected user */}
{selectedUser && searchResults.length >= 1 && (
<div className="rounded-md border border-blue-200 bg-blue-50 px-4 py-2 text-sm">
Selected:{" "}
<span className="font-medium">{selectedUser.user_email}</span>
<span className="ml-2 text-xs text-gray-500">
{selectedUser.user_id}
</span>
</div>
)}
{isLoadingRateLimit && (
<div className="py-4 text-center text-sm text-gray-500">
Loading rate limits...
</div>
)}
{rateLimitData && (
<RateLimitDisplay data={rateLimitData} onReset={handleReset} />
)}

View File

@@ -0,0 +1,212 @@
"use client";
import { useState } from "react";
import { useToast } from "@/components/molecules/Toast/use-toast";
import type { UserRateLimitResponse } from "@/app/api/__generated__/models/userRateLimitResponse";
import {
getV2GetUserRateLimit,
getV2GetAllUsersHistory,
postV2ResetUserRateLimitUsage,
} from "@/app/api/__generated__/endpoints/admin/admin";
export interface UserOption {
user_id: string;
user_email: string;
}
/**
* Returns true when the input looks like a complete email address.
* Used to decide whether to call the direct email lookup endpoint
* vs. the broader user-history search.
*/
function looksLikeEmail(input: string): boolean {
return /^[^\s@]+@[^\s@]+\.[^\s@]+$/.test(input);
}
/**
* Returns true when the input looks like a UUID (user ID).
*/
function looksLikeUuid(input: string): boolean {
return /^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i.test(
input,
);
}
export function useRateLimitManager() {
const { toast } = useToast();
const [isSearching, setIsSearching] = useState(false);
const [isLoadingRateLimit, setIsLoadingRateLimit] = useState(false);
const [searchResults, setSearchResults] = useState<UserOption[]>([]);
const [selectedUser, setSelectedUser] = useState<UserOption | null>(null);
const [rateLimitData, setRateLimitData] =
useState<UserRateLimitResponse | null>(null);
/** Direct lookup by email or user ID via the rate-limit endpoint. */
async function handleDirectLookup(trimmed: string) {
setIsSearching(true);
setSearchResults([]);
setSelectedUser(null);
setRateLimitData(null);
try {
const params = looksLikeEmail(trimmed)
? { email: trimmed }
: { user_id: trimmed };
const response = await getV2GetUserRateLimit(params);
if (response.status !== 200) {
throw new Error("Failed to fetch rate limit");
}
setRateLimitData(response.data);
setSelectedUser({
user_id: response.data.user_id,
user_email: response.data.user_email ?? response.data.user_id,
});
} catch (error) {
console.error("Error fetching rate limit:", error);
const hint = looksLikeEmail(trimmed)
? "No user found with that email address."
: "Check the user ID and try again.";
toast({
title: "Error",
description: `Failed to fetch rate limits. ${hint}`,
variant: "destructive",
});
setRateLimitData(null);
} finally {
setIsSearching(false);
}
}
/** Fuzzy name/email search via the spending-history endpoint. */
async function handleFuzzySearch(trimmed: string) {
setIsSearching(true);
setSearchResults([]);
setSelectedUser(null);
setRateLimitData(null);
try {
const response = await getV2GetAllUsersHistory({
search: trimmed,
page: 1,
page_size: 50,
});
if (response.status !== 200) {
throw new Error("Failed to search users");
}
// Deduplicate by user_id to get unique users
const seen = new Set<string>();
const users: UserOption[] = [];
for (const tx of response.data.history) {
if (!seen.has(tx.user_id)) {
seen.add(tx.user_id);
users.push({
user_id: tx.user_id,
user_email: String(tx.user_email ?? tx.user_id),
});
}
}
if (users.length === 0) {
toast({
title: "No results",
description: "No users found matching your search.",
});
}
// Always show the result list so the user explicitly picks a match.
// The history endpoint paginates transactions, not users, so a single
// page may not be authoritative -- avoid auto-selecting.
setSearchResults(users);
} catch (error) {
console.error("Error searching users:", error);
toast({
title: "Error",
description: "Failed to search users.",
variant: "destructive",
});
} finally {
setIsSearching(false);
}
}
async function handleSearch(query: string) {
const trimmed = query.trim();
if (!trimmed) return;
// Direct lookup when the input is a full email or UUID.
// This avoids the spending-history indirection and works even for
// users who have no credit transaction history.
if (looksLikeEmail(trimmed) || looksLikeUuid(trimmed)) {
await handleDirectLookup(trimmed);
} else {
await handleFuzzySearch(trimmed);
}
}
async function fetchRateLimit(userId: string) {
setIsLoadingRateLimit(true);
try {
const response = await getV2GetUserRateLimit({ user_id: userId });
if (response.status !== 200) {
throw new Error("Failed to fetch rate limit");
}
setRateLimitData(response.data);
} catch (error) {
console.error("Error fetching rate limit:", error);
toast({
title: "Error",
description: "Failed to fetch user rate limit.",
variant: "destructive",
});
setRateLimitData(null);
} finally {
setIsLoadingRateLimit(false);
}
}
async function handleSelectUser(user: UserOption) {
setSelectedUser(user);
setRateLimitData(null);
await fetchRateLimit(user.user_id);
}
async function handleReset(resetWeekly: boolean) {
if (!rateLimitData) return;
try {
const response = await postV2ResetUserRateLimitUsage({
user_id: rateLimitData.user_id,
reset_weekly: resetWeekly,
});
if (response.status !== 200) {
throw new Error("Failed to reset usage");
}
setRateLimitData(response.data);
toast({
title: "Success",
description: resetWeekly
? "Daily and weekly usage reset to zero."
: "Daily usage reset to zero.",
});
} catch (error) {
console.error("Error resetting rate limit:", error);
toast({
title: "Error",
description: "Failed to reset rate limit usage.",
variant: "destructive",
});
}
}
return {
isSearching,
isLoadingRateLimit,
searchResults,
selectedUser,
rateLimitData,
handleSearch,
handleSelectUser,
handleReset,
};
}

View File

@@ -11,6 +11,7 @@ import { PaginationControls } from "../../../../../components/__legacy__/ui/pagi
import { SearchAndFilterAdminSpending } from "./SearchAndFilterAdminSpending";
import { getUsersTransactionHistory } from "@/app/(platform)/admin/spending/actions";
import { AdminAddMoneyButton } from "./AddMoneyButton";
import { RateLimitModal } from "./RateLimitModal";
import { CreditTransactionType } from "@/lib/autogpt-server-api";
export async function AdminUserGrantHistory({
@@ -80,10 +81,7 @@ export async function AdminUserGrantHistory({
return (
<div className="space-y-4">
<SearchAndFilterAdminSpending
initialStatus={initialStatus}
initialSearch={initialSearch}
/>
<SearchAndFilterAdminSpending initialSearch={initialSearch} />
<div className="rounded-md border bg-white">
<Table>
@@ -105,7 +103,7 @@ export async function AdminUserGrantHistory({
{history.length === 0 ? (
<TableRow>
<TableCell
colSpan={8}
colSpan={9}
className="py-10 text-center text-gray-500"
>
No transactions found
@@ -114,7 +112,7 @@ export async function AdminUserGrantHistory({
) : (
history.map((transaction) => (
<TableRow
key={transaction.user_id}
key={`${transaction.user_id}-${transaction.transaction_time}`}
className="hover:bg-gray-50"
>
<TableCell className="font-medium">
@@ -147,25 +145,29 @@ export async function AdminUserGrantHistory({
${transaction.current_balance / 100}
</TableCell> */}
<TableCell className="text-right">
<AdminAddMoneyButton
userId={transaction.user_id}
userEmail={
transaction.user_email ?? "User Email wasn't attached"
}
currentBalance={transaction.current_balance}
defaultAmount={
transaction.transaction_type ===
CreditTransactionType.USAGE
? -transaction.amount
: undefined
}
defaultComments={
transaction.transaction_type ===
CreditTransactionType.USAGE
? "Refund for usage"
: undefined
}
/>
<div className="flex items-center justify-end gap-2">
<RateLimitModal
userId={transaction.user_id}
userEmail={transaction.user_email ?? ""}
/>
<AdminAddMoneyButton
userId={transaction.user_id}
userEmail={transaction.user_email ?? ""}
currentBalance={transaction.current_balance}
defaultAmount={
transaction.transaction_type ===
CreditTransactionType.USAGE
? -transaction.amount
: undefined
}
defaultComments={
transaction.transaction_type ===
CreditTransactionType.USAGE
? "Refund for usage"
: undefined
}
/>
</div>
</TableCell>
</TableRow>
))

View File

@@ -0,0 +1,138 @@
"use client";
import { useState, useEffect } from "react";
import { Button } from "@/components/atoms/Button/Button";
import {
Dialog,
DialogContent,
DialogHeader,
DialogTitle,
DialogDescription,
} from "@/components/__legacy__/ui/dialog";
import { useToast } from "@/components/molecules/Toast/use-toast";
import type { UserRateLimitResponse } from "@/app/api/__generated__/models/userRateLimitResponse";
import {
getV2GetUserRateLimit,
postV2ResetUserRateLimitUsage,
} from "@/app/api/__generated__/endpoints/admin/admin";
import { Gauge } from "@phosphor-icons/react";
import { RateLimitDisplay } from "../../rate-limits/components/RateLimitDisplay";
export function RateLimitModal({
userId,
userEmail,
}: {
userId: string;
userEmail: string;
}) {
const { toast } = useToast();
const [open, setOpen] = useState(false);
const [isLoading, setIsLoading] = useState(false);
const [rateLimitData, setRateLimitData] =
useState<UserRateLimitResponse | null>(null);
useEffect(() => {
if (!open) {
setRateLimitData(null);
return;
}
async function fetchRateLimit() {
setIsLoading(true);
try {
const response = await getV2GetUserRateLimit({ user_id: userId });
if (response.status !== 200) {
throw new Error("Failed to fetch rate limit");
}
setRateLimitData(response.data);
} catch (error) {
console.error("Error fetching rate limit:", error);
toast({
title: "Error",
description: "Failed to fetch user rate limit.",
variant: "destructive",
});
setRateLimitData(null);
} finally {
setIsLoading(false);
}
}
fetchRateLimit();
}, [open, userId, toast]);
async function handleReset(resetWeekly: boolean) {
if (!rateLimitData) return;
try {
const response = await postV2ResetUserRateLimitUsage({
user_id: rateLimitData.user_id,
reset_weekly: resetWeekly,
});
if (response.status !== 200) {
throw new Error("Failed to reset usage");
}
setRateLimitData(response.data);
toast({
title: "Success",
description: resetWeekly
? "Daily and weekly usage reset to zero."
: "Daily usage reset to zero.",
});
} catch (error) {
console.error("Error resetting rate limit:", error);
toast({
title: "Error",
description: "Failed to reset rate limit usage.",
variant: "destructive",
});
}
}
return (
<>
<Button
size="small"
variant="outline"
onClick={(e) => {
e.stopPropagation();
setOpen(true);
}}
>
<Gauge size={16} className="mr-1" />
Rate Limits
</Button>
<Dialog open={open} onOpenChange={setOpen}>
<DialogContent className="sm:max-w-lg">
<DialogHeader>
<DialogTitle>Rate Limits</DialogTitle>
<DialogDescription>
CoPilot rate limits for {userEmail || userId}
</DialogDescription>
</DialogHeader>
{isLoading && (
<div className="py-8 text-center text-gray-500">
Loading rate limits...
</div>
)}
{!isLoading && rateLimitData && (
<RateLimitDisplay
data={rateLimitData}
onReset={handleReset}
className="space-y-4"
/>
)}
{!isLoading && !rateLimitData && (
<div className="py-8 text-center text-gray-500">
No rate limit data available for this user.
</div>
)}
</DialogContent>
</Dialog>
</>
);
}

View File

@@ -2,9 +2,6 @@
import { useState, useEffect } from "react";
import { useRouter, usePathname, useSearchParams } from "next/navigation";
import { Input } from "@/components/__legacy__/ui/input";
import { Button } from "@/components/__legacy__/ui/button";
import { Search } from "lucide-react";
import { CreditTransactionType } from "@/lib/autogpt-server-api";
import {
Select,
@@ -13,11 +10,11 @@ import {
SelectTrigger,
SelectValue,
} from "@/components/__legacy__/ui/select";
import { AdminUserSearch } from "../../components/AdminUserSearch";
export function SearchAndFilterAdminSpending({
initialSearch,
}: {
initialStatus?: CreditTransactionType;
initialSearch?: string;
}) {
const router = useRouter();
@@ -37,11 +34,11 @@ export function SearchAndFilterAdminSpending({
setSearchQuery(searchParams.get("search") || "");
}, [searchParams]);
const handleSearch = () => {
function handleSearch(query: string) {
const params = new URLSearchParams(searchParams.toString());
if (searchQuery) {
params.set("search", searchQuery);
if (query) {
params.set("search", query);
} else {
params.delete("search");
}
@@ -55,21 +52,15 @@ export function SearchAndFilterAdminSpending({
params.set("page", "1"); // Reset to first page on new search
router.push(`${pathname}?${params.toString()}`);
};
}
return (
<div className="flex items-center justify-between">
<div className="flex w-full items-center gap-2">
<Input
placeholder="Search users by Name or Email..."
value={searchQuery}
onChange={(e) => setSearchQuery(e.target.value)}
onKeyDown={(e) => e.key === "Enter" && handleSearch()}
/>
<Button variant="outline" onClick={handleSearch}>
<Search className="h-4 w-4" />
</Button>
</div>
<AdminUserSearch
value={searchQuery}
onChange={setSearchQuery}
onSearch={handleSearch}
/>
<Select
value={selectedStatus}

View File

@@ -1442,15 +1442,27 @@
"get": {
"tags": ["v2", "admin", "copilot", "admin"],
"summary": "Get User Rate Limit",
"description": "Get a user's current usage and effective rate limits. Admin-only.",
"description": "Get a user's current usage and effective rate limits. Admin-only.\n\nAccepts either ``user_id`` or ``email`` as a query parameter.\nWhen ``email`` is provided the user is looked up by email first.",
"operationId": "getV2Get user rate limit",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "user_id",
"in": "query",
"required": true,
"schema": { "type": "string", "title": "User Id" }
"required": false,
"schema": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "User Id"
}
},
{
"name": "email",
"in": "query",
"required": false,
"schema": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Email"
}
}
],
"responses": {
@@ -14699,6 +14711,10 @@
"UserRateLimitResponse": {
"properties": {
"user_id": { "type": "string", "title": "User Id" },
"user_email": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "User Email"
},
"daily_token_limit": {
"type": "integer",
"title": "Daily Token Limit"