mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Compare commits
54 Commits
master
...
feat/platf
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8a6c42e357 | ||
|
|
35a05a4b14 | ||
|
|
13e1980a95 | ||
|
|
3bf072e683 | ||
|
|
d3ac5ee3b2 | ||
|
|
b4fdba3cd4 | ||
|
|
43feb42eca | ||
|
|
f7feb02f14 | ||
|
|
36ad35dc54 | ||
|
|
5326959815 | ||
|
|
620c49facb | ||
|
|
96ca4398c5 | ||
|
|
8aeb782ccb | ||
|
|
3806f58e6b | ||
|
|
93345ff840 | ||
|
|
c1d6689215 | ||
|
|
652c12768e | ||
|
|
7a7912aed3 | ||
|
|
7ffcd704c9 | ||
|
|
9eaa903978 | ||
|
|
77ebcfe55d | ||
|
|
db029f0b49 | ||
|
|
a268291585 | ||
|
|
9caca6d899 | ||
|
|
481f704317 | ||
|
|
781224f8d5 | ||
|
|
8a91bd8a99 | ||
|
|
7bf10c605a | ||
|
|
7778de1d7b | ||
|
|
015a626379 | ||
|
|
be6501f10e | ||
|
|
32f6ef0a45 | ||
|
|
c7d5c1c844 | ||
|
|
17e78ca382 | ||
|
|
7ba05366ed | ||
|
|
ca74f980c1 | ||
|
|
68f5d2ad08 | ||
|
|
2b3d730ca9 | ||
|
|
f28628e34b | ||
|
|
b6a027fd2b | ||
|
|
fb74fcf4a4 | ||
|
|
28b26dde94 | ||
|
|
d677978c90 | ||
|
|
a347c274b7 | ||
|
|
f79d8f0449 | ||
|
|
1bc48c55d5 | ||
|
|
9d0a31c0f1 | ||
|
|
9b086e39c6 | ||
|
|
5867e4d613 | ||
|
|
f871717f68 | ||
|
|
f08e52dc86 | ||
|
|
500b345b3b | ||
|
|
995dd1b5f3 | ||
|
|
336114f217 |
106
.claude/skills/open-pr/SKILL.md
Normal file
106
.claude/skills/open-pr/SKILL.md
Normal file
@@ -0,0 +1,106 @@
|
||||
---
|
||||
name: open-pr
|
||||
description: Open a pull request with proper PR template, test coverage, and review workflow. Guides agents through creating a PR that follows repo conventions, ensures existing behaviors aren't broken, covers new behaviors with tests, and handles review via bot when local testing isn't possible. TRIGGER when user asks to "open a PR", "create a PR", "make a PR", "submit a PR", "open pull request", "push and create PR", or any variation of opening/submitting a pull request.
|
||||
user-invocable: true
|
||||
args: "[base-branch] — optional target branch (defaults to dev)."
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Open a Pull Request
|
||||
|
||||
## Step 1: Pre-flight checks
|
||||
|
||||
Before opening the PR:
|
||||
|
||||
1. Ensure all changes are committed
|
||||
2. Ensure the branch is pushed to the remote (`git push -u origin <branch>`)
|
||||
3. Run linters/formatters across the whole repo (not just changed files) and commit any fixes
|
||||
|
||||
## Step 2: Test coverage
|
||||
|
||||
**This is critical.** Before opening the PR, verify:
|
||||
|
||||
### Existing behavior is not broken
|
||||
- Identify which modules/components your changes touch
|
||||
- Run the existing test suites for those areas
|
||||
- If tests fail, fix them before opening the PR — do not open a PR with known regressions
|
||||
|
||||
### New behavior has test coverage
|
||||
- Every new feature, endpoint, or behavior change needs tests
|
||||
- If you added a new block, add tests for that block
|
||||
- If you changed API behavior, add or update API tests
|
||||
- If you changed frontend behavior, verify it doesn't break existing flows
|
||||
|
||||
If you cannot run the full test suite locally, note which tests you ran and which you couldn't in the test plan.
|
||||
|
||||
## Step 3: Create the PR using the repo template
|
||||
|
||||
Read the canonical PR template at `.github/PULL_REQUEST_TEMPLATE.md` and use it **verbatim** as your PR body:
|
||||
|
||||
1. Read the template: `cat .github/PULL_REQUEST_TEMPLATE.md`
|
||||
2. Preserve the exact section titles and formatting, including:
|
||||
- `### Why / What / How`
|
||||
- `### Changes 🏗️`
|
||||
- `### Checklist 📋`
|
||||
3. Replace HTML comment prompts (`<!-- ... -->`) with actual content; do not leave them in
|
||||
4. **Do not pre-check boxes** — leave all checkboxes as `- [ ]` until each step is actually completed
|
||||
5. Do not alter the template structure, rename sections, or remove any checklist items
|
||||
|
||||
**PR title must use conventional commit format** (e.g., `feat(backend): add new block`, `fix(frontend): resolve routing bug`, `dx(skills): update PR workflow`). See CLAUDE.md for the full list of scopes.
|
||||
|
||||
Use `gh pr create` with the base branch (defaults to `dev` if no `[base-branch]` was provided). Use `--body-file` to avoid shell interpretation of backticks and special characters:
|
||||
|
||||
```bash
|
||||
BASE_BRANCH="${BASE_BRANCH:-dev}"
|
||||
PR_BODY=$(mktemp)
|
||||
cat > "$PR_BODY" << 'PREOF'
|
||||
<filled-in template from .github/PULL_REQUEST_TEMPLATE.md>
|
||||
PREOF
|
||||
gh pr create --base "$BASE_BRANCH" --title "<type>(scope): short description" --body-file "$PR_BODY"
|
||||
rm "$PR_BODY"
|
||||
```
|
||||
|
||||
## Step 4: Review workflow
|
||||
|
||||
### If you have a workspace that allows testing (docker, running backend, etc.)
|
||||
- Run `/pr-test` to do E2E manual testing of the PR using docker compose, agent-browser, and API calls. This is the most thorough way to validate your changes before review.
|
||||
- After testing, run `/pr-review` to self-review the PR for correctness, security, code quality, and testing gaps before requesting human review.
|
||||
|
||||
### If you do NOT have a workspace that allows testing
|
||||
This is common for agents running in worktrees without a full stack. In this case:
|
||||
|
||||
1. Run `/pr-review` locally to catch obvious issues before pushing
|
||||
2. **Comment `/review` on the PR** after creating it to trigger the review bot
|
||||
3. **Poll for the review** rather than blindly waiting — check for new review comments every 30 seconds using `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate` and the GraphQL inline threads query. The bot typically responds within 30 minutes, but polling lets the agent react as soon as it arrives.
|
||||
4. Do NOT proceed or merge until the bot review comes back
|
||||
5. Address any issues the bot raises — use `/pr-address` which has a full polling loop with CI + comment tracking
|
||||
|
||||
```bash
|
||||
# After creating the PR:
|
||||
PR_NUMBER=$(gh pr view --json number -q .number)
|
||||
gh pr comment "$PR_NUMBER" --body "/review"
|
||||
# Then use /pr-address to poll for and address the review when it arrives
|
||||
```
|
||||
|
||||
## Step 5: Address review feedback
|
||||
|
||||
Once the review bot or human reviewers leave comments:
|
||||
- Run `/pr-address` to address review comments. It will loop until CI is green and all comments are resolved.
|
||||
- Do not merge without human approval.
|
||||
|
||||
## Related skills
|
||||
|
||||
| Skill | When to use |
|
||||
|---|---|
|
||||
| `/pr-test` | E2E testing with docker compose, agent-browser, API calls — use when you have a running workspace |
|
||||
| `/pr-review` | Review for correctness, security, code quality — use before requesting human review |
|
||||
| `/pr-address` | Address reviewer comments and loop until CI green — use after reviews come in |
|
||||
|
||||
## Step 6: Post-creation
|
||||
|
||||
After the PR is created and review is triggered:
|
||||
- Share the PR URL with the user
|
||||
- If waiting on the review bot, let the user know the expected wait time (~30 min)
|
||||
- Do not merge without human approval
|
||||
195
.claude/skills/setup-repo/SKILL.md
Normal file
195
.claude/skills/setup-repo/SKILL.md
Normal file
@@ -0,0 +1,195 @@
|
||||
---
|
||||
name: setup-repo
|
||||
description: Initialize a worktree-based repo layout for parallel development. Creates a main worktree, a reviews worktree for PR reviews, and N numbered work branches. Handles .env creation, dependency installation, and branchlet config. TRIGGER when user asks to set up the repo from scratch, initialize worktrees, bootstrap their dev environment, "setup repo", "setup worktrees", "initialize dev environment", "set up branches", or when a freshly cloned repo has no sibling worktrees.
|
||||
user-invocable: true
|
||||
args: "No arguments — interactive setup via prompts."
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Repository Setup
|
||||
|
||||
This skill sets up a worktree-based development layout from a freshly cloned repo. It creates:
|
||||
- A **main** worktree (the primary checkout)
|
||||
- A **reviews** worktree (for PR reviews)
|
||||
- **N work branches** (branch1..branchN) for parallel development
|
||||
|
||||
## Step 1: Identify the repo
|
||||
|
||||
Determine the repo root and parent directory:
|
||||
|
||||
```bash
|
||||
ROOT=$(git rev-parse --show-toplevel)
|
||||
REPO_NAME=$(basename "$ROOT")
|
||||
PARENT=$(dirname "$ROOT")
|
||||
```
|
||||
|
||||
Detect if the repo is already inside a worktree layout by counting sibling worktrees (not just checking the directory name, which could be anything):
|
||||
|
||||
```bash
|
||||
# Count worktrees that are siblings (live under $PARENT but aren't $ROOT itself)
|
||||
SIBLING_COUNT=$(git worktree list --porcelain 2>/dev/null | grep "^worktree " | grep -c "$PARENT/" || true)
|
||||
if [ "$SIBLING_COUNT" -gt 1 ]; then
|
||||
echo "INFO: Existing worktree layout detected at $PARENT ($SIBLING_COUNT worktrees)"
|
||||
# Use $ROOT as-is; skip renaming/restructuring
|
||||
else
|
||||
echo "INFO: Fresh clone detected, proceeding with setup"
|
||||
fi
|
||||
```
|
||||
|
||||
## Step 2: Ask the user questions
|
||||
|
||||
Use AskUserQuestion to gather setup preferences:
|
||||
|
||||
1. **How many parallel work branches do you need?** (Options: 4, 8, 16, or custom)
|
||||
- These become `branch1` through `branchN`
|
||||
2. **Which branch should be the base?** (Options: origin/master, origin/dev, or custom)
|
||||
- All work branches and reviews will start from this
|
||||
|
||||
## Step 3: Fetch and set up branches
|
||||
|
||||
```bash
|
||||
cd "$ROOT"
|
||||
git fetch origin
|
||||
|
||||
# Create the reviews branch from base (skip if already exists)
|
||||
if git show-ref --verify --quiet refs/heads/reviews; then
|
||||
echo "INFO: Branch 'reviews' already exists, skipping"
|
||||
else
|
||||
git branch reviews <base-branch>
|
||||
fi
|
||||
|
||||
# Create numbered work branches from base (skip if already exists)
|
||||
for i in $(seq 1 "$COUNT"); do
|
||||
if git show-ref --verify --quiet "refs/heads/branch$i"; then
|
||||
echo "INFO: Branch 'branch$i' already exists, skipping"
|
||||
else
|
||||
git branch "branch$i" <base-branch>
|
||||
fi
|
||||
done
|
||||
```
|
||||
|
||||
## Step 4: Create worktrees
|
||||
|
||||
Create worktrees as siblings to the main checkout:
|
||||
|
||||
```bash
|
||||
if [ -d "$PARENT/reviews" ]; then
|
||||
echo "INFO: Worktree '$PARENT/reviews' already exists, skipping"
|
||||
else
|
||||
git worktree add "$PARENT/reviews" reviews
|
||||
fi
|
||||
|
||||
for i in $(seq 1 "$COUNT"); do
|
||||
if [ -d "$PARENT/branch$i" ]; then
|
||||
echo "INFO: Worktree '$PARENT/branch$i' already exists, skipping"
|
||||
else
|
||||
git worktree add "$PARENT/branch$i" "branch$i"
|
||||
fi
|
||||
done
|
||||
```
|
||||
|
||||
## Step 5: Set up environment files
|
||||
|
||||
**Do NOT assume .env files exist.** For each worktree (including main if needed):
|
||||
|
||||
1. Check if `.env` exists in the source worktree for each path
|
||||
2. If `.env` exists, copy it
|
||||
3. If only `.env.default` or `.env.example` exists, copy that as `.env`
|
||||
4. If neither exists, warn the user and list which env files are missing
|
||||
|
||||
Env file locations to check (same as the `/worktree` skill — keep these in sync):
|
||||
- `autogpt_platform/.env`
|
||||
- `autogpt_platform/backend/.env`
|
||||
- `autogpt_platform/frontend/.env`
|
||||
|
||||
> **Note:** This env copying logic intentionally mirrors the `/worktree` skill's approach. If you update the path list or fallback logic here, update `/worktree` as well.
|
||||
|
||||
```bash
|
||||
SOURCE="$ROOT"
|
||||
WORKTREES="reviews"
|
||||
for i in $(seq 1 "$COUNT"); do WORKTREES="$WORKTREES branch$i"; done
|
||||
|
||||
FOUND_ANY_ENV=0
|
||||
for wt in $WORKTREES; do
|
||||
TARGET="$PARENT/$wt"
|
||||
for envpath in autogpt_platform autogpt_platform/backend autogpt_platform/frontend; do
|
||||
if [ -f "$SOURCE/$envpath/.env" ]; then
|
||||
FOUND_ANY_ENV=1
|
||||
cp "$SOURCE/$envpath/.env" "$TARGET/$envpath/.env"
|
||||
elif [ -f "$SOURCE/$envpath/.env.default" ]; then
|
||||
FOUND_ANY_ENV=1
|
||||
cp "$SOURCE/$envpath/.env.default" "$TARGET/$envpath/.env"
|
||||
echo "NOTE: $wt/$envpath/.env was created from .env.default — you may need to edit it"
|
||||
elif [ -f "$SOURCE/$envpath/.env.example" ]; then
|
||||
FOUND_ANY_ENV=1
|
||||
cp "$SOURCE/$envpath/.env.example" "$TARGET/$envpath/.env"
|
||||
echo "NOTE: $wt/$envpath/.env was created from .env.example — you may need to edit it"
|
||||
else
|
||||
echo "WARNING: No .env, .env.default, or .env.example found at $SOURCE/$envpath/"
|
||||
fi
|
||||
done
|
||||
done
|
||||
|
||||
if [ "$FOUND_ANY_ENV" -eq 0 ]; then
|
||||
echo "WARNING: No environment files or templates were found in the source worktree."
|
||||
# Use AskUserQuestion to confirm: "Continue setup without env files?"
|
||||
# If the user declines, stop here and let them set up .env files first.
|
||||
fi
|
||||
```
|
||||
|
||||
## Step 6: Copy branchlet config
|
||||
|
||||
Copy `.branchlet.json` from main to each worktree so branchlet can manage sub-worktrees:
|
||||
|
||||
```bash
|
||||
if [ -f "$ROOT/.branchlet.json" ]; then
|
||||
for wt in $WORKTREES; do
|
||||
cp "$ROOT/.branchlet.json" "$PARENT/$wt/.branchlet.json"
|
||||
done
|
||||
fi
|
||||
```
|
||||
|
||||
## Step 7: Install dependencies
|
||||
|
||||
Install deps in all worktrees. Run these sequentially per worktree:
|
||||
|
||||
```bash
|
||||
for wt in $WORKTREES; do
|
||||
TARGET="$PARENT/$wt"
|
||||
echo "=== Installing deps for $wt ==="
|
||||
(cd "$TARGET/autogpt_platform/autogpt_libs" && poetry install) &&
|
||||
(cd "$TARGET/autogpt_platform/backend" && poetry install && poetry run prisma generate) &&
|
||||
(cd "$TARGET/autogpt_platform/frontend" && pnpm install) &&
|
||||
echo "=== Done: $wt ===" ||
|
||||
echo "=== FAILED: $wt ==="
|
||||
done
|
||||
```
|
||||
|
||||
This is slow. Run in background if possible and notify when complete.
|
||||
|
||||
## Step 8: Verify and report
|
||||
|
||||
After setup, verify and report to the user:
|
||||
|
||||
```bash
|
||||
git worktree list
|
||||
```
|
||||
|
||||
Summarize:
|
||||
- Number of worktrees created
|
||||
- Which env files were copied vs created from defaults vs missing
|
||||
- Any warnings or errors encountered
|
||||
|
||||
## Final directory layout
|
||||
|
||||
```
|
||||
parent/
|
||||
main/ # Primary checkout (already exists)
|
||||
reviews/ # PR review worktree
|
||||
branch1/ # Work branch 1
|
||||
branch2/ # Work branch 2
|
||||
...
|
||||
branchN/ # Work branch N
|
||||
```
|
||||
@@ -83,13 +83,13 @@ The AutoGPT frontend is where users interact with our powerful AI automation pla
|
||||
|
||||
**Agent Builder:** For those who want to customize, our intuitive, low-code interface allows you to design and configure your own AI agents.
|
||||
|
||||
**Workflow Management:** Build, modify, and optimize your automation workflows with ease. You build your agent by connecting blocks, where each block performs a single action.
|
||||
**Workflow Management:** Build, modify, and optimize your automation workflows with ease. You build your agent by connecting blocks, where each block performs a single action.
|
||||
|
||||
**Deployment Controls:** Manage the lifecycle of your agents, from testing to production.
|
||||
|
||||
**Ready-to-Use Agents:** Don't want to build? Simply select from our library of pre-configured agents and put them to work immediately.
|
||||
|
||||
**Agent Interaction:** Whether you've built your own or are using pre-configured agents, easily run and interact with them through our user-friendly interface.
|
||||
**Agent Interaction:** Whether you've built your own or are using pre-configured agents, easily run and interact with them through our user-friendly interface.
|
||||
|
||||
**Monitoring and Analytics:** Keep track of your agents' performance and gain insights to continually improve your automation processes.
|
||||
|
||||
|
||||
@@ -0,0 +1,146 @@
|
||||
"""Admin endpoints for checking and resetting user CoPilot rate limit usage."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from autogpt_libs.auth import get_user_id, requires_admin_user
|
||||
from fastapi import APIRouter, Body, HTTPException, Security
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.rate_limit import (
|
||||
get_global_rate_limits,
|
||||
get_usage_status,
|
||||
reset_user_usage,
|
||||
)
|
||||
from backend.data.user import get_user_by_email, get_user_email_by_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
config = ChatConfig()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/admin",
|
||||
tags=["copilot", "admin"],
|
||||
dependencies=[Security(requires_admin_user)],
|
||||
)
|
||||
|
||||
|
||||
class UserRateLimitResponse(BaseModel):
|
||||
user_id: str
|
||||
user_email: Optional[str] = None
|
||||
daily_token_limit: int
|
||||
weekly_token_limit: int
|
||||
daily_tokens_used: int
|
||||
weekly_tokens_used: int
|
||||
|
||||
|
||||
async def _resolve_user_id(
|
||||
user_id: Optional[str], email: Optional[str]
|
||||
) -> tuple[str, Optional[str]]:
|
||||
"""Resolve a user_id and email from the provided parameters.
|
||||
|
||||
Returns (user_id, email). Accepts either user_id or email; at least one
|
||||
must be provided. When both are provided, ``email`` takes precedence.
|
||||
"""
|
||||
if email:
|
||||
user = await get_user_by_email(email)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="No user found with the provided email."
|
||||
)
|
||||
return user.id, email
|
||||
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Either user_id or email query parameter is required.",
|
||||
)
|
||||
|
||||
# We have a user_id; try to look up their email for display purposes.
|
||||
# This is non-critical -- a failure should not block the response.
|
||||
try:
|
||||
resolved_email = await get_user_email_by_id(user_id)
|
||||
except Exception:
|
||||
logger.warning("Failed to resolve email for user %s", user_id, exc_info=True)
|
||||
resolved_email = None
|
||||
return user_id, resolved_email
|
||||
|
||||
|
||||
@router.get(
|
||||
"/rate_limit",
|
||||
response_model=UserRateLimitResponse,
|
||||
summary="Get User Rate Limit",
|
||||
)
|
||||
async def get_user_rate_limit(
|
||||
user_id: Optional[str] = None,
|
||||
email: Optional[str] = None,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> UserRateLimitResponse:
|
||||
"""Get a user's current usage and effective rate limits. Admin-only.
|
||||
|
||||
Accepts either ``user_id`` or ``email`` as a query parameter.
|
||||
When ``email`` is provided the user is looked up by email first.
|
||||
"""
|
||||
resolved_id, resolved_email = await _resolve_user_id(user_id, email)
|
||||
|
||||
logger.info("Admin %s checking rate limit for user %s", admin_user_id, resolved_id)
|
||||
|
||||
daily_limit, weekly_limit = await get_global_rate_limits(
|
||||
resolved_id, config.daily_token_limit, config.weekly_token_limit
|
||||
)
|
||||
usage = await get_usage_status(resolved_id, daily_limit, weekly_limit)
|
||||
|
||||
return UserRateLimitResponse(
|
||||
user_id=resolved_id,
|
||||
user_email=resolved_email,
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
daily_tokens_used=usage.daily.used,
|
||||
weekly_tokens_used=usage.weekly.used,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/rate_limit/reset",
|
||||
response_model=UserRateLimitResponse,
|
||||
summary="Reset User Rate Limit Usage",
|
||||
)
|
||||
async def reset_user_rate_limit(
|
||||
user_id: str = Body(embed=True),
|
||||
reset_weekly: bool = Body(False, embed=True),
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> UserRateLimitResponse:
|
||||
"""Reset a user's daily usage counter (and optionally weekly). Admin-only."""
|
||||
logger.info(
|
||||
"Admin %s resetting rate limit for user %s (reset_weekly=%s)",
|
||||
admin_user_id,
|
||||
user_id,
|
||||
reset_weekly,
|
||||
)
|
||||
|
||||
try:
|
||||
await reset_user_usage(user_id, reset_weekly=reset_weekly)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to reset user usage")
|
||||
raise HTTPException(status_code=500, detail="Failed to reset usage") from e
|
||||
|
||||
daily_limit, weekly_limit = await get_global_rate_limits(
|
||||
user_id, config.daily_token_limit, config.weekly_token_limit
|
||||
)
|
||||
usage = await get_usage_status(user_id, daily_limit, weekly_limit)
|
||||
|
||||
try:
|
||||
resolved_email = await get_user_email_by_id(user_id)
|
||||
except Exception:
|
||||
logger.warning("Failed to resolve email for user %s", user_id, exc_info=True)
|
||||
resolved_email = None
|
||||
|
||||
return UserRateLimitResponse(
|
||||
user_id=user_id,
|
||||
user_email=resolved_email,
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
daily_tokens_used=usage.daily.used,
|
||||
weekly_tokens_used=usage.weekly.used,
|
||||
)
|
||||
@@ -0,0 +1,263 @@
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
|
||||
|
||||
from .rate_limit_admin_routes import router as rate_limit_admin_router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(rate_limit_admin_router)
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
_MOCK_MODULE = "backend.api.features.admin.rate_limit_admin_routes"
|
||||
|
||||
_TARGET_EMAIL = "target@example.com"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_admin_auth(mock_jwt_admin):
|
||||
"""Setup admin auth overrides for all tests in this module"""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def _mock_usage_status(
|
||||
daily_used: int = 500_000, weekly_used: int = 3_000_000
|
||||
) -> CoPilotUsageStatus:
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
now = datetime.now(UTC)
|
||||
return CoPilotUsageStatus(
|
||||
daily=UsageWindow(
|
||||
used=daily_used, limit=2_500_000, resets_at=now + timedelta(hours=6)
|
||||
),
|
||||
weekly=UsageWindow(
|
||||
used=weekly_used, limit=12_500_000, resets_at=now + timedelta(days=3)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _patch_rate_limit_deps(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
daily_used: int = 500_000,
|
||||
weekly_used: int = 3_000_000,
|
||||
):
|
||||
"""Patch the common rate-limit + user-lookup dependencies."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_global_rate_limits",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(2_500_000, 12_500_000),
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_usage_status",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_mock_usage_status(daily_used=daily_used, weekly_used=weekly_used),
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_email_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_TARGET_EMAIL,
|
||||
)
|
||||
|
||||
|
||||
def test_get_rate_limit(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test getting rate limit and usage for a user."""
|
||||
_patch_rate_limit_deps(mocker, target_user_id)
|
||||
|
||||
response = client.get("/admin/rate_limit", params={"user_id": target_user_id})
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["user_id"] == target_user_id
|
||||
assert data["user_email"] == _TARGET_EMAIL
|
||||
assert data["daily_token_limit"] == 2_500_000
|
||||
assert data["weekly_token_limit"] == 12_500_000
|
||||
assert data["daily_tokens_used"] == 500_000
|
||||
assert data["weekly_tokens_used"] == 3_000_000
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(data, indent=2, sort_keys=True) + "\n",
|
||||
"get_rate_limit",
|
||||
)
|
||||
|
||||
|
||||
def test_get_rate_limit_by_email(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test looking up rate limits via email instead of user_id."""
|
||||
_patch_rate_limit_deps(mocker, target_user_id)
|
||||
|
||||
mock_user = SimpleNamespace(id=target_user_id, email=_TARGET_EMAIL)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_by_email",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user,
|
||||
)
|
||||
|
||||
response = client.get("/admin/rate_limit", params={"email": _TARGET_EMAIL})
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["user_id"] == target_user_id
|
||||
assert data["user_email"] == _TARGET_EMAIL
|
||||
assert data["daily_token_limit"] == 2_500_000
|
||||
|
||||
|
||||
def test_get_rate_limit_by_email_not_found(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""Test that looking up a non-existent email returns 404."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_by_email",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
response = client.get("/admin/rate_limit", params={"email": "nobody@example.com"})
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_get_rate_limit_no_params() -> None:
|
||||
"""Test that omitting both user_id and email returns 400."""
|
||||
response = client.get("/admin/rate_limit")
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
def test_reset_user_usage_daily_only(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test resetting only daily usage (default behaviour)."""
|
||||
mock_reset = mocker.patch(
|
||||
f"{_MOCK_MODULE}.reset_user_usage",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
_patch_rate_limit_deps(mocker, target_user_id, daily_used=0, weekly_used=3_000_000)
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/reset",
|
||||
json={"user_id": target_user_id},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["daily_tokens_used"] == 0
|
||||
# Weekly is untouched
|
||||
assert data["weekly_tokens_used"] == 3_000_000
|
||||
|
||||
mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=False)
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(data, indent=2, sort_keys=True) + "\n",
|
||||
"reset_user_usage_daily_only",
|
||||
)
|
||||
|
||||
|
||||
def test_reset_user_usage_daily_and_weekly(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test resetting both daily and weekly usage."""
|
||||
mock_reset = mocker.patch(
|
||||
f"{_MOCK_MODULE}.reset_user_usage",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
_patch_rate_limit_deps(mocker, target_user_id, daily_used=0, weekly_used=0)
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/reset",
|
||||
json={"user_id": target_user_id, "reset_weekly": True},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["daily_tokens_used"] == 0
|
||||
assert data["weekly_tokens_used"] == 0
|
||||
|
||||
mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=True)
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(data, indent=2, sort_keys=True) + "\n",
|
||||
"reset_user_usage_daily_and_weekly",
|
||||
)
|
||||
|
||||
|
||||
def test_reset_user_usage_redis_failure(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test that Redis failure on reset returns 500."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.reset_user_usage",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Redis connection refused"),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/reset",
|
||||
json={"user_id": target_user_id},
|
||||
)
|
||||
|
||||
assert response.status_code == 500
|
||||
|
||||
|
||||
def test_get_rate_limit_email_lookup_failure(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test that failing to resolve a user email degrades gracefully."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_global_rate_limits",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(2_500_000, 12_500_000),
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_usage_status",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_mock_usage_status(),
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_email_by_id",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("DB connection lost"),
|
||||
)
|
||||
|
||||
response = client.get("/admin/rate_limit", params={"user_id": target_user_id})
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["user_id"] == target_user_id
|
||||
assert data["user_email"] is None
|
||||
|
||||
|
||||
def test_admin_endpoints_require_admin_role(mock_jwt_user) -> None:
|
||||
"""Test that rate limit admin endpoints require admin role."""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
|
||||
response = client.get("/admin/rate_limit", params={"user_id": "test"})
|
||||
assert response.status_code == 403
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/reset",
|
||||
json={"user_id": "test"},
|
||||
)
|
||||
assert response.status_code == 403
|
||||
@@ -30,8 +30,14 @@ from backend.copilot.model import (
|
||||
from backend.copilot.rate_limit import (
|
||||
CoPilotUsageStatus,
|
||||
RateLimitExceeded,
|
||||
acquire_reset_lock,
|
||||
check_rate_limit,
|
||||
get_daily_reset_count,
|
||||
get_global_rate_limits,
|
||||
get_usage_status,
|
||||
increment_daily_reset_count,
|
||||
release_reset_lock,
|
||||
reset_daily_usage,
|
||||
)
|
||||
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
|
||||
from backend.copilot.tools.e2b_sandbox import kill_sandbox
|
||||
@@ -59,9 +65,16 @@ from backend.copilot.tools.models import (
|
||||
UnderstandingUpdatedResponse,
|
||||
)
|
||||
from backend.copilot.tracking import track_user_message
|
||||
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.data.understanding import get_business_understanding
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.exceptions import InsufficientBalanceError, NotFoundError
|
||||
from backend.util.settings import Settings
|
||||
|
||||
settings = Settings()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
config = ChatConfig()
|
||||
|
||||
@@ -69,8 +82,6 @@ _UUID_RE = re.compile(
|
||||
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", re.I
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _validate_and_get_session(
|
||||
session_id: str,
|
||||
@@ -421,11 +432,187 @@ async def get_copilot_usage(
|
||||
"""Get CoPilot usage status for the authenticated user.
|
||||
|
||||
Returns current token usage vs limits for daily and weekly windows.
|
||||
Global defaults sourced from LaunchDarkly (falling back to config).
|
||||
"""
|
||||
daily_limit, weekly_limit = await get_global_rate_limits(
|
||||
user_id, config.daily_token_limit, config.weekly_token_limit
|
||||
)
|
||||
return await get_usage_status(
|
||||
user_id=user_id,
|
||||
daily_token_limit=config.daily_token_limit,
|
||||
weekly_token_limit=config.weekly_token_limit,
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
rate_limit_reset_cost=config.rate_limit_reset_cost,
|
||||
)
|
||||
|
||||
|
||||
class RateLimitResetResponse(BaseModel):
|
||||
"""Response from resetting the daily rate limit."""
|
||||
|
||||
success: bool
|
||||
credits_charged: int = Field(description="Credits charged (in cents)")
|
||||
remaining_balance: int = Field(description="Credit balance after charge (in cents)")
|
||||
usage: CoPilotUsageStatus = Field(description="Updated usage status after reset")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/usage/reset",
|
||||
status_code=200,
|
||||
responses={
|
||||
400: {
|
||||
"description": "Bad Request (feature disabled or daily limit not reached)"
|
||||
},
|
||||
402: {"description": "Payment Required (insufficient credits)"},
|
||||
429: {
|
||||
"description": "Too Many Requests (max daily resets exceeded or reset in progress)"
|
||||
},
|
||||
503: {
|
||||
"description": "Service Unavailable (Redis reset failed; credits refunded or support needed)"
|
||||
},
|
||||
},
|
||||
)
|
||||
async def reset_copilot_usage(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> RateLimitResetResponse:
|
||||
"""Reset the daily CoPilot rate limit by spending credits.
|
||||
|
||||
Allows users who have hit their daily token limit to spend credits
|
||||
to reset their daily usage counter and continue working.
|
||||
Returns 400 if the feature is disabled or the user is not over the limit.
|
||||
Returns 402 if the user has insufficient credits.
|
||||
"""
|
||||
cost = config.rate_limit_reset_cost
|
||||
if cost <= 0:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Rate limit reset is not available.",
|
||||
)
|
||||
|
||||
if not settings.config.enable_credit:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Rate limit reset is not available (credit system is disabled).",
|
||||
)
|
||||
|
||||
daily_limit, weekly_limit = await get_global_rate_limits(
|
||||
user_id, config.daily_token_limit, config.weekly_token_limit
|
||||
)
|
||||
|
||||
if daily_limit <= 0:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No daily limit is configured — nothing to reset.",
|
||||
)
|
||||
|
||||
# Check max daily resets. get_daily_reset_count returns None when Redis
|
||||
# is unavailable; reject the reset in that case to prevent unlimited
|
||||
# free resets when the counter store is down.
|
||||
reset_count = await get_daily_reset_count(user_id)
|
||||
if reset_count is None:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Unable to verify reset eligibility — please try again later.",
|
||||
)
|
||||
if config.max_daily_resets > 0 and reset_count >= config.max_daily_resets:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"You've used all {config.max_daily_resets} resets for today.",
|
||||
)
|
||||
|
||||
# Acquire a per-user lock to prevent TOCTOU races (concurrent resets).
|
||||
if not await acquire_reset_lock(user_id):
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail="A reset is already in progress. Please try again.",
|
||||
)
|
||||
|
||||
try:
|
||||
# Verify the user is actually at or over their daily limit.
|
||||
usage_status = await get_usage_status(
|
||||
user_id=user_id,
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
)
|
||||
if daily_limit > 0 and usage_status.daily.used < daily_limit:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="You have not reached your daily limit yet.",
|
||||
)
|
||||
|
||||
# If the weekly limit is also exhausted, resetting the daily counter
|
||||
# won't help — the user would still be blocked by the weekly limit.
|
||||
if weekly_limit > 0 and usage_status.weekly.used >= weekly_limit:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Your weekly limit is also reached. Resetting the daily limit won't help.",
|
||||
)
|
||||
|
||||
# Charge credits.
|
||||
credit_model = await get_user_credit_model(user_id)
|
||||
try:
|
||||
remaining = await credit_model.spend_credits(
|
||||
user_id=user_id,
|
||||
cost=cost,
|
||||
metadata=UsageTransactionMetadata(
|
||||
reason="CoPilot daily rate limit reset",
|
||||
),
|
||||
)
|
||||
except InsufficientBalanceError as e:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail="Insufficient credits to reset your rate limit.",
|
||||
) from e
|
||||
|
||||
# Reset daily usage in Redis. If this fails, refund the credits
|
||||
# so the user is not charged for a service they did not receive.
|
||||
if not await reset_daily_usage(user_id, daily_token_limit=daily_limit):
|
||||
# Compensate: refund the charged credits.
|
||||
refunded = False
|
||||
try:
|
||||
await credit_model.top_up_credits(user_id, cost)
|
||||
refunded = True
|
||||
logger.warning(
|
||||
"Refunded %d credits to user %s after Redis reset failure",
|
||||
cost,
|
||||
user_id[:8],
|
||||
)
|
||||
except Exception:
|
||||
logger.error(
|
||||
"CRITICAL: Failed to refund %d credits to user %s "
|
||||
"after Redis reset failure — manual intervention required",
|
||||
cost,
|
||||
user_id[:8],
|
||||
exc_info=True,
|
||||
)
|
||||
if refunded:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Rate limit reset failed — please try again later. "
|
||||
"Your credits have not been charged.",
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Rate limit reset failed and the automatic refund "
|
||||
"also failed. Please contact support for assistance.",
|
||||
)
|
||||
|
||||
# Track the reset count for daily cap enforcement.
|
||||
await increment_daily_reset_count(user_id)
|
||||
finally:
|
||||
await release_reset_lock(user_id)
|
||||
|
||||
# Return updated usage status.
|
||||
updated_usage = await get_usage_status(
|
||||
user_id=user_id,
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
rate_limit_reset_cost=config.rate_limit_reset_cost,
|
||||
)
|
||||
|
||||
return RateLimitResetResponse(
|
||||
success=True,
|
||||
credits_charged=cost,
|
||||
remaining_balance=remaining,
|
||||
usage=updated_usage,
|
||||
)
|
||||
|
||||
|
||||
@@ -526,12 +713,16 @@ async def stream_chat_post(
|
||||
|
||||
# Pre-turn rate limit check (token-based).
|
||||
# check_rate_limit short-circuits internally when both limits are 0.
|
||||
# Global defaults sourced from LaunchDarkly, falling back to config.
|
||||
if user_id:
|
||||
try:
|
||||
daily_limit, weekly_limit = await get_global_rate_limits(
|
||||
user_id, config.daily_token_limit, config.weekly_token_limit
|
||||
)
|
||||
await check_rate_limit(
|
||||
user_id=user_id,
|
||||
daily_token_limit=config.daily_token_limit,
|
||||
weekly_token_limit=config.weekly_token_limit,
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
)
|
||||
except RateLimitExceeded as e:
|
||||
raise HTTPException(status_code=429, detail=str(e)) from e
|
||||
@@ -894,6 +1085,47 @@ async def session_assign_user(
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# ========== Suggested Prompts ==========
|
||||
|
||||
|
||||
class SuggestedTheme(BaseModel):
|
||||
"""A themed group of suggested prompts."""
|
||||
|
||||
name: str
|
||||
prompts: list[str]
|
||||
|
||||
|
||||
class SuggestedPromptsResponse(BaseModel):
|
||||
"""Response model for user-specific suggested prompts grouped by theme."""
|
||||
|
||||
themes: list[SuggestedTheme]
|
||||
|
||||
|
||||
@router.get(
|
||||
"/suggested-prompts",
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
)
|
||||
async def get_suggested_prompts(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> SuggestedPromptsResponse:
|
||||
"""
|
||||
Get LLM-generated suggested prompts grouped by theme.
|
||||
|
||||
Returns personalized quick-action prompts based on the user's
|
||||
business understanding. Returns empty themes list if no custom
|
||||
prompts are available.
|
||||
"""
|
||||
understanding = await get_business_understanding(user_id)
|
||||
if understanding is None or not understanding.suggested_prompts:
|
||||
return SuggestedPromptsResponse(themes=[])
|
||||
|
||||
themes = [
|
||||
SuggestedTheme(name=name, prompts=prompts)
|
||||
for name, prompts in understanding.suggested_prompts.items()
|
||||
]
|
||||
return SuggestedPromptsResponse(themes=themes)
|
||||
|
||||
|
||||
# ========== Configuration ==========
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Tests for chat API routes: session title update, file attachment validation, usage, and rate limiting."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
@@ -368,6 +368,7 @@ def test_usage_returns_daily_and_weekly(
|
||||
user_id=test_user_id,
|
||||
daily_token_limit=10000,
|
||||
weekly_token_limit=50000,
|
||||
rate_limit_reset_cost=chat_routes.config.rate_limit_reset_cost,
|
||||
)
|
||||
|
||||
|
||||
@@ -380,6 +381,7 @@ def test_usage_uses_config_limits(
|
||||
|
||||
mocker.patch.object(chat_routes.config, "daily_token_limit", 99999)
|
||||
mocker.patch.object(chat_routes.config, "weekly_token_limit", 77777)
|
||||
mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", 500)
|
||||
|
||||
response = client.get("/usage")
|
||||
|
||||
@@ -388,6 +390,7 @@ def test_usage_uses_config_limits(
|
||||
user_id=test_user_id,
|
||||
daily_token_limit=99999,
|
||||
weekly_token_limit=77777,
|
||||
rate_limit_reset_cost=500,
|
||||
)
|
||||
|
||||
|
||||
@@ -400,3 +403,69 @@ def test_usage_rejects_unauthenticated_request() -> None:
|
||||
response = unauthenticated_client.get("/usage")
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
# ─── Suggested prompts endpoint ──────────────────────────────────────
|
||||
|
||||
|
||||
def _mock_get_business_understanding(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
*,
|
||||
return_value=None,
|
||||
):
|
||||
"""Mock get_business_understanding."""
|
||||
return mocker.patch(
|
||||
"backend.api.features.chat.routes.get_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=return_value,
|
||||
)
|
||||
|
||||
|
||||
def test_suggested_prompts_returns_themes(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""User with themed prompts gets them back as themes list."""
|
||||
mock_understanding = MagicMock()
|
||||
mock_understanding.suggested_prompts = {
|
||||
"Learn": ["L1", "L2"],
|
||||
"Create": ["C1"],
|
||||
}
|
||||
_mock_get_business_understanding(mocker, return_value=mock_understanding)
|
||||
|
||||
response = client.get("/suggested-prompts")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "themes" in data
|
||||
themes_by_name = {t["name"]: t["prompts"] for t in data["themes"]}
|
||||
assert themes_by_name["Learn"] == ["L1", "L2"]
|
||||
assert themes_by_name["Create"] == ["C1"]
|
||||
|
||||
|
||||
def test_suggested_prompts_no_understanding(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""User with no understanding gets empty themes list."""
|
||||
_mock_get_business_understanding(mocker, return_value=None)
|
||||
|
||||
response = client.get("/suggested-prompts")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"themes": []}
|
||||
|
||||
|
||||
def test_suggested_prompts_empty_prompts(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""User with understanding but empty prompts gets empty themes list."""
|
||||
mock_understanding = MagicMock()
|
||||
mock_understanding.suggested_prompts = {}
|
||||
_mock_get_business_understanding(mocker, return_value=mock_understanding)
|
||||
|
||||
response = client.get("/suggested-prompts")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"themes": []}
|
||||
|
||||
@@ -17,8 +17,6 @@ from backend.data.includes import library_agent_include
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
from .db import get_library_agent_by_graph_id, update_library_agent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -61,28 +59,17 @@ async def add_graph_to_library(
|
||||
graph_model: GraphModel,
|
||||
user_id: str,
|
||||
) -> library_model.LibraryAgent:
|
||||
"""Check existing / restore soft-deleted / create new LibraryAgent."""
|
||||
if existing := await get_library_agent_by_graph_id(
|
||||
user_id, graph_model.id, graph_model.version
|
||||
):
|
||||
return existing
|
||||
"""Check existing / restore soft-deleted / create new LibraryAgent.
|
||||
|
||||
deleted_agent = await prisma.models.LibraryAgent.prisma().find_unique(
|
||||
where={
|
||||
"userId_agentGraphId_agentGraphVersion": {
|
||||
"userId": user_id,
|
||||
"agentGraphId": graph_model.id,
|
||||
"agentGraphVersion": graph_model.version,
|
||||
}
|
||||
},
|
||||
Uses a create-then-catch-UniqueViolationError-then-update pattern on
|
||||
the (userId, agentGraphId, agentGraphVersion) composite unique constraint.
|
||||
This is more robust than ``upsert`` because Prisma's upsert atomicity
|
||||
guarantees are not well-documented for all versions.
|
||||
"""
|
||||
settings_json = SafeJson(GraphSettings.from_graph(graph_model).model_dump())
|
||||
_include = library_agent_include(
|
||||
user_id, include_nodes=False, include_executions=False
|
||||
)
|
||||
if deleted_agent and (deleted_agent.isDeleted or deleted_agent.isArchived):
|
||||
return await update_library_agent(
|
||||
deleted_agent.id,
|
||||
user_id,
|
||||
is_deleted=False,
|
||||
is_archived=False,
|
||||
)
|
||||
|
||||
try:
|
||||
added_agent = await prisma.models.LibraryAgent.prisma().create(
|
||||
@@ -98,23 +85,32 @@ async def add_graph_to_library(
|
||||
},
|
||||
"isCreatedByUser": False,
|
||||
"useGraphIsActiveVersion": False,
|
||||
"settings": SafeJson(
|
||||
GraphSettings.from_graph(graph_model).model_dump()
|
||||
),
|
||||
"settings": settings_json,
|
||||
},
|
||||
include=library_agent_include(
|
||||
user_id, include_nodes=False, include_executions=False
|
||||
),
|
||||
include=_include,
|
||||
)
|
||||
except prisma.errors.UniqueViolationError:
|
||||
# Race condition: concurrent request created the row between our
|
||||
# check and create. Re-read instead of crashing.
|
||||
existing = await get_library_agent_by_graph_id(
|
||||
user_id, graph_model.id, graph_model.version
|
||||
# Already exists — update to restore if previously soft-deleted/archived
|
||||
added_agent = await prisma.models.LibraryAgent.prisma().update(
|
||||
where={
|
||||
"userId_agentGraphId_agentGraphVersion": {
|
||||
"userId": user_id,
|
||||
"agentGraphId": graph_model.id,
|
||||
"agentGraphVersion": graph_model.version,
|
||||
}
|
||||
},
|
||||
data={
|
||||
"isDeleted": False,
|
||||
"isArchived": False,
|
||||
"settings": settings_json,
|
||||
},
|
||||
include=_include,
|
||||
)
|
||||
if existing:
|
||||
return existing
|
||||
raise # Shouldn't happen, but don't swallow unexpected errors
|
||||
if added_agent is None:
|
||||
raise NotFoundError(
|
||||
f"LibraryAgent for graph #{graph_model.id} "
|
||||
f"v{graph_model.version} not found after UniqueViolationError"
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Added graph #{graph_model.id} v{graph_model.version} "
|
||||
|
||||
@@ -1,71 +1,80 @@
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import prisma.errors
|
||||
import pytest
|
||||
|
||||
from ._add_to_library import add_graph_to_library
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_graph_to_library_restores_archived_agent() -> None:
|
||||
graph_model = MagicMock(id="graph-id", version=2)
|
||||
archived_agent = MagicMock(id="library-agent-id", isDeleted=False, isArchived=True)
|
||||
restored_agent = MagicMock(name="LibraryAgentModel")
|
||||
async def test_add_graph_to_library_create_new_agent() -> None:
|
||||
"""When no matching LibraryAgent exists, create inserts a new one."""
|
||||
graph_model = MagicMock(id="graph-id", version=2, nodes=[])
|
||||
created_agent = MagicMock(name="CreatedLibraryAgent")
|
||||
converted_agent = MagicMock(name="ConvertedLibraryAgent")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.get_library_agent_by_graph_id",
|
||||
new=AsyncMock(return_value=None),
|
||||
),
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.prisma.models.LibraryAgent.prisma"
|
||||
) as mock_prisma,
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.update_library_agent",
|
||||
new=AsyncMock(return_value=restored_agent),
|
||||
) as mock_update,
|
||||
"backend.api.features.library._add_to_library.library_model.LibraryAgent.from_db",
|
||||
return_value=converted_agent,
|
||||
) as mock_from_db,
|
||||
):
|
||||
mock_prisma.return_value.find_unique = AsyncMock(return_value=archived_agent)
|
||||
mock_prisma.return_value.create = AsyncMock(return_value=created_agent)
|
||||
|
||||
result = await add_graph_to_library("slv-id", graph_model, "user-id")
|
||||
|
||||
assert result is restored_agent
|
||||
mock_update.assert_awaited_once_with(
|
||||
"library-agent-id",
|
||||
"user-id",
|
||||
is_deleted=False,
|
||||
is_archived=False,
|
||||
)
|
||||
mock_prisma.return_value.create.assert_not_called()
|
||||
assert result is converted_agent
|
||||
mock_from_db.assert_called_once_with(created_agent)
|
||||
# Verify create was called with correct data
|
||||
create_call = mock_prisma.return_value.create.call_args
|
||||
create_data = create_call.kwargs["data"]
|
||||
assert create_data["User"] == {"connect": {"id": "user-id"}}
|
||||
assert create_data["AgentGraph"] == {
|
||||
"connect": {"graphVersionId": {"id": "graph-id", "version": 2}}
|
||||
}
|
||||
assert create_data["isCreatedByUser"] is False
|
||||
assert create_data["useGraphIsActiveVersion"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_graph_to_library_restores_deleted_agent() -> None:
|
||||
graph_model = MagicMock(id="graph-id", version=2)
|
||||
deleted_agent = MagicMock(id="library-agent-id", isDeleted=True, isArchived=False)
|
||||
restored_agent = MagicMock(name="LibraryAgentModel")
|
||||
async def test_add_graph_to_library_unique_violation_updates_existing() -> None:
|
||||
"""UniqueViolationError on create falls back to update."""
|
||||
graph_model = MagicMock(id="graph-id", version=2, nodes=[])
|
||||
updated_agent = MagicMock(name="UpdatedLibraryAgent")
|
||||
converted_agent = MagicMock(name="ConvertedLibraryAgent")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.get_library_agent_by_graph_id",
|
||||
new=AsyncMock(return_value=None),
|
||||
),
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.prisma.models.LibraryAgent.prisma"
|
||||
) as mock_prisma,
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.update_library_agent",
|
||||
new=AsyncMock(return_value=restored_agent),
|
||||
) as mock_update,
|
||||
"backend.api.features.library._add_to_library.library_model.LibraryAgent.from_db",
|
||||
return_value=converted_agent,
|
||||
) as mock_from_db,
|
||||
):
|
||||
mock_prisma.return_value.find_unique = AsyncMock(return_value=deleted_agent)
|
||||
mock_prisma.return_value.create = AsyncMock(
|
||||
side_effect=prisma.errors.UniqueViolationError(
|
||||
MagicMock(), message="unique constraint"
|
||||
)
|
||||
)
|
||||
mock_prisma.return_value.update = AsyncMock(return_value=updated_agent)
|
||||
|
||||
result = await add_graph_to_library("slv-id", graph_model, "user-id")
|
||||
|
||||
assert result is restored_agent
|
||||
mock_update.assert_awaited_once_with(
|
||||
"library-agent-id",
|
||||
"user-id",
|
||||
is_deleted=False,
|
||||
is_archived=False,
|
||||
)
|
||||
mock_prisma.return_value.create.assert_not_called()
|
||||
assert result is converted_agent
|
||||
mock_from_db.assert_called_once_with(updated_agent)
|
||||
# Verify update was called with correct where and data
|
||||
update_call = mock_prisma.return_value.update.call_args
|
||||
assert update_call.kwargs["where"] == {
|
||||
"userId_agentGraphId_agentGraphVersion": {
|
||||
"userId": "user-id",
|
||||
"agentGraphId": "graph-id",
|
||||
"agentGraphVersion": 2,
|
||||
}
|
||||
}
|
||||
update_data = update_call.kwargs["data"]
|
||||
assert update_data["isDeleted"] is False
|
||||
assert update_data["isArchived"] is False
|
||||
|
||||
@@ -436,32 +436,53 @@ async def create_library_agent(
|
||||
async with transaction() as tx:
|
||||
library_agents = await asyncio.gather(
|
||||
*(
|
||||
prisma.models.LibraryAgent.prisma(tx).create(
|
||||
data=prisma.types.LibraryAgentCreateInput(
|
||||
isCreatedByUser=(user_id == user_id),
|
||||
useGraphIsActiveVersion=True,
|
||||
User={"connect": {"id": user_id}},
|
||||
AgentGraph={
|
||||
"connect": {
|
||||
"graphVersionId": {
|
||||
"id": graph_entry.id,
|
||||
"version": graph_entry.version,
|
||||
prisma.models.LibraryAgent.prisma(tx).upsert(
|
||||
where={
|
||||
"userId_agentGraphId_agentGraphVersion": {
|
||||
"userId": user_id,
|
||||
"agentGraphId": graph_entry.id,
|
||||
"agentGraphVersion": graph_entry.version,
|
||||
}
|
||||
},
|
||||
data={
|
||||
"create": prisma.types.LibraryAgentCreateInput(
|
||||
isCreatedByUser=(user_id == graph.user_id),
|
||||
useGraphIsActiveVersion=True,
|
||||
User={"connect": {"id": user_id}},
|
||||
AgentGraph={
|
||||
"connect": {
|
||||
"graphVersionId": {
|
||||
"id": graph_entry.id,
|
||||
"version": graph_entry.version,
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
settings=SafeJson(
|
||||
GraphSettings.from_graph(
|
||||
graph_entry,
|
||||
hitl_safe_mode=hitl_safe_mode,
|
||||
sensitive_action_safe_mode=sensitive_action_safe_mode,
|
||||
).model_dump()
|
||||
),
|
||||
**(
|
||||
{"Folder": {"connect": {"id": folder_id}}}
|
||||
if folder_id and graph_entry is graph
|
||||
else {}
|
||||
),
|
||||
),
|
||||
"update": {
|
||||
"isDeleted": False,
|
||||
"isArchived": False,
|
||||
"useGraphIsActiveVersion": True,
|
||||
"settings": SafeJson(
|
||||
GraphSettings.from_graph(
|
||||
graph_entry,
|
||||
hitl_safe_mode=hitl_safe_mode,
|
||||
sensitive_action_safe_mode=sensitive_action_safe_mode,
|
||||
).model_dump()
|
||||
),
|
||||
},
|
||||
settings=SafeJson(
|
||||
GraphSettings.from_graph(
|
||||
graph_entry,
|
||||
hitl_safe_mode=hitl_safe_mode,
|
||||
sensitive_action_safe_mode=sensitive_action_safe_mode,
|
||||
).model_dump()
|
||||
),
|
||||
**(
|
||||
{"Folder": {"connect": {"id": folder_id}}}
|
||||
if folder_id and graph_entry is graph
|
||||
else {}
|
||||
),
|
||||
),
|
||||
},
|
||||
include=library_agent_include(
|
||||
user_id, include_nodes=False, include_executions=False
|
||||
),
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import prisma.enums
|
||||
import prisma.models
|
||||
@@ -85,10 +87,6 @@ async def test_get_library_agents(mocker):
|
||||
async def test_add_agent_to_library(mocker):
|
||||
await connect()
|
||||
|
||||
# Mock the transaction context
|
||||
mock_transaction = mocker.patch("backend.api.features.library.db.transaction")
|
||||
mock_transaction.return_value.__aenter__ = mocker.AsyncMock(return_value=None)
|
||||
mock_transaction.return_value.__aexit__ = mocker.AsyncMock(return_value=None)
|
||||
# Mock data
|
||||
mock_store_listing_data = prisma.models.StoreListingVersion(
|
||||
id="version123",
|
||||
@@ -143,13 +141,11 @@ async def test_add_agent_to_library(mocker):
|
||||
)
|
||||
|
||||
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
|
||||
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
|
||||
mock_library_agent.return_value.find_unique = mocker.AsyncMock(return_value=None)
|
||||
mock_library_agent.return_value.create = mocker.AsyncMock(
|
||||
return_value=mock_library_agent_data
|
||||
)
|
||||
|
||||
# Mock graph_db.get_graph function that's called to check for HITL blocks
|
||||
# Mock graph_db.get_graph function that's called in resolve_graph_for_library
|
||||
# (lives in _add_to_library.py after refactor, not db.py)
|
||||
mock_graph_db = mocker.patch(
|
||||
"backend.api.features.library._add_to_library.graph_db"
|
||||
@@ -175,37 +171,27 @@ async def test_add_agent_to_library(mocker):
|
||||
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
|
||||
where={"id": "version123"}, include={"AgentGraph": True}
|
||||
)
|
||||
mock_library_agent.return_value.find_unique.assert_called_once_with(
|
||||
where={
|
||||
"userId_agentGraphId_agentGraphVersion": {
|
||||
"userId": "test-user",
|
||||
"agentGraphId": "agent1",
|
||||
"agentGraphVersion": 1,
|
||||
}
|
||||
},
|
||||
)
|
||||
# Check that create was called with the expected data including settings
|
||||
create_call_args = mock_library_agent.return_value.create.call_args
|
||||
assert create_call_args is not None
|
||||
|
||||
# Verify the main structure
|
||||
expected_data = {
|
||||
# Verify the create data structure
|
||||
create_data = create_call_args.kwargs["data"]
|
||||
expected_create = {
|
||||
"User": {"connect": {"id": "test-user"}},
|
||||
"AgentGraph": {"connect": {"graphVersionId": {"id": "agent1", "version": 1}}},
|
||||
"isCreatedByUser": False,
|
||||
"useGraphIsActiveVersion": False,
|
||||
}
|
||||
|
||||
actual_data = create_call_args[1]["data"]
|
||||
# Check that all expected fields are present
|
||||
for key, value in expected_data.items():
|
||||
assert actual_data[key] == value
|
||||
for key, value in expected_create.items():
|
||||
assert create_data[key] == value
|
||||
|
||||
# Check that settings field is present and is a SafeJson object
|
||||
assert "settings" in actual_data
|
||||
assert hasattr(actual_data["settings"], "__class__") # Should be a SafeJson object
|
||||
assert "settings" in create_data
|
||||
assert hasattr(create_data["settings"], "__class__") # Should be a SafeJson object
|
||||
|
||||
# Check include parameter
|
||||
assert create_call_args[1]["include"] == library_agent_include(
|
||||
assert create_call_args.kwargs["include"] == library_agent_include(
|
||||
"test-user", include_nodes=False, include_executions=False
|
||||
)
|
||||
|
||||
@@ -320,3 +306,50 @@ async def test_update_graph_in_library_allows_archived_library_agent(mocker):
|
||||
include_archived=True,
|
||||
)
|
||||
mock_update_library_agent.assert_awaited_once_with("test-user", created_graph)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_library_agent_uses_upsert():
|
||||
"""create_library_agent should use upsert (not create) to handle duplicates."""
|
||||
mock_graph = MagicMock()
|
||||
mock_graph.id = "graph-1"
|
||||
mock_graph.version = 1
|
||||
mock_graph.user_id = "user-1"
|
||||
mock_graph.nodes = []
|
||||
mock_graph.sub_graphs = []
|
||||
|
||||
mock_upserted = MagicMock(name="UpsertedLibraryAgent")
|
||||
|
||||
@asynccontextmanager
|
||||
async def fake_tx():
|
||||
yield None
|
||||
|
||||
with (
|
||||
patch("backend.api.features.library.db.transaction", fake_tx),
|
||||
patch("prisma.models.LibraryAgent.prisma") as mock_prisma,
|
||||
patch(
|
||||
"backend.api.features.library.db.add_generated_agent_image",
|
||||
new=AsyncMock(),
|
||||
),
|
||||
patch(
|
||||
"backend.api.features.library.model.LibraryAgent.from_db",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
):
|
||||
mock_prisma.return_value.upsert = AsyncMock(return_value=mock_upserted)
|
||||
|
||||
result = await db.create_library_agent(mock_graph, "user-1")
|
||||
|
||||
assert len(result) == 1
|
||||
upsert_call = mock_prisma.return_value.upsert.call_args
|
||||
assert upsert_call is not None
|
||||
# Verify the upsert where clause uses the composite unique key
|
||||
where = upsert_call.kwargs["where"]
|
||||
assert "userId_agentGraphId_agentGraphVersion" in where
|
||||
# Verify the upsert data has both create and update branches
|
||||
data = upsert_call.kwargs["data"]
|
||||
assert "create" in data
|
||||
assert "update" in data
|
||||
# Verify update branch restores soft-deleted/archived agents
|
||||
assert data["update"]["isDeleted"] is False
|
||||
assert data["update"]["isArchived"] is False
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
# Platform bot linking API
|
||||
@@ -0,0 +1,35 @@
|
||||
"""Bot API key authentication for platform linking endpoints."""
|
||||
|
||||
import hmac
|
||||
import os
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
|
||||
from backend.util.settings import Settings
|
||||
|
||||
|
||||
async def get_bot_api_key(request: Request) -> str | None:
|
||||
"""Extract the bot API key from the X-Bot-API-Key header."""
|
||||
return request.headers.get("x-bot-api-key")
|
||||
|
||||
|
||||
def check_bot_api_key(api_key: str | None) -> None:
|
||||
"""Validate the bot API key. Uses constant-time comparison.
|
||||
|
||||
Reads the key from env on each call so rotated secrets take effect
|
||||
without restarting the process.
|
||||
"""
|
||||
configured_key = os.getenv("PLATFORM_BOT_API_KEY", "")
|
||||
|
||||
if not configured_key:
|
||||
settings = Settings()
|
||||
if settings.config.enable_auth:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Bot API key not configured.",
|
||||
)
|
||||
# Auth disabled (local dev) — allow without key
|
||||
return
|
||||
|
||||
if not api_key or not hmac.compare_digest(api_key, configured_key):
|
||||
raise HTTPException(status_code=401, detail="Invalid bot API key.")
|
||||
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
Bot Chat Proxy endpoints.
|
||||
|
||||
Allows the bot service to send messages to CoPilot on behalf of
|
||||
linked users, authenticated via bot API key.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from prisma.models import PlatformLink
|
||||
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.executor.utils import enqueue_copilot_turn
|
||||
from backend.copilot.model import (
|
||||
ChatMessage,
|
||||
append_and_save_message,
|
||||
create_chat_session,
|
||||
get_chat_session,
|
||||
)
|
||||
from backend.copilot.response_model import StreamFinish
|
||||
|
||||
from .auth import check_bot_api_key, get_bot_api_key
|
||||
from .models import BotChatRequest, BotChatSessionResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post(
|
||||
"/chat/session",
|
||||
response_model=BotChatSessionResponse,
|
||||
summary="Create a CoPilot session for a linked user (bot-facing)",
|
||||
)
|
||||
async def bot_create_session(
|
||||
request: BotChatRequest,
|
||||
x_bot_api_key: str | None = Depends(get_bot_api_key),
|
||||
) -> BotChatSessionResponse:
|
||||
"""Creates a new CoPilot chat session on behalf of a linked user."""
|
||||
check_bot_api_key(x_bot_api_key)
|
||||
|
||||
link = await PlatformLink.prisma().find_first(where={"userId": request.user_id})
|
||||
if not link:
|
||||
raise HTTPException(status_code=404, detail="User has no platform links.")
|
||||
|
||||
session = await create_chat_session(request.user_id)
|
||||
|
||||
return BotChatSessionResponse(session_id=session.session_id)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/chat/stream",
|
||||
summary="Stream a CoPilot response for a linked user (bot-facing)",
|
||||
)
|
||||
async def bot_chat_stream(
|
||||
request: BotChatRequest,
|
||||
x_bot_api_key: str | None = Depends(get_bot_api_key),
|
||||
):
|
||||
"""
|
||||
Send a message to CoPilot on behalf of a linked user and stream
|
||||
the response back as Server-Sent Events.
|
||||
|
||||
The bot authenticates with its API key — no user JWT needed.
|
||||
"""
|
||||
check_bot_api_key(x_bot_api_key)
|
||||
|
||||
user_id = request.user_id
|
||||
|
||||
# Verify user has a platform link
|
||||
link = await PlatformLink.prisma().find_first(where={"userId": user_id})
|
||||
if not link:
|
||||
raise HTTPException(status_code=404, detail="User has no platform links.")
|
||||
|
||||
# Get or create session
|
||||
session_id = request.session_id
|
||||
if session_id:
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found.")
|
||||
else:
|
||||
session = await create_chat_session(user_id)
|
||||
session_id = session.session_id
|
||||
|
||||
# Save user message
|
||||
message = ChatMessage(role="user", content=request.message)
|
||||
await append_and_save_message(session_id, message)
|
||||
|
||||
# Create a turn and enqueue
|
||||
turn_id = str(uuid4())
|
||||
|
||||
await stream_registry.create_session(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
tool_call_id="chat_stream",
|
||||
tool_name="chat",
|
||||
turn_id=turn_id,
|
||||
)
|
||||
|
||||
subscribe_from_id = "0-0"
|
||||
|
||||
await enqueue_copilot_turn(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
message=request.message,
|
||||
turn_id=turn_id,
|
||||
is_user_message=True,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Bot chat: user ...%s, session %s, turn %s",
|
||||
user_id[-8:],
|
||||
session_id,
|
||||
turn_id,
|
||||
)
|
||||
|
||||
async def event_generator():
|
||||
subscriber_queue = None
|
||||
try:
|
||||
subscriber_queue = await stream_registry.subscribe_to_session(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
last_message_id=subscribe_from_id,
|
||||
)
|
||||
|
||||
if subscriber_queue is None:
|
||||
yield StreamFinish().to_sse()
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
while True:
|
||||
try:
|
||||
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
|
||||
|
||||
if isinstance(chunk, str):
|
||||
yield chunk
|
||||
else:
|
||||
yield chunk.to_sse()
|
||||
|
||||
if isinstance(chunk, StreamFinish) or (
|
||||
isinstance(chunk, str) and "[DONE]" in chunk
|
||||
):
|
||||
break
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
yield ": keepalive\n\n"
|
||||
|
||||
except Exception:
|
||||
logger.exception("Bot chat stream error for session %s", session_id)
|
||||
yield 'data: {"type": "error", "content": "Stream error"}\n\n'
|
||||
yield "data: [DONE]\n\n"
|
||||
finally:
|
||||
if subscriber_queue is not None:
|
||||
await stream_registry.unsubscribe_from_session(
|
||||
session_id=session_id,
|
||||
subscriber_queue=subscriber_queue,
|
||||
)
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Session-Id": session_id,
|
||||
},
|
||||
)
|
||||
@@ -0,0 +1,107 @@
|
||||
"""Pydantic models for the platform bot linking API."""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Platform(str, Enum):
|
||||
"""Supported platform types (mirrors Prisma PlatformType)."""
|
||||
|
||||
DISCORD = "DISCORD"
|
||||
TELEGRAM = "TELEGRAM"
|
||||
SLACK = "SLACK"
|
||||
TEAMS = "TEAMS"
|
||||
WHATSAPP = "WHATSAPP"
|
||||
GITHUB = "GITHUB"
|
||||
LINEAR = "LINEAR"
|
||||
|
||||
|
||||
# ── Request Models ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class CreateLinkTokenRequest(BaseModel):
|
||||
"""Request from the bot service to create a linking token."""
|
||||
|
||||
platform: Platform = Field(description="Platform name")
|
||||
platform_user_id: str = Field(
|
||||
description="The user's ID on the platform",
|
||||
min_length=1,
|
||||
max_length=255,
|
||||
)
|
||||
platform_username: str | None = Field(
|
||||
default=None,
|
||||
description="Display name (best effort)",
|
||||
max_length=255,
|
||||
)
|
||||
channel_id: str | None = Field(
|
||||
default=None,
|
||||
description="Channel ID for sending confirmation back",
|
||||
max_length=255,
|
||||
)
|
||||
|
||||
|
||||
class ResolveRequest(BaseModel):
|
||||
"""Resolve a platform identity to an AutoGPT user."""
|
||||
|
||||
platform: Platform
|
||||
platform_user_id: str = Field(min_length=1, max_length=255)
|
||||
|
||||
|
||||
class BotChatRequest(BaseModel):
|
||||
"""Request from the bot to chat as a linked user."""
|
||||
|
||||
user_id: str = Field(description="The linked AutoGPT user ID")
|
||||
message: str = Field(
|
||||
description="The user's message", min_length=1, max_length=32000
|
||||
)
|
||||
session_id: str | None = Field(
|
||||
default=None,
|
||||
description="Existing chat session ID. If omitted, a new session is created.",
|
||||
)
|
||||
|
||||
|
||||
# ── Response Models ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class LinkTokenResponse(BaseModel):
|
||||
token: str
|
||||
expires_at: datetime
|
||||
link_url: str
|
||||
|
||||
|
||||
class LinkTokenStatusResponse(BaseModel):
|
||||
status: Literal["pending", "linked", "expired"]
|
||||
user_id: str | None = None
|
||||
|
||||
|
||||
class ResolveResponse(BaseModel):
|
||||
linked: bool
|
||||
user_id: str | None = None
|
||||
|
||||
|
||||
class PlatformLinkInfo(BaseModel):
|
||||
id: str
|
||||
platform: str
|
||||
platform_user_id: str
|
||||
platform_username: str | None
|
||||
linked_at: datetime
|
||||
|
||||
|
||||
class ConfirmLinkResponse(BaseModel):
|
||||
success: bool
|
||||
platform: str
|
||||
platform_user_id: str
|
||||
platform_username: str | None
|
||||
|
||||
|
||||
class DeleteLinkResponse(BaseModel):
|
||||
success: bool
|
||||
|
||||
|
||||
class BotChatSessionResponse(BaseModel):
|
||||
"""Returned when creating a new session via the bot proxy."""
|
||||
|
||||
session_id: str
|
||||
@@ -0,0 +1,340 @@
|
||||
"""
|
||||
Platform Bot Linking API routes.
|
||||
|
||||
Enables linking external chat platform identities (Discord, Telegram, Slack, etc.)
|
||||
to AutoGPT user accounts. Used by the multi-platform CoPilot bot.
|
||||
|
||||
Flow:
|
||||
1. Bot calls POST /api/platform-linking/tokens to create a link token
|
||||
for an unlinked platform user.
|
||||
2. Bot sends the user a link: {frontend}/link/{token}
|
||||
3. User clicks the link, logs in to AutoGPT, and the frontend calls
|
||||
POST /api/platform-linking/tokens/{token}/confirm to complete the link.
|
||||
4. Bot can poll GET /api/platform-linking/tokens/{token}/status or just
|
||||
check on next message via GET /api/platform-linking/resolve.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import secrets
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Annotated
|
||||
|
||||
from autogpt_libs import auth
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, Security
|
||||
from prisma.models import PlatformLink, PlatformLinkToken
|
||||
|
||||
from .auth import check_bot_api_key, get_bot_api_key
|
||||
from .models import (
|
||||
ConfirmLinkResponse,
|
||||
CreateLinkTokenRequest,
|
||||
DeleteLinkResponse,
|
||||
LinkTokenResponse,
|
||||
LinkTokenStatusResponse,
|
||||
PlatformLinkInfo,
|
||||
ResolveRequest,
|
||||
ResolveResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
LINK_TOKEN_EXPIRY_MINUTES = 30
|
||||
|
||||
# Path parameter with validation for link tokens
|
||||
TokenPath = Annotated[
|
||||
str,
|
||||
Path(max_length=64, pattern=r"^[A-Za-z0-9_-]+$"),
|
||||
]
|
||||
|
||||
|
||||
# ── Bot-facing endpoints (API key auth) ───────────────────────────────
|
||||
|
||||
|
||||
@router.post(
|
||||
"/tokens",
|
||||
response_model=LinkTokenResponse,
|
||||
summary="Create a link token for an unlinked platform user",
|
||||
)
|
||||
async def create_link_token(
|
||||
request: CreateLinkTokenRequest,
|
||||
x_bot_api_key: str | None = Depends(get_bot_api_key),
|
||||
) -> LinkTokenResponse:
|
||||
"""
|
||||
Called by the bot service when it encounters an unlinked user.
|
||||
Generates a one-time token the user can use to link their account.
|
||||
"""
|
||||
check_bot_api_key(x_bot_api_key)
|
||||
|
||||
platform = request.platform.value
|
||||
|
||||
# Check if already linked
|
||||
existing = await PlatformLink.prisma().find_first(
|
||||
where={
|
||||
"platform": platform,
|
||||
"platformUserId": request.platform_user_id,
|
||||
}
|
||||
)
|
||||
if existing:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="This platform account is already linked.",
|
||||
)
|
||||
|
||||
# Invalidate any existing pending tokens for this user
|
||||
await PlatformLinkToken.prisma().update_many(
|
||||
where={
|
||||
"platform": platform,
|
||||
"platformUserId": request.platform_user_id,
|
||||
"usedAt": None,
|
||||
},
|
||||
data={"usedAt": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
# Generate token
|
||||
token = secrets.token_urlsafe(32)
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(
|
||||
minutes=LINK_TOKEN_EXPIRY_MINUTES
|
||||
)
|
||||
|
||||
await PlatformLinkToken.prisma().create(
|
||||
data={
|
||||
"token": token,
|
||||
"platform": platform,
|
||||
"platformUserId": request.platform_user_id,
|
||||
"platformUsername": request.platform_username,
|
||||
"channelId": request.channel_id,
|
||||
"expiresAt": expires_at,
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Created link token for %s (expires %s)",
|
||||
platform,
|
||||
expires_at.isoformat(),
|
||||
)
|
||||
|
||||
link_base_url = os.getenv(
|
||||
"PLATFORM_LINK_BASE_URL", "https://platform.agpt.co/link"
|
||||
)
|
||||
link_url = f"{link_base_url}/{token}"
|
||||
|
||||
return LinkTokenResponse(
|
||||
token=token,
|
||||
expires_at=expires_at,
|
||||
link_url=link_url,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/tokens/{token}/status",
|
||||
response_model=LinkTokenStatusResponse,
|
||||
summary="Check if a link token has been consumed",
|
||||
)
|
||||
async def get_link_token_status(
|
||||
token: TokenPath,
|
||||
x_bot_api_key: str | None = Depends(get_bot_api_key),
|
||||
) -> LinkTokenStatusResponse:
|
||||
"""
|
||||
Called by the bot service to check if a user has completed linking.
|
||||
"""
|
||||
check_bot_api_key(x_bot_api_key)
|
||||
|
||||
link_token = await PlatformLinkToken.prisma().find_unique(where={"token": token})
|
||||
|
||||
if not link_token:
|
||||
raise HTTPException(status_code=404, detail="Token not found")
|
||||
|
||||
if link_token.usedAt is not None:
|
||||
# Token was used — find the linked account
|
||||
link = await PlatformLink.prisma().find_first(
|
||||
where={
|
||||
"platform": link_token.platform,
|
||||
"platformUserId": link_token.platformUserId,
|
||||
}
|
||||
)
|
||||
return LinkTokenStatusResponse(
|
||||
status="linked",
|
||||
user_id=link.userId if link else None,
|
||||
)
|
||||
|
||||
if link_token.expiresAt.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc):
|
||||
return LinkTokenStatusResponse(status="expired")
|
||||
|
||||
return LinkTokenStatusResponse(status="pending")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/resolve",
|
||||
response_model=ResolveResponse,
|
||||
summary="Resolve a platform identity to an AutoGPT user",
|
||||
)
|
||||
async def resolve_platform_user(
|
||||
request: ResolveRequest,
|
||||
x_bot_api_key: str | None = Depends(get_bot_api_key),
|
||||
) -> ResolveResponse:
|
||||
"""
|
||||
Called by the bot service on every incoming message to check if
|
||||
the platform user has a linked AutoGPT account.
|
||||
"""
|
||||
check_bot_api_key(x_bot_api_key)
|
||||
|
||||
link = await PlatformLink.prisma().find_first(
|
||||
where={
|
||||
"platform": request.platform.value,
|
||||
"platformUserId": request.platform_user_id,
|
||||
}
|
||||
)
|
||||
|
||||
if not link:
|
||||
return ResolveResponse(linked=False)
|
||||
|
||||
return ResolveResponse(linked=True, user_id=link.userId)
|
||||
|
||||
|
||||
# ── User-facing endpoints (JWT auth) ──────────────────────────────────
|
||||
|
||||
|
||||
@router.post(
|
||||
"/tokens/{token}/confirm",
|
||||
response_model=ConfirmLinkResponse,
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
summary="Confirm a link token (user must be authenticated)",
|
||||
)
|
||||
async def confirm_link_token(
|
||||
token: TokenPath,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> ConfirmLinkResponse:
|
||||
"""
|
||||
Called by the frontend when the user clicks the link and is logged in.
|
||||
Consumes the token and creates the platform link.
|
||||
Uses atomic update_many to prevent race conditions on double-click.
|
||||
"""
|
||||
link_token = await PlatformLinkToken.prisma().find_unique(where={"token": token})
|
||||
|
||||
if not link_token:
|
||||
raise HTTPException(status_code=404, detail="Token not found.")
|
||||
|
||||
if link_token.usedAt is not None:
|
||||
raise HTTPException(status_code=410, detail="This link has already been used.")
|
||||
|
||||
if link_token.expiresAt.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc):
|
||||
raise HTTPException(status_code=410, detail="This link has expired.")
|
||||
|
||||
# Atomically mark token as used (only if still unused)
|
||||
updated = await PlatformLinkToken.prisma().update_many(
|
||||
where={"token": token, "usedAt": None},
|
||||
data={"usedAt": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
if updated == 0:
|
||||
raise HTTPException(status_code=410, detail="This link has already been used.")
|
||||
|
||||
# Check if this platform identity is already linked
|
||||
existing = await PlatformLink.prisma().find_first(
|
||||
where={
|
||||
"platform": link_token.platform,
|
||||
"platformUserId": link_token.platformUserId,
|
||||
}
|
||||
)
|
||||
if existing:
|
||||
detail = (
|
||||
"This platform account is already linked to your account."
|
||||
if existing.userId == user_id
|
||||
else "This platform account is already linked to another user."
|
||||
)
|
||||
raise HTTPException(status_code=409, detail=detail)
|
||||
|
||||
# Create the link — catch unique constraint race condition
|
||||
try:
|
||||
await PlatformLink.prisma().create(
|
||||
data={
|
||||
"userId": user_id,
|
||||
"platform": link_token.platform,
|
||||
"platformUserId": link_token.platformUserId,
|
||||
"platformUsername": link_token.platformUsername,
|
||||
}
|
||||
)
|
||||
except Exception as exc:
|
||||
if "unique" in str(exc).lower():
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="This platform account was just linked by another request.",
|
||||
) from exc
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
"Linked %s:%s to user ...%s",
|
||||
link_token.platform,
|
||||
link_token.platformUserId,
|
||||
user_id[-8:],
|
||||
)
|
||||
|
||||
return ConfirmLinkResponse(
|
||||
success=True,
|
||||
platform=link_token.platform,
|
||||
platform_user_id=link_token.platformUserId,
|
||||
platform_username=link_token.platformUsername,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/links",
|
||||
response_model=list[PlatformLinkInfo],
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
summary="List all platform links for the authenticated user",
|
||||
)
|
||||
async def list_my_links(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> list[PlatformLinkInfo]:
|
||||
"""Returns all platform identities linked to the current user's account."""
|
||||
links = await PlatformLink.prisma().find_many(
|
||||
where={"userId": user_id},
|
||||
order={"linkedAt": "desc"},
|
||||
)
|
||||
|
||||
return [
|
||||
PlatformLinkInfo(
|
||||
id=link.id,
|
||||
platform=link.platform,
|
||||
platform_user_id=link.platformUserId,
|
||||
platform_username=link.platformUsername,
|
||||
linked_at=link.linkedAt,
|
||||
)
|
||||
for link in links
|
||||
]
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/links/{link_id}",
|
||||
response_model=DeleteLinkResponse,
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
summary="Unlink a platform identity",
|
||||
)
|
||||
async def delete_link(
|
||||
link_id: str,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> DeleteLinkResponse:
|
||||
"""
|
||||
Removes a platform link. The user will need to re-link if they
|
||||
want to use the bot on that platform again.
|
||||
"""
|
||||
link = await PlatformLink.prisma().find_unique(where={"id": link_id})
|
||||
|
||||
if not link:
|
||||
raise HTTPException(status_code=404, detail="Link not found.")
|
||||
|
||||
if link.userId != user_id:
|
||||
raise HTTPException(status_code=403, detail="Not your link.")
|
||||
|
||||
await PlatformLink.prisma().delete(where={"id": link_id})
|
||||
|
||||
logger.info(
|
||||
"Unlinked %s:%s from user ...%s",
|
||||
link.platform,
|
||||
link.platformUserId,
|
||||
user_id[-8:],
|
||||
)
|
||||
|
||||
return DeleteLinkResponse(success=True)
|
||||
@@ -0,0 +1,138 @@
|
||||
"""Tests for platform bot linking API routes."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from backend.api.features.platform_linking.auth import check_bot_api_key
|
||||
from backend.api.features.platform_linking.models import (
|
||||
ConfirmLinkResponse,
|
||||
CreateLinkTokenRequest,
|
||||
DeleteLinkResponse,
|
||||
LinkTokenStatusResponse,
|
||||
Platform,
|
||||
ResolveRequest,
|
||||
)
|
||||
|
||||
|
||||
class TestPlatformEnum:
|
||||
def test_all_platforms_exist(self):
|
||||
assert Platform.DISCORD.value == "DISCORD"
|
||||
assert Platform.TELEGRAM.value == "TELEGRAM"
|
||||
assert Platform.SLACK.value == "SLACK"
|
||||
assert Platform.TEAMS.value == "TEAMS"
|
||||
assert Platform.WHATSAPP.value == "WHATSAPP"
|
||||
assert Platform.GITHUB.value == "GITHUB"
|
||||
assert Platform.LINEAR.value == "LINEAR"
|
||||
|
||||
|
||||
class TestBotApiKeyAuth:
|
||||
@patch.dict("os.environ", {"PLATFORM_BOT_API_KEY": ""}, clear=False)
|
||||
@patch("backend.api.features.platform_linking.auth.Settings")
|
||||
def test_no_key_configured_allows_when_auth_disabled(self, mock_settings_cls):
|
||||
mock_settings_cls.return_value.config.enable_auth = False
|
||||
check_bot_api_key(None)
|
||||
|
||||
@patch.dict("os.environ", {"PLATFORM_BOT_API_KEY": ""}, clear=False)
|
||||
@patch("backend.api.features.platform_linking.auth.Settings")
|
||||
def test_no_key_configured_rejects_when_auth_enabled(self, mock_settings_cls):
|
||||
mock_settings_cls.return_value.config.enable_auth = True
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
check_bot_api_key(None)
|
||||
assert exc_info.value.status_code == 503
|
||||
|
||||
@patch.dict("os.environ", {"PLATFORM_BOT_API_KEY": "secret123"}, clear=False)
|
||||
def test_valid_key(self):
|
||||
check_bot_api_key("secret123")
|
||||
|
||||
@patch.dict("os.environ", {"PLATFORM_BOT_API_KEY": "secret123"}, clear=False)
|
||||
def test_invalid_key_rejected(self):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
check_bot_api_key("wrong")
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
@patch.dict("os.environ", {"PLATFORM_BOT_API_KEY": "secret123"}, clear=False)
|
||||
def test_missing_key_rejected(self):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
check_bot_api_key(None)
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
|
||||
class TestCreateLinkTokenRequest:
|
||||
def test_valid_request(self):
|
||||
req = CreateLinkTokenRequest(
|
||||
platform=Platform.DISCORD,
|
||||
platform_user_id="353922987235213313",
|
||||
)
|
||||
assert req.platform == Platform.DISCORD
|
||||
assert req.platform_user_id == "353922987235213313"
|
||||
|
||||
def test_empty_platform_user_id_rejected(self):
|
||||
from pydantic import ValidationError
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
CreateLinkTokenRequest(
|
||||
platform=Platform.DISCORD,
|
||||
platform_user_id="",
|
||||
)
|
||||
|
||||
def test_too_long_platform_user_id_rejected(self):
|
||||
from pydantic import ValidationError
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
CreateLinkTokenRequest(
|
||||
platform=Platform.DISCORD,
|
||||
platform_user_id="x" * 256,
|
||||
)
|
||||
|
||||
def test_invalid_platform_rejected(self):
|
||||
from pydantic import ValidationError
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
CreateLinkTokenRequest.model_validate(
|
||||
{"platform": "INVALID", "platform_user_id": "123"}
|
||||
)
|
||||
|
||||
|
||||
class TestResolveRequest:
|
||||
def test_valid_request(self):
|
||||
req = ResolveRequest(
|
||||
platform=Platform.TELEGRAM,
|
||||
platform_user_id="123456789",
|
||||
)
|
||||
assert req.platform == Platform.TELEGRAM
|
||||
|
||||
def test_empty_id_rejected(self):
|
||||
from pydantic import ValidationError
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
ResolveRequest(
|
||||
platform=Platform.SLACK,
|
||||
platform_user_id="",
|
||||
)
|
||||
|
||||
|
||||
class TestResponseModels:
|
||||
def test_link_token_status_literal(self):
|
||||
resp = LinkTokenStatusResponse(status="pending")
|
||||
assert resp.status == "pending"
|
||||
|
||||
resp = LinkTokenStatusResponse(status="linked", user_id="abc")
|
||||
assert resp.status == "linked"
|
||||
|
||||
resp = LinkTokenStatusResponse(status="expired")
|
||||
assert resp.status == "expired"
|
||||
|
||||
def test_confirm_link_response(self):
|
||||
resp = ConfirmLinkResponse(
|
||||
success=True,
|
||||
platform="DISCORD",
|
||||
platform_user_id="123",
|
||||
platform_username="testuser",
|
||||
)
|
||||
assert resp.success is True
|
||||
|
||||
def test_delete_link_response(self):
|
||||
resp = DeleteLinkResponse(success=True)
|
||||
assert resp.success is True
|
||||
@@ -18,6 +18,7 @@ from prisma.errors import PrismaError
|
||||
|
||||
import backend.api.features.admin.credit_admin_routes
|
||||
import backend.api.features.admin.execution_analytics_routes
|
||||
import backend.api.features.admin.rate_limit_admin_routes
|
||||
import backend.api.features.admin.store_admin_routes
|
||||
import backend.api.features.builder
|
||||
import backend.api.features.builder.routes
|
||||
@@ -29,6 +30,8 @@ import backend.api.features.library.routes
|
||||
import backend.api.features.mcp.routes as mcp_routes
|
||||
import backend.api.features.oauth
|
||||
import backend.api.features.otto.routes
|
||||
import backend.api.features.platform_linking.chat_proxy
|
||||
import backend.api.features.platform_linking.routes
|
||||
import backend.api.features.postmark.postmark
|
||||
import backend.api.features.store.model
|
||||
import backend.api.features.store.routes
|
||||
@@ -318,6 +321,11 @@ app.include_router(
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/executions",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.admin.rate_limit_admin_routes.router,
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/copilot",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.executions.review.routes.router,
|
||||
tags=["v2", "executions", "review"],
|
||||
@@ -355,6 +363,16 @@ app.include_router(
|
||||
tags=["oauth"],
|
||||
prefix="/api/oauth",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.platform_linking.routes.router,
|
||||
tags=["platform-linking"],
|
||||
prefix="/api/platform-linking",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.platform_linking.chat_proxy.router,
|
||||
tags=["platform-linking"],
|
||||
prefix="/api/platform-linking",
|
||||
)
|
||||
|
||||
app.mount("/external-api", external_api)
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from backend.blocks._base import (
|
||||
@@ -19,6 +20,33 @@ from backend.blocks.llm import (
|
||||
)
|
||||
from backend.data.model import APIKeyCredentials, NodeExecutionStats, SchemaField
|
||||
|
||||
# Minimum max_output_tokens accepted by OpenAI-compatible APIs.
|
||||
# A true/false answer fits comfortably within this budget.
|
||||
MIN_LLM_OUTPUT_TOKENS = 16
|
||||
|
||||
|
||||
def _parse_boolean_response(response_text: str) -> tuple[bool, str | None]:
|
||||
"""Parse an LLM response into a boolean result.
|
||||
|
||||
Returns a ``(result, error)`` tuple. *error* is ``None`` when the
|
||||
response is unambiguous; otherwise it contains a diagnostic message
|
||||
and *result* defaults to ``False``.
|
||||
"""
|
||||
text = response_text.strip().lower()
|
||||
if text == "true":
|
||||
return True, None
|
||||
if text == "false":
|
||||
return False, None
|
||||
|
||||
# Fuzzy match – use word boundaries to avoid false positives like "untrue".
|
||||
tokens = set(re.findall(r"\b(true|false|yes|no|1|0)\b", text))
|
||||
if tokens == {"true"} or tokens == {"yes"} or tokens == {"1"}:
|
||||
return True, None
|
||||
if tokens == {"false"} or tokens == {"no"} or tokens == {"0"}:
|
||||
return False, None
|
||||
|
||||
return False, f"Unclear AI response: '{response_text}'"
|
||||
|
||||
|
||||
class AIConditionBlock(AIBlockBase):
|
||||
"""
|
||||
@@ -162,54 +190,26 @@ class AIConditionBlock(AIBlockBase):
|
||||
]
|
||||
|
||||
# Call the LLM
|
||||
try:
|
||||
response = await self.llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=input_data.model,
|
||||
prompt=prompt,
|
||||
max_tokens=10, # We only expect a true/false response
|
||||
response = await self.llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=input_data.model,
|
||||
prompt=prompt,
|
||||
max_tokens=MIN_LLM_OUTPUT_TOKENS,
|
||||
)
|
||||
|
||||
# Extract the boolean result from the response
|
||||
result, error = _parse_boolean_response(response.response)
|
||||
if error:
|
||||
yield "error", error
|
||||
|
||||
# Update internal stats
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
input_token_count=response.prompt_tokens,
|
||||
output_token_count=response.completion_tokens,
|
||||
)
|
||||
|
||||
# Extract the boolean result from the response
|
||||
response_text = response.response.strip().lower()
|
||||
if response_text == "true":
|
||||
result = True
|
||||
elif response_text == "false":
|
||||
result = False
|
||||
else:
|
||||
# If the response is not clear, try to interpret it using word boundaries
|
||||
import re
|
||||
|
||||
# Use word boundaries to avoid false positives like 'untrue' or '10'
|
||||
tokens = set(re.findall(r"\b(true|false|yes|no|1|0)\b", response_text))
|
||||
|
||||
if tokens == {"true"} or tokens == {"yes"} or tokens == {"1"}:
|
||||
result = True
|
||||
elif tokens == {"false"} or tokens == {"no"} or tokens == {"0"}:
|
||||
result = False
|
||||
else:
|
||||
# Unclear or conflicting response - default to False and yield error
|
||||
result = False
|
||||
yield "error", f"Unclear AI response: '{response.response}'"
|
||||
|
||||
# Update internal stats
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
input_token_count=response.prompt_tokens,
|
||||
output_token_count=response.completion_tokens,
|
||||
)
|
||||
)
|
||||
self.prompt = response.prompt
|
||||
|
||||
except Exception as e:
|
||||
# In case of any error, default to False to be safe
|
||||
result = False
|
||||
# Log the error but don't fail the block execution
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.error(f"AI condition evaluation failed: {str(e)}")
|
||||
yield "error", f"AI evaluation failed: {str(e)}"
|
||||
)
|
||||
self.prompt = response.prompt
|
||||
|
||||
# Yield results
|
||||
yield "result", result
|
||||
|
||||
147
autogpt_platform/backend/backend/blocks/ai_condition_test.py
Normal file
147
autogpt_platform/backend/backend/blocks/ai_condition_test.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""Tests for AIConditionBlock – regression coverage for max_tokens and error propagation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.ai_condition import (
|
||||
MIN_LLM_OUTPUT_TOKENS,
|
||||
AIConditionBlock,
|
||||
_parse_boolean_response,
|
||||
)
|
||||
from backend.blocks.llm import (
|
||||
DEFAULT_LLM_MODEL,
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
AICredentials,
|
||||
LLMResponse,
|
||||
)
|
||||
|
||||
_TEST_AI_CREDENTIALS = cast(AICredentials, TEST_CREDENTIALS_INPUT)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper to collect all yields from the async generator
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _collect_outputs(block: AIConditionBlock, input_data, credentials):
|
||||
outputs: dict[str, object] = {}
|
||||
async for name, value in block.run(input_data, credentials=credentials):
|
||||
outputs[name] = value
|
||||
return outputs
|
||||
|
||||
|
||||
def _make_input(**overrides) -> AIConditionBlock.Input:
|
||||
defaults: dict = {
|
||||
"input_value": "hello@example.com",
|
||||
"condition": "the input is an email address",
|
||||
"yes_value": "yes!",
|
||||
"no_value": "no!",
|
||||
"model": DEFAULT_LLM_MODEL,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return AIConditionBlock.Input(**defaults)
|
||||
|
||||
|
||||
def _mock_llm_response(response_text: str) -> LLMResponse:
|
||||
return LLMResponse(
|
||||
raw_response="",
|
||||
prompt=[],
|
||||
response=response_text,
|
||||
tool_calls=None,
|
||||
prompt_tokens=10,
|
||||
completion_tokens=5,
|
||||
reasoning=None,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _parse_boolean_response unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseBooleanResponse:
|
||||
def test_true_exact(self):
|
||||
assert _parse_boolean_response("true") == (True, None)
|
||||
|
||||
def test_false_exact(self):
|
||||
assert _parse_boolean_response("false") == (False, None)
|
||||
|
||||
def test_true_with_whitespace(self):
|
||||
assert _parse_boolean_response(" True ") == (True, None)
|
||||
|
||||
def test_yes_fuzzy(self):
|
||||
assert _parse_boolean_response("Yes") == (True, None)
|
||||
|
||||
def test_no_fuzzy(self):
|
||||
assert _parse_boolean_response("no") == (False, None)
|
||||
|
||||
def test_one_fuzzy(self):
|
||||
assert _parse_boolean_response("1") == (True, None)
|
||||
|
||||
def test_zero_fuzzy(self):
|
||||
assert _parse_boolean_response("0") == (False, None)
|
||||
|
||||
def test_unclear_response(self):
|
||||
result, error = _parse_boolean_response("I'm not sure")
|
||||
assert result is False
|
||||
assert error is not None
|
||||
assert "Unclear" in error
|
||||
|
||||
def test_conflicting_tokens(self):
|
||||
result, error = _parse_boolean_response("true and false")
|
||||
assert result is False
|
||||
assert error is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Regression: max_tokens is set to MIN_LLM_OUTPUT_TOKENS
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMaxTokensRegression:
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_call_receives_min_output_tokens(self):
|
||||
"""max_tokens must be MIN_LLM_OUTPUT_TOKENS (16) – the previous value
|
||||
of 1 was too low and caused OpenAI to reject the request."""
|
||||
block = AIConditionBlock()
|
||||
captured_kwargs: dict = {}
|
||||
|
||||
async def spy_llm_call(**kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
return _mock_llm_response("true")
|
||||
|
||||
block.llm_call = spy_llm_call # type: ignore[assignment]
|
||||
|
||||
input_data = _make_input()
|
||||
await _collect_outputs(block, input_data, credentials=TEST_CREDENTIALS)
|
||||
|
||||
assert captured_kwargs["max_tokens"] == MIN_LLM_OUTPUT_TOKENS
|
||||
assert captured_kwargs["max_tokens"] == 16
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Regression: exceptions from llm_call must propagate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExceptionPropagation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_call_exception_propagates(self):
|
||||
"""If llm_call raises, the exception must NOT be swallowed.
|
||||
Previously the block caught all exceptions and silently returned
|
||||
result=False."""
|
||||
block = AIConditionBlock()
|
||||
|
||||
async def boom(**kwargs):
|
||||
raise RuntimeError("LLM provider error")
|
||||
|
||||
block.llm_call = boom # type: ignore[assignment]
|
||||
|
||||
input_data = _make_input()
|
||||
with pytest.raises(RuntimeError, match="LLM provider error"):
|
||||
await _collect_outputs(block, input_data, credentials=TEST_CREDENTIALS)
|
||||
@@ -73,7 +73,7 @@ class ReadDiscordMessagesBlock(Block):
|
||||
id="df06086a-d5ac-4abb-9996-2ad0acb2eff7",
|
||||
input_schema=ReadDiscordMessagesBlock.Input, # Assign input schema
|
||||
output_schema=ReadDiscordMessagesBlock.Output, # Assign output schema
|
||||
description="Reads messages from a Discord channel using a bot token.",
|
||||
description="Reads new messages from a Discord channel using a bot token and triggers when a new message is posted",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
test_input={
|
||||
"continuous_read": False,
|
||||
|
||||
@@ -28,9 +28,9 @@ class AgentInputBlock(Block):
|
||||
"""
|
||||
This block is used to provide input to the graph.
|
||||
|
||||
It takes in a value, name, description, default values list and bool to limit selection to default values.
|
||||
It takes in a value, name, and description.
|
||||
|
||||
It Outputs the value passed as input.
|
||||
It outputs the value passed as input.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
@@ -47,12 +47,6 @@ class AgentInputBlock(Block):
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
placeholder_values: list = SchemaField(
|
||||
description="The placeholder values to be passed as input.",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
hidden=True,
|
||||
)
|
||||
advanced: bool = SchemaField(
|
||||
description="Whether to show the input in the advanced section, if the field is not required.",
|
||||
default=False,
|
||||
@@ -65,10 +59,7 @@ class AgentInputBlock(Block):
|
||||
)
|
||||
|
||||
def generate_schema(self):
|
||||
schema = copy.deepcopy(self.get_field_schema("value"))
|
||||
if possible_values := self.placeholder_values:
|
||||
schema["enum"] = possible_values
|
||||
return schema
|
||||
return copy.deepcopy(self.get_field_schema("value"))
|
||||
|
||||
class Output(BlockSchema):
|
||||
# Use BlockSchema to avoid automatic error field for interface definition
|
||||
@@ -86,18 +77,16 @@ class AgentInputBlock(Block):
|
||||
"value": "Hello, World!",
|
||||
"name": "input_1",
|
||||
"description": "Example test input.",
|
||||
"placeholder_values": [],
|
||||
},
|
||||
{
|
||||
"value": "Hello, World!",
|
||||
"value": 42,
|
||||
"name": "input_2",
|
||||
"description": "Example test input with placeholders.",
|
||||
"placeholder_values": ["Hello, World!"],
|
||||
"description": "Example numeric input.",
|
||||
},
|
||||
],
|
||||
"test_output": [
|
||||
("result", "Hello, World!"),
|
||||
("result", "Hello, World!"),
|
||||
("result", 42),
|
||||
],
|
||||
"categories": {BlockCategory.INPUT, BlockCategory.BASIC},
|
||||
"block_type": BlockType.INPUT,
|
||||
@@ -245,13 +234,11 @@ class AgentShortTextInputBlock(AgentInputBlock):
|
||||
"value": "Hello",
|
||||
"name": "short_text_1",
|
||||
"description": "Short text example 1",
|
||||
"placeholder_values": [],
|
||||
},
|
||||
{
|
||||
"value": "Quick test",
|
||||
"name": "short_text_2",
|
||||
"description": "Short text example 2",
|
||||
"placeholder_values": ["Quick test", "Another option"],
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
@@ -285,13 +272,11 @@ class AgentLongTextInputBlock(AgentInputBlock):
|
||||
"value": "Lorem ipsum dolor sit amet...",
|
||||
"name": "long_text_1",
|
||||
"description": "Long text example 1",
|
||||
"placeholder_values": [],
|
||||
},
|
||||
{
|
||||
"value": "Another multiline text input.",
|
||||
"name": "long_text_2",
|
||||
"description": "Long text example 2",
|
||||
"placeholder_values": ["Another multiline text input."],
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
@@ -325,13 +310,11 @@ class AgentNumberInputBlock(AgentInputBlock):
|
||||
"value": 42,
|
||||
"name": "number_input_1",
|
||||
"description": "Number example 1",
|
||||
"placeholder_values": [],
|
||||
},
|
||||
{
|
||||
"value": 314,
|
||||
"name": "number_input_2",
|
||||
"description": "Number example 2",
|
||||
"placeholder_values": [314, 2718],
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
@@ -501,6 +484,12 @@ class AgentDropdownInputBlock(AgentInputBlock):
|
||||
title="Dropdown Options",
|
||||
)
|
||||
|
||||
def generate_schema(self):
|
||||
schema = super().generate_schema()
|
||||
if possible_values := self.placeholder_values:
|
||||
schema["enum"] = possible_values
|
||||
return schema
|
||||
|
||||
class Output(AgentInputBlock.Output):
|
||||
result: str = SchemaField(description="Selected dropdown value.")
|
||||
|
||||
|
||||
@@ -104,6 +104,18 @@ class LlmModelMeta(EnumMeta):
|
||||
|
||||
|
||||
class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value: object) -> "LlmModel | None":
|
||||
"""Handle provider-prefixed model names like 'anthropic/claude-sonnet-4-6'."""
|
||||
if isinstance(value, str) and "/" in value:
|
||||
stripped = value.split("/", 1)[1]
|
||||
try:
|
||||
return cls(stripped)
|
||||
except ValueError:
|
||||
return None
|
||||
return None
|
||||
|
||||
# OpenAI models
|
||||
O3_MINI = "o3-mini"
|
||||
O3 = "o3-2025-04-16"
|
||||
|
||||
@@ -4,6 +4,8 @@ import pytest
|
||||
|
||||
from backend.blocks import get_blocks
|
||||
from backend.blocks._base import Block, BlockSchemaInput
|
||||
from backend.blocks.io import AgentDropdownInputBlock, AgentInputBlock
|
||||
from backend.data.graph import BaseGraph
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.test import execute_block_test
|
||||
|
||||
@@ -279,3 +281,66 @@ class TestAutoCredentialsFieldsValidation:
|
||||
assert "Duplicate auto_credentials kwarg_name 'credentials'" in str(
|
||||
exc_info.value
|
||||
)
|
||||
|
||||
|
||||
def test_agent_input_block_ignores_legacy_placeholder_values():
|
||||
"""Verify AgentInputBlock.Input.model_construct tolerates extra placeholder_values
|
||||
for backward compatibility with existing agent JSON."""
|
||||
legacy_data = {
|
||||
"name": "url",
|
||||
"value": "",
|
||||
"description": "Enter a URL",
|
||||
"placeholder_values": ["https://example.com"],
|
||||
}
|
||||
instance = AgentInputBlock.Input.model_construct(**legacy_data)
|
||||
schema = instance.generate_schema()
|
||||
assert (
|
||||
"enum" not in schema
|
||||
), "AgentInputBlock should not produce enum from legacy placeholder_values"
|
||||
|
||||
|
||||
def test_dropdown_input_block_produces_enum():
|
||||
"""Verify AgentDropdownInputBlock.Input.generate_schema() produces enum."""
|
||||
options = ["Option A", "Option B"]
|
||||
instance = AgentDropdownInputBlock.Input.model_construct(
|
||||
name="choice", value=None, placeholder_values=options
|
||||
)
|
||||
schema = instance.generate_schema()
|
||||
assert schema.get("enum") == options
|
||||
|
||||
|
||||
def test_generate_schema_integration_legacy_placeholder_values():
|
||||
"""Test the full Graph._generate_schema path with legacy placeholder_values
|
||||
on AgentInputBlock — verifies no enum leaks through the graph loading path."""
|
||||
legacy_input_default = {
|
||||
"name": "url",
|
||||
"value": "",
|
||||
"description": "Enter a URL",
|
||||
"placeholder_values": ["https://example.com"],
|
||||
}
|
||||
result = BaseGraph._generate_schema(
|
||||
(AgentInputBlock.Input, legacy_input_default),
|
||||
)
|
||||
url_props = result["properties"]["url"]
|
||||
assert (
|
||||
"enum" not in url_props
|
||||
), "Graph schema should not contain enum from AgentInputBlock placeholder_values"
|
||||
|
||||
|
||||
def test_generate_schema_integration_dropdown_produces_enum():
|
||||
"""Test the full Graph._generate_schema path with AgentDropdownInputBlock
|
||||
— verifies enum IS produced for dropdown blocks."""
|
||||
dropdown_input_default = {
|
||||
"name": "color",
|
||||
"value": None,
|
||||
"placeholder_values": ["Red", "Green", "Blue"],
|
||||
}
|
||||
result = BaseGraph._generate_schema(
|
||||
(AgentDropdownInputBlock.Input, dropdown_input_default),
|
||||
)
|
||||
color_props = result["properties"]["color"]
|
||||
assert color_props.get("enum") == [
|
||||
"Red",
|
||||
"Green",
|
||||
"Blue",
|
||||
], "Graph schema should contain enum from AgentDropdownInputBlock"
|
||||
|
||||
@@ -207,6 +207,51 @@ class TestXMLParserBlockSecurity:
|
||||
pass
|
||||
|
||||
|
||||
class TestXMLParserBlockSyntaxErrors:
|
||||
"""XML syntax errors should raise ValueError (not SyntaxError).
|
||||
|
||||
This ensures the base Block.execute() wraps them as BlockExecutionError
|
||||
(expected / user-caused) instead of BlockUnknownError (unexpected / alerts
|
||||
Sentry).
|
||||
"""
|
||||
|
||||
async def test_unclosed_tag_raises_value_error(self):
|
||||
"""Unclosed tags should raise ValueError, not SyntaxError."""
|
||||
block = XMLParserBlock()
|
||||
bad_xml = "<root><unclosed>"
|
||||
|
||||
with pytest.raises(ValueError, match="Unclosed tag"):
|
||||
async for _ in block.run(XMLParserBlock.Input(input_xml=bad_xml)):
|
||||
pass
|
||||
|
||||
async def test_unexpected_closing_tag_raises_value_error(self):
|
||||
"""Extra closing tags should raise ValueError, not SyntaxError."""
|
||||
block = XMLParserBlock()
|
||||
bad_xml = "</unexpected>"
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
async for _ in block.run(XMLParserBlock.Input(input_xml=bad_xml)):
|
||||
pass
|
||||
|
||||
async def test_empty_xml_raises_value_error(self):
|
||||
"""Empty XML input should raise ValueError."""
|
||||
block = XMLParserBlock()
|
||||
|
||||
with pytest.raises(ValueError, match="XML input is empty"):
|
||||
async for _ in block.run(XMLParserBlock.Input(input_xml="")):
|
||||
pass
|
||||
|
||||
async def test_syntax_error_from_parser_becomes_value_error(self):
|
||||
"""SyntaxErrors from gravitasml library become ValueError (BlockExecutionError)."""
|
||||
block = XMLParserBlock()
|
||||
# Malformed XML that might trigger a SyntaxError from the parser
|
||||
bad_xml = "<root><child>no closing"
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
async for _ in block.run(XMLParserBlock.Input(input_xml=bad_xml)):
|
||||
pass
|
||||
|
||||
|
||||
class TestStoreMediaFileSecurity:
|
||||
"""Test file storage security limits."""
|
||||
|
||||
|
||||
@@ -809,3 +809,33 @@ class TestUserErrorStatusCodeHandling:
|
||||
|
||||
mock_warning.assert_called_once()
|
||||
mock_exception.assert_not_called()
|
||||
|
||||
|
||||
class TestLlmModelMissing:
|
||||
"""Test that LlmModel handles provider-prefixed model names."""
|
||||
|
||||
def test_provider_prefixed_model_resolves(self):
|
||||
"""Provider-prefixed model string should resolve to the correct enum member."""
|
||||
assert (
|
||||
llm.LlmModel("anthropic/claude-sonnet-4-6")
|
||||
== llm.LlmModel.CLAUDE_4_6_SONNET
|
||||
)
|
||||
|
||||
def test_bare_model_still_works(self):
|
||||
"""Bare (non-prefixed) model string should still resolve correctly."""
|
||||
assert llm.LlmModel("claude-sonnet-4-6") == llm.LlmModel.CLAUDE_4_6_SONNET
|
||||
|
||||
def test_invalid_prefixed_model_raises(self):
|
||||
"""Unknown provider-prefixed model string should raise ValueError."""
|
||||
with pytest.raises(ValueError):
|
||||
llm.LlmModel("invalid/nonexistent-model")
|
||||
|
||||
def test_slash_containing_value_direct_lookup(self):
|
||||
"""Enum values with '/' (e.g., OpenRouter models) should resolve via direct lookup, not _missing_."""
|
||||
assert llm.LlmModel("google/gemini-2.5-pro") == llm.LlmModel.GEMINI_2_5_PRO
|
||||
|
||||
def test_double_prefixed_slash_model(self):
|
||||
"""Double-prefixed value should still resolve by stripping first prefix."""
|
||||
assert (
|
||||
llm.LlmModel("extra/google/gemini-2.5-pro") == llm.LlmModel.GEMINI_2_5_PRO
|
||||
)
|
||||
|
||||
@@ -44,7 +44,7 @@ class XMLParserBlock(Block):
|
||||
elif token.type == "TAG_CLOSE":
|
||||
depth -= 1
|
||||
if depth < 0:
|
||||
raise SyntaxError("Unexpected closing tag in XML input.")
|
||||
raise ValueError("Unexpected closing tag in XML input.")
|
||||
elif token.type in {"TEXT", "ESCAPE"}:
|
||||
if depth == 0 and token.value:
|
||||
raise ValueError(
|
||||
@@ -53,7 +53,7 @@ class XMLParserBlock(Block):
|
||||
)
|
||||
|
||||
if depth != 0:
|
||||
raise SyntaxError("Unclosed tag detected in XML input.")
|
||||
raise ValueError("Unclosed tag detected in XML input.")
|
||||
if not root_seen:
|
||||
raise ValueError("XML must include a root element.")
|
||||
|
||||
@@ -76,4 +76,7 @@ class XMLParserBlock(Block):
|
||||
except ValueError as val_e:
|
||||
raise ValueError(f"Validation error for dict:{val_e}") from val_e
|
||||
except SyntaxError as syn_e:
|
||||
raise SyntaxError(f"Error in input xml syntax: {syn_e}") from syn_e
|
||||
# Raise as ValueError so the base Block.execute() wraps it as
|
||||
# BlockExecutionError (expected user-caused failure) instead of
|
||||
# BlockUnknownError (unexpected platform error that alerts Sentry).
|
||||
raise ValueError(f"Error in input xml syntax: {syn_e}") from syn_e
|
||||
|
||||
@@ -91,6 +91,20 @@ class ChatConfig(BaseSettings):
|
||||
description="Max tokens per week, resets Monday 00:00 UTC (0 = unlimited)",
|
||||
)
|
||||
|
||||
# Cost (in credits / cents) to reset the daily rate limit using credits.
|
||||
# When a user hits their daily limit, they can spend this amount to reset
|
||||
# the daily counter and keep working. Set to 0 to disable the feature.
|
||||
rate_limit_reset_cost: int = Field(
|
||||
default=500,
|
||||
ge=0,
|
||||
description="Credit cost (in cents) for resetting the daily rate limit. 0 = disabled.",
|
||||
)
|
||||
max_daily_resets: int = Field(
|
||||
default=5,
|
||||
ge=0,
|
||||
description="Maximum number of credit-based rate limit resets per user per day. 0 = unlimited.",
|
||||
)
|
||||
|
||||
# Claude Agent SDK Configuration
|
||||
use_claude_agent_sdk: bool = Field(
|
||||
default=True,
|
||||
|
||||
@@ -18,7 +18,7 @@ from prisma.types import (
|
||||
from backend.data import db
|
||||
from backend.util.json import SafeJson, sanitize_string
|
||||
|
||||
from .model import ChatMessage, ChatSession, ChatSessionInfo
|
||||
from .model import ChatMessage, ChatSession, ChatSessionInfo, invalidate_session_cache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -217,6 +217,9 @@ async def add_chat_messages_batch(
|
||||
if msg.get("function_call") is not None:
|
||||
data["functionCall"] = SafeJson(msg["function_call"])
|
||||
|
||||
if msg.get("duration_ms") is not None:
|
||||
data["durationMs"] = msg["duration_ms"]
|
||||
|
||||
messages_data.append(data)
|
||||
|
||||
# Run create_many and session update in parallel within transaction
|
||||
@@ -359,3 +362,22 @@ async def update_tool_message_content(
|
||||
f"tool_call_id {tool_call_id}: {e}"
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
async def set_turn_duration(session_id: str, duration_ms: int) -> None:
|
||||
"""Set durationMs on the last assistant message in a session.
|
||||
|
||||
Also invalidates the Redis session cache so the next GET returns
|
||||
the updated duration.
|
||||
"""
|
||||
last_msg = await PrismaChatMessage.prisma().find_first(
|
||||
where={"sessionId": session_id, "role": "assistant"},
|
||||
order={"sequence": "desc"},
|
||||
)
|
||||
if last_msg:
|
||||
await PrismaChatMessage.prisma().update(
|
||||
where={"id": last_msg.id},
|
||||
data={"durationMs": duration_ms},
|
||||
)
|
||||
# Invalidate cache so the session is re-fetched from DB with durationMs
|
||||
await invalidate_session_cache(session_id)
|
||||
|
||||
@@ -54,6 +54,7 @@ class ChatMessage(BaseModel):
|
||||
refusal: str | None = None
|
||||
tool_calls: list[dict] | None = None
|
||||
function_call: dict | None = None
|
||||
duration_ms: int | None = None
|
||||
|
||||
@staticmethod
|
||||
def from_db(prisma_message: PrismaChatMessage) -> "ChatMessage":
|
||||
@@ -66,6 +67,7 @@ class ChatMessage(BaseModel):
|
||||
refusal=prisma_message.refusal,
|
||||
tool_calls=_parse_json_field(prisma_message.toolCalls),
|
||||
function_call=_parse_json_field(prisma_message.functionCall),
|
||||
duration_ms=prisma_message.durationMs,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -205,9 +205,10 @@ Important files (code, configs, outputs) should be saved to workspace to ensure
|
||||
### SDK tool-result files
|
||||
When tool outputs are large, the SDK truncates them and saves the full output to
|
||||
a local file under `~/.claude/projects/.../tool-results/`. To read these files,
|
||||
always use `read_file` or `Read` (NOT `read_workspace_file`).
|
||||
`read_workspace_file` reads from cloud workspace storage, where SDK
|
||||
tool-results are NOT stored.
|
||||
always use `Read` (NOT `bash_exec`, NOT `read_workspace_file`).
|
||||
These files are on the host filesystem — `bash_exec` runs in the sandbox and
|
||||
CANNOT access them. `read_workspace_file` reads from cloud workspace storage,
|
||||
where SDK tool-results are NOT stored.
|
||||
{_SHARED_TOOL_NOTES}{extra_notes}"""
|
||||
|
||||
|
||||
|
||||
@@ -36,6 +36,10 @@ class CoPilotUsageStatus(BaseModel):
|
||||
|
||||
daily: UsageWindow
|
||||
weekly: UsageWindow
|
||||
reset_cost: int = Field(
|
||||
default=0,
|
||||
description="Credit cost (in cents) to reset the daily limit. 0 = feature disabled.",
|
||||
)
|
||||
|
||||
|
||||
class RateLimitExceeded(Exception):
|
||||
@@ -61,6 +65,7 @@ async def get_usage_status(
|
||||
user_id: str,
|
||||
daily_token_limit: int,
|
||||
weekly_token_limit: int,
|
||||
rate_limit_reset_cost: int = 0,
|
||||
) -> CoPilotUsageStatus:
|
||||
"""Get current usage status for a user.
|
||||
|
||||
@@ -68,6 +73,7 @@ async def get_usage_status(
|
||||
user_id: The user's ID.
|
||||
daily_token_limit: Max tokens per day (0 = unlimited).
|
||||
weekly_token_limit: Max tokens per week (0 = unlimited).
|
||||
rate_limit_reset_cost: Credit cost (cents) to reset daily limit (0 = disabled).
|
||||
|
||||
Returns:
|
||||
CoPilotUsageStatus with current usage and limits.
|
||||
@@ -97,6 +103,7 @@ async def get_usage_status(
|
||||
limit=weekly_token_limit,
|
||||
resets_at=_weekly_reset_time(now=now),
|
||||
),
|
||||
reset_cost=rate_limit_reset_cost,
|
||||
)
|
||||
|
||||
|
||||
@@ -141,6 +148,110 @@ async def check_rate_limit(
|
||||
raise RateLimitExceeded("weekly", _weekly_reset_time(now=now))
|
||||
|
||||
|
||||
async def reset_daily_usage(user_id: str, daily_token_limit: int = 0) -> bool:
|
||||
"""Reset a user's daily token usage counter in Redis.
|
||||
|
||||
Called after a user pays credits to extend their daily limit.
|
||||
Also reduces the weekly usage counter by ``daily_token_limit`` tokens
|
||||
(clamped to 0) so the user effectively gets one extra day's worth of
|
||||
weekly capacity.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID.
|
||||
daily_token_limit: The configured daily token limit. When positive,
|
||||
the weekly counter is reduced by this amount.
|
||||
|
||||
Fails open: returns False if Redis is unavailable (consistent with
|
||||
the fail-open design of this module).
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
|
||||
# Use a MULTI/EXEC transaction so that DELETE (daily) and DECRBY
|
||||
# (weekly) either both execute or neither does. This prevents the
|
||||
# scenario where the daily counter is cleared but the weekly
|
||||
# counter is not decremented — which would let the caller refund
|
||||
# credits even though the daily limit was already reset.
|
||||
d_key = _daily_key(user_id, now=now)
|
||||
w_key = _weekly_key(user_id, now=now) if daily_token_limit > 0 else None
|
||||
|
||||
pipe = redis.pipeline(transaction=True)
|
||||
pipe.delete(d_key)
|
||||
if w_key is not None:
|
||||
pipe.decrby(w_key, daily_token_limit)
|
||||
results = await pipe.execute()
|
||||
|
||||
# Clamp negative weekly counter to 0 (best-effort; not critical).
|
||||
if w_key is not None:
|
||||
new_val = results[1] # DECRBY result
|
||||
if new_val < 0:
|
||||
await redis.set(w_key, 0, keepttl=True)
|
||||
|
||||
logger.info("Reset daily usage for user %s", user_id[:8])
|
||||
return True
|
||||
except (RedisError, ConnectionError, OSError):
|
||||
logger.warning("Redis unavailable for resetting daily usage")
|
||||
return False
|
||||
|
||||
|
||||
_RESET_LOCK_PREFIX = "copilot:reset_lock"
|
||||
_RESET_COUNT_PREFIX = "copilot:reset_count"
|
||||
|
||||
|
||||
async def acquire_reset_lock(user_id: str, ttl_seconds: int = 10) -> bool:
|
||||
"""Acquire a short-lived lock to serialize rate limit resets per user."""
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
key = f"{_RESET_LOCK_PREFIX}:{user_id}"
|
||||
return bool(await redis.set(key, "1", nx=True, ex=ttl_seconds))
|
||||
except (RedisError, ConnectionError, OSError) as exc:
|
||||
logger.warning("Redis unavailable for reset lock, rejecting reset: %s", exc)
|
||||
return False
|
||||
|
||||
|
||||
async def release_reset_lock(user_id: str) -> None:
|
||||
"""Release the per-user reset lock."""
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
await redis.delete(f"{_RESET_LOCK_PREFIX}:{user_id}")
|
||||
except (RedisError, ConnectionError, OSError):
|
||||
pass # Lock will expire via TTL
|
||||
|
||||
|
||||
async def get_daily_reset_count(user_id: str) -> int | None:
|
||||
"""Get how many times the user has reset today.
|
||||
|
||||
Returns None when Redis is unavailable so callers can fail-closed
|
||||
for billed operations (as opposed to failing open for read-only
|
||||
rate-limit checks).
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
key = f"{_RESET_COUNT_PREFIX}:{user_id}:{now.strftime('%Y-%m-%d')}"
|
||||
val = await redis.get(key)
|
||||
return int(val or 0)
|
||||
except (RedisError, ConnectionError, OSError):
|
||||
logger.warning("Redis unavailable for reading daily reset count")
|
||||
return None
|
||||
|
||||
|
||||
async def increment_daily_reset_count(user_id: str) -> None:
|
||||
"""Increment and track how many resets this user has done today."""
|
||||
now = datetime.now(UTC)
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
key = f"{_RESET_COUNT_PREFIX}:{user_id}:{now.strftime('%Y-%m-%d')}"
|
||||
pipe = redis.pipeline(transaction=True)
|
||||
pipe.incr(key)
|
||||
seconds_until_reset = int((_daily_reset_time(now=now) - now).total_seconds())
|
||||
pipe.expire(key, max(seconds_until_reset, 1))
|
||||
await pipe.execute()
|
||||
except (RedisError, ConnectionError, OSError):
|
||||
logger.warning("Redis unavailable for tracking reset count")
|
||||
|
||||
|
||||
async def record_token_usage(
|
||||
user_id: str,
|
||||
prompt_tokens: int,
|
||||
@@ -231,6 +342,67 @@ async def record_token_usage(
|
||||
)
|
||||
|
||||
|
||||
async def get_global_rate_limits(
|
||||
user_id: str,
|
||||
config_daily: int,
|
||||
config_weekly: int,
|
||||
) -> tuple[int, int]:
|
||||
"""Resolve global rate limits from LaunchDarkly, falling back to config.
|
||||
|
||||
Args:
|
||||
user_id: User ID for LD flag evaluation context.
|
||||
config_daily: Fallback daily limit from ChatConfig.
|
||||
config_weekly: Fallback weekly limit from ChatConfig.
|
||||
|
||||
Returns:
|
||||
(daily_token_limit, weekly_token_limit) tuple.
|
||||
"""
|
||||
# Lazy import to avoid circular dependency:
|
||||
# rate_limit -> feature_flag -> settings -> ... -> rate_limit
|
||||
from backend.util.feature_flag import Flag, get_feature_flag_value
|
||||
|
||||
daily_raw = await get_feature_flag_value(
|
||||
Flag.COPILOT_DAILY_TOKEN_LIMIT.value, user_id, config_daily
|
||||
)
|
||||
weekly_raw = await get_feature_flag_value(
|
||||
Flag.COPILOT_WEEKLY_TOKEN_LIMIT.value, user_id, config_weekly
|
||||
)
|
||||
try:
|
||||
daily = max(0, int(daily_raw))
|
||||
except (TypeError, ValueError):
|
||||
logger.warning("Invalid LD value for daily token limit: %r", daily_raw)
|
||||
daily = config_daily
|
||||
try:
|
||||
weekly = max(0, int(weekly_raw))
|
||||
except (TypeError, ValueError):
|
||||
logger.warning("Invalid LD value for weekly token limit: %r", weekly_raw)
|
||||
weekly = config_weekly
|
||||
return daily, weekly
|
||||
|
||||
|
||||
async def reset_user_usage(user_id: str, *, reset_weekly: bool = False) -> None:
|
||||
"""Reset a user's usage counters.
|
||||
|
||||
Always deletes the daily Redis key. When *reset_weekly* is ``True``,
|
||||
the weekly key is deleted as well.
|
||||
|
||||
Unlike read paths (``get_usage_status``, ``check_rate_limit``) which
|
||||
fail-open on Redis errors, resets intentionally re-raise so the caller
|
||||
knows the operation did not succeed. A silent failure here would leave
|
||||
the admin believing the counters were zeroed when they were not.
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
keys_to_delete = [_daily_key(user_id, now=now)]
|
||||
if reset_weekly:
|
||||
keys_to_delete.append(_weekly_key(user_id, now=now))
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
await redis.delete(*keys_to_delete)
|
||||
except (RedisError, ConnectionError, OSError):
|
||||
logger.warning("Redis unavailable for resetting user usage")
|
||||
raise
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Private helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -12,6 +12,7 @@ from .rate_limit import (
|
||||
check_rate_limit,
|
||||
get_usage_status,
|
||||
record_token_usage,
|
||||
reset_daily_usage,
|
||||
)
|
||||
|
||||
_USER = "test-user-rl"
|
||||
@@ -332,3 +333,91 @@ class TestRecordTokenUsage:
|
||||
):
|
||||
# Should not raise — fail-open
|
||||
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# reset_daily_usage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResetDailyUsage:
|
||||
@staticmethod
|
||||
def _make_pipeline_mock(decrby_result: int = 0) -> MagicMock:
|
||||
"""Create a pipeline mock that returns [delete_result, decrby_result]."""
|
||||
pipe = MagicMock()
|
||||
pipe.execute = AsyncMock(return_value=[1, decrby_result])
|
||||
return pipe
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deletes_daily_key(self):
|
||||
mock_pipe = self._make_pipeline_mock(decrby_result=0)
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.pipeline = lambda **_kw: mock_pipe
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
result = await reset_daily_usage(_USER, daily_token_limit=10000)
|
||||
|
||||
assert result is True
|
||||
mock_pipe.delete.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reduces_weekly_usage_via_decrby(self):
|
||||
"""Weekly counter should be reduced via DECRBY in the pipeline."""
|
||||
mock_pipe = self._make_pipeline_mock(decrby_result=35000)
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.pipeline = lambda **_kw: mock_pipe
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
await reset_daily_usage(_USER, daily_token_limit=10000)
|
||||
|
||||
mock_pipe.decrby.assert_called_once()
|
||||
mock_redis.set.assert_not_called() # 35000 > 0, no clamp needed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clamps_negative_weekly_to_zero(self):
|
||||
"""If DECRBY goes negative, SET to 0 (outside the pipeline)."""
|
||||
mock_pipe = self._make_pipeline_mock(decrby_result=-5000)
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.pipeline = lambda **_kw: mock_pipe
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
await reset_daily_usage(_USER, daily_token_limit=10000)
|
||||
|
||||
mock_pipe.decrby.assert_called_once()
|
||||
mock_redis.set.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_weekly_reduction_when_daily_limit_zero(self):
|
||||
"""When daily_token_limit is 0, weekly counter should not be touched."""
|
||||
mock_pipe = self._make_pipeline_mock()
|
||||
mock_pipe.execute = AsyncMock(return_value=[1]) # only delete result
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.pipeline = lambda **_kw: mock_pipe
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
await reset_daily_usage(_USER, daily_token_limit=0)
|
||||
|
||||
mock_pipe.delete.assert_called_once()
|
||||
mock_pipe.decrby.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_false_when_redis_unavailable(self):
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
side_effect=ConnectionError("Redis down"),
|
||||
):
|
||||
result = await reset_daily_usage(_USER, daily_token_limit=10000)
|
||||
|
||||
assert result is False
|
||||
|
||||
294
autogpt_platform/backend/backend/copilot/reset_usage_test.py
Normal file
294
autogpt_platform/backend/backend/copilot/reset_usage_test.py
Normal file
@@ -0,0 +1,294 @@
|
||||
"""Unit tests for the POST /usage/reset endpoint."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from backend.api.features.chat.routes import reset_copilot_usage
|
||||
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
|
||||
|
||||
# Minimal config mock matching ChatConfig fields used by the endpoint.
|
||||
def _make_config(
|
||||
rate_limit_reset_cost: int = 500,
|
||||
daily_token_limit: int = 2_500_000,
|
||||
weekly_token_limit: int = 12_500_000,
|
||||
max_daily_resets: int = 5,
|
||||
):
|
||||
cfg = MagicMock()
|
||||
cfg.rate_limit_reset_cost = rate_limit_reset_cost
|
||||
cfg.daily_token_limit = daily_token_limit
|
||||
cfg.weekly_token_limit = weekly_token_limit
|
||||
cfg.max_daily_resets = max_daily_resets
|
||||
return cfg
|
||||
|
||||
|
||||
def _usage(daily_used: int = 3_000_000, daily_limit: int = 2_500_000):
|
||||
return CoPilotUsageStatus(
|
||||
daily=UsageWindow(
|
||||
used=daily_used,
|
||||
limit=daily_limit,
|
||||
resets_at=datetime.now(UTC) + timedelta(hours=6),
|
||||
),
|
||||
weekly=UsageWindow(
|
||||
used=5_000_000,
|
||||
limit=12_500_000,
|
||||
resets_at=datetime.now(UTC) + timedelta(days=3),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
_MODULE = "backend.api.features.chat.routes"
|
||||
|
||||
|
||||
def _mock_settings(enable_credit: bool = True):
|
||||
"""Return a mock Settings object with the given enable_credit flag."""
|
||||
mock = MagicMock()
|
||||
mock.config.enable_credit = enable_credit
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestResetCopilotUsage:
|
||||
async def test_feature_disabled_returns_400(self):
|
||||
"""When rate_limit_reset_cost=0, endpoint returns 400."""
|
||||
|
||||
with patch(f"{_MODULE}.config", _make_config(rate_limit_reset_cost=0)):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await reset_copilot_usage(user_id="user-1")
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "not available" in exc_info.value.detail
|
||||
|
||||
async def test_no_daily_limit_returns_400(self):
|
||||
"""When daily_token_limit=0 (unlimited), endpoint returns 400."""
|
||||
|
||||
with (
|
||||
patch(f"{_MODULE}.config", _make_config(daily_token_limit=0)),
|
||||
patch(f"{_MODULE}.settings", _mock_settings()),
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await reset_copilot_usage(user_id="user-1")
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "nothing to reset" in exc_info.value.detail.lower()
|
||||
|
||||
async def test_not_at_limit_returns_400(self):
|
||||
"""When user hasn't hit their daily limit, returns 400."""
|
||||
|
||||
cfg = _make_config()
|
||||
with (
|
||||
patch(f"{_MODULE}.config", cfg),
|
||||
patch(f"{_MODULE}.settings", _mock_settings()),
|
||||
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
|
||||
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
|
||||
patch(f"{_MODULE}.release_reset_lock", AsyncMock()) as mock_release,
|
||||
patch(
|
||||
f"{_MODULE}.get_usage_status",
|
||||
AsyncMock(return_value=_usage(daily_used=1_000_000)),
|
||||
),
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await reset_copilot_usage(user_id="user-1")
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "not reached" in exc_info.value.detail
|
||||
mock_release.assert_awaited_once()
|
||||
|
||||
async def test_insufficient_credits_returns_402(self):
|
||||
"""When user doesn't have enough credits, returns 402."""
|
||||
|
||||
mock_credit_model = AsyncMock()
|
||||
mock_credit_model.spend_credits.side_effect = InsufficientBalanceError(
|
||||
message="Insufficient balance",
|
||||
user_id="user-1",
|
||||
balance=50,
|
||||
amount=200,
|
||||
)
|
||||
|
||||
cfg = _make_config()
|
||||
with (
|
||||
patch(f"{_MODULE}.config", cfg),
|
||||
patch(f"{_MODULE}.settings", _mock_settings()),
|
||||
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
|
||||
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
|
||||
patch(f"{_MODULE}.release_reset_lock", AsyncMock()) as mock_release,
|
||||
patch(
|
||||
f"{_MODULE}.get_usage_status",
|
||||
AsyncMock(return_value=_usage()),
|
||||
),
|
||||
patch(
|
||||
f"{_MODULE}.get_user_credit_model",
|
||||
AsyncMock(return_value=mock_credit_model),
|
||||
),
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await reset_copilot_usage(user_id="user-1")
|
||||
assert exc_info.value.status_code == 402
|
||||
mock_release.assert_awaited_once()
|
||||
|
||||
async def test_happy_path(self):
|
||||
"""Successful reset: charges credits, resets usage, returns response."""
|
||||
|
||||
mock_credit_model = AsyncMock()
|
||||
mock_credit_model.spend_credits.return_value = 1500 # remaining balance
|
||||
|
||||
cfg = _make_config()
|
||||
updated_usage = _usage(daily_used=0)
|
||||
|
||||
with (
|
||||
patch(f"{_MODULE}.config", cfg),
|
||||
patch(f"{_MODULE}.settings", _mock_settings()),
|
||||
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
|
||||
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
|
||||
patch(f"{_MODULE}.release_reset_lock", AsyncMock()),
|
||||
patch(
|
||||
f"{_MODULE}.get_usage_status",
|
||||
AsyncMock(side_effect=[_usage(), updated_usage]),
|
||||
),
|
||||
patch(
|
||||
f"{_MODULE}.get_user_credit_model",
|
||||
AsyncMock(return_value=mock_credit_model),
|
||||
),
|
||||
patch(
|
||||
f"{_MODULE}.reset_daily_usage", AsyncMock(return_value=True)
|
||||
) as mock_reset,
|
||||
patch(f"{_MODULE}.increment_daily_reset_count", AsyncMock()) as mock_incr,
|
||||
):
|
||||
result = await reset_copilot_usage(user_id="user-1")
|
||||
assert result.success is True
|
||||
assert result.credits_charged == 500
|
||||
assert result.remaining_balance == 1500
|
||||
mock_reset.assert_awaited_once()
|
||||
mock_incr.assert_awaited_once()
|
||||
|
||||
async def test_max_daily_resets_exceeded(self):
|
||||
"""When user has exhausted daily resets, returns 429."""
|
||||
|
||||
cfg = _make_config(max_daily_resets=3)
|
||||
with (
|
||||
patch(f"{_MODULE}.config", cfg),
|
||||
patch(f"{_MODULE}.settings", _mock_settings()),
|
||||
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=3)),
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await reset_copilot_usage(user_id="user-1")
|
||||
assert exc_info.value.status_code == 429
|
||||
|
||||
async def test_credit_system_disabled_returns_400(self):
|
||||
"""When enable_credit=False, endpoint returns 400."""
|
||||
|
||||
with (
|
||||
patch(f"{_MODULE}.config", _make_config()),
|
||||
patch(f"{_MODULE}.settings", _mock_settings(enable_credit=False)),
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await reset_copilot_usage(user_id="user-1")
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "credit system is disabled" in exc_info.value.detail.lower()
|
||||
|
||||
async def test_weekly_limit_exhausted_returns_400(self):
|
||||
"""When the weekly limit is also exhausted, resetting daily won't help."""
|
||||
|
||||
cfg = _make_config()
|
||||
weekly_exhausted = CoPilotUsageStatus(
|
||||
daily=UsageWindow(
|
||||
used=3_000_000,
|
||||
limit=2_500_000,
|
||||
resets_at=datetime.now(UTC) + timedelta(hours=6),
|
||||
),
|
||||
weekly=UsageWindow(
|
||||
used=12_500_000,
|
||||
limit=12_500_000,
|
||||
resets_at=datetime.now(UTC) + timedelta(days=3),
|
||||
),
|
||||
)
|
||||
with (
|
||||
patch(f"{_MODULE}.config", cfg),
|
||||
patch(f"{_MODULE}.settings", _mock_settings()),
|
||||
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
|
||||
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
|
||||
patch(f"{_MODULE}.release_reset_lock", AsyncMock()) as mock_release,
|
||||
patch(
|
||||
f"{_MODULE}.get_usage_status",
|
||||
AsyncMock(return_value=weekly_exhausted),
|
||||
),
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await reset_copilot_usage(user_id="user-1")
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "weekly" in exc_info.value.detail.lower()
|
||||
mock_release.assert_awaited_once()
|
||||
|
||||
async def test_redis_failure_for_reset_count_returns_503(self):
|
||||
"""When Redis is unavailable for get_daily_reset_count, returns 503."""
|
||||
|
||||
with (
|
||||
patch(f"{_MODULE}.config", _make_config()),
|
||||
patch(f"{_MODULE}.settings", _mock_settings()),
|
||||
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=None)),
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await reset_copilot_usage(user_id="user-1")
|
||||
assert exc_info.value.status_code == 503
|
||||
assert "verify" in exc_info.value.detail.lower()
|
||||
|
||||
async def test_redis_reset_failure_refunds_credits(self):
|
||||
"""When reset_daily_usage fails, credits are refunded and 503 returned."""
|
||||
|
||||
mock_credit_model = AsyncMock()
|
||||
mock_credit_model.spend_credits.return_value = 1500
|
||||
|
||||
cfg = _make_config()
|
||||
with (
|
||||
patch(f"{_MODULE}.config", cfg),
|
||||
patch(f"{_MODULE}.settings", _mock_settings()),
|
||||
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
|
||||
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
|
||||
patch(f"{_MODULE}.release_reset_lock", AsyncMock()),
|
||||
patch(
|
||||
f"{_MODULE}.get_usage_status",
|
||||
AsyncMock(return_value=_usage()),
|
||||
),
|
||||
patch(
|
||||
f"{_MODULE}.get_user_credit_model",
|
||||
AsyncMock(return_value=mock_credit_model),
|
||||
),
|
||||
patch(f"{_MODULE}.reset_daily_usage", AsyncMock(return_value=False)),
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await reset_copilot_usage(user_id="user-1")
|
||||
assert exc_info.value.status_code == 503
|
||||
assert "not been charged" in exc_info.value.detail
|
||||
mock_credit_model.top_up_credits.assert_awaited_once()
|
||||
|
||||
async def test_redis_reset_failure_refund_also_fails(self):
|
||||
"""When both reset and refund fail, error message reflects the truth."""
|
||||
|
||||
mock_credit_model = AsyncMock()
|
||||
mock_credit_model.spend_credits.return_value = 1500
|
||||
mock_credit_model.top_up_credits.side_effect = RuntimeError("db down")
|
||||
|
||||
cfg = _make_config()
|
||||
with (
|
||||
patch(f"{_MODULE}.config", cfg),
|
||||
patch(f"{_MODULE}.settings", _mock_settings()),
|
||||
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
|
||||
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
|
||||
patch(f"{_MODULE}.release_reset_lock", AsyncMock()),
|
||||
patch(
|
||||
f"{_MODULE}.get_usage_status",
|
||||
AsyncMock(return_value=_usage()),
|
||||
),
|
||||
patch(
|
||||
f"{_MODULE}.get_user_credit_model",
|
||||
AsyncMock(return_value=mock_credit_model),
|
||||
),
|
||||
patch(f"{_MODULE}.reset_daily_usage", AsyncMock(return_value=False)),
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await reset_copilot_usage(user_id="user-1")
|
||||
assert exc_info.value.status_code == 503
|
||||
assert "contact support" in exc_info.value.detail.lower()
|
||||
@@ -67,9 +67,17 @@ These define the agent's interface — what it accepts and what it produces.
|
||||
**AgentInputBlock** (ID: `c0a8e994-ebf1-4a9c-a4d8-89d09c86741b`):
|
||||
- Defines a user-facing input field on the agent
|
||||
- Required `input_default` fields: `name` (str), `value` (default: null)
|
||||
- Optional: `title`, `description`, `placeholder_values` (for dropdowns)
|
||||
- Optional: `title`, `description`
|
||||
- Output: `result` — the user-provided value at runtime
|
||||
- Create one AgentInputBlock per distinct input the agent needs
|
||||
- For dropdown/select inputs, use **AgentDropdownInputBlock** instead (see below)
|
||||
|
||||
**AgentDropdownInputBlock** (ID: `655d6fdf-a334-421c-b733-520549c07cd1`):
|
||||
- Specialized input block that presents a dropdown/select to the user
|
||||
- Required `input_default` fields: `name` (str), `placeholder_values` (list of options, must have at least one)
|
||||
- Optional: `title`, `description`, `value` (default selection)
|
||||
- Output: `result` — the user-selected value at runtime
|
||||
- Use this instead of AgentInputBlock when the user should pick from a fixed set of options
|
||||
|
||||
**AgentOutputBlock** (ID: `363ae599-353e-4804-937e-b2ee3cef3da4`):
|
||||
- Defines a user-facing output displayed after the agent runs
|
||||
|
||||
@@ -25,24 +25,64 @@ def build_test_transcript(pairs: list[tuple[str, str]]) -> str:
|
||||
|
||||
Use this helper in any copilot SDK test that needs a well-formed
|
||||
transcript without hitting the real storage layer.
|
||||
|
||||
Delegates to ``build_structured_transcript`` — plain content strings
|
||||
are automatically wrapped in ``[{"type": "text", "text": ...}]`` for
|
||||
assistant messages.
|
||||
"""
|
||||
# Cast widening: tuple[str, str] is structurally compatible with
|
||||
# tuple[str, str | list[dict]] but list invariance requires explicit
|
||||
# annotation.
|
||||
widened: list[tuple[str, str | list[dict]]] = list(pairs)
|
||||
return build_structured_transcript(widened)
|
||||
|
||||
|
||||
def build_structured_transcript(
|
||||
entries: list[tuple[str, str | list[dict]]],
|
||||
) -> str:
|
||||
"""Build a JSONL transcript with structured content blocks.
|
||||
|
||||
Each entry is (role, content) where content is either a plain string
|
||||
(for user messages) or a list of content block dicts (for assistant
|
||||
messages with thinking/tool_use/text blocks).
|
||||
|
||||
Example::
|
||||
|
||||
build_structured_transcript([
|
||||
("user", "Hello"),
|
||||
("assistant", [
|
||||
{"type": "thinking", "thinking": "...", "signature": "sig1"},
|
||||
{"type": "text", "text": "Hi there"},
|
||||
]),
|
||||
])
|
||||
"""
|
||||
lines: list[str] = []
|
||||
last_uuid: str | None = None
|
||||
for role, content in pairs:
|
||||
for role, content in entries:
|
||||
uid = str(uuid4())
|
||||
entry_type = "assistant" if role == "assistant" else "user"
|
||||
msg: dict = {"role": role, "content": content}
|
||||
if role == "assistant":
|
||||
msg.update(
|
||||
{
|
||||
"model": "",
|
||||
"id": f"msg_{uid[:8]}",
|
||||
"type": "message",
|
||||
"content": [{"type": "text", "text": content}],
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
}
|
||||
)
|
||||
if role == "assistant" and isinstance(content, list):
|
||||
msg: dict = {
|
||||
"role": "assistant",
|
||||
"model": "claude-test",
|
||||
"id": f"msg_{uid[:8]}",
|
||||
"type": "message",
|
||||
"content": content,
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
}
|
||||
elif role == "assistant":
|
||||
msg = {
|
||||
"role": "assistant",
|
||||
"model": "claude-test",
|
||||
"id": f"msg_{uid[:8]}",
|
||||
"type": "message",
|
||||
"content": [{"type": "text", "text": content}],
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
}
|
||||
else:
|
||||
msg = {"role": role, "content": content}
|
||||
entry = {
|
||||
"type": entry_type,
|
||||
"uuid": uid,
|
||||
|
||||
@@ -442,8 +442,11 @@ class TestCompactTranscript:
|
||||
assert result is not None
|
||||
assert validate_transcript(result)
|
||||
msgs = _transcript_to_messages(result)
|
||||
assert len(msgs) == 2
|
||||
# 3 messages: compressed prefix (2) + preserved last assistant (1)
|
||||
assert len(msgs) == 3
|
||||
assert msgs[1]["content"] == "Summarized response"
|
||||
# The last assistant entry is preserved verbatim from original
|
||||
assert msgs[2]["content"] == "Details"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_on_compression_failure(self, mock_chat_config):
|
||||
|
||||
@@ -15,6 +15,7 @@ from claude_agent_sdk import (
|
||||
ResultMessage,
|
||||
SystemMessage,
|
||||
TextBlock,
|
||||
ThinkingBlock,
|
||||
ToolResultBlock,
|
||||
ToolUseBlock,
|
||||
UserMessage,
|
||||
@@ -100,6 +101,11 @@ class SDKResponseAdapter:
|
||||
StreamTextDelta(id=self.text_block_id, delta=block.text)
|
||||
)
|
||||
|
||||
elif isinstance(block, ThinkingBlock):
|
||||
# Thinking blocks are preserved in the transcript but
|
||||
# not streamed to the frontend — skip silently.
|
||||
pass
|
||||
|
||||
elif isinstance(block, ToolUseBlock):
|
||||
self._end_text_if_open(responses)
|
||||
|
||||
|
||||
@@ -124,8 +124,11 @@ class TestScenarioCompactAndRetry:
|
||||
assert result != original # Must be different
|
||||
assert validate_transcript(result)
|
||||
msgs = _transcript_to_messages(result)
|
||||
assert len(msgs) == 2
|
||||
# 3 messages: compressed prefix (2) + preserved last assistant (1)
|
||||
assert len(msgs) == 3
|
||||
assert msgs[0]["content"] == "[summary of conversation]"
|
||||
# Last assistant preserved verbatim
|
||||
assert msgs[2]["content"] == "Long answer 2"
|
||||
|
||||
def test_compacted_transcript_loads_into_builder(self):
|
||||
"""TranscriptBuilder can load a compacted transcript and continue."""
|
||||
@@ -737,7 +740,10 @@ class TestRetryEdgeCases:
|
||||
assert result is not None
|
||||
assert result != transcript
|
||||
msgs = _transcript_to_messages(result)
|
||||
assert len(msgs) == 2
|
||||
# 3 messages: compressed prefix (2) + preserved last assistant (1)
|
||||
assert len(msgs) == 3
|
||||
# Last assistant preserved verbatim
|
||||
assert msgs[2]["content"] == "Answer 19"
|
||||
|
||||
def test_messages_to_transcript_roundtrip_preserves_content(self):
|
||||
"""Verify messages → transcript → messages preserves all content."""
|
||||
|
||||
@@ -185,6 +185,24 @@ def _is_prompt_too_long(err: BaseException) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _is_sdk_disconnect_error(exc: BaseException) -> bool:
|
||||
"""Return True if *exc* is an expected SDK cleanup error from client disconnect.
|
||||
|
||||
Two known patterns occur when ``GeneratorExit`` tears down the async
|
||||
generator and the SDK's ``__aexit__`` runs in a different context/task:
|
||||
|
||||
* ``RuntimeError``: cancel scope exited in wrong task (anyio)
|
||||
* ``ValueError``: ContextVar token created in a different Context (OTEL)
|
||||
|
||||
These are suppressed to avoid polluting Sentry with non-actionable noise.
|
||||
"""
|
||||
if isinstance(exc, RuntimeError) and "cancel scope" in str(exc):
|
||||
return True
|
||||
if isinstance(exc, ValueError) and "was created in a different Context" in str(exc):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_tool_only_message(sdk_msg: object) -> bool:
|
||||
"""Return True if *sdk_msg* is an AssistantMessage containing only ToolUseBlocks.
|
||||
|
||||
@@ -409,6 +427,63 @@ _HEARTBEAT_INTERVAL = 10.0 # seconds
|
||||
STREAM_LOCK_PREFIX = "copilot:stream:lock:"
|
||||
|
||||
|
||||
async def _safe_close_sdk_client(
|
||||
sdk_client: ClaudeSDKClient,
|
||||
log_prefix: str,
|
||||
) -> None:
|
||||
"""Close a ClaudeSDKClient, suppressing errors from client disconnect.
|
||||
|
||||
When the SSE client disconnects mid-stream, ``GeneratorExit`` propagates
|
||||
through the async generator stack and causes ``ClaudeSDKClient.__aexit__``
|
||||
to run in a different async context or task than where the client was
|
||||
opened. This triggers two known error classes:
|
||||
|
||||
* ``ValueError``: ``<Token var=<ContextVar name='current_context'>>
|
||||
was created in a different Context`` — OpenTelemetry's
|
||||
``context.detach()`` fails because the OTEL context token was
|
||||
created in the original generator coroutine but detach runs in
|
||||
the GC / cleanup coroutine (Sentry: AUTOGPT-SERVER-8BT).
|
||||
|
||||
* ``RuntimeError``: ``Attempted to exit cancel scope in a different
|
||||
task than it was entered in`` — anyio's ``TaskGroup.__aexit__``
|
||||
detects that the cancel scope was entered in one task but is
|
||||
being exited in another (Sentry: AUTOGPT-SERVER-8BW).
|
||||
|
||||
Both are harmless — the TCP connection is already dead and no
|
||||
resources leak. Logging them at ``debug`` level keeps observability
|
||||
without polluting Sentry.
|
||||
"""
|
||||
try:
|
||||
await sdk_client.__aexit__(None, None, None)
|
||||
except (ValueError, RuntimeError) as exc:
|
||||
if _is_sdk_disconnect_error(exc):
|
||||
# Expected during client disconnect — suppress to avoid Sentry noise.
|
||||
logger.debug(
|
||||
"%s SDK client cleanup error suppressed (client disconnect): %s: %s",
|
||||
log_prefix,
|
||||
type(exc).__name__,
|
||||
exc,
|
||||
)
|
||||
else:
|
||||
raise
|
||||
except GeneratorExit:
|
||||
# GeneratorExit can propagate through __aexit__ — suppress it here
|
||||
# since the generator is already being torn down.
|
||||
logger.debug(
|
||||
"%s SDK client cleanup GeneratorExit suppressed (client disconnect)",
|
||||
log_prefix,
|
||||
)
|
||||
except Exception:
|
||||
# Unexpected cleanup error — log at error level so Sentry captures it
|
||||
# (via its logging integration), but don't propagate since we're in
|
||||
# teardown and the caller cannot meaningfully handle this.
|
||||
logger.error(
|
||||
"%s Unexpected SDK client cleanup error",
|
||||
log_prefix,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
async def _iter_sdk_messages(
|
||||
client: ClaudeSDKClient,
|
||||
) -> AsyncGenerator[Any, None]:
|
||||
@@ -595,7 +670,9 @@ def _format_sdk_content_blocks(blocks: list) -> list[dict[str, Any]]:
|
||||
"""Convert SDK content blocks to transcript format.
|
||||
|
||||
Handles TextBlock, ToolUseBlock, ToolResultBlock, and ThinkingBlock.
|
||||
Unknown block types are logged and skipped.
|
||||
Raw dicts (e.g. ``redacted_thinking`` blocks that the SDK may not have
|
||||
a typed class for) are passed through verbatim to preserve them in the
|
||||
transcript. Unknown typed block objects are logged and skipped.
|
||||
"""
|
||||
result: list[dict[str, Any]] = []
|
||||
for block in blocks or []:
|
||||
@@ -627,6 +704,9 @@ def _format_sdk_content_blocks(blocks: list) -> list[dict[str, Any]]:
|
||||
"signature": block.signature,
|
||||
}
|
||||
)
|
||||
elif isinstance(block, dict) and "type" in block:
|
||||
# Preserve raw dict blocks (e.g. redacted_thinking) verbatim.
|
||||
result.append(block)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[SDK] Unknown content block type: {type(block).__name__}. "
|
||||
@@ -1188,7 +1268,17 @@ async def _run_stream_attempt(
|
||||
|
||||
consecutive_empty_tool_calls = 0
|
||||
|
||||
async with ClaudeSDKClient(options=state.options) as client:
|
||||
# Use manual __aenter__/__aexit__ instead of ``async with`` so we can
|
||||
# suppress SDK cleanup errors that occur when the SSE client disconnects
|
||||
# mid-stream. GeneratorExit causes the SDK's ``__aexit__`` to run in a
|
||||
# different async context/task than where the client was opened, which
|
||||
# triggers:
|
||||
# - ValueError: ContextVar token mismatch (AUTOGPT-SERVER-8BT)
|
||||
# - RuntimeError: cancel scope in wrong task (AUTOGPT-SERVER-8BW)
|
||||
# Both are harmless — the TCP connection is already dead.
|
||||
sdk_client = ClaudeSDKClient(options=state.options)
|
||||
client = await sdk_client.__aenter__()
|
||||
try:
|
||||
logger.info(
|
||||
"%s Sending query — resume=%s, total_msgs=%d, "
|
||||
"query_len=%d, attached_files=%d, image_blocks=%d",
|
||||
@@ -1448,6 +1538,8 @@ async def _run_stream_attempt(
|
||||
|
||||
if acc.stream_completed:
|
||||
break
|
||||
finally:
|
||||
await _safe_close_sdk_client(sdk_client, ctx.log_prefix)
|
||||
|
||||
# --- Post-stream processing (only on success) ---
|
||||
if state.adapter.has_unresolved_tool_calls:
|
||||
@@ -2169,9 +2261,16 @@ async def stream_chat_completion_sdk(
|
||||
error_msg = "Operation cancelled"
|
||||
else:
|
||||
error_msg = str(e) or type(e).__name__
|
||||
# SDK cleanup RuntimeError is expected during cancellation, log as warning
|
||||
if isinstance(e, RuntimeError) and "cancel scope" in str(e):
|
||||
logger.warning("%s SDK cleanup error: %s", log_prefix, error_msg)
|
||||
# SDK cleanup errors are expected during client disconnect —
|
||||
# log as warning rather than error to reduce Sentry noise.
|
||||
# These are normally caught by _safe_close_sdk_client but
|
||||
# can escape in edge cases (e.g. GeneratorExit timing).
|
||||
if _is_sdk_disconnect_error(e):
|
||||
logger.warning(
|
||||
"%s SDK cleanup error (client disconnect): %s",
|
||||
log_prefix,
|
||||
error_msg,
|
||||
)
|
||||
else:
|
||||
logger.error("%s Error: %s", log_prefix, error_msg, exc_info=True)
|
||||
|
||||
@@ -2193,10 +2292,11 @@ async def stream_chat_completion_sdk(
|
||||
)
|
||||
|
||||
# Yield StreamError for immediate feedback (only for non-cancellation errors)
|
||||
# Skip for CancelledError and RuntimeError cleanup issues (both are cancellations)
|
||||
is_cancellation = isinstance(e, asyncio.CancelledError) or (
|
||||
isinstance(e, RuntimeError) and "cancel scope" in str(e)
|
||||
)
|
||||
# Skip for CancelledError and SDK disconnect cleanup errors — these
|
||||
# are not actionable by the user and the SSE connection is already dead.
|
||||
is_cancellation = isinstance(
|
||||
e, asyncio.CancelledError
|
||||
) or _is_sdk_disconnect_error(e)
|
||||
if not is_cancellation:
|
||||
yield StreamError(errorText=display_msg, code=code)
|
||||
|
||||
|
||||
@@ -8,7 +8,12 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from .service import _prepare_file_attachments, _resolve_sdk_model
|
||||
from .service import (
|
||||
_is_sdk_disconnect_error,
|
||||
_prepare_file_attachments,
|
||||
_resolve_sdk_model,
|
||||
_safe_close_sdk_client,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -499,3 +504,111 @@ class TestResolveSdkModel:
|
||||
)
|
||||
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
|
||||
assert _resolve_sdk_model() == "claude-opus-4-6"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_sdk_disconnect_error — classify client disconnect cleanup errors
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsSdkDisconnectError:
|
||||
"""Tests for _is_sdk_disconnect_error — identifies expected SDK cleanup errors."""
|
||||
|
||||
def test_cancel_scope_runtime_error(self):
|
||||
"""RuntimeError about cancel scope in wrong task is a disconnect error."""
|
||||
exc = RuntimeError(
|
||||
"Attempted to exit cancel scope in a different task than it was entered in"
|
||||
)
|
||||
assert _is_sdk_disconnect_error(exc) is True
|
||||
|
||||
def test_context_var_value_error(self):
|
||||
"""ValueError about ContextVar token mismatch is a disconnect error."""
|
||||
exc = ValueError(
|
||||
"<Token var=<ContextVar name='current_context'>> "
|
||||
"was created in a different Context"
|
||||
)
|
||||
assert _is_sdk_disconnect_error(exc) is True
|
||||
|
||||
def test_unrelated_runtime_error(self):
|
||||
"""Unrelated RuntimeError should NOT be classified as disconnect error."""
|
||||
exc = RuntimeError("something else went wrong")
|
||||
assert _is_sdk_disconnect_error(exc) is False
|
||||
|
||||
def test_unrelated_value_error(self):
|
||||
"""Unrelated ValueError should NOT be classified as disconnect error."""
|
||||
exc = ValueError("invalid argument")
|
||||
assert _is_sdk_disconnect_error(exc) is False
|
||||
|
||||
def test_other_exception_types(self):
|
||||
"""Non-RuntimeError/ValueError should NOT be classified as disconnect error."""
|
||||
assert _is_sdk_disconnect_error(TypeError("bad type")) is False
|
||||
assert _is_sdk_disconnect_error(OSError("network down")) is False
|
||||
assert _is_sdk_disconnect_error(asyncio.CancelledError()) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _safe_close_sdk_client — suppress cleanup errors during disconnect
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSafeCloseSdkClient:
|
||||
"""Tests for _safe_close_sdk_client — suppresses expected SDK cleanup errors."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clean_exit(self):
|
||||
"""Normal __aexit__ (no error) should succeed silently."""
|
||||
client = AsyncMock()
|
||||
client.__aexit__ = AsyncMock(return_value=None)
|
||||
await _safe_close_sdk_client(client, "[test]")
|
||||
client.__aexit__.assert_awaited_once_with(None, None, None)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_scope_runtime_error_suppressed(self):
|
||||
"""RuntimeError from cancel scope mismatch should be suppressed."""
|
||||
client = AsyncMock()
|
||||
client.__aexit__ = AsyncMock(
|
||||
side_effect=RuntimeError(
|
||||
"Attempted to exit cancel scope in a different task"
|
||||
)
|
||||
)
|
||||
# Should NOT raise
|
||||
await _safe_close_sdk_client(client, "[test]")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_var_value_error_suppressed(self):
|
||||
"""ValueError from ContextVar token mismatch should be suppressed."""
|
||||
client = AsyncMock()
|
||||
client.__aexit__ = AsyncMock(
|
||||
side_effect=ValueError(
|
||||
"<Token var=<ContextVar name='current_context'>> "
|
||||
"was created in a different Context"
|
||||
)
|
||||
)
|
||||
# Should NOT raise
|
||||
await _safe_close_sdk_client(client, "[test]")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unexpected_exception_suppressed_with_error_log(self):
|
||||
"""Unexpected exceptions should be caught (not propagated) but logged at error."""
|
||||
client = AsyncMock()
|
||||
client.__aexit__ = AsyncMock(side_effect=OSError("unexpected"))
|
||||
# Should NOT raise — unexpected errors are also suppressed to
|
||||
# avoid crashing the generator during teardown. Logged at error
|
||||
# level so Sentry captures them via its logging integration.
|
||||
await _safe_close_sdk_client(client, "[test]")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unrelated_runtime_error_propagates(self):
|
||||
"""Non-cancel-scope RuntimeError should propagate (not suppressed)."""
|
||||
client = AsyncMock()
|
||||
client.__aexit__ = AsyncMock(side_effect=RuntimeError("something unrelated"))
|
||||
with pytest.raises(RuntimeError, match="something unrelated"):
|
||||
await _safe_close_sdk_client(client, "[test]")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unrelated_value_error_propagates(self):
|
||||
"""Non-disconnect ValueError should propagate (not suppressed)."""
|
||||
client = AsyncMock()
|
||||
client.__aexit__ = AsyncMock(side_effect=ValueError("invalid argument"))
|
||||
with pytest.raises(ValueError, match="invalid argument"):
|
||||
await _safe_close_sdk_client(client, "[test]")
|
||||
|
||||
@@ -0,0 +1,822 @@
|
||||
"""Tests for thinking/redacted_thinking block preservation.
|
||||
|
||||
Validates the fix for the Anthropic API error:
|
||||
"thinking or redacted_thinking blocks in the latest assistant message
|
||||
cannot be modified. These blocks must remain as they were in the
|
||||
original response."
|
||||
|
||||
The API requires that thinking blocks in the LAST assistant message are
|
||||
preserved value-identical. Older assistant messages may have thinking blocks
|
||||
stripped entirely. This test suite covers:
|
||||
|
||||
1. _flatten_assistant_content — strips thinking from older messages
|
||||
2. compact_transcript — preserves last assistant's thinking blocks
|
||||
3. response_adapter — handles ThinkingBlock without error
|
||||
4. _format_sdk_content_blocks — preserves redacted_thinking blocks
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from claude_agent_sdk import AssistantMessage, TextBlock, ThinkingBlock
|
||||
|
||||
from backend.copilot.response_model import (
|
||||
StreamStartStep,
|
||||
StreamTextDelta,
|
||||
StreamTextStart,
|
||||
)
|
||||
from backend.util import json
|
||||
|
||||
from .conftest import build_structured_transcript
|
||||
from .response_adapter import SDKResponseAdapter
|
||||
from .service import _format_sdk_content_blocks
|
||||
from .transcript import (
|
||||
_find_last_assistant_entry,
|
||||
_flatten_assistant_content,
|
||||
_messages_to_transcript,
|
||||
_rechain_tail,
|
||||
_transcript_to_messages,
|
||||
compact_transcript,
|
||||
validate_transcript,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures: realistic thinking block content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
THINKING_BLOCK = {
|
||||
"type": "thinking",
|
||||
"thinking": "Let me analyze the user's request carefully...",
|
||||
"signature": "ErUBCkYIAxgCIkD0V2MsRXPkuGolGexaW9V1kluijxXGF",
|
||||
}
|
||||
|
||||
REDACTED_THINKING_BLOCK = {
|
||||
"type": "redacted_thinking",
|
||||
"data": "EmwKAhgBEgy2VEE8PJaS2oLJCPkaT...",
|
||||
}
|
||||
|
||||
|
||||
def _make_thinking_transcript() -> str:
|
||||
"""Build a transcript with thinking blocks in multiple assistant turns.
|
||||
|
||||
Layout:
|
||||
User 1 → Assistant 1 (thinking + text + tool_use)
|
||||
User 2 (tool_result) → Assistant 2 (thinking + text)
|
||||
User 3 → Assistant 3 (thinking + redacted_thinking + text) ← LAST
|
||||
"""
|
||||
return build_structured_transcript(
|
||||
[
|
||||
("user", "What files are in this project?"),
|
||||
(
|
||||
"assistant",
|
||||
[
|
||||
{
|
||||
"type": "thinking",
|
||||
"thinking": "I should list the files.",
|
||||
"signature": "sig_old_1",
|
||||
},
|
||||
{"type": "text", "text": "Let me check the files."},
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "tu1",
|
||||
"name": "list_files",
|
||||
"input": {"path": "/"},
|
||||
},
|
||||
],
|
||||
),
|
||||
("user", "Here are the files: a.py, b.py"),
|
||||
(
|
||||
"assistant",
|
||||
[
|
||||
{
|
||||
"type": "thinking",
|
||||
"thinking": "Good, I see two Python files.",
|
||||
"signature": "sig_old_2",
|
||||
},
|
||||
{"type": "text", "text": "I found a.py and b.py."},
|
||||
],
|
||||
),
|
||||
("user", "Tell me about a.py"),
|
||||
(
|
||||
"assistant",
|
||||
[
|
||||
THINKING_BLOCK,
|
||||
REDACTED_THINKING_BLOCK,
|
||||
{"type": "text", "text": "a.py contains the main entry point."},
|
||||
],
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _last_assistant_content(transcript_jsonl: str) -> list[dict] | None:
|
||||
"""Extract the content blocks of the last assistant entry in a transcript."""
|
||||
last_content = None
|
||||
for line in transcript_jsonl.strip().split("\n"):
|
||||
entry = json.loads(line)
|
||||
msg = entry.get("message", {})
|
||||
if msg.get("role") == "assistant":
|
||||
last_content = msg.get("content")
|
||||
return last_content
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _find_last_assistant_entry — unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFindLastAssistantEntry:
|
||||
def test_splits_at_last_assistant(self):
|
||||
"""Prefix contains everything before last assistant; tail starts at it."""
|
||||
transcript = build_structured_transcript(
|
||||
[
|
||||
("user", "Hello"),
|
||||
("assistant", [{"type": "text", "text": "Hi"}]),
|
||||
("user", "More"),
|
||||
("assistant", [{"type": "text", "text": "Details"}]),
|
||||
]
|
||||
)
|
||||
prefix, tail = _find_last_assistant_entry(transcript)
|
||||
# 3 entries in prefix (user, assistant, user), 1 in tail (last assistant)
|
||||
assert len(prefix) == 3
|
||||
assert len(tail) == 1
|
||||
|
||||
def test_no_assistant_returns_all_in_prefix(self):
|
||||
"""When there's no assistant, all lines are in prefix, tail is empty."""
|
||||
transcript = build_structured_transcript(
|
||||
[("user", "Hello"), ("user", "Another question")]
|
||||
)
|
||||
prefix, tail = _find_last_assistant_entry(transcript)
|
||||
assert len(prefix) == 2
|
||||
assert tail == []
|
||||
|
||||
def test_assistant_at_index_zero(self):
|
||||
"""When assistant is the first entry, prefix is empty."""
|
||||
transcript = build_structured_transcript(
|
||||
[("assistant", [{"type": "text", "text": "Start"}])]
|
||||
)
|
||||
prefix, tail = _find_last_assistant_entry(transcript)
|
||||
assert prefix == []
|
||||
assert len(tail) == 1
|
||||
|
||||
def test_trailing_user_included_in_tail(self):
|
||||
"""User message after last assistant is part of the tail."""
|
||||
transcript = build_structured_transcript(
|
||||
[
|
||||
("user", "Q1"),
|
||||
("assistant", [{"type": "text", "text": "A1"}]),
|
||||
("user", "Q2"),
|
||||
]
|
||||
)
|
||||
prefix, tail = _find_last_assistant_entry(transcript)
|
||||
assert len(prefix) == 1 # first user
|
||||
assert len(tail) == 2 # last assistant + trailing user
|
||||
|
||||
def test_multi_entry_turn_fully_preserved(self):
|
||||
"""An assistant turn spanning multiple JSONL entries (same message.id)
|
||||
must be entirely in the tail, not split across prefix and tail."""
|
||||
# Build manually because build_structured_transcript generates unique ids
|
||||
lines = [
|
||||
json.dumps(
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "u1",
|
||||
"parentUuid": "",
|
||||
"message": {"role": "user", "content": "Hello"},
|
||||
}
|
||||
),
|
||||
json.dumps(
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "a1-think",
|
||||
"parentUuid": "u1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_same_turn",
|
||||
"type": "message",
|
||||
"content": [THINKING_BLOCK],
|
||||
"stop_reason": None,
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
),
|
||||
json.dumps(
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "a1-tool",
|
||||
"parentUuid": "u1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_same_turn",
|
||||
"type": "message",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "tu1",
|
||||
"name": "Bash",
|
||||
"input": {},
|
||||
},
|
||||
],
|
||||
"stop_reason": "tool_use",
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
),
|
||||
]
|
||||
transcript = "\n".join(lines) + "\n"
|
||||
prefix, tail = _find_last_assistant_entry(transcript)
|
||||
# Both assistant entries share msg_same_turn → both in tail
|
||||
assert len(prefix) == 1 # only the user entry
|
||||
assert len(tail) == 2 # both assistant entries (thinking + tool_use)
|
||||
|
||||
def test_no_message_id_preserves_last_assistant(self):
|
||||
"""When the last assistant entry has no message.id, it should still
|
||||
be preserved in the tail (fail closed) rather than being compressed."""
|
||||
lines = [
|
||||
json.dumps(
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "u1",
|
||||
"parentUuid": "",
|
||||
"message": {"role": "user", "content": "Hello"},
|
||||
}
|
||||
),
|
||||
json.dumps(
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"parentUuid": "u1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [THINKING_BLOCK, {"type": "text", "text": "Hi"}],
|
||||
},
|
||||
}
|
||||
),
|
||||
]
|
||||
transcript = "\n".join(lines) + "\n"
|
||||
prefix, tail = _find_last_assistant_entry(transcript)
|
||||
assert len(prefix) == 1 # user entry
|
||||
assert len(tail) == 1 # assistant entry preserved
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _rechain_tail — UUID chain patching
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRechainTail:
|
||||
def test_patches_first_entry_parentuuid(self):
|
||||
"""First tail entry's parentUuid should point to last prefix uuid."""
|
||||
prefix = _messages_to_transcript(
|
||||
[
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi"},
|
||||
]
|
||||
)
|
||||
# Get the last uuid from the prefix
|
||||
last_prefix_uuid = None
|
||||
for line in prefix.strip().split("\n"):
|
||||
entry = json.loads(line)
|
||||
last_prefix_uuid = entry.get("uuid")
|
||||
|
||||
tail_lines = [
|
||||
json.dumps(
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "tail-a1",
|
||||
"parentUuid": "old-parent",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "Tail msg"}],
|
||||
},
|
||||
}
|
||||
)
|
||||
]
|
||||
result = _rechain_tail(prefix, tail_lines)
|
||||
entry = json.loads(result.strip())
|
||||
assert entry["parentUuid"] == last_prefix_uuid
|
||||
assert entry["uuid"] == "tail-a1" # uuid preserved
|
||||
|
||||
def test_chains_multiple_tail_entries(self):
|
||||
"""Subsequent tail entries chain to each other."""
|
||||
prefix = _messages_to_transcript([{"role": "user", "content": "Hi"}])
|
||||
tail_lines = [
|
||||
json.dumps(
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "t1",
|
||||
"parentUuid": "old1",
|
||||
"message": {"role": "assistant", "content": []},
|
||||
}
|
||||
),
|
||||
json.dumps(
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "t2",
|
||||
"parentUuid": "old2",
|
||||
"message": {"role": "user", "content": "Follow-up"},
|
||||
}
|
||||
),
|
||||
]
|
||||
result = _rechain_tail(prefix, tail_lines)
|
||||
entries = [json.loads(ln) for ln in result.strip().split("\n")]
|
||||
assert len(entries) == 2
|
||||
# Second entry's parentUuid should be first entry's uuid
|
||||
assert entries[1]["parentUuid"] == "t1"
|
||||
|
||||
def test_empty_tail_returns_empty(self):
|
||||
"""No tail entries → empty string."""
|
||||
prefix = _messages_to_transcript([{"role": "user", "content": "Hi"}])
|
||||
assert _rechain_tail(prefix, []) == ""
|
||||
|
||||
def test_preserves_message_content_verbatim(self):
|
||||
"""Tail message content (including thinking blocks) must not be modified."""
|
||||
prefix = _messages_to_transcript([{"role": "user", "content": "Hi"}])
|
||||
original_content = [
|
||||
THINKING_BLOCK,
|
||||
REDACTED_THINKING_BLOCK,
|
||||
{"type": "text", "text": "Response"},
|
||||
]
|
||||
tail_lines = [
|
||||
json.dumps(
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "t1",
|
||||
"parentUuid": "old",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": original_content,
|
||||
},
|
||||
}
|
||||
)
|
||||
]
|
||||
result = _rechain_tail(prefix, tail_lines)
|
||||
entry = json.loads(result.strip())
|
||||
assert entry["message"]["content"] == original_content
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _flatten_assistant_content — thinking blocks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFlattenThinkingBlocks:
|
||||
def test_thinking_blocks_are_stripped(self):
|
||||
"""Thinking blocks should not appear in flattened text for compression."""
|
||||
blocks = [
|
||||
{"type": "thinking", "thinking": "secret thoughts", "signature": "sig"},
|
||||
{"type": "text", "text": "Hello user"},
|
||||
]
|
||||
result = _flatten_assistant_content(blocks)
|
||||
assert "secret thoughts" not in result
|
||||
assert "Hello user" in result
|
||||
|
||||
def test_redacted_thinking_blocks_are_stripped(self):
|
||||
"""Redacted thinking blocks should not appear in flattened text."""
|
||||
blocks = [
|
||||
{"type": "redacted_thinking", "data": "encrypted_data"},
|
||||
{"type": "text", "text": "Response text"},
|
||||
]
|
||||
result = _flatten_assistant_content(blocks)
|
||||
assert "encrypted_data" not in result
|
||||
assert "Response text" in result
|
||||
|
||||
def test_thinking_only_message_flattens_to_empty(self):
|
||||
"""A message with only thinking blocks flattens to empty string."""
|
||||
blocks = [
|
||||
{"type": "thinking", "thinking": "just thinking...", "signature": "sig"},
|
||||
]
|
||||
result = _flatten_assistant_content(blocks)
|
||||
assert result == ""
|
||||
|
||||
def test_mixed_thinking_text_tool(self):
|
||||
"""Mixed blocks: only text and tool_use survive flattening."""
|
||||
blocks = [
|
||||
{"type": "thinking", "thinking": "hmm", "signature": "sig"},
|
||||
{"type": "redacted_thinking", "data": "xyz"},
|
||||
{"type": "text", "text": "I'll read the file."},
|
||||
{"type": "tool_use", "name": "Read", "input": {"path": "/x"}},
|
||||
]
|
||||
result = _flatten_assistant_content(blocks)
|
||||
assert "hmm" not in result
|
||||
assert "xyz" not in result
|
||||
assert "I'll read the file." in result
|
||||
assert "[tool_use: Read]" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# compact_transcript — thinking block preservation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCompactTranscriptThinkingBlocks:
|
||||
"""Verify that compact_transcript preserves thinking blocks in the
|
||||
last assistant message while stripping them from older messages."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_last_assistant_thinking_blocks_preserved(self, mock_chat_config):
|
||||
"""After compaction, the last assistant entry must retain its
|
||||
original thinking and redacted_thinking blocks verbatim."""
|
||||
transcript = _make_thinking_transcript()
|
||||
|
||||
compacted_msgs = [
|
||||
{"role": "user", "content": "[conversation summary]"},
|
||||
{"role": "assistant", "content": "Summarized response"},
|
||||
]
|
||||
mock_result = type(
|
||||
"CompressResult",
|
||||
(),
|
||||
{
|
||||
"was_compacted": True,
|
||||
"messages": compacted_msgs,
|
||||
"original_token_count": 800,
|
||||
"token_count": 200,
|
||||
"messages_summarized": 4,
|
||||
"messages_dropped": 0,
|
||||
},
|
||||
)()
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
result = await compact_transcript(transcript, model="test-model")
|
||||
|
||||
assert result is not None
|
||||
assert validate_transcript(result)
|
||||
|
||||
last_content = _last_assistant_content(result)
|
||||
assert last_content is not None, "No assistant entry found"
|
||||
assert isinstance(last_content, list)
|
||||
|
||||
# The last assistant must have the thinking blocks preserved
|
||||
block_types = [b["type"] for b in last_content]
|
||||
assert (
|
||||
"thinking" in block_types
|
||||
), "thinking block missing from last assistant message"
|
||||
assert (
|
||||
"redacted_thinking" in block_types
|
||||
), "redacted_thinking block missing from last assistant message"
|
||||
assert "text" in block_types
|
||||
|
||||
# Verify the thinking block content is value-identical
|
||||
thinking_blocks = [b for b in last_content if b["type"] == "thinking"]
|
||||
assert len(thinking_blocks) == 1
|
||||
assert thinking_blocks[0]["thinking"] == THINKING_BLOCK["thinking"]
|
||||
assert thinking_blocks[0]["signature"] == THINKING_BLOCK["signature"]
|
||||
|
||||
redacted_blocks = [b for b in last_content if b["type"] == "redacted_thinking"]
|
||||
assert len(redacted_blocks) == 1
|
||||
assert redacted_blocks[0]["data"] == REDACTED_THINKING_BLOCK["data"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_older_assistant_thinking_blocks_stripped(self, mock_chat_config):
|
||||
"""Older assistant messages should NOT retain thinking blocks
|
||||
after compaction (they're compressed into summaries)."""
|
||||
transcript = _make_thinking_transcript()
|
||||
|
||||
# The compressor will receive messages where older assistant
|
||||
# entries have already had thinking blocks stripped.
|
||||
captured_messages: list[dict] = []
|
||||
|
||||
async def mock_compression(messages, model, log_prefix):
|
||||
captured_messages.extend(messages)
|
||||
return type(
|
||||
"CompressResult",
|
||||
(),
|
||||
{
|
||||
"was_compacted": True,
|
||||
"messages": messages,
|
||||
"original_token_count": 800,
|
||||
"token_count": 400,
|
||||
"messages_summarized": 2,
|
||||
"messages_dropped": 0,
|
||||
},
|
||||
)()
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
side_effect=mock_compression,
|
||||
):
|
||||
await compact_transcript(transcript, model="test-model")
|
||||
|
||||
# Check that the messages sent to compression don't contain
|
||||
# thinking content from older assistant messages
|
||||
for msg in captured_messages:
|
||||
if msg["role"] == "assistant":
|
||||
content = msg.get("content", "")
|
||||
assert (
|
||||
"I should list the files." not in content
|
||||
), "Old thinking block content leaked into compression input"
|
||||
assert (
|
||||
"Good, I see two Python files." not in content
|
||||
), "Old thinking block content leaked into compression input"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trailing_user_message_after_last_assistant(self, mock_chat_config):
|
||||
"""When the last entry is a user message, the last *assistant*
|
||||
message's thinking blocks should still be preserved."""
|
||||
transcript = build_structured_transcript(
|
||||
[
|
||||
("user", "Hello"),
|
||||
(
|
||||
"assistant",
|
||||
[
|
||||
THINKING_BLOCK,
|
||||
{"type": "text", "text": "Hi there"},
|
||||
],
|
||||
),
|
||||
("user", "Follow-up question"),
|
||||
]
|
||||
)
|
||||
|
||||
# The compressor only receives the prefix (1 user message); the
|
||||
# tail (assistant + trailing user) is preserved verbatim.
|
||||
compacted_msgs = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
]
|
||||
mock_result = type(
|
||||
"CompressResult",
|
||||
(),
|
||||
{
|
||||
"was_compacted": True,
|
||||
"messages": compacted_msgs,
|
||||
"original_token_count": 400,
|
||||
"token_count": 100,
|
||||
"messages_summarized": 0,
|
||||
"messages_dropped": 0,
|
||||
},
|
||||
)()
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
result = await compact_transcript(transcript, model="test-model")
|
||||
|
||||
assert result is not None
|
||||
|
||||
last_content = _last_assistant_content(result)
|
||||
assert last_content is not None
|
||||
assert isinstance(last_content, list)
|
||||
block_types = [b["type"] for b in last_content]
|
||||
assert (
|
||||
"thinking" in block_types
|
||||
), "thinking block lost from last assistant despite trailing user msg"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_assistant_with_thinking_preserved(self, mock_chat_config):
|
||||
"""When there's only one assistant message (which is also the last),
|
||||
its thinking blocks must be preserved."""
|
||||
transcript = build_structured_transcript(
|
||||
[
|
||||
("user", "Hello"),
|
||||
(
|
||||
"assistant",
|
||||
[
|
||||
THINKING_BLOCK,
|
||||
{"type": "text", "text": "World"},
|
||||
],
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
compacted_msgs = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "World"},
|
||||
]
|
||||
mock_result = type(
|
||||
"CompressResult",
|
||||
(),
|
||||
{
|
||||
"was_compacted": True,
|
||||
"messages": compacted_msgs,
|
||||
"original_token_count": 200,
|
||||
"token_count": 100,
|
||||
"messages_summarized": 0,
|
||||
"messages_dropped": 0,
|
||||
},
|
||||
)()
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
result = await compact_transcript(transcript, model="test-model")
|
||||
|
||||
assert result is not None
|
||||
|
||||
last_content = _last_assistant_content(result)
|
||||
assert last_content is not None
|
||||
assert isinstance(last_content, list)
|
||||
block_types = [b["type"] for b in last_content]
|
||||
assert "thinking" in block_types
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tail_parentuuid_rewired_to_prefix(self, mock_chat_config):
|
||||
"""After compaction, the first tail entry's parentUuid must point to
|
||||
the last entry in the compressed prefix — not its original parent."""
|
||||
transcript = _make_thinking_transcript()
|
||||
|
||||
compacted_msgs = [
|
||||
{"role": "user", "content": "[conversation summary]"},
|
||||
{"role": "assistant", "content": "Summarized response"},
|
||||
]
|
||||
mock_result = type(
|
||||
"CompressResult",
|
||||
(),
|
||||
{
|
||||
"was_compacted": True,
|
||||
"messages": compacted_msgs,
|
||||
"original_token_count": 800,
|
||||
"token_count": 200,
|
||||
"messages_summarized": 4,
|
||||
"messages_dropped": 0,
|
||||
},
|
||||
)()
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
result = await compact_transcript(transcript, model="test-model")
|
||||
|
||||
assert result is not None
|
||||
lines = [ln for ln in result.strip().split("\n") if ln.strip()]
|
||||
entries = [json.loads(ln) for ln in lines]
|
||||
|
||||
# Find the boundary: the compressed prefix ends just before the
|
||||
# first tail entry (last assistant in original transcript).
|
||||
tail_start = None
|
||||
for i, entry in enumerate(entries):
|
||||
msg = entry.get("message", {})
|
||||
if isinstance(msg.get("content"), list):
|
||||
# Structured content = preserved tail entry
|
||||
tail_start = i
|
||||
break
|
||||
|
||||
assert tail_start is not None, "Could not find preserved tail entry"
|
||||
assert tail_start > 0, "Tail should not be the first entry"
|
||||
|
||||
# The tail entry's parentUuid must be the uuid of the preceding entry
|
||||
prefix_last_uuid = entries[tail_start - 1]["uuid"]
|
||||
tail_first_parent = entries[tail_start]["parentUuid"]
|
||||
assert tail_first_parent == prefix_last_uuid, (
|
||||
f"Tail parentUuid {tail_first_parent!r} != "
|
||||
f"last prefix uuid {prefix_last_uuid!r}"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_thinking_blocks_still_works(self, mock_chat_config):
|
||||
"""Compaction should still work normally when there are no thinking
|
||||
blocks in the transcript."""
|
||||
transcript = build_structured_transcript(
|
||||
[
|
||||
("user", "Hello"),
|
||||
("assistant", [{"type": "text", "text": "Hi"}]),
|
||||
("user", "More"),
|
||||
("assistant", [{"type": "text", "text": "Details"}]),
|
||||
]
|
||||
)
|
||||
|
||||
compacted_msgs = [
|
||||
{"role": "user", "content": "[summary]"},
|
||||
{"role": "assistant", "content": "Summary"},
|
||||
]
|
||||
mock_result = type(
|
||||
"CompressResult",
|
||||
(),
|
||||
{
|
||||
"was_compacted": True,
|
||||
"messages": compacted_msgs,
|
||||
"original_token_count": 200,
|
||||
"token_count": 50,
|
||||
"messages_summarized": 2,
|
||||
"messages_dropped": 0,
|
||||
},
|
||||
)()
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
result = await compact_transcript(transcript, model="test-model")
|
||||
|
||||
assert result is not None
|
||||
assert validate_transcript(result)
|
||||
# Verify last assistant content is preserved even without thinking blocks
|
||||
last_content = _last_assistant_content(result)
|
||||
assert last_content is not None
|
||||
assert last_content == [{"type": "text", "text": "Details"}]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _transcript_to_messages — thinking block handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTranscriptToMessagesThinking:
|
||||
def test_thinking_blocks_excluded_from_flattened_content(self):
|
||||
"""When _transcript_to_messages flattens content, thinking block
|
||||
text should not leak into the message content string."""
|
||||
transcript = build_structured_transcript(
|
||||
[
|
||||
("user", "Hello"),
|
||||
(
|
||||
"assistant",
|
||||
[
|
||||
{
|
||||
"type": "thinking",
|
||||
"thinking": "SECRET_THOUGHT",
|
||||
"signature": "sig",
|
||||
},
|
||||
{"type": "text", "text": "Visible response"},
|
||||
],
|
||||
),
|
||||
]
|
||||
)
|
||||
messages = _transcript_to_messages(transcript)
|
||||
assistant_msg = [m for m in messages if m["role"] == "assistant"][0]
|
||||
assert "SECRET_THOUGHT" not in assistant_msg["content"]
|
||||
assert "Visible response" in assistant_msg["content"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# response_adapter — ThinkingBlock handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResponseAdapterThinkingBlock:
|
||||
def test_thinking_block_does_not_crash(self):
|
||||
"""ThinkingBlock in AssistantMessage should not cause an error."""
|
||||
adapter = SDKResponseAdapter(message_id="msg-1", session_id="sess-1")
|
||||
msg = AssistantMessage(
|
||||
content=[
|
||||
ThinkingBlock(
|
||||
thinking="Let me think about this...",
|
||||
signature="sig_test_123",
|
||||
),
|
||||
TextBlock(text="Here is my response."),
|
||||
],
|
||||
model="claude-test",
|
||||
)
|
||||
results = adapter.convert_message(msg)
|
||||
# Should produce stream events for text only, no crash
|
||||
types = [type(r) for r in results]
|
||||
assert StreamStartStep in types
|
||||
assert StreamTextStart in types or StreamTextDelta in types
|
||||
|
||||
def test_thinking_block_does_not_emit_stream_events(self):
|
||||
"""ThinkingBlock should NOT produce any StreamTextDelta events
|
||||
containing thinking content."""
|
||||
adapter = SDKResponseAdapter(message_id="msg-1", session_id="sess-1")
|
||||
msg = AssistantMessage(
|
||||
content=[
|
||||
ThinkingBlock(
|
||||
thinking="My secret thoughts",
|
||||
signature="sig_test_456",
|
||||
),
|
||||
TextBlock(text="Public response"),
|
||||
],
|
||||
model="claude-test",
|
||||
)
|
||||
results = adapter.convert_message(msg)
|
||||
text_deltas = [r for r in results if isinstance(r, StreamTextDelta)]
|
||||
for delta in text_deltas:
|
||||
assert "secret thoughts" not in (delta.delta or "")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _format_sdk_content_blocks — redacted_thinking handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFormatSdkContentBlocks:
|
||||
def test_thinking_block_preserved(self):
|
||||
"""ThinkingBlock should be serialized with type, thinking, and signature."""
|
||||
blocks = [
|
||||
ThinkingBlock(thinking="My thoughts", signature="sig123"),
|
||||
TextBlock(text="Response"),
|
||||
]
|
||||
result = _format_sdk_content_blocks(blocks)
|
||||
assert len(result) == 2
|
||||
assert result[0] == {
|
||||
"type": "thinking",
|
||||
"thinking": "My thoughts",
|
||||
"signature": "sig123",
|
||||
}
|
||||
assert result[1] == {"type": "text", "text": "Response"}
|
||||
|
||||
def test_raw_dict_redacted_thinking_preserved(self):
|
||||
"""Raw dict blocks (e.g. redacted_thinking) pass through unchanged."""
|
||||
raw_block = {"type": "redacted_thinking", "data": "EmwKAh...encrypted"}
|
||||
blocks = [
|
||||
raw_block,
|
||||
TextBlock(text="Response"),
|
||||
]
|
||||
result = _format_sdk_content_blocks(blocks)
|
||||
assert len(result) == 2
|
||||
assert result[0] == raw_block
|
||||
assert result[1] == {"type": "text", "text": "Response"}
|
||||
@@ -605,20 +605,31 @@ COMPACT_MSG_ID_PREFIX = "msg_compact_"
|
||||
ENTRY_TYPE_MESSAGE = "message"
|
||||
|
||||
|
||||
_THINKING_BLOCK_TYPES = frozenset({"thinking", "redacted_thinking"})
|
||||
|
||||
|
||||
def _flatten_assistant_content(blocks: list) -> str:
|
||||
"""Flatten assistant content blocks into a single plain-text string.
|
||||
|
||||
Structured ``tool_use`` blocks are converted to ``[tool_use: name]``
|
||||
placeholders. This is intentional: ``compress_context`` requires plain
|
||||
text for token counting and LLM summarization. The structural loss is
|
||||
acceptable because compaction only runs when the original transcript was
|
||||
already too large for the model — a summarized plain-text version is
|
||||
better than no context at all.
|
||||
placeholders. ``thinking`` and ``redacted_thinking`` blocks are
|
||||
silently dropped — they carry no useful context for compression
|
||||
summaries and must not leak into compacted transcripts (the Anthropic
|
||||
API requires thinking blocks in the last assistant message to be
|
||||
value-identical to the original response; including stale thinking
|
||||
text would violate that constraint).
|
||||
|
||||
This is intentional: ``compress_context`` requires plain text for
|
||||
token counting and LLM summarization. The structural loss is
|
||||
acceptable because compaction only runs when the original transcript
|
||||
was already too large for the model.
|
||||
"""
|
||||
parts: list[str] = []
|
||||
for block in blocks:
|
||||
if isinstance(block, dict):
|
||||
btype = block.get("type", "")
|
||||
if btype in _THINKING_BLOCK_TYPES:
|
||||
continue
|
||||
if btype == "text":
|
||||
parts.append(block.get("text", ""))
|
||||
elif btype == "tool_use":
|
||||
@@ -805,6 +816,68 @@ async def _run_compression(
|
||||
)
|
||||
|
||||
|
||||
def _find_last_assistant_entry(
|
||||
content: str,
|
||||
) -> tuple[list[str], list[str]]:
|
||||
"""Split JSONL lines into (compressible_prefix, preserved_tail).
|
||||
|
||||
The tail starts at the **first** entry of the last assistant turn and
|
||||
includes everything after it (typically trailing user messages). An
|
||||
assistant turn can span multiple consecutive JSONL entries sharing the
|
||||
same ``message.id`` (e.g., a thinking entry followed by a tool_use
|
||||
entry). All entries of the turn are preserved verbatim.
|
||||
|
||||
The Anthropic API requires that ``thinking`` and ``redacted_thinking``
|
||||
blocks in the **last** assistant message remain value-identical to the
|
||||
original response (the API validates parsed signature values, not raw
|
||||
JSON bytes). By excluding the entire turn from compression we
|
||||
guarantee those blocks are never altered.
|
||||
|
||||
Returns ``(all_lines, [])`` when no assistant entry is found.
|
||||
"""
|
||||
lines = [ln for ln in content.strip().split("\n") if ln.strip()]
|
||||
|
||||
# Parse all lines once to avoid double JSON deserialization.
|
||||
# json.loads with fallback=None returns Any; non-dict entries are
|
||||
# safely skipped by the isinstance(entry, dict) guards below.
|
||||
parsed: list = [json.loads(ln, fallback=None) for ln in lines]
|
||||
|
||||
# Reverse scan: find the message.id and index of the last assistant entry.
|
||||
last_asst_msg_id: str | None = None
|
||||
last_asst_idx: int | None = None
|
||||
for i in range(len(parsed) - 1, -1, -1):
|
||||
entry = parsed[i]
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
msg = entry.get("message", {})
|
||||
if msg.get("role") == "assistant":
|
||||
last_asst_idx = i
|
||||
last_asst_msg_id = msg.get("id")
|
||||
break
|
||||
|
||||
if last_asst_idx is None:
|
||||
return lines, []
|
||||
|
||||
# If the assistant entry has no message.id, fall back to preserving
|
||||
# from that single entry onward — safer than compressing everything.
|
||||
if last_asst_msg_id is None:
|
||||
return lines[:last_asst_idx], lines[last_asst_idx:]
|
||||
|
||||
# Forward scan: find the first entry of this turn (same message.id).
|
||||
first_turn_idx: int | None = None
|
||||
for i, entry in enumerate(parsed):
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
msg = entry.get("message", {})
|
||||
if msg.get("role") == "assistant" and msg.get("id") == last_asst_msg_id:
|
||||
first_turn_idx = i
|
||||
break
|
||||
|
||||
if first_turn_idx is None:
|
||||
return lines, []
|
||||
return lines[:first_turn_idx], lines[first_turn_idx:]
|
||||
|
||||
|
||||
async def compact_transcript(
|
||||
content: str,
|
||||
*,
|
||||
@@ -816,42 +889,50 @@ async def compact_transcript(
|
||||
Converts transcript entries to plain messages, runs ``compress_context``
|
||||
(the same compressor used for pre-query history), and rebuilds JSONL.
|
||||
|
||||
Structured content (``tool_use`` blocks, ``tool_result`` nesting, images)
|
||||
is flattened to plain text for compression. This matches the fidelity of
|
||||
the Plan C (DB compression) fallback path, where
|
||||
``_format_conversation_context`` similarly renders tool calls as
|
||||
``You called tool: name(args)`` and results as ``Tool result: ...``.
|
||||
Neither path preserves structured API content blocks — the compacted
|
||||
context serves as text history for the LLM, which creates proper
|
||||
structured tool calls going forward.
|
||||
The **last assistant entry** (and any entries after it) are preserved
|
||||
verbatim — never flattened or compressed. The Anthropic API requires
|
||||
``thinking`` and ``redacted_thinking`` blocks in the latest assistant
|
||||
message to be value-identical to the original response (the API
|
||||
validates parsed signature values, not raw JSON bytes); compressing
|
||||
them would destroy the cryptographic signatures and cause
|
||||
``invalid_request_error``.
|
||||
|
||||
Images are per-turn attachments loaded from workspace storage by file ID
|
||||
(via ``_prepare_file_attachments``), not part of the conversation history.
|
||||
They are re-attached each turn and are unaffected by compaction.
|
||||
Structured content in *older* assistant entries (``tool_use`` blocks,
|
||||
``thinking`` blocks, ``tool_result`` nesting, images) is flattened to
|
||||
plain text for compression. This matches the fidelity of the Plan C
|
||||
(DB compression) fallback path.
|
||||
|
||||
Returns the compacted JSONL string, or ``None`` on failure.
|
||||
|
||||
See also:
|
||||
``_compress_messages`` in ``service.py`` — compresses ``ChatMessage``
|
||||
lists for pre-query DB history. Both share ``compress_context()``
|
||||
but operate on different input formats (JSONL transcript entries
|
||||
here vs. ChatMessage dicts there).
|
||||
lists for pre-query DB history.
|
||||
"""
|
||||
messages = _transcript_to_messages(content)
|
||||
if len(messages) < 2:
|
||||
logger.warning("%s Too few messages to compact (%d)", log_prefix, len(messages))
|
||||
prefix_lines, tail_lines = _find_last_assistant_entry(content)
|
||||
|
||||
# Build the JSONL string for the compressible prefix
|
||||
prefix_content = "\n".join(prefix_lines) + "\n" if prefix_lines else ""
|
||||
messages = _transcript_to_messages(prefix_content) if prefix_content else []
|
||||
|
||||
if len(messages) + len(tail_lines) < 2:
|
||||
total = len(messages) + len(tail_lines)
|
||||
logger.warning("%s Too few messages to compact (%d)", log_prefix, total)
|
||||
return None
|
||||
if not messages:
|
||||
logger.warning("%s Nothing to compress (only tail entries remain)", log_prefix)
|
||||
return None
|
||||
try:
|
||||
result = await _run_compression(messages, model, log_prefix)
|
||||
if not result.was_compacted:
|
||||
# Compressor says it's within budget, but the SDK rejected it.
|
||||
# Return None so the caller falls through to DB fallback.
|
||||
logger.warning(
|
||||
"%s Compressor reports within budget but SDK rejected — "
|
||||
"signalling failure",
|
||||
log_prefix,
|
||||
)
|
||||
return None
|
||||
if not result.messages:
|
||||
logger.warning("%s Compressor returned empty messages", log_prefix)
|
||||
return None
|
||||
logger.info(
|
||||
"%s Compacted transcript: %d->%d tokens (%d summarized, %d dropped)",
|
||||
log_prefix,
|
||||
@@ -860,7 +941,29 @@ async def compact_transcript(
|
||||
result.messages_summarized,
|
||||
result.messages_dropped,
|
||||
)
|
||||
compacted = _messages_to_transcript(result.messages)
|
||||
compressed_part = _messages_to_transcript(result.messages)
|
||||
|
||||
# Re-append the preserved tail (last assistant + trailing entries)
|
||||
# with parentUuid patched to chain onto the compressed prefix.
|
||||
tail_part = _rechain_tail(compressed_part, tail_lines)
|
||||
compacted = compressed_part + tail_part
|
||||
|
||||
if len(compacted) >= len(content):
|
||||
# Byte count can increase due to preserved tail entries
|
||||
# (thinking blocks, JSON overhead) even when token count
|
||||
# decreased. Log a warning but still return — the API
|
||||
# validates tokens not bytes, and the caller falls through
|
||||
# to DB fallback if the transcript is still too large.
|
||||
logger.warning(
|
||||
"%s Compacted transcript (%d bytes) is not smaller than "
|
||||
"original (%d bytes) — may still reduce token count",
|
||||
log_prefix,
|
||||
len(compacted),
|
||||
len(content),
|
||||
)
|
||||
# Authoritative validation — the caller (_reduce_context) also
|
||||
# validates, but this is the canonical check that guarantees we
|
||||
# never return a malformed transcript from this function.
|
||||
if not validate_transcript(compacted):
|
||||
logger.warning("%s Compacted transcript failed validation", log_prefix)
|
||||
return None
|
||||
@@ -870,3 +973,43 @@ async def compact_transcript(
|
||||
"%s Transcript compaction failed: %s", log_prefix, e, exc_info=True
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _rechain_tail(compressed_prefix: str, tail_lines: list[str]) -> str:
|
||||
"""Patch tail entries so their parentUuid chain links to the compressed prefix.
|
||||
|
||||
The first tail entry's ``parentUuid`` is set to the ``uuid`` of the
|
||||
last entry in the compressed prefix. Subsequent tail entries are
|
||||
rechained to point to their predecessor in the tail — their original
|
||||
``parentUuid`` values may reference entries that were compressed away.
|
||||
"""
|
||||
if not tail_lines:
|
||||
return ""
|
||||
# Find the last uuid in the compressed prefix
|
||||
last_prefix_uuid = ""
|
||||
for line in reversed(compressed_prefix.strip().split("\n")):
|
||||
if not line.strip():
|
||||
continue
|
||||
entry = json.loads(line, fallback=None)
|
||||
if isinstance(entry, dict) and "uuid" in entry:
|
||||
last_prefix_uuid = entry["uuid"]
|
||||
break
|
||||
|
||||
result_lines: list[str] = []
|
||||
prev_uuid: str | None = None
|
||||
for i, line in enumerate(tail_lines):
|
||||
entry = json.loads(line, fallback=None)
|
||||
if not isinstance(entry, dict):
|
||||
# Safety guard: _find_last_assistant_entry already filters empty
|
||||
# lines, and well-formed JSONL always parses to dicts. Non-dict
|
||||
# lines are passed through unchanged; prev_uuid is intentionally
|
||||
# NOT updated so the next dict entry chains to the last known uuid.
|
||||
result_lines.append(line)
|
||||
continue
|
||||
if i == 0:
|
||||
entry["parentUuid"] = last_prefix_uuid
|
||||
elif prev_uuid is not None:
|
||||
entry["parentUuid"] = prev_uuid
|
||||
prev_uuid = entry.get("uuid")
|
||||
result_lines.append(json.dumps(entry, separators=(",", ":")))
|
||||
return "\n".join(result_lines) + "\n"
|
||||
|
||||
@@ -26,6 +26,7 @@ import orjson
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from backend.api.model import CopilotCompletionPayload
|
||||
from backend.data.db_accessors import chat_db
|
||||
from backend.data.notification_bus import (
|
||||
AsyncRedisNotificationEventBus,
|
||||
NotificationEvent,
|
||||
@@ -111,6 +112,14 @@ def _parse_session_meta(meta: dict[Any, Any], session_id: str = "") -> ActiveSes
|
||||
``session_id`` is used as a fallback for ``turn_id`` when the meta hash
|
||||
pre-dates the turn_id field (backward compat for in-flight sessions).
|
||||
"""
|
||||
created_at = datetime.now(timezone.utc)
|
||||
created_at_raw = meta.get("created_at")
|
||||
if created_at_raw:
|
||||
try:
|
||||
created_at = datetime.fromisoformat(str(created_at_raw))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
return ActiveSession(
|
||||
session_id=meta.get("session_id", "") or session_id,
|
||||
user_id=meta.get("user_id", "") or None,
|
||||
@@ -119,6 +128,7 @@ def _parse_session_meta(meta: dict[Any, Any], session_id: str = "") -> ActiveSes
|
||||
turn_id=meta.get("turn_id", "") or session_id,
|
||||
blocking=meta.get("blocking") == "1",
|
||||
status=meta.get("status", "running"), # type: ignore[arg-type]
|
||||
created_at=created_at,
|
||||
)
|
||||
|
||||
|
||||
@@ -802,6 +812,33 @@ async def mark_session_completed(
|
||||
f"Failed to publish error event for session {session_id}: {e}"
|
||||
)
|
||||
|
||||
# Compute wall-clock duration from session created_at.
|
||||
# Only persist when (a) the session completed successfully and
|
||||
# (b) created_at was actually present in Redis meta (not a fallback).
|
||||
duration_ms: int | None = None
|
||||
if meta and not error_message:
|
||||
created_at_raw = meta.get("created_at")
|
||||
if created_at_raw:
|
||||
try:
|
||||
created_at = datetime.fromisoformat(str(created_at_raw))
|
||||
if created_at.tzinfo is None:
|
||||
created_at = created_at.replace(tzinfo=timezone.utc)
|
||||
elapsed = datetime.now(timezone.utc) - created_at
|
||||
duration_ms = max(0, int(elapsed.total_seconds() * 1000))
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
"Failed to compute session duration for %s (created_at=%r)",
|
||||
session_id,
|
||||
created_at_raw,
|
||||
)
|
||||
|
||||
# Persist duration on the last assistant message
|
||||
if duration_ms is not None:
|
||||
try:
|
||||
await chat_db().set_turn_duration(session_id, duration_ms)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save turn duration for {session_id}: {e}")
|
||||
|
||||
# Publish StreamFinish AFTER status is set to "completed"/"failed".
|
||||
# This is the SINGLE place that publishes StreamFinish — services and
|
||||
# the processor must NOT publish it themselves.
|
||||
|
||||
@@ -102,7 +102,6 @@ async def setup_test_data(server):
|
||||
"value": "",
|
||||
"advanced": False,
|
||||
"description": "Test input field",
|
||||
"placeholder_values": [],
|
||||
},
|
||||
metadata={"position": {"x": 0, "y": 0}},
|
||||
)
|
||||
@@ -242,7 +241,6 @@ async def setup_llm_test_data(server):
|
||||
"value": "",
|
||||
"advanced": False,
|
||||
"description": "Prompt for the LLM",
|
||||
"placeholder_values": [],
|
||||
},
|
||||
metadata={"position": {"x": 0, "y": 0}},
|
||||
)
|
||||
@@ -396,7 +394,6 @@ async def setup_firecrawl_test_data(server):
|
||||
"value": "",
|
||||
"advanced": False,
|
||||
"description": "URL for Firecrawl to scrape",
|
||||
"placeholder_values": [],
|
||||
},
|
||||
metadata={"position": {"x": 0, "y": 0}},
|
||||
)
|
||||
|
||||
@@ -4,6 +4,8 @@ import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from backend.data.dynamic_fields import DICT_SPLIT
|
||||
|
||||
from .helpers import (
|
||||
AGENT_EXECUTOR_BLOCK_ID,
|
||||
MCP_TOOL_BLOCK_ID,
|
||||
@@ -1536,8 +1538,8 @@ class AgentFixer:
|
||||
for link in links:
|
||||
sink_name = link.get("sink_name", "")
|
||||
|
||||
if "_#_" in sink_name:
|
||||
parent, child = sink_name.split("_#_", 1)
|
||||
if DICT_SPLIT in sink_name:
|
||||
parent, child = sink_name.split(DICT_SPLIT, 1)
|
||||
|
||||
# Check if child is a numeric index (invalid for _#_ notation)
|
||||
if child.isdigit():
|
||||
|
||||
@@ -4,6 +4,8 @@ import re
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from backend.data.dynamic_fields import DICT_SPLIT
|
||||
|
||||
from .blocks import get_blocks_as_dicts
|
||||
|
||||
__all__ = [
|
||||
@@ -51,8 +53,8 @@ def generate_uuid() -> str:
|
||||
|
||||
def get_defined_property_type(schema: dict[str, Any], name: str) -> str | None:
|
||||
"""Get property type from a schema, handling nested `_#_` notation."""
|
||||
if "_#_" in name:
|
||||
parent, child = name.split("_#_", 1)
|
||||
if DICT_SPLIT in name:
|
||||
parent, child = name.split(DICT_SPLIT, 1)
|
||||
parent_schema = schema.get(parent, {})
|
||||
if "properties" in parent_schema and isinstance(
|
||||
parent_schema["properties"], dict
|
||||
|
||||
@@ -5,6 +5,8 @@ import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from backend.data.dynamic_fields import DICT_SPLIT
|
||||
|
||||
from .helpers import (
|
||||
AGENT_EXECUTOR_BLOCK_ID,
|
||||
AGENT_INPUT_BLOCK_ID,
|
||||
@@ -256,95 +258,6 @@ class AgentValidator:
|
||||
|
||||
return valid
|
||||
|
||||
def validate_nested_sink_links(
|
||||
self,
|
||||
agent: AgentDict,
|
||||
blocks: list[dict[str, Any]],
|
||||
node_lookup: dict[str, dict[str, Any]] | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Validate nested sink links (links with _#_ notation).
|
||||
Returns True if all nested links are valid, False otherwise.
|
||||
"""
|
||||
valid = True
|
||||
block_input_schemas = {
|
||||
block.get("id", ""): block.get("inputSchema", {}).get("properties", {})
|
||||
for block in blocks
|
||||
}
|
||||
block_names = {
|
||||
block.get("id", ""): block.get("name", "Unknown Block") for block in blocks
|
||||
}
|
||||
if node_lookup is None:
|
||||
node_lookup = self._build_node_lookup(agent)
|
||||
|
||||
for link in agent.get("links", []):
|
||||
sink_name = link.get("sink_name", "")
|
||||
sink_id = link.get("sink_id")
|
||||
|
||||
if not sink_name or not sink_id:
|
||||
continue
|
||||
|
||||
if "_#_" in sink_name:
|
||||
parent, child = sink_name.split("_#_", 1)
|
||||
|
||||
sink_node = node_lookup.get(sink_id)
|
||||
if not sink_node:
|
||||
continue
|
||||
|
||||
block_id = sink_node.get("block_id")
|
||||
input_props = block_input_schemas.get(block_id, {})
|
||||
|
||||
parent_schema = input_props.get(parent)
|
||||
if not parent_schema:
|
||||
block_name = block_names.get(block_id, "Unknown Block")
|
||||
self.add_error(
|
||||
f"Invalid nested sink link '{sink_name}' for "
|
||||
f"node '{sink_id}' (block "
|
||||
f"'{block_name}' - {block_id}): Parent property "
|
||||
f"'{parent}' does not exist in the block's "
|
||||
f"input schema."
|
||||
)
|
||||
valid = False
|
||||
continue
|
||||
|
||||
# Check if additionalProperties is allowed either directly
|
||||
# or via anyOf
|
||||
allows_additional_properties = parent_schema.get(
|
||||
"additionalProperties", False
|
||||
)
|
||||
|
||||
# Check anyOf for additionalProperties
|
||||
if not allows_additional_properties and "anyOf" in parent_schema:
|
||||
any_of_schemas = parent_schema.get("anyOf", [])
|
||||
if isinstance(any_of_schemas, list):
|
||||
for schema_option in any_of_schemas:
|
||||
if isinstance(schema_option, dict) and schema_option.get(
|
||||
"additionalProperties"
|
||||
):
|
||||
allows_additional_properties = True
|
||||
break
|
||||
|
||||
if not allows_additional_properties:
|
||||
if not (
|
||||
isinstance(parent_schema, dict)
|
||||
and "properties" in parent_schema
|
||||
and isinstance(parent_schema["properties"], dict)
|
||||
and child in parent_schema["properties"]
|
||||
):
|
||||
block_name = block_names.get(block_id, "Unknown Block")
|
||||
self.add_error(
|
||||
f"Invalid nested sink link '{sink_name}' "
|
||||
f"for node '{link.get('sink_id', '')}' (block "
|
||||
f"'{block_name}' - {block_id}): Child "
|
||||
f"property '{child}' does not exist in "
|
||||
f"parent '{parent}' schema. Available "
|
||||
f"properties: "
|
||||
f"{list(parent_schema.get('properties', {}).keys())}"
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_prompt_double_curly_braces_spaces(self, agent: AgentDict) -> bool:
|
||||
"""
|
||||
Validate that prompt parameters do not contain spaces in double curly
|
||||
@@ -471,8 +384,8 @@ class AgentValidator:
|
||||
output_props = block_output_schemas.get(block_id, {})
|
||||
|
||||
# Handle nested source names (with _#_ notation)
|
||||
if "_#_" in source_name:
|
||||
parent, child = source_name.split("_#_", 1)
|
||||
if DICT_SPLIT in source_name:
|
||||
parent, child = source_name.split(DICT_SPLIT, 1)
|
||||
|
||||
parent_schema = output_props.get(parent)
|
||||
if not parent_schema:
|
||||
@@ -553,6 +466,195 @@ class AgentValidator:
|
||||
|
||||
return valid
|
||||
|
||||
def validate_sink_input_existence(
|
||||
self,
|
||||
agent: AgentDict,
|
||||
blocks: list[dict[str, Any]],
|
||||
node_lookup: dict[str, dict[str, Any]] | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Validate that all sink_names in links and input_default keys in nodes
|
||||
exist in the corresponding block's input schema.
|
||||
|
||||
Checks that for each link the sink_name references a valid input
|
||||
property in the sink block's inputSchema, and that every key in a
|
||||
node's input_default is a recognised input property. Also handles
|
||||
nested inputs with _#_ notation and dynamic schemas for
|
||||
AgentExecutorBlock.
|
||||
|
||||
Args:
|
||||
agent: The agent dictionary to validate
|
||||
blocks: List of available blocks with their schemas
|
||||
node_lookup: Optional pre-built node-id → node dict
|
||||
|
||||
Returns:
|
||||
True if all sink input fields exist, False otherwise
|
||||
"""
|
||||
valid = True
|
||||
|
||||
block_input_schemas = {
|
||||
block.get("id", ""): block.get("inputSchema", {}).get("properties", {})
|
||||
for block in blocks
|
||||
}
|
||||
block_names = {
|
||||
block.get("id", ""): block.get("name", "Unknown Block") for block in blocks
|
||||
}
|
||||
if node_lookup is None:
|
||||
node_lookup = self._build_node_lookup(agent)
|
||||
|
||||
def get_input_props(node: dict[str, Any]) -> dict[str, Any]:
|
||||
block_id = node.get("block_id", "")
|
||||
if block_id == AGENT_EXECUTOR_BLOCK_ID:
|
||||
input_default = node.get("input_default", {})
|
||||
dynamic_input_schema = input_default.get("input_schema", {})
|
||||
if not isinstance(dynamic_input_schema, dict):
|
||||
dynamic_input_schema = {}
|
||||
dynamic_props = dynamic_input_schema.get("properties", {})
|
||||
if not isinstance(dynamic_props, dict):
|
||||
dynamic_props = {}
|
||||
static_props = block_input_schemas.get(block_id, {})
|
||||
return {**static_props, **dynamic_props}
|
||||
return block_input_schemas.get(block_id, {})
|
||||
|
||||
def check_nested_input(
|
||||
input_props: dict[str, Any],
|
||||
field_name: str,
|
||||
context: str,
|
||||
block_name: str,
|
||||
block_id: str,
|
||||
) -> bool:
|
||||
parent, child = field_name.split(DICT_SPLIT, 1)
|
||||
parent_schema = input_props.get(parent)
|
||||
if not parent_schema:
|
||||
self.add_error(
|
||||
f"{context}: Parent property '{parent}' does not "
|
||||
f"exist in block '{block_name}' ({block_id}) input "
|
||||
f"schema."
|
||||
)
|
||||
return False
|
||||
|
||||
allows_additional = parent_schema.get("additionalProperties", False)
|
||||
# Only anyOf is checked here because Pydantic's JSON schema
|
||||
# emits optional/union fields via anyOf. allOf and oneOf are
|
||||
# not currently used by any block's dict-typed inputs, so
|
||||
# false positives from them are not a concern in practice.
|
||||
if not allows_additional and "anyOf" in parent_schema:
|
||||
for schema_option in parent_schema.get("anyOf", []):
|
||||
if not isinstance(schema_option, dict):
|
||||
continue
|
||||
if schema_option.get("additionalProperties"):
|
||||
allows_additional = True
|
||||
break
|
||||
items_schema = schema_option.get("items")
|
||||
if isinstance(items_schema, dict) and items_schema.get(
|
||||
"additionalProperties"
|
||||
):
|
||||
allows_additional = True
|
||||
break
|
||||
|
||||
if not allows_additional:
|
||||
if not (
|
||||
isinstance(parent_schema, dict)
|
||||
and "properties" in parent_schema
|
||||
and isinstance(parent_schema["properties"], dict)
|
||||
and child in parent_schema["properties"]
|
||||
):
|
||||
available = (
|
||||
list(parent_schema.get("properties", {}).keys())
|
||||
if isinstance(parent_schema, dict)
|
||||
else []
|
||||
)
|
||||
self.add_error(
|
||||
f"{context}: Child property '{child}' does not "
|
||||
f"exist in parent '{parent}' of block "
|
||||
f"'{block_name}' ({block_id}) input schema. "
|
||||
f"Available properties: {available}"
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
for link in agent.get("links", []):
|
||||
sink_id = link.get("sink_id")
|
||||
sink_name = link.get("sink_name", "")
|
||||
link_id = link.get("id", "Unknown")
|
||||
|
||||
if not sink_name:
|
||||
# Missing sink_name is caught by validate_data_type_compatibility
|
||||
continue
|
||||
|
||||
sink_node = node_lookup.get(sink_id)
|
||||
if not sink_node:
|
||||
# Already caught by validate_link_node_references
|
||||
continue
|
||||
|
||||
block_id = sink_node.get("block_id", "")
|
||||
block_name = block_names.get(block_id, "Unknown Block")
|
||||
input_props = get_input_props(sink_node)
|
||||
|
||||
context = (
|
||||
f"Invalid sink input field '{sink_name}' in link "
|
||||
f"'{link_id}' to node '{sink_id}'"
|
||||
)
|
||||
|
||||
if DICT_SPLIT in sink_name:
|
||||
if not check_nested_input(
|
||||
input_props, sink_name, context, block_name, block_id
|
||||
):
|
||||
valid = False
|
||||
else:
|
||||
if sink_name not in input_props:
|
||||
available_inputs = list(input_props.keys())
|
||||
self.add_error(
|
||||
f"{context} (block '{block_name}' - {block_id}): "
|
||||
f"Input property '{sink_name}' does not exist in "
|
||||
f"the block's input schema. "
|
||||
f"Available inputs: {available_inputs}"
|
||||
)
|
||||
valid = False
|
||||
|
||||
for node in agent.get("nodes", []):
|
||||
node_id = node.get("id")
|
||||
block_id = node.get("block_id", "")
|
||||
block_name = block_names.get(block_id, "Unknown Block")
|
||||
input_default = node.get("input_default", {})
|
||||
|
||||
if not isinstance(input_default, dict) or not input_default:
|
||||
continue
|
||||
|
||||
if (
|
||||
block_id not in block_input_schemas
|
||||
and block_id != AGENT_EXECUTOR_BLOCK_ID
|
||||
):
|
||||
continue
|
||||
|
||||
input_props = get_input_props(node)
|
||||
|
||||
for key in input_default:
|
||||
if key == "credentials":
|
||||
continue
|
||||
|
||||
context = (
|
||||
f"Node '{node_id}' (block '{block_name}' - {block_id}) "
|
||||
f"has unknown input_default key '{key}'"
|
||||
)
|
||||
|
||||
if DICT_SPLIT in key:
|
||||
if not check_nested_input(
|
||||
input_props, key, context, block_name, block_id
|
||||
):
|
||||
valid = False
|
||||
else:
|
||||
if key not in input_props:
|
||||
available_inputs = list(input_props.keys())
|
||||
self.add_error(
|
||||
f"{context} which does not exist in the "
|
||||
f"block's input schema. "
|
||||
f"Available inputs: {available_inputs}"
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_io_blocks(self, agent: AgentDict) -> bool:
|
||||
"""
|
||||
Validate that the agent has at least one AgentInputBlock and one
|
||||
@@ -998,14 +1100,14 @@ class AgentValidator:
|
||||
"Data type compatibility",
|
||||
self.validate_data_type_compatibility(agent, blocks, node_lookup),
|
||||
),
|
||||
(
|
||||
"Nested sink links",
|
||||
self.validate_nested_sink_links(agent, blocks, node_lookup),
|
||||
),
|
||||
(
|
||||
"Source output existence",
|
||||
self.validate_source_output_existence(agent, blocks, node_lookup),
|
||||
),
|
||||
(
|
||||
"Sink input existence",
|
||||
self.validate_sink_input_existence(agent, blocks, node_lookup),
|
||||
),
|
||||
(
|
||||
"Prompt double curly braces spaces",
|
||||
self.validate_prompt_double_curly_braces_spaces(agent),
|
||||
|
||||
@@ -331,43 +331,6 @@ class TestValidatePromptDoubleCurlyBracesSpaces:
|
||||
assert any("spaces" in e for e in v.errors)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# validate_nested_sink_links
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestValidateNestedSinkLinks:
|
||||
def test_valid_nested_link_passes(self):
|
||||
v = AgentValidator()
|
||||
block = _make_block(
|
||||
block_id="b1",
|
||||
input_schema={
|
||||
"properties": {
|
||||
"config": {
|
||||
"type": "object",
|
||||
"properties": {"key": {"type": "string"}},
|
||||
}
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
)
|
||||
node = _make_node(node_id="n1", block_id="b1")
|
||||
link = _make_link(sink_id="n1", sink_name="config_#_key", source_id="n2")
|
||||
agent = _make_agent(nodes=[node], links=[link])
|
||||
|
||||
assert v.validate_nested_sink_links(agent, [block]) is True
|
||||
|
||||
def test_invalid_parent_fails(self):
|
||||
v = AgentValidator()
|
||||
block = _make_block(block_id="b1")
|
||||
node = _make_node(node_id="n1", block_id="b1")
|
||||
link = _make_link(sink_id="n1", sink_name="nonexistent_#_key", source_id="n2")
|
||||
agent = _make_agent(nodes=[node], links=[link])
|
||||
|
||||
assert v.validate_nested_sink_links(agent, [block]) is False
|
||||
assert any("does not exist" in e for e in v.errors)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# validate_agent_executor_block_schemas
|
||||
# ============================================================================
|
||||
@@ -595,11 +558,28 @@ class TestValidate:
|
||||
input_block = _make_block(
|
||||
block_id=AGENT_INPUT_BLOCK_ID,
|
||||
name="AgentInputBlock",
|
||||
input_schema={
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"title": {"type": "string"},
|
||||
"value": {},
|
||||
"description": {"type": "string"},
|
||||
},
|
||||
"required": ["name"],
|
||||
},
|
||||
output_schema={"properties": {"result": {}}},
|
||||
)
|
||||
output_block = _make_block(
|
||||
block_id=AGENT_OUTPUT_BLOCK_ID,
|
||||
name="AgentOutputBlock",
|
||||
input_schema={
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"title": {"type": "string"},
|
||||
"value": {},
|
||||
},
|
||||
"required": ["name"],
|
||||
},
|
||||
)
|
||||
input_node = _make_node(
|
||||
node_id="n-in",
|
||||
@@ -650,6 +630,201 @@ class TestValidate:
|
||||
assert "AgentOutputBlock" in error_message
|
||||
|
||||
|
||||
class TestValidateSinkInputExistence:
|
||||
"""Tests for validate_sink_input_existence."""
|
||||
|
||||
def test_valid_sink_name_passes(self):
|
||||
v = AgentValidator()
|
||||
block = _make_block(
|
||||
block_id="b1",
|
||||
input_schema={"properties": {"url": {"type": "string"}}, "required": []},
|
||||
)
|
||||
node = _make_node(node_id="n1", block_id="b1")
|
||||
link = _make_link(
|
||||
source_id="src", source_name="out", sink_id="n1", sink_name="url"
|
||||
)
|
||||
agent = _make_agent(nodes=[node], links=[link])
|
||||
|
||||
assert v.validate_sink_input_existence(agent, [block]) is True
|
||||
|
||||
def test_invalid_sink_name_fails(self):
|
||||
v = AgentValidator()
|
||||
block = _make_block(
|
||||
block_id="b1",
|
||||
input_schema={"properties": {"url": {"type": "string"}}, "required": []},
|
||||
)
|
||||
node = _make_node(node_id="n1", block_id="b1")
|
||||
link = _make_link(
|
||||
source_id="src", source_name="out", sink_id="n1", sink_name="nonexistent"
|
||||
)
|
||||
agent = _make_agent(nodes=[node], links=[link])
|
||||
|
||||
assert v.validate_sink_input_existence(agent, [block]) is False
|
||||
assert any("nonexistent" in e for e in v.errors)
|
||||
|
||||
def test_valid_nested_link_passes(self):
|
||||
v = AgentValidator()
|
||||
block = _make_block(
|
||||
block_id="b1",
|
||||
input_schema={
|
||||
"properties": {
|
||||
"config": {
|
||||
"type": "object",
|
||||
"properties": {"key": {"type": "string"}},
|
||||
}
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
)
|
||||
node = _make_node(node_id="n1", block_id="b1")
|
||||
link = _make_link(
|
||||
source_id="src",
|
||||
source_name="out",
|
||||
sink_id="n1",
|
||||
sink_name="config_#_key",
|
||||
)
|
||||
agent = _make_agent(nodes=[node], links=[link])
|
||||
|
||||
assert v.validate_sink_input_existence(agent, [block]) is True
|
||||
|
||||
def test_invalid_nested_child_fails(self):
|
||||
v = AgentValidator()
|
||||
block = _make_block(
|
||||
block_id="b1",
|
||||
input_schema={
|
||||
"properties": {
|
||||
"config": {
|
||||
"type": "object",
|
||||
"properties": {"key": {"type": "string"}},
|
||||
}
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
)
|
||||
node = _make_node(node_id="n1", block_id="b1")
|
||||
link = _make_link(
|
||||
source_id="src",
|
||||
source_name="out",
|
||||
sink_id="n1",
|
||||
sink_name="config_#_missing",
|
||||
)
|
||||
agent = _make_agent(nodes=[node], links=[link])
|
||||
|
||||
assert v.validate_sink_input_existence(agent, [block]) is False
|
||||
|
||||
def test_unknown_input_default_key_fails(self):
|
||||
v = AgentValidator()
|
||||
block = _make_block(
|
||||
block_id="b1",
|
||||
input_schema={"properties": {"url": {"type": "string"}}, "required": []},
|
||||
)
|
||||
node = _make_node(
|
||||
node_id="n1", block_id="b1", input_default={"nonexistent_key": "value"}
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
assert v.validate_sink_input_existence(agent, [block]) is False
|
||||
assert any("nonexistent_key" in e for e in v.errors)
|
||||
|
||||
def test_credentials_key_skipped(self):
|
||||
v = AgentValidator()
|
||||
block = _make_block(
|
||||
block_id="b1",
|
||||
input_schema={"properties": {"url": {"type": "string"}}, "required": []},
|
||||
)
|
||||
node = _make_node(
|
||||
node_id="n1",
|
||||
block_id="b1",
|
||||
input_default={
|
||||
"url": "http://example.com",
|
||||
"credentials": {"api_key": "x"},
|
||||
},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
assert v.validate_sink_input_existence(agent, [block]) is True
|
||||
|
||||
def test_agent_executor_dynamic_schema_passes(self):
|
||||
v = AgentValidator()
|
||||
block = _make_block(
|
||||
block_id=AGENT_EXECUTOR_BLOCK_ID,
|
||||
input_schema={
|
||||
"properties": {
|
||||
"graph_id": {"type": "string"},
|
||||
"input_schema": {"type": "object"},
|
||||
},
|
||||
"required": ["graph_id"],
|
||||
},
|
||||
)
|
||||
node = _make_node(
|
||||
node_id="n1",
|
||||
block_id=AGENT_EXECUTOR_BLOCK_ID,
|
||||
input_default={
|
||||
"graph_id": "abc",
|
||||
"input_schema": {
|
||||
"properties": {"query": {"type": "string"}},
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
)
|
||||
link = _make_link(
|
||||
source_id="src",
|
||||
source_name="out",
|
||||
sink_id="n1",
|
||||
sink_name="query",
|
||||
)
|
||||
agent = _make_agent(nodes=[node], links=[link])
|
||||
|
||||
assert v.validate_sink_input_existence(agent, [block]) is True
|
||||
|
||||
def test_input_default_nested_invalid_child_fails(self):
|
||||
v = AgentValidator()
|
||||
block = _make_block(
|
||||
block_id="b1",
|
||||
input_schema={
|
||||
"properties": {
|
||||
"config": {
|
||||
"type": "object",
|
||||
"properties": {"key": {"type": "string"}},
|
||||
}
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
)
|
||||
node = _make_node(
|
||||
node_id="n1",
|
||||
block_id="b1",
|
||||
input_default={"config_#_invalid_child": "value"},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
assert v.validate_sink_input_existence(agent, [block]) is False
|
||||
assert any("invalid_child" in e for e in v.errors)
|
||||
|
||||
def test_input_default_nested_valid_child_passes(self):
|
||||
v = AgentValidator()
|
||||
block = _make_block(
|
||||
block_id="b1",
|
||||
input_schema={
|
||||
"properties": {
|
||||
"config": {
|
||||
"type": "object",
|
||||
"properties": {"key": {"type": "string"}},
|
||||
}
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
)
|
||||
node = _make_node(
|
||||
node_id="n1",
|
||||
block_id="b1",
|
||||
input_default={"config_#_key": "value"},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
assert v.validate_sink_input_existence(agent, [block]) is True
|
||||
|
||||
|
||||
class TestValidateMCPToolBlocks:
|
||||
"""Tests for validate_mcp_tool_blocks."""
|
||||
|
||||
|
||||
@@ -537,7 +537,7 @@ async def check_hitl_review(
|
||||
)
|
||||
|
||||
synthetic_node_exec_id = (
|
||||
f"{synthetic_node_id}{COPILOT_NODE_EXEC_ID_SEPARATOR}" f"{uuid.uuid4().hex[:8]}"
|
||||
f"{synthetic_node_id}{COPILOT_NODE_EXEC_ID_SEPARATOR}{uuid.uuid4().hex[:8]}"
|
||||
)
|
||||
|
||||
review_context = ExecutionContext(
|
||||
@@ -582,7 +582,16 @@ def _resolve_discriminated_credentials(
|
||||
block: AnyBlockSchema,
|
||||
input_data: dict[str, Any],
|
||||
) -> dict[str, CredentialsFieldInfo]:
|
||||
"""Resolve credential requirements, applying discriminator logic where needed."""
|
||||
"""Resolve credential requirements, applying discriminator logic where needed.
|
||||
|
||||
Handles two discrimination modes:
|
||||
1. **Provider-based** (``discriminator_mapping`` is set): the discriminator
|
||||
field value selects the provider (e.g. an AI model name -> provider).
|
||||
2. **URL/host-based** (``discriminator`` is set but ``discriminator_mapping``
|
||||
is ``None``): the discriminator field value (typically a URL) is added to
|
||||
``discriminator_values`` so that host-scoped credential matching can
|
||||
compare the credential's host against the target URL.
|
||||
"""
|
||||
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
||||
if not credentials_fields_info:
|
||||
return {}
|
||||
@@ -592,25 +601,42 @@ def _resolve_discriminated_credentials(
|
||||
for field_name, field_info in credentials_fields_info.items():
|
||||
effective_field_info = field_info
|
||||
|
||||
if field_info.discriminator and field_info.discriminator_mapping:
|
||||
if field_info.discriminator:
|
||||
discriminator_value = input_data.get(field_info.discriminator)
|
||||
if discriminator_value is None:
|
||||
field = block.input_schema.model_fields.get(field_info.discriminator)
|
||||
if field and field.default is not PydanticUndefined:
|
||||
discriminator_value = field.default
|
||||
|
||||
if (
|
||||
discriminator_value
|
||||
and discriminator_value in field_info.discriminator_mapping
|
||||
):
|
||||
effective_field_info = field_info.discriminate(discriminator_value)
|
||||
effective_field_info.discriminator_values.add(discriminator_value)
|
||||
logger.debug(
|
||||
"Discriminated provider for %s: %s -> %s",
|
||||
field_name,
|
||||
discriminator_value,
|
||||
effective_field_info.provider,
|
||||
)
|
||||
if discriminator_value is not None:
|
||||
if field_info.discriminator_mapping:
|
||||
# Provider-based discrimination (e.g. model -> provider)
|
||||
if discriminator_value in field_info.discriminator_mapping:
|
||||
effective_field_info = field_info.discriminate(
|
||||
discriminator_value
|
||||
)
|
||||
effective_field_info.discriminator_values.add(
|
||||
discriminator_value
|
||||
)
|
||||
# Model names are safe to log (not PII); URLs are
|
||||
# intentionally omitted in the host-based branch below.
|
||||
logger.debug(
|
||||
"Discriminated provider for %s: %s -> %s",
|
||||
field_name,
|
||||
discriminator_value,
|
||||
effective_field_info.provider,
|
||||
)
|
||||
else:
|
||||
# URL/host-based discrimination (e.g. url -> host matching).
|
||||
# Deep copy to avoid mutating the cached schema-level
|
||||
# field_info (model_copy() is shallow — the mutable set
|
||||
# would be shared).
|
||||
effective_field_info = field_info.model_copy(deep=True)
|
||||
effective_field_info.discriminator_values.add(discriminator_value)
|
||||
logger.debug(
|
||||
"Added discriminator value for host matching on %s",
|
||||
field_name,
|
||||
)
|
||||
|
||||
resolved[field_name] = effective_field_info
|
||||
|
||||
|
||||
@@ -0,0 +1,916 @@
|
||||
"""Tests for credential resolution across all credential types in the CoPilot.
|
||||
|
||||
These tests verify that:
|
||||
1. `_resolve_discriminated_credentials` correctly populates discriminator_values
|
||||
for URL-based (host-scoped) and provider-based (api_key) credential fields.
|
||||
2. `find_matching_credential` correctly matches credentials for all types:
|
||||
APIKeyCredentials, OAuth2Credentials, UserPasswordCredentials, and
|
||||
HostScopedCredentials.
|
||||
3. The full `resolve_block_credentials` flow correctly resolves matching
|
||||
credentials or reports them as missing for each credential type.
|
||||
4. `RunBlockTool._execute` end-to-end tests return correct response types.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.blocks.http import SendAuthenticatedWebRequestBlock
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsFieldInfo,
|
||||
CredentialsType,
|
||||
HostScopedCredentials,
|
||||
OAuth2Credentials,
|
||||
UserPasswordCredentials,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
from ._test_data import make_session
|
||||
from .helpers import _resolve_discriminated_credentials, resolve_block_credentials
|
||||
from .models import BlockDetailsResponse, SetupRequirementsResponse
|
||||
from .run_block import RunBlockTool
|
||||
from .utils import find_matching_credential
|
||||
|
||||
_TEST_USER_ID = "test-user-http-cred"
|
||||
|
||||
# Properly typed constants to avoid type: ignore on CredentialsFieldInfo construction.
|
||||
_HOST_SCOPED_TYPES: frozenset[CredentialsType] = frozenset(["host_scoped"])
|
||||
_API_KEY_TYPES: frozenset[CredentialsType] = frozenset(["api_key"])
|
||||
_OAUTH2_TYPES: frozenset[CredentialsType] = frozenset(["oauth2"])
|
||||
_USER_PASSWORD_TYPES: frozenset[CredentialsType] = frozenset(["user_password"])
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _resolve_discriminated_credentials tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResolveDiscriminatedCredentials:
|
||||
"""Tests for _resolve_discriminated_credentials with URL-based discrimination."""
|
||||
|
||||
def _get_auth_block(self):
|
||||
return SendAuthenticatedWebRequestBlock()
|
||||
|
||||
def test_url_discriminator_populates_discriminator_values(self):
|
||||
"""When input_data contains a URL, discriminator_values should include it."""
|
||||
block = self._get_auth_block()
|
||||
input_data = {"url": "https://api.example.com/v1/data"}
|
||||
|
||||
result = _resolve_discriminated_credentials(block, input_data)
|
||||
|
||||
assert "credentials" in result
|
||||
field_info = result["credentials"]
|
||||
assert "https://api.example.com/v1/data" in field_info.discriminator_values
|
||||
|
||||
def test_url_discriminator_without_url_keeps_empty_values(self):
|
||||
"""When no URL is provided, discriminator_values should remain empty."""
|
||||
block = self._get_auth_block()
|
||||
input_data = {}
|
||||
|
||||
result = _resolve_discriminated_credentials(block, input_data)
|
||||
|
||||
assert "credentials" in result
|
||||
field_info = result["credentials"]
|
||||
assert len(field_info.discriminator_values) == 0
|
||||
|
||||
def test_url_discriminator_does_not_mutate_original_field_info(self):
|
||||
"""The original block schema field_info must not be mutated."""
|
||||
block = self._get_auth_block()
|
||||
|
||||
# Grab a reference to the original schema-level field_info
|
||||
original_info = block.input_schema.get_credentials_fields_info()["credentials"]
|
||||
|
||||
# Call with a URL, which adds to discriminator_values on the copy
|
||||
_resolve_discriminated_credentials(
|
||||
block, {"url": "https://api.example.com/v1/data"}
|
||||
)
|
||||
|
||||
# The original object must remain unchanged
|
||||
assert len(original_info.discriminator_values) == 0
|
||||
|
||||
# And a fresh call without URL should also return empty values
|
||||
result = _resolve_discriminated_credentials(block, {})
|
||||
field_info = result["credentials"]
|
||||
assert len(field_info.discriminator_values) == 0
|
||||
|
||||
def test_url_discriminator_preserves_provider_and_type(self):
|
||||
"""Provider and supported_types should be preserved after URL discrimination."""
|
||||
block = self._get_auth_block()
|
||||
input_data = {"url": "https://api.example.com/v1/data"}
|
||||
|
||||
result = _resolve_discriminated_credentials(block, input_data)
|
||||
|
||||
field_info = result["credentials"]
|
||||
assert ProviderName.HTTP in field_info.provider
|
||||
assert "host_scoped" in field_info.supported_types
|
||||
|
||||
def test_provider_discriminator_still_works(self):
|
||||
"""Verify provider-based discrimination (e.g. model -> provider) is preserved.
|
||||
|
||||
The refactored conditional in _resolve_discriminated_credentials split the
|
||||
original single ``if`` into nested ``if/else`` branches. This test ensures
|
||||
the provider-based path still narrows the provider correctly.
|
||||
"""
|
||||
from backend.blocks.llm import AITextGeneratorBlock
|
||||
|
||||
block = AITextGeneratorBlock()
|
||||
input_data = {"model": "gpt-4o-mini"}
|
||||
|
||||
result = _resolve_discriminated_credentials(block, input_data)
|
||||
|
||||
assert "credentials" in result
|
||||
field_info = result["credentials"]
|
||||
# Should narrow provider to openai
|
||||
assert ProviderName.OPENAI in field_info.provider
|
||||
assert "gpt-4o-mini" in field_info.discriminator_values
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# find_matching_credential tests (host-scoped)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFindMatchingHostScopedCredential:
|
||||
"""Tests for find_matching_credential with host-scoped credentials."""
|
||||
|
||||
def _make_host_scoped_cred(
|
||||
self, host: str, cred_id: str = "test-cred-id"
|
||||
) -> HostScopedCredentials:
|
||||
return HostScopedCredentials(
|
||||
id=cred_id,
|
||||
provider="http",
|
||||
host=host,
|
||||
headers={"Authorization": SecretStr("Bearer test-token")},
|
||||
title=f"Cred for {host}",
|
||||
)
|
||||
|
||||
def _make_field_info(
|
||||
self, discriminator_values: set | None = None
|
||||
) -> CredentialsFieldInfo:
|
||||
return CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([ProviderName.HTTP]),
|
||||
credentials_types=_HOST_SCOPED_TYPES,
|
||||
credentials_scopes=None,
|
||||
discriminator="url",
|
||||
discriminator_values=discriminator_values or set(),
|
||||
)
|
||||
|
||||
def test_matches_credential_for_correct_host(self):
|
||||
"""A host-scoped credential matching the URL host should be returned."""
|
||||
cred = self._make_host_scoped_cred("api.example.com")
|
||||
field_info = self._make_field_info({"https://api.example.com/v1/data"})
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is not None
|
||||
assert result.id == cred.id
|
||||
|
||||
def test_rejects_credential_for_wrong_host(self):
|
||||
"""A host-scoped credential for a different host should not match."""
|
||||
cred = self._make_host_scoped_cred("api.github.com")
|
||||
field_info = self._make_field_info({"https://api.stripe.com/v1/charges"})
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is None
|
||||
|
||||
def test_matches_any_when_no_discriminator_values(self):
|
||||
"""With empty discriminator_values, any host-scoped credential matches.
|
||||
|
||||
Note: this tests the current fallback behavior in _credential_is_for_host()
|
||||
where empty discriminator_values means "no host constraint" and any
|
||||
host-scoped credential is accepted. This is by design for the case where
|
||||
the target URL is not yet known (e.g. schema preview with empty input).
|
||||
"""
|
||||
cred = self._make_host_scoped_cred("api.anything.com")
|
||||
field_info = self._make_field_info(set())
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is not None
|
||||
|
||||
def test_wildcard_host_matching(self):
|
||||
"""Wildcard host (*.example.com) should match subdomains."""
|
||||
cred = self._make_host_scoped_cred("*.example.com")
|
||||
field_info = self._make_field_info({"https://api.example.com/v1/data"})
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is not None
|
||||
|
||||
def test_selects_correct_credential_from_multiple(self):
|
||||
"""When multiple host-scoped credentials exist, the correct one is selected."""
|
||||
cred_github = self._make_host_scoped_cred("api.github.com", "github-cred")
|
||||
cred_stripe = self._make_host_scoped_cred("api.stripe.com", "stripe-cred")
|
||||
field_info = self._make_field_info({"https://api.stripe.com/v1/charges"})
|
||||
|
||||
result = find_matching_credential([cred_github, cred_stripe], field_info)
|
||||
assert result is not None
|
||||
assert result.id == "stripe-cred"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# find_matching_credential tests (api_key)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFindMatchingAPIKeyCredential:
|
||||
"""Tests for find_matching_credential with API key credentials."""
|
||||
|
||||
def _make_api_key_cred(
|
||||
self, provider: str = "google_maps", cred_id: str = "test-api-key-id"
|
||||
) -> APIKeyCredentials:
|
||||
return APIKeyCredentials(
|
||||
id=cred_id,
|
||||
provider=provider,
|
||||
api_key=SecretStr("sk-test-key-123"),
|
||||
title=f"API key for {provider}",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
def _make_field_info(
|
||||
self, provider: ProviderName = ProviderName.GOOGLE_MAPS
|
||||
) -> CredentialsFieldInfo:
|
||||
return CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([provider]),
|
||||
credentials_types=_API_KEY_TYPES,
|
||||
credentials_scopes=None,
|
||||
)
|
||||
|
||||
def test_matches_credential_for_correct_provider(self):
|
||||
"""An API key credential matching the provider should be returned."""
|
||||
cred = self._make_api_key_cred("google_maps")
|
||||
field_info = self._make_field_info(ProviderName.GOOGLE_MAPS)
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is not None
|
||||
assert result.id == cred.id
|
||||
|
||||
def test_rejects_credential_for_wrong_provider(self):
|
||||
"""An API key credential for a different provider should not match."""
|
||||
cred = self._make_api_key_cred("openai")
|
||||
field_info = self._make_field_info(ProviderName.GOOGLE_MAPS)
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is None
|
||||
|
||||
def test_rejects_credential_for_wrong_type(self):
|
||||
"""An OAuth2 credential should not match an api_key requirement."""
|
||||
oauth_cred = OAuth2Credentials(
|
||||
id="oauth-cred-id",
|
||||
provider="google_maps",
|
||||
access_token=SecretStr("mock-token"),
|
||||
scopes=[],
|
||||
title="OAuth cred (wrong type)",
|
||||
)
|
||||
field_info = self._make_field_info(ProviderName.GOOGLE_MAPS)
|
||||
|
||||
result = find_matching_credential([oauth_cred], field_info)
|
||||
assert result is None
|
||||
|
||||
def test_selects_correct_credential_from_multiple(self):
|
||||
"""When multiple API key credentials exist, the correct provider is selected."""
|
||||
cred_maps = self._make_api_key_cred("google_maps", "maps-key")
|
||||
cred_openai = self._make_api_key_cred("openai", "openai-key")
|
||||
field_info = self._make_field_info(ProviderName.OPENAI)
|
||||
|
||||
result = find_matching_credential([cred_maps, cred_openai], field_info)
|
||||
assert result is not None
|
||||
assert result.id == "openai-key"
|
||||
|
||||
def test_returns_none_when_no_credentials(self):
|
||||
"""Should return None when the credential list is empty."""
|
||||
field_info = self._make_field_info(ProviderName.GOOGLE_MAPS)
|
||||
|
||||
result = find_matching_credential([], field_info)
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# find_matching_credential tests (oauth2)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFindMatchingOAuth2Credential:
|
||||
"""Tests for find_matching_credential with OAuth2 credentials."""
|
||||
|
||||
def _make_oauth2_cred(
|
||||
self,
|
||||
provider: str = "google",
|
||||
scopes: list[str] | None = None,
|
||||
cred_id: str = "test-oauth2-id",
|
||||
) -> OAuth2Credentials:
|
||||
return OAuth2Credentials(
|
||||
id=cred_id,
|
||||
provider=provider,
|
||||
access_token=SecretStr("mock-access-token"),
|
||||
refresh_token=SecretStr("mock-refresh-token"),
|
||||
access_token_expires_at=1234567890,
|
||||
scopes=scopes or [],
|
||||
title=f"OAuth2 cred for {provider}",
|
||||
)
|
||||
|
||||
def _make_field_info(
|
||||
self,
|
||||
provider: ProviderName = ProviderName.GOOGLE,
|
||||
required_scopes: frozenset[str] | None = None,
|
||||
) -> CredentialsFieldInfo:
|
||||
return CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([provider]),
|
||||
credentials_types=_OAUTH2_TYPES,
|
||||
credentials_scopes=required_scopes,
|
||||
)
|
||||
|
||||
def test_matches_credential_for_correct_provider(self):
|
||||
"""An OAuth2 credential matching the provider should be returned."""
|
||||
cred = self._make_oauth2_cred("google")
|
||||
field_info = self._make_field_info(ProviderName.GOOGLE)
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is not None
|
||||
assert result.id == cred.id
|
||||
|
||||
def test_rejects_credential_for_wrong_provider(self):
|
||||
"""An OAuth2 credential for a different provider should not match."""
|
||||
cred = self._make_oauth2_cred("github")
|
||||
field_info = self._make_field_info(ProviderName.GOOGLE)
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is None
|
||||
|
||||
def test_matches_credential_with_required_scopes(self):
|
||||
"""An OAuth2 credential with all required scopes should match."""
|
||||
cred = self._make_oauth2_cred(
|
||||
"google",
|
||||
scopes=[
|
||||
"https://www.googleapis.com/auth/gmail.readonly",
|
||||
"https://www.googleapis.com/auth/gmail.send",
|
||||
],
|
||||
)
|
||||
field_info = self._make_field_info(
|
||||
ProviderName.GOOGLE,
|
||||
required_scopes=frozenset(
|
||||
["https://www.googleapis.com/auth/gmail.readonly"]
|
||||
),
|
||||
)
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is not None
|
||||
|
||||
def test_rejects_credential_with_insufficient_scopes(self):
|
||||
"""An OAuth2 credential missing required scopes should not match."""
|
||||
cred = self._make_oauth2_cred(
|
||||
"google",
|
||||
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
|
||||
)
|
||||
field_info = self._make_field_info(
|
||||
ProviderName.GOOGLE,
|
||||
required_scopes=frozenset(
|
||||
[
|
||||
"https://www.googleapis.com/auth/gmail.readonly",
|
||||
"https://www.googleapis.com/auth/gmail.send",
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is None
|
||||
|
||||
def test_matches_credential_when_no_scopes_required(self):
|
||||
"""An OAuth2 credential should match when no scopes are required."""
|
||||
cred = self._make_oauth2_cred("google", scopes=[])
|
||||
field_info = self._make_field_info(ProviderName.GOOGLE)
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is not None
|
||||
|
||||
def test_selects_correct_credential_from_multiple(self):
|
||||
"""When multiple OAuth2 credentials exist, the correct one is selected."""
|
||||
cred_google = self._make_oauth2_cred("google", cred_id="google-cred")
|
||||
cred_github = self._make_oauth2_cred("github", cred_id="github-cred")
|
||||
field_info = self._make_field_info(ProviderName.GITHUB)
|
||||
|
||||
result = find_matching_credential([cred_google, cred_github], field_info)
|
||||
assert result is not None
|
||||
assert result.id == "github-cred"
|
||||
|
||||
def test_returns_none_when_no_credentials(self):
|
||||
"""Should return None when the credential list is empty."""
|
||||
field_info = self._make_field_info(ProviderName.GOOGLE)
|
||||
|
||||
result = find_matching_credential([], field_info)
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# find_matching_credential tests (user_password)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFindMatchingUserPasswordCredential:
|
||||
"""Tests for find_matching_credential with user/password credentials."""
|
||||
|
||||
def _make_user_password_cred(
|
||||
self, provider: str = "smtp", cred_id: str = "test-userpass-id"
|
||||
) -> UserPasswordCredentials:
|
||||
return UserPasswordCredentials(
|
||||
id=cred_id,
|
||||
provider=provider,
|
||||
username=SecretStr("test-user"),
|
||||
password=SecretStr("test-pass"),
|
||||
title=f"Credentials for {provider}",
|
||||
)
|
||||
|
||||
def _make_field_info(
|
||||
self, provider: ProviderName = ProviderName.SMTP
|
||||
) -> CredentialsFieldInfo:
|
||||
return CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([provider]),
|
||||
credentials_types=_USER_PASSWORD_TYPES,
|
||||
credentials_scopes=None,
|
||||
)
|
||||
|
||||
def test_matches_credential_for_correct_provider(self):
|
||||
"""A user/password credential matching the provider should be returned."""
|
||||
cred = self._make_user_password_cred("smtp")
|
||||
field_info = self._make_field_info(ProviderName.SMTP)
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is not None
|
||||
assert result.id == cred.id
|
||||
|
||||
def test_rejects_credential_for_wrong_provider(self):
|
||||
"""A user/password credential for a different provider should not match."""
|
||||
cred = self._make_user_password_cred("smtp")
|
||||
field_info = self._make_field_info(ProviderName.HUBSPOT)
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is None
|
||||
|
||||
def test_rejects_credential_for_wrong_type(self):
|
||||
"""An API key credential should not match a user_password requirement."""
|
||||
api_key_cred = APIKeyCredentials(
|
||||
id="api-key-cred-id",
|
||||
provider="smtp",
|
||||
api_key=SecretStr("wrong-type-key"),
|
||||
title="API key cred (wrong type)",
|
||||
)
|
||||
field_info = self._make_field_info(ProviderName.SMTP)
|
||||
|
||||
result = find_matching_credential([api_key_cred], field_info)
|
||||
assert result is None
|
||||
|
||||
def test_selects_correct_credential_from_multiple(self):
|
||||
"""When multiple user/password credentials exist, the correct one is selected."""
|
||||
cred_smtp = self._make_user_password_cred("smtp", "smtp-cred")
|
||||
cred_hubspot = self._make_user_password_cred("hubspot", "hubspot-cred")
|
||||
field_info = self._make_field_info(ProviderName.HUBSPOT)
|
||||
|
||||
result = find_matching_credential([cred_smtp, cred_hubspot], field_info)
|
||||
assert result is not None
|
||||
assert result.id == "hubspot-cred"
|
||||
|
||||
def test_returns_none_when_no_credentials(self):
|
||||
"""Should return None when the credential list is empty."""
|
||||
field_info = self._make_field_info(ProviderName.SMTP)
|
||||
|
||||
result = find_matching_credential([], field_info)
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# find_matching_credential tests (mixed credential types)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFindMatchingCredentialMixedTypes:
|
||||
"""Tests that find_matching_credential correctly filters by type in a mixed list."""
|
||||
|
||||
def test_selects_api_key_from_mixed_list(self):
|
||||
"""API key requirement should skip OAuth2 and user_password credentials."""
|
||||
oauth_cred = OAuth2Credentials(
|
||||
id="oauth-id",
|
||||
provider="openai",
|
||||
access_token=SecretStr("token"),
|
||||
scopes=[],
|
||||
)
|
||||
userpass_cred = UserPasswordCredentials(
|
||||
id="userpass-id",
|
||||
provider="openai",
|
||||
username=SecretStr("user"),
|
||||
password=SecretStr("pass"),
|
||||
)
|
||||
api_key_cred = APIKeyCredentials(
|
||||
id="apikey-id",
|
||||
provider="openai",
|
||||
api_key=SecretStr("sk-key"),
|
||||
)
|
||||
field_info = CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([ProviderName.OPENAI]),
|
||||
credentials_types=_API_KEY_TYPES,
|
||||
credentials_scopes=None,
|
||||
)
|
||||
|
||||
result = find_matching_credential(
|
||||
[oauth_cred, userpass_cred, api_key_cred], field_info
|
||||
)
|
||||
assert result is not None
|
||||
assert result.id == "apikey-id"
|
||||
|
||||
def test_selects_oauth2_from_mixed_list(self):
|
||||
"""OAuth2 requirement should skip API key and user_password credentials."""
|
||||
api_key_cred = APIKeyCredentials(
|
||||
id="apikey-id",
|
||||
provider="google",
|
||||
api_key=SecretStr("key"),
|
||||
)
|
||||
userpass_cred = UserPasswordCredentials(
|
||||
id="userpass-id",
|
||||
provider="google",
|
||||
username=SecretStr("user"),
|
||||
password=SecretStr("pass"),
|
||||
)
|
||||
oauth_cred = OAuth2Credentials(
|
||||
id="oauth-id",
|
||||
provider="google",
|
||||
access_token=SecretStr("token"),
|
||||
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
|
||||
)
|
||||
field_info = CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([ProviderName.GOOGLE]),
|
||||
credentials_types=_OAUTH2_TYPES,
|
||||
credentials_scopes=frozenset(
|
||||
["https://www.googleapis.com/auth/gmail.readonly"]
|
||||
),
|
||||
)
|
||||
|
||||
result = find_matching_credential(
|
||||
[api_key_cred, userpass_cred, oauth_cred], field_info
|
||||
)
|
||||
assert result is not None
|
||||
assert result.id == "oauth-id"
|
||||
|
||||
def test_selects_user_password_from_mixed_list(self):
|
||||
"""User/password requirement should skip API key and OAuth2 credentials."""
|
||||
api_key_cred = APIKeyCredentials(
|
||||
id="apikey-id",
|
||||
provider="smtp",
|
||||
api_key=SecretStr("key"),
|
||||
)
|
||||
oauth_cred = OAuth2Credentials(
|
||||
id="oauth-id",
|
||||
provider="smtp",
|
||||
access_token=SecretStr("token"),
|
||||
scopes=[],
|
||||
)
|
||||
userpass_cred = UserPasswordCredentials(
|
||||
id="userpass-id",
|
||||
provider="smtp",
|
||||
username=SecretStr("user"),
|
||||
password=SecretStr("pass"),
|
||||
)
|
||||
field_info = CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([ProviderName.SMTP]),
|
||||
credentials_types=_USER_PASSWORD_TYPES,
|
||||
credentials_scopes=None,
|
||||
)
|
||||
|
||||
result = find_matching_credential(
|
||||
[api_key_cred, oauth_cred, userpass_cred], field_info
|
||||
)
|
||||
assert result is not None
|
||||
assert result.id == "userpass-id"
|
||||
|
||||
def test_returns_none_when_only_wrong_types_available(self):
|
||||
"""Should return None when all available creds have the wrong type."""
|
||||
oauth_cred = OAuth2Credentials(
|
||||
id="oauth-id",
|
||||
provider="google_maps",
|
||||
access_token=SecretStr("token"),
|
||||
scopes=[],
|
||||
)
|
||||
field_info = CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([ProviderName.GOOGLE_MAPS]),
|
||||
credentials_types=_API_KEY_TYPES,
|
||||
credentials_scopes=None,
|
||||
)
|
||||
|
||||
result = find_matching_credential([oauth_cred], field_info)
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_block_credentials tests (integration — all credential types)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResolveBlockCredentials:
|
||||
"""Integration tests for resolve_block_credentials across credential types."""
|
||||
|
||||
async def test_matches_host_scoped_credential_for_url(self):
|
||||
"""resolve_block_credentials should match a host-scoped cred for the given URL."""
|
||||
block = SendAuthenticatedWebRequestBlock()
|
||||
input_data = {"url": "https://api.example.com/v1/data"}
|
||||
|
||||
mock_cred = HostScopedCredentials(
|
||||
id="matching-cred-id",
|
||||
provider="http",
|
||||
host="api.example.com",
|
||||
headers={"Authorization": SecretStr("Bearer token")},
|
||||
title="Example API Cred",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[mock_cred],
|
||||
):
|
||||
matched, missing = await resolve_block_credentials(
|
||||
_TEST_USER_ID, block, input_data
|
||||
)
|
||||
|
||||
assert "credentials" in matched
|
||||
assert matched["credentials"].id == "matching-cred-id"
|
||||
assert len(missing) == 0
|
||||
|
||||
async def test_reports_missing_when_no_matching_host(self):
|
||||
"""resolve_block_credentials should report missing creds when host doesn't match."""
|
||||
block = SendAuthenticatedWebRequestBlock()
|
||||
input_data = {"url": "https://api.stripe.com/v1/charges"}
|
||||
|
||||
wrong_host_cred = HostScopedCredentials(
|
||||
id="wrong-cred-id",
|
||||
provider="http",
|
||||
host="api.github.com",
|
||||
headers={"Authorization": SecretStr("Bearer token")},
|
||||
title="GitHub API Cred",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[wrong_host_cred],
|
||||
):
|
||||
matched, missing = await resolve_block_credentials(
|
||||
_TEST_USER_ID, block, input_data
|
||||
)
|
||||
|
||||
assert len(matched) == 0
|
||||
assert len(missing) == 1
|
||||
|
||||
async def test_reports_missing_when_no_credentials(self):
|
||||
"""resolve_block_credentials should report missing when user has no creds at all."""
|
||||
block = SendAuthenticatedWebRequestBlock()
|
||||
input_data = {"url": "https://api.example.com/v1/data"}
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
):
|
||||
matched, missing = await resolve_block_credentials(
|
||||
_TEST_USER_ID, block, input_data
|
||||
)
|
||||
|
||||
assert len(matched) == 0
|
||||
assert len(missing) == 1
|
||||
|
||||
async def test_matches_api_key_credential_for_llm_block(self):
|
||||
"""resolve_block_credentials should match an API key cred for an LLM block."""
|
||||
from backend.blocks.llm import AITextGeneratorBlock
|
||||
|
||||
block = AITextGeneratorBlock()
|
||||
input_data = {"model": "gpt-4o-mini"}
|
||||
|
||||
mock_cred = APIKeyCredentials(
|
||||
id="openai-key-id",
|
||||
provider="openai",
|
||||
api_key=SecretStr("sk-test-key"),
|
||||
title="OpenAI API Key",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[mock_cred],
|
||||
):
|
||||
matched, missing = await resolve_block_credentials(
|
||||
_TEST_USER_ID, block, input_data
|
||||
)
|
||||
|
||||
assert "credentials" in matched
|
||||
assert matched["credentials"].id == "openai-key-id"
|
||||
assert len(missing) == 0
|
||||
|
||||
async def test_reports_missing_api_key_for_wrong_provider(self):
|
||||
"""resolve_block_credentials should report missing when API key provider mismatches."""
|
||||
from backend.blocks.llm import AITextGeneratorBlock
|
||||
|
||||
block = AITextGeneratorBlock()
|
||||
input_data = {"model": "gpt-4o-mini"}
|
||||
|
||||
wrong_provider_cred = APIKeyCredentials(
|
||||
id="wrong-key-id",
|
||||
provider="google_maps",
|
||||
api_key=SecretStr("sk-wrong"),
|
||||
title="Google Maps Key",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[wrong_provider_cred],
|
||||
):
|
||||
matched, missing = await resolve_block_credentials(
|
||||
_TEST_USER_ID, block, input_data
|
||||
)
|
||||
|
||||
assert len(matched) == 0
|
||||
assert len(missing) == 1
|
||||
|
||||
async def test_matches_oauth2_credential_for_google_block(self):
|
||||
"""resolve_block_credentials should match an OAuth2 cred for a Google block."""
|
||||
from backend.blocks.google.gmail import GmailReadBlock
|
||||
|
||||
block = GmailReadBlock()
|
||||
input_data = {}
|
||||
|
||||
mock_cred = OAuth2Credentials(
|
||||
id="google-oauth-id",
|
||||
provider="google",
|
||||
access_token=SecretStr("mock-token"),
|
||||
refresh_token=SecretStr("mock-refresh"),
|
||||
access_token_expires_at=9999999999,
|
||||
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
|
||||
title="Google OAuth",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[mock_cred],
|
||||
):
|
||||
matched, missing = await resolve_block_credentials(
|
||||
_TEST_USER_ID, block, input_data
|
||||
)
|
||||
|
||||
assert "credentials" in matched
|
||||
assert matched["credentials"].id == "google-oauth-id"
|
||||
assert len(missing) == 0
|
||||
|
||||
async def test_reports_missing_oauth2_with_insufficient_scopes(self):
|
||||
"""resolve_block_credentials should report missing when OAuth2 scopes are insufficient."""
|
||||
from backend.blocks.google.gmail import GmailSendBlock
|
||||
|
||||
block = GmailSendBlock()
|
||||
input_data = {}
|
||||
|
||||
# GmailSendBlock requires gmail.send scope; provide only readonly
|
||||
insufficient_cred = OAuth2Credentials(
|
||||
id="limited-oauth-id",
|
||||
provider="google",
|
||||
access_token=SecretStr("mock-token"),
|
||||
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
|
||||
title="Google OAuth (limited)",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[insufficient_cred],
|
||||
):
|
||||
matched, missing = await resolve_block_credentials(
|
||||
_TEST_USER_ID, block, input_data
|
||||
)
|
||||
|
||||
assert len(matched) == 0
|
||||
assert len(missing) == 1
|
||||
|
||||
async def test_matches_user_password_credential_for_email_block(self):
|
||||
"""resolve_block_credentials should match a user/password cred for an SMTP block."""
|
||||
from backend.blocks.email_block import SendEmailBlock
|
||||
|
||||
block = SendEmailBlock()
|
||||
input_data = {}
|
||||
|
||||
mock_cred = UserPasswordCredentials(
|
||||
id="smtp-cred-id",
|
||||
provider="smtp",
|
||||
username=SecretStr("test-user"),
|
||||
password=SecretStr("test-pass"),
|
||||
title="SMTP Credentials",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[mock_cred],
|
||||
):
|
||||
matched, missing = await resolve_block_credentials(
|
||||
_TEST_USER_ID, block, input_data
|
||||
)
|
||||
|
||||
assert "credentials" in matched
|
||||
assert matched["credentials"].id == "smtp-cred-id"
|
||||
assert len(missing) == 0
|
||||
|
||||
async def test_reports_missing_user_password_for_wrong_provider(self):
|
||||
"""resolve_block_credentials should report missing when user/password provider mismatches."""
|
||||
from backend.blocks.email_block import SendEmailBlock
|
||||
|
||||
block = SendEmailBlock()
|
||||
input_data = {}
|
||||
|
||||
wrong_cred = UserPasswordCredentials(
|
||||
id="wrong-cred-id",
|
||||
provider="dataforseo",
|
||||
username=SecretStr("user"),
|
||||
password=SecretStr("pass"),
|
||||
title="DataForSEO Creds",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[wrong_cred],
|
||||
):
|
||||
matched, missing = await resolve_block_credentials(
|
||||
_TEST_USER_ID, block, input_data
|
||||
)
|
||||
|
||||
assert len(matched) == 0
|
||||
assert len(missing) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RunBlockTool integration tests for authenticated HTTP
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunBlockToolAuthenticatedHttp:
|
||||
"""End-to-end tests for RunBlockTool with SendAuthenticatedWebRequestBlock."""
|
||||
|
||||
async def test_returns_setup_requirements_when_creds_missing(self):
|
||||
"""When no matching host-scoped credential exists, return SetupRequirementsResponse."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
block = SendAuthenticatedWebRequestBlock()
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.helpers.get_block",
|
||||
return_value=block,
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
block_id=block.id,
|
||||
input_data={"url": "https://api.example.com/data", "method": "GET"},
|
||||
)
|
||||
|
||||
assert isinstance(response, SetupRequirementsResponse)
|
||||
assert "credentials" in response.message.lower()
|
||||
|
||||
async def test_returns_details_when_creds_matched_but_missing_required_inputs(self):
|
||||
"""When creds present + required inputs missing -> BlockDetailsResponse.
|
||||
|
||||
Note: with input_data={}, no URL is provided so discriminator_values is
|
||||
empty, meaning _credential_is_for_host() matches any host-scoped
|
||||
credential vacuously. This test exercises the "creds present + inputs
|
||||
missing" branch, not host-based matching (which is covered by
|
||||
TestFindMatchingHostScopedCredential and TestResolveBlockCredentials).
|
||||
"""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
block = SendAuthenticatedWebRequestBlock()
|
||||
|
||||
mock_cred = HostScopedCredentials(
|
||||
id="matching-cred-id",
|
||||
provider="http",
|
||||
host="api.example.com",
|
||||
headers={"Authorization": SecretStr("Bearer token")},
|
||||
title="Example API Cred",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.helpers.get_block",
|
||||
return_value=block,
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[mock_cred],
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
# Call with empty input to get schema
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
block_id=block.id,
|
||||
input_data={},
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockDetailsResponse)
|
||||
assert response.block.name == block.name
|
||||
# The matched credential should be included in the details
|
||||
assert len(response.block.credentials) > 0
|
||||
assert response.block.credentials[0].id == "matching-cred-id"
|
||||
@@ -121,7 +121,7 @@ def _serialize_missing_credential(
|
||||
provider = next(iter(field_info.provider), "unknown")
|
||||
scopes = sorted(field_info.required_scopes or [])
|
||||
|
||||
return {
|
||||
result: dict[str, Any] = {
|
||||
"id": field_key,
|
||||
"title": field_key.replace("_", " ").title(),
|
||||
"provider": provider,
|
||||
@@ -131,6 +131,17 @@ def _serialize_missing_credential(
|
||||
"scopes": scopes,
|
||||
}
|
||||
|
||||
# Include discriminator info so the frontend can auto-match
|
||||
# host-scoped credentials (e.g. SendAuthenticatedWebRequestBlock).
|
||||
if field_info.discriminator:
|
||||
result["discriminator"] = field_info.discriminator
|
||||
if field_info.discriminator_values:
|
||||
result["discriminator_values"] = sorted(
|
||||
str(v) for v in field_info.discriminator_values
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def build_missing_credentials_from_graph(
|
||||
graph: GraphModel, matched_credentials: dict[str, CredentialsMetaInput] | None
|
||||
|
||||
@@ -344,6 +344,7 @@ class DatabaseManager(AppService):
|
||||
get_next_sequence = _(chat_db.get_next_sequence)
|
||||
update_tool_message_content = _(chat_db.update_tool_message_content)
|
||||
update_chat_session_title = _(chat_db.update_chat_session_title)
|
||||
set_turn_duration = _(chat_db.set_turn_duration)
|
||||
|
||||
|
||||
class DatabaseManagerClient(AppServiceClient):
|
||||
@@ -540,3 +541,4 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
get_next_sequence = d.get_next_sequence
|
||||
update_tool_message_content = d.update_tool_message_content
|
||||
update_chat_session_title = d.update_chat_session_title
|
||||
set_turn_duration = d.set_turn_duration
|
||||
|
||||
@@ -342,6 +342,7 @@ class GraphExecution(GraphExecutionMeta):
|
||||
if (
|
||||
(block := get_block(exec.block_id))
|
||||
and block.block_type == BlockType.INPUT
|
||||
and "name" in exec.input_data
|
||||
)
|
||||
}
|
||||
),
|
||||
@@ -360,8 +361,10 @@ class GraphExecution(GraphExecutionMeta):
|
||||
outputs: CompletedBlockOutput = defaultdict(list)
|
||||
for exec in complete_node_executions:
|
||||
if (
|
||||
block := get_block(exec.block_id)
|
||||
) and block.block_type == BlockType.OUTPUT:
|
||||
(block := get_block(exec.block_id))
|
||||
and block.block_type == BlockType.OUTPUT
|
||||
and "name" in exec.input_data
|
||||
):
|
||||
outputs[exec.input_data["name"]].append(exec.input_data.get("value"))
|
||||
|
||||
return GraphExecution(
|
||||
|
||||
@@ -581,7 +581,6 @@ class GraphModel(Graph, GraphMeta):
|
||||
field_name,
|
||||
field_info,
|
||||
) in node.block.input_schema.get_credentials_fields_info().items():
|
||||
|
||||
discriminator = field_info.discriminator
|
||||
if not discriminator:
|
||||
node_credential_data.append((field_info, (node.id, field_name)))
|
||||
@@ -1529,6 +1528,28 @@ async def fork_graph(graph_id: str, graph_version: int, user_id: str) -> GraphMo
|
||||
async def __create_graph(tx, graph: Graph, user_id: str):
|
||||
graphs = [graph] + graph.sub_graphs
|
||||
|
||||
# Auto-increment version for any graph entry (parent or sub-graph) whose
|
||||
# (id, version) already exists. This prevents UniqueViolationError when
|
||||
# the copilot re-saves an agent that already exists at the requested version.
|
||||
# NOTE: This issues one find_first query per graph entry (N+1 pattern).
|
||||
# Sub-graph counts are typically small (< 5), so the overhead is negligible.
|
||||
for g in graphs:
|
||||
existing = await AgentGraph.prisma(tx).find_first(
|
||||
where={"id": g.id},
|
||||
order={"version": "desc"},
|
||||
)
|
||||
if existing and existing.version >= g.version:
|
||||
old_version = g.version
|
||||
g.version = existing.version + 1
|
||||
logger.warning(
|
||||
"Auto-incremented graph %s version from %d to %d "
|
||||
"(version %d already exists)",
|
||||
g.id,
|
||||
old_version,
|
||||
g.version,
|
||||
existing.version,
|
||||
)
|
||||
|
||||
await AgentGraph.prisma(tx).create_many(
|
||||
data=[
|
||||
AgentGraphCreateInput(
|
||||
|
||||
@@ -722,7 +722,7 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||
credentials_scopes=self.required_scopes,
|
||||
discriminator=self.discriminator,
|
||||
discriminator_mapping=self.discriminator_mapping,
|
||||
discriminator_values=self.discriminator_values,
|
||||
discriminator_values=set(self.discriminator_values), # defensive copy
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -40,6 +40,9 @@ _MAX_PAGES = 100
|
||||
# LLM extraction timeout (seconds)
|
||||
_LLM_TIMEOUT = 30
|
||||
|
||||
SUGGESTION_THEMES = ["Learn", "Create", "Automate", "Organize"]
|
||||
PROMPTS_PER_THEME = 5
|
||||
|
||||
|
||||
def _mask_email(email: str) -> str:
|
||||
"""Mask an email for safe logging: 'alice@example.com' -> 'a***e@example.com'."""
|
||||
@@ -332,6 +335,11 @@ Fields:
|
||||
- current_software (list of strings): software/tools currently used
|
||||
- existing_automation (list of strings): existing automations
|
||||
- additional_notes (string): any additional context
|
||||
- suggested_prompts (object with keys "Learn", "Create", "Automate", "Organize"): for each key, \
|
||||
provide a list of 5 short action prompts (each under 20 words) that would help this person. \
|
||||
"Learn" = questions about AutoGPT features; "Create" = content/document generation tasks; \
|
||||
"Automate" = recurring workflow automation ideas; "Organize" = structuring/prioritizing tasks. \
|
||||
Should be specific to their industry, role, and pain points; actionable and conversational in tone.
|
||||
|
||||
Form data:
|
||||
"""
|
||||
@@ -378,6 +386,29 @@ async def extract_business_understanding(
|
||||
|
||||
# Filter out null values before constructing
|
||||
cleaned = {k: v for k, v in data.items() if v is not None}
|
||||
|
||||
# Validate suggested_prompts: themed dict, filter >20 words, cap at 5 per theme
|
||||
raw_prompts = cleaned.get("suggested_prompts", {})
|
||||
if isinstance(raw_prompts, dict):
|
||||
themed: dict[str, list[str]] = {}
|
||||
for theme in SUGGESTION_THEMES:
|
||||
theme_prompts = raw_prompts.get(theme, [])
|
||||
if not isinstance(theme_prompts, list):
|
||||
continue
|
||||
valid = [
|
||||
s
|
||||
for p in theme_prompts
|
||||
if isinstance(p, str) and (s := p.strip()) and len(s.split()) <= 20
|
||||
]
|
||||
if valid:
|
||||
themed[theme] = valid[:PROMPTS_PER_THEME]
|
||||
if themed:
|
||||
cleaned["suggested_prompts"] = themed
|
||||
else:
|
||||
cleaned.pop("suggested_prompts", None)
|
||||
else:
|
||||
cleaned.pop("suggested_prompts", None)
|
||||
|
||||
return BusinessUnderstandingInput(**cleaned)
|
||||
|
||||
|
||||
|
||||
@@ -284,6 +284,7 @@ async def test_populate_understanding_full_flow():
|
||||
],
|
||||
}
|
||||
mock_input = MagicMock()
|
||||
mock_input.suggested_prompts = {"Learn": ["P1"], "Create": ["P2"]}
|
||||
|
||||
with (
|
||||
patch(
|
||||
@@ -397,15 +398,25 @@ def test_extraction_prompt_no_format_placeholders():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_business_understanding_success():
|
||||
"""Happy path: LLM returns valid JSON that maps to BusinessUnderstandingInput."""
|
||||
async def test_extract_business_understanding_themed_prompts():
|
||||
"""Happy path: LLM returns themed prompts as dict."""
|
||||
mock_choice = MagicMock()
|
||||
mock_choice.message.content = json.dumps(
|
||||
{
|
||||
"user_name": "Alice",
|
||||
"business_name": "Acme Corp",
|
||||
"industry": "Technology",
|
||||
"pain_points": ["manual reporting"],
|
||||
"suggested_prompts": {
|
||||
"Learn": ["Learn 1", "Learn 2", "Learn 3", "Learn 4", "Learn 5"],
|
||||
"Create": [
|
||||
"Create 1",
|
||||
"Create 2",
|
||||
"Create 3",
|
||||
"Create 4",
|
||||
"Create 5",
|
||||
],
|
||||
"Automate": ["Auto 1", "Auto 2", "Auto 3", "Auto 4", "Auto 5"],
|
||||
"Organize": ["Org 1", "Org 2", "Org 3", "Org 4", "Org 5"],
|
||||
},
|
||||
}
|
||||
)
|
||||
mock_response = MagicMock()
|
||||
@@ -418,9 +429,42 @@ async def test_extract_business_understanding_success():
|
||||
result = await extract_business_understanding("Q: Name?\nA: Alice")
|
||||
|
||||
assert result.user_name == "Alice"
|
||||
assert result.business_name == "Acme Corp"
|
||||
assert result.industry == "Technology"
|
||||
assert result.pain_points == ["manual reporting"]
|
||||
assert result.suggested_prompts is not None
|
||||
assert len(result.suggested_prompts) == 4
|
||||
assert len(result.suggested_prompts["Learn"]) == 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_themed_prompts_filters_long_and_unknown_keys():
|
||||
"""Long prompts are filtered, unknown keys are dropped, each theme capped at 5."""
|
||||
long_prompt = " ".join(["word"] * 21)
|
||||
mock_choice = MagicMock()
|
||||
mock_choice.message.content = json.dumps(
|
||||
{
|
||||
"user_name": "Alice",
|
||||
"suggested_prompts": {
|
||||
"Learn": [long_prompt, "Valid learn 1", "Valid learn 2"],
|
||||
"UnknownTheme": ["Should be dropped"],
|
||||
"Automate": ["A1", "A2", "A3", "A4", "A5", "A6"],
|
||||
},
|
||||
}
|
||||
)
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [mock_choice]
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
|
||||
with patch("backend.data.tally.AsyncOpenAI", return_value=mock_client):
|
||||
result = await extract_business_understanding("Q: Name?\nA: Alice")
|
||||
|
||||
assert result.suggested_prompts is not None
|
||||
# Unknown key dropped
|
||||
assert "UnknownTheme" not in result.suggested_prompts
|
||||
# Long prompt filtered
|
||||
assert result.suggested_prompts["Learn"] == ["Valid learn 1", "Valid learn 2"]
|
||||
# Capped at 5
|
||||
assert result.suggested_prompts["Automate"] == ["A1", "A2", "A3", "A4", "A5"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -49,6 +49,25 @@ def _json_to_list(value: Any) -> list[str]:
|
||||
return []
|
||||
|
||||
|
||||
def _json_to_themed_prompts(value: Any) -> dict[str, list[str]]:
|
||||
"""Convert Json field to themed prompts dict.
|
||||
|
||||
Handles both the new ``dict[str, list[str]]`` format and the legacy
|
||||
``list[str]`` format. Legacy rows are placed under a ``"General"`` key so
|
||||
existing personalised prompts remain readable until a backfill regenerates
|
||||
them into the proper themed shape.
|
||||
"""
|
||||
if isinstance(value, dict):
|
||||
return {
|
||||
k: [i for i in v if isinstance(i, str)]
|
||||
for k, v in value.items()
|
||||
if isinstance(k, str) and isinstance(v, list)
|
||||
}
|
||||
if isinstance(value, list) and value:
|
||||
return {"General": [str(p) for p in value if isinstance(p, str)]}
|
||||
return {}
|
||||
|
||||
|
||||
class BusinessUnderstandingInput(pydantic.BaseModel):
|
||||
"""Input model for updating business understanding - all fields optional for incremental updates."""
|
||||
|
||||
@@ -104,6 +123,11 @@ class BusinessUnderstandingInput(pydantic.BaseModel):
|
||||
None, description="Any additional context"
|
||||
)
|
||||
|
||||
# Suggested prompts (UI-only, not included in system prompt)
|
||||
suggested_prompts: Optional[dict[str, list[str]]] = pydantic.Field(
|
||||
None, description="LLM-generated suggested prompts grouped by theme"
|
||||
)
|
||||
|
||||
|
||||
class BusinessUnderstanding(pydantic.BaseModel):
|
||||
"""Full business understanding model returned from database."""
|
||||
@@ -140,6 +164,9 @@ class BusinessUnderstanding(pydantic.BaseModel):
|
||||
# Additional context
|
||||
additional_notes: Optional[str] = None
|
||||
|
||||
# Suggested prompts (UI-only, not included in system prompt)
|
||||
suggested_prompts: dict[str, list[str]] = pydantic.Field(default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_record: CoPilotUnderstanding) -> "BusinessUnderstanding":
|
||||
"""Convert database record to Pydantic model."""
|
||||
@@ -167,6 +194,7 @@ class BusinessUnderstanding(pydantic.BaseModel):
|
||||
current_software=_json_to_list(business.get("current_software")),
|
||||
existing_automation=_json_to_list(business.get("existing_automation")),
|
||||
additional_notes=business.get("additional_notes"),
|
||||
suggested_prompts=_json_to_themed_prompts(data.get("suggested_prompts")),
|
||||
)
|
||||
|
||||
|
||||
@@ -246,33 +274,22 @@ async def get_business_understanding(
|
||||
return understanding
|
||||
|
||||
|
||||
async def upsert_business_understanding(
|
||||
user_id: str,
|
||||
def merge_business_understanding_data(
|
||||
existing_data: dict[str, Any],
|
||||
input_data: BusinessUnderstandingInput,
|
||||
) -> BusinessUnderstanding:
|
||||
"""
|
||||
Create or update business understanding with incremental merge strategy.
|
||||
) -> dict[str, Any]:
|
||||
"""Merge new input into existing data dict using incremental strategy.
|
||||
|
||||
- String fields: new value overwrites if provided (not None)
|
||||
- List fields: new items are appended to existing (deduplicated)
|
||||
- suggested_prompts: fully replaced if provided (not None)
|
||||
|
||||
Data is stored as: {name: ..., business: {version: 1, ...}}
|
||||
Returns the merged data dict (mutates and returns *existing_data*).
|
||||
"""
|
||||
# Get existing record for merge
|
||||
existing = await CoPilotUnderstanding.prisma().find_unique(
|
||||
where={"userId": user_id}
|
||||
)
|
||||
|
||||
# Get existing data structure or start fresh
|
||||
existing_data: dict[str, Any] = {}
|
||||
if existing and isinstance(existing.data, dict):
|
||||
existing_data = dict(existing.data)
|
||||
|
||||
existing_business: dict[str, Any] = {}
|
||||
if isinstance(existing_data.get("business"), dict):
|
||||
existing_business = dict(existing_data["business"])
|
||||
|
||||
# Business fields (stored inside business object)
|
||||
business_string_fields = [
|
||||
"job_title",
|
||||
"business_name",
|
||||
@@ -310,16 +327,48 @@ async def upsert_business_understanding(
|
||||
merged = _merge_lists(existing_list, value)
|
||||
existing_business[field] = merged
|
||||
|
||||
# Suggested prompts - fully replace if provided
|
||||
if input_data.suggested_prompts is not None:
|
||||
existing_data["suggested_prompts"] = input_data.suggested_prompts
|
||||
|
||||
# Set version and nest business data
|
||||
existing_business["version"] = 1
|
||||
existing_data["business"] = existing_business
|
||||
|
||||
return existing_data
|
||||
|
||||
|
||||
async def upsert_business_understanding(
|
||||
user_id: str,
|
||||
input_data: BusinessUnderstandingInput,
|
||||
) -> BusinessUnderstanding:
|
||||
"""
|
||||
Create or update business understanding with incremental merge strategy.
|
||||
|
||||
- String fields: new value overwrites if provided (not None)
|
||||
- List fields: new items are appended to existing (deduplicated)
|
||||
- suggested_prompts: fully replaced if provided (not None)
|
||||
|
||||
Data is stored as: {name: ..., business: {version: 1, ...}}
|
||||
"""
|
||||
# Get existing record for merge
|
||||
existing = await CoPilotUnderstanding.prisma().find_unique(
|
||||
where={"userId": user_id}
|
||||
)
|
||||
|
||||
# Get existing data structure or start fresh
|
||||
existing_data: dict[str, Any] = {}
|
||||
if existing and isinstance(existing.data, dict):
|
||||
existing_data = dict(existing.data)
|
||||
|
||||
merged_data = merge_business_understanding_data(existing_data, input_data)
|
||||
|
||||
# Upsert with the merged data
|
||||
record = await CoPilotUnderstanding.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "data": SafeJson(existing_data)},
|
||||
"update": {"data": SafeJson(existing_data)},
|
||||
"create": {"userId": user_id, "data": SafeJson(merged_data)},
|
||||
"update": {"data": SafeJson(merged_data)},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
148
autogpt_platform/backend/backend/data/understanding_test.py
Normal file
148
autogpt_platform/backend/backend/data/understanding_test.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""Tests for business understanding merge and format logic."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from backend.data.understanding import (
|
||||
BusinessUnderstanding,
|
||||
BusinessUnderstandingInput,
|
||||
_json_to_themed_prompts,
|
||||
format_understanding_for_prompt,
|
||||
merge_business_understanding_data,
|
||||
)
|
||||
|
||||
|
||||
def _make_input(**kwargs: Any) -> BusinessUnderstandingInput:
|
||||
"""Create a BusinessUnderstandingInput with only the specified fields."""
|
||||
return BusinessUnderstandingInput.model_validate(kwargs)
|
||||
|
||||
|
||||
# ─── merge_business_understanding_data: themed prompts ─────────────────
|
||||
|
||||
|
||||
def test_merge_themed_prompts_overwrites_existing():
|
||||
"""New themed prompts should fully replace existing ones (not merge)."""
|
||||
existing = {
|
||||
"name": "Alice",
|
||||
"business": {"industry": "Tech", "version": 1},
|
||||
"suggested_prompts": {
|
||||
"Learn": ["Old learn prompt"],
|
||||
"Create": ["Old create prompt"],
|
||||
},
|
||||
}
|
||||
new_prompts = {
|
||||
"Automate": ["Schedule daily reports", "Set up email alerts"],
|
||||
"Organize": ["Sort inbox by priority"],
|
||||
}
|
||||
input_data = _make_input(suggested_prompts=new_prompts)
|
||||
|
||||
result = merge_business_understanding_data(existing, input_data)
|
||||
|
||||
assert result["suggested_prompts"] == new_prompts
|
||||
|
||||
|
||||
def test_merge_themed_prompts_none_preserves_existing():
|
||||
"""When input has suggested_prompts=None, existing themed prompts are preserved."""
|
||||
existing_prompts = {
|
||||
"Learn": ["How to automate?"],
|
||||
"Create": ["Build a chatbot"],
|
||||
}
|
||||
existing = {
|
||||
"name": "Alice",
|
||||
"business": {"industry": "Tech", "version": 1},
|
||||
"suggested_prompts": existing_prompts,
|
||||
}
|
||||
input_data = _make_input(industry="Finance")
|
||||
|
||||
result = merge_business_understanding_data(existing, input_data)
|
||||
|
||||
assert result["suggested_prompts"] == existing_prompts
|
||||
assert result["business"]["industry"] == "Finance"
|
||||
|
||||
|
||||
# ─── from_db: themed prompts deserialization ───────────────────────────
|
||||
|
||||
|
||||
def test_from_db_themed_prompts():
|
||||
"""from_db correctly deserializes a themed dict for suggested_prompts."""
|
||||
themed = {
|
||||
"Learn": ["What can I automate?"],
|
||||
"Create": ["Build a workflow"],
|
||||
}
|
||||
db_record = MagicMock()
|
||||
db_record.id = "test-id"
|
||||
db_record.userId = "user-1"
|
||||
db_record.createdAt = datetime.now(tz=timezone.utc)
|
||||
db_record.updatedAt = datetime.now(tz=timezone.utc)
|
||||
db_record.data = {
|
||||
"name": "Alice",
|
||||
"business": {"industry": "Tech", "version": 1},
|
||||
"suggested_prompts": themed,
|
||||
}
|
||||
|
||||
result = BusinessUnderstanding.from_db(db_record)
|
||||
|
||||
assert result.suggested_prompts == themed
|
||||
|
||||
|
||||
def test_from_db_legacy_list_prompts_preserved_under_general():
|
||||
"""from_db preserves legacy list[str] prompts under a 'General' key."""
|
||||
db_record = MagicMock()
|
||||
db_record.id = "test-id"
|
||||
db_record.userId = "user-1"
|
||||
db_record.createdAt = datetime.now(tz=timezone.utc)
|
||||
db_record.updatedAt = datetime.now(tz=timezone.utc)
|
||||
db_record.data = {
|
||||
"name": "Alice",
|
||||
"business": {"industry": "Tech", "version": 1},
|
||||
"suggested_prompts": ["Old prompt 1", "Old prompt 2"],
|
||||
}
|
||||
|
||||
result = BusinessUnderstanding.from_db(db_record)
|
||||
|
||||
assert result.suggested_prompts == {"General": ["Old prompt 1", "Old prompt 2"]}
|
||||
|
||||
|
||||
# ─── _json_to_themed_prompts helper ───────────────────────────────────
|
||||
|
||||
|
||||
def test_json_to_themed_prompts_with_dict():
|
||||
value = {"Learn": ["a", "b"], "Create": ["c"]}
|
||||
assert _json_to_themed_prompts(value) == {"Learn": ["a", "b"], "Create": ["c"]}
|
||||
|
||||
|
||||
def test_json_to_themed_prompts_with_list_returns_general():
|
||||
assert _json_to_themed_prompts(["a", "b"]) == {"General": ["a", "b"]}
|
||||
|
||||
|
||||
def test_json_to_themed_prompts_with_none_returns_empty():
|
||||
assert _json_to_themed_prompts(None) == {}
|
||||
|
||||
|
||||
# ─── format_understanding_for_prompt: excludes themed prompts ──────────
|
||||
|
||||
|
||||
def test_format_understanding_excludes_themed_prompts():
|
||||
"""Themed suggested_prompts are UI-only and must NOT appear in the system prompt."""
|
||||
understanding = BusinessUnderstanding(
|
||||
id="test-id",
|
||||
user_id="user-1",
|
||||
created_at=datetime.now(tz=timezone.utc),
|
||||
updated_at=datetime.now(tz=timezone.utc),
|
||||
user_name="Alice",
|
||||
industry="Technology",
|
||||
suggested_prompts={
|
||||
"Learn": ["Automate reports"],
|
||||
"Create": ["Set up alerts", "Track KPIs"],
|
||||
},
|
||||
)
|
||||
|
||||
formatted = format_understanding_for_prompt(understanding)
|
||||
|
||||
assert "Alice" in formatted
|
||||
assert "Technology" in formatted
|
||||
assert "suggested_prompts" not in formatted
|
||||
assert "Automate reports" not in formatted
|
||||
assert "Set up alerts" not in formatted
|
||||
assert "Track KPIs" not in formatted
|
||||
@@ -76,20 +76,24 @@ async def get_or_create_workspace(user_id: str) -> Workspace:
|
||||
"""
|
||||
Get user's workspace, creating one if it doesn't exist.
|
||||
|
||||
Uses upsert to atomically handle concurrent creation attempts.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID
|
||||
|
||||
Returns:
|
||||
Workspace instance
|
||||
"""
|
||||
workspace = await UserWorkspace.prisma().find_unique(where={"userId": user_id})
|
||||
if workspace:
|
||||
return Workspace.from_db(workspace)
|
||||
|
||||
try:
|
||||
workspace = await UserWorkspace.prisma().create(data={"userId": user_id})
|
||||
workspace = await UserWorkspace.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id},
|
||||
"update": {}, # No-op update; workspace already exists
|
||||
},
|
||||
)
|
||||
except UniqueViolationError:
|
||||
# Concurrent request already created it
|
||||
# Defense-in-depth: should not happen with upsert, but handle gracefully
|
||||
workspace = await UserWorkspace.prisma().find_unique(where={"userId": user_id})
|
||||
if workspace is None:
|
||||
raise
|
||||
@@ -125,6 +129,13 @@ async def create_workspace_file(
|
||||
"""
|
||||
Create a new workspace file record.
|
||||
|
||||
Raises ``UniqueViolationError`` if a record with the same
|
||||
``(workspaceId, path)`` already exists. The caller
|
||||
(``WorkspaceManager._persist_db_record``) relies on this to trigger
|
||||
its delete-old-file-then-retry flow, which cleans up the old storage
|
||||
blob before re-creating the DB record. Using ``upsert`` here would
|
||||
silently overwrite ``storagePath`` and orphan the old blob in storage.
|
||||
|
||||
Args:
|
||||
workspace_id: The workspace ID
|
||||
file_id: The file ID (same as used in storage path for consistency)
|
||||
|
||||
@@ -39,6 +39,8 @@ class Flag(str, Enum):
|
||||
ENABLE_PLATFORM_PAYMENT = "enable-platform-payment"
|
||||
CHAT = "chat"
|
||||
COPILOT_SDK = "copilot-sdk"
|
||||
COPILOT_DAILY_TOKEN_LIMIT = "copilot-daily-token-limit"
|
||||
COPILOT_WEEKLY_TOKEN_LIMIT = "copilot-weekly-token-limit"
|
||||
|
||||
|
||||
def is_configured() -> bool:
|
||||
|
||||
@@ -84,6 +84,14 @@ def _before_send(event, hint):
|
||||
):
|
||||
return None
|
||||
|
||||
# Prisma UniqueViolationError — always caught and handled in our codebase.
|
||||
# These arise from concurrent create operations racing on unique constraints
|
||||
# (workspace files, credits, library folders, store listings, chat messages).
|
||||
# Every call site has an except handler; the global FastAPI handler also
|
||||
# catches them and returns 400. Safe to drop unconditionally.
|
||||
if exc_type and exc_type.__name__ == "UniqueViolationError":
|
||||
return None
|
||||
|
||||
# Google metadata DNS errors — expected in non-GCP environments
|
||||
if (
|
||||
"metadata.google.internal" in exc_msg
|
||||
|
||||
@@ -227,10 +227,16 @@ class AppService(BaseAppService, ABC):
|
||||
def _handle_internal_http_error(status_code: int = 500, log_error: bool = True):
|
||||
def handler(request: Request, exc: Exception):
|
||||
if log_error:
|
||||
logger.error(
|
||||
f"{request.method} {request.url.path} failed: {exc}",
|
||||
exc_info=exc if status_code == 500 else None,
|
||||
)
|
||||
if status_code >= 500:
|
||||
logger.error(
|
||||
f"{request.method} {request.url.path} failed: {exc}",
|
||||
exc_info=exc,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"{request.method} {request.url.path} failed: {exc}",
|
||||
exc_info=exc,
|
||||
)
|
||||
return responses.JSONResponse(
|
||||
status_code=status_code,
|
||||
content=RemoteCallError(
|
||||
@@ -563,7 +569,6 @@ def get_service_client(
|
||||
self._connection_failure_count >= 3
|
||||
and current_time - self._last_client_reset > 30
|
||||
):
|
||||
|
||||
logger.warning(
|
||||
f"Connection failures detected ({self._connection_failure_count}), recreating HTTP clients"
|
||||
)
|
||||
|
||||
@@ -108,6 +108,9 @@ class VirusScannerService:
|
||||
return VirusScanResult(
|
||||
is_clean=True, scan_time_ms=0, file_size=len(content)
|
||||
)
|
||||
if len(content) == 0:
|
||||
logger.debug(f"Skipping virus scan for empty file {filename}")
|
||||
return VirusScanResult(is_clean=True, scan_time_ms=0, file_size=0)
|
||||
if len(content) > self.settings.max_scan_size:
|
||||
logger.warning(
|
||||
f"File {filename} ({len(content)} bytes) exceeds client max scan size ({self.settings.max_scan_size}); Stopping virus scan"
|
||||
@@ -123,7 +126,7 @@ class VirusScannerService:
|
||||
raise RuntimeError("ClamAV service is unreachable")
|
||||
|
||||
start = time.monotonic()
|
||||
chunk_size = len(content) # Start with full content length
|
||||
chunk_size = max(1, len(content)) # Start with full content length
|
||||
for retry in range(self.settings.max_retries):
|
||||
# For small files, don't check min_chunk_size limit
|
||||
if chunk_size < self.settings.min_chunk_size and chunk_size < len(content):
|
||||
@@ -212,5 +215,5 @@ async def scan_content_safe(content: bytes, *, filename: str = "unknown") -> Non
|
||||
except VirusDetectedError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Virus scanning failed for {filename}: {str(e)}")
|
||||
logger.warning(f"Virus scanning failed for {filename}: {str(e)}")
|
||||
raise VirusScanError(f"Virus scanning failed: {str(e)}") from e
|
||||
|
||||
@@ -85,7 +85,36 @@ class TestVirusScannerService:
|
||||
)
|
||||
assert result_dirty.is_clean is False
|
||||
|
||||
# Note: ping method was removed from current implementation
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_empty_file(self, scanner):
|
||||
"""Empty files (0 bytes) should be accepted without hitting ClamAV."""
|
||||
content = b""
|
||||
result = await scanner.scan_file(content, filename="empty.png")
|
||||
|
||||
assert result.is_clean is True
|
||||
assert result.threat_name is None
|
||||
assert result.file_size == 0
|
||||
assert result.scan_time_ms == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_single_byte_file(self, scanner):
|
||||
"""A 1-byte file should be scanned normally (regression: chunk_size must not be 0)."""
|
||||
|
||||
async def mock_instream(_):
|
||||
await asyncio.sleep(0.001)
|
||||
return None
|
||||
|
||||
mock_client = Mock()
|
||||
mock_client.ping = AsyncMock(return_value=True)
|
||||
mock_client.instream = AsyncMock(side_effect=mock_instream)
|
||||
scanner._client = mock_client
|
||||
|
||||
content = b"\x00"
|
||||
result = await scanner.scan_file(content, filename="tiny.bin")
|
||||
|
||||
assert result.is_clean is True
|
||||
assert result.file_size == 1
|
||||
assert result.scan_time_ms > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_clean_file(self, scanner):
|
||||
@@ -251,3 +280,27 @@ class TestHelperFunctions:
|
||||
|
||||
with pytest.raises(VirusScanError, match="Virus scanning failed"):
|
||||
await scan_content_safe(b"test content", filename="test.txt")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_content_safe_logs_warning_not_error_on_failure(self):
|
||||
"""Scan failures should log at WARNING level, not ERROR, to avoid paging on-call."""
|
||||
with patch("backend.util.virus_scanner.get_virus_scanner") as mock_get_scanner:
|
||||
mock_scanner = Mock()
|
||||
mock_scanner.scan_file = AsyncMock()
|
||||
mock_scanner.scan_file.side_effect = Exception(
|
||||
"range() arg 3 must not be zero"
|
||||
)
|
||||
mock_get_scanner.return_value = mock_scanner
|
||||
|
||||
with (
|
||||
pytest.raises(VirusScanError),
|
||||
patch("backend.util.virus_scanner.logger") as mock_logger,
|
||||
):
|
||||
await scan_content_safe(b"test", filename="screenshot.png")
|
||||
|
||||
mock_logger.warning.assert_called_once()
|
||||
# Check the formatted log message contains the error text.
|
||||
# Use str() to handle both f-string and %-style logging formats.
|
||||
log_msg = str(mock_logger.warning.call_args)
|
||||
assert "range()" in log_msg
|
||||
mock_logger.error.assert_not_called()
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
-- Add durationMs column to ChatMessage for persisting turn elapsed time.
|
||||
ALTER TABLE "ChatMessage" ADD COLUMN "durationMs" INTEGER;
|
||||
@@ -0,0 +1,44 @@
|
||||
-- CreateEnum
|
||||
CREATE TYPE "PlatformType" AS ENUM ('DISCORD', 'TELEGRAM', 'SLACK', 'TEAMS', 'WHATSAPP', 'GITHUB', 'LINEAR');
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "PlatformLink" (
|
||||
"id" TEXT NOT NULL,
|
||||
"userId" TEXT NOT NULL,
|
||||
"platform" "PlatformType" NOT NULL,
|
||||
"platformUserId" TEXT NOT NULL,
|
||||
"platformUsername" TEXT,
|
||||
"linkedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
|
||||
CONSTRAINT "PlatformLink_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "PlatformLinkToken" (
|
||||
"id" TEXT NOT NULL,
|
||||
"token" TEXT NOT NULL,
|
||||
"platform" "PlatformType" NOT NULL,
|
||||
"platformUserId" TEXT NOT NULL,
|
||||
"platformUsername" TEXT,
|
||||
"channelId" TEXT,
|
||||
"expiresAt" TIMESTAMP(3) NOT NULL,
|
||||
"usedAt" TIMESTAMP(3),
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
|
||||
CONSTRAINT "PlatformLinkToken_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "PlatformLink_platform_platformUserId_key" ON "PlatformLink"("platform", "platformUserId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "PlatformLink_userId_idx" ON "PlatformLink"("userId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "PlatformLinkToken_token_key" ON "PlatformLinkToken"("token");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "PlatformLinkToken_expiresAt_idx" ON "PlatformLinkToken"("expiresAt");
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "PlatformLink" ADD CONSTRAINT "PlatformLink_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
@@ -71,6 +71,9 @@ model User {
|
||||
OAuthAuthorizationCodes OAuthAuthorizationCode[]
|
||||
OAuthAccessTokens OAuthAccessToken[]
|
||||
OAuthRefreshTokens OAuthRefreshToken[]
|
||||
|
||||
// Platform bot linking
|
||||
PlatformLinks PlatformLink[]
|
||||
}
|
||||
|
||||
enum OnboardingStep {
|
||||
@@ -246,7 +249,8 @@ model ChatMessage {
|
||||
functionCall Json? // Deprecated but kept for compatibility
|
||||
|
||||
// Ordering within session
|
||||
sequence Int
|
||||
sequence Int
|
||||
durationMs Int? // Wall-clock milliseconds for this assistant turn
|
||||
|
||||
@@unique([sessionId, sequence])
|
||||
}
|
||||
@@ -1301,3 +1305,50 @@ model OAuthRefreshToken {
|
||||
@@index([userId, applicationId])
|
||||
@@index([expiresAt]) // For cleanup
|
||||
}
|
||||
|
||||
// ── Platform Bot Linking ──────────────────────────────────────────────
|
||||
// Links external chat platform identities (Discord, Telegram, Slack, etc.)
|
||||
// to AutoGPT user accounts, enabling the multi-platform CoPilot bot.
|
||||
|
||||
enum PlatformType {
|
||||
DISCORD
|
||||
TELEGRAM
|
||||
SLACK
|
||||
TEAMS
|
||||
WHATSAPP
|
||||
GITHUB
|
||||
LINEAR
|
||||
}
|
||||
|
||||
// Maps a platform user identity to an AutoGPT account.
|
||||
// One AutoGPT user can have multiple platform links (e.g. Discord + Telegram).
|
||||
// Each platform identity can only link to one AutoGPT account.
|
||||
model PlatformLink {
|
||||
id String @id @default(uuid())
|
||||
userId String
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
platform PlatformType
|
||||
platformUserId String // The user's ID on that platform
|
||||
platformUsername String? // Display name (best-effort, may go stale)
|
||||
linkedAt DateTime @default(now())
|
||||
|
||||
@@unique([platform, platformUserId])
|
||||
@@index([userId])
|
||||
}
|
||||
|
||||
// One-time tokens for the account linking flow.
|
||||
// Generated when an unlinked user messages the bot; consumed when they
|
||||
// complete the link on the AutoGPT web app.
|
||||
model PlatformLinkToken {
|
||||
id String @id @default(uuid())
|
||||
token String @unique
|
||||
platform PlatformType
|
||||
platformUserId String
|
||||
platformUsername String?
|
||||
channelId String? // So the bot can send a confirmation message
|
||||
expiresAt DateTime
|
||||
usedAt DateTime?
|
||||
createdAt DateTime @default(now())
|
||||
|
||||
@@index([expiresAt])
|
||||
}
|
||||
|
||||
8
autogpt_platform/backend/snapshots/get_rate_limit
Normal file
8
autogpt_platform/backend/snapshots/get_rate_limit
Normal file
@@ -0,0 +1,8 @@
|
||||
{
|
||||
"daily_token_limit": 2500000,
|
||||
"daily_tokens_used": 500000,
|
||||
"user_email": "target@example.com",
|
||||
"user_id": "5e53486c-cf57-477e-ba2a-cb02dc828e1c",
|
||||
"weekly_token_limit": 12500000,
|
||||
"weekly_tokens_used": 3000000
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
{
|
||||
"daily_token_limit": 2500000,
|
||||
"daily_tokens_used": 0,
|
||||
"user_email": "target@example.com",
|
||||
"user_id": "5e53486c-cf57-477e-ba2a-cb02dc828e1c",
|
||||
"weekly_token_limit": 12500000,
|
||||
"weekly_tokens_used": 0
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
{
|
||||
"daily_token_limit": 2500000,
|
||||
"daily_tokens_used": 0,
|
||||
"user_email": "target@example.com",
|
||||
"user_id": "5e53486c-cf57-477e-ba2a-cb02dc828e1c",
|
||||
"weekly_token_limit": 12500000,
|
||||
"weekly_tokens_used": 3000000
|
||||
}
|
||||
@@ -526,7 +526,12 @@ class TestValidateOrchestratorBlocks:
|
||||
"id": AGENT_INPUT_BLOCK_ID,
|
||||
"name": "AgentInputBlock",
|
||||
"inputSchema": {
|
||||
"properties": {"name": {"type": "string"}},
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"title": {"type": "string"},
|
||||
"value": {},
|
||||
"description": {"type": "string"},
|
||||
},
|
||||
"required": ["name"],
|
||||
},
|
||||
"outputSchema": {"properties": {"result": {}}},
|
||||
@@ -537,6 +542,7 @@ class TestValidateOrchestratorBlocks:
|
||||
"inputSchema": {
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"title": {"type": "string"},
|
||||
"value": {},
|
||||
},
|
||||
"required": ["name"],
|
||||
@@ -683,7 +689,12 @@ class TestOrchestratorE2EPipeline:
|
||||
"id": AGENT_INPUT_BLOCK_ID,
|
||||
"name": "AgentInputBlock",
|
||||
"inputSchema": {
|
||||
"properties": {"name": {"type": "string"}},
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"title": {"type": "string"},
|
||||
"value": {},
|
||||
"description": {"type": "string"},
|
||||
},
|
||||
"required": ["name"],
|
||||
},
|
||||
"outputSchema": {"properties": {"result": {}}},
|
||||
@@ -694,6 +705,7 @@ class TestOrchestratorE2EPipeline:
|
||||
"inputSchema": {
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"title": {"type": "string"},
|
||||
"value": {},
|
||||
},
|
||||
"required": ["name"],
|
||||
|
||||
@@ -254,7 +254,6 @@ class TestDataCreator:
|
||||
"value": "",
|
||||
"advanced": False,
|
||||
"description": None,
|
||||
"placeholder_values": [],
|
||||
},
|
||||
metadata={"position": {"x": -1012, "y": 674}},
|
||||
)
|
||||
@@ -274,7 +273,6 @@ class TestDataCreator:
|
||||
"value": "",
|
||||
"advanced": False,
|
||||
"description": None,
|
||||
"placeholder_values": [],
|
||||
},
|
||||
metadata={"position": {"x": -1117, "y": 78}},
|
||||
)
|
||||
|
||||
24
autogpt_platform/copilot-bot/.env.example
Normal file
24
autogpt_platform/copilot-bot/.env.example
Normal file
@@ -0,0 +1,24 @@
|
||||
# ── AutoGPT Platform API ──────────────────────────────────────────────
|
||||
AUTOGPT_API_URL=http://localhost:8006
|
||||
# Service API key for bot-facing endpoints (must match backend PLATFORM_BOT_API_KEY)
|
||||
# PLATFORM_BOT_API_KEY=
|
||||
|
||||
# ── Discord ───────────────────────────────────────────────────────────
|
||||
DISCORD_BOT_TOKEN=
|
||||
DISCORD_PUBLIC_KEY=
|
||||
DISCORD_APPLICATION_ID=
|
||||
# Optional: comma-separated role IDs that trigger mentions
|
||||
# DISCORD_MENTION_ROLE_IDS=
|
||||
|
||||
# ── Telegram ──────────────────────────────────────────────────────────
|
||||
# TELEGRAM_BOT_TOKEN=
|
||||
|
||||
# ── Slack ─────────────────────────────────────────────────────────────
|
||||
# SLACK_BOT_TOKEN=
|
||||
# SLACK_SIGNING_SECRET=
|
||||
# SLACK_APP_TOKEN=
|
||||
|
||||
# ── State ─────────────────────────────────────────────────────────────
|
||||
# For production, use Redis:
|
||||
# REDIS_URL=redis://localhost:6379
|
||||
# For development, in-memory state is used by default.
|
||||
4
autogpt_platform/copilot-bot/.gitignore
vendored
Normal file
4
autogpt_platform/copilot-bot/.gitignore
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
node_modules/
|
||||
dist/
|
||||
.env
|
||||
*.log
|
||||
73
autogpt_platform/copilot-bot/README.md
Normal file
73
autogpt_platform/copilot-bot/README.md
Normal file
@@ -0,0 +1,73 @@
|
||||
# CoPilot Bot
|
||||
|
||||
Multi-platform bot service for AutoGPT CoPilot, built with [Vercel Chat SDK](https://chat-sdk.dev).
|
||||
|
||||
Deploys CoPilot to Discord, Telegram, Slack, and more from a single codebase.
|
||||
|
||||
## How it works
|
||||
|
||||
1. User messages the bot on any platform (Discord, Telegram, Slack)
|
||||
2. Bot checks if the platform user is linked to an AutoGPT account
|
||||
3. If not linked → sends a one-time link URL
|
||||
4. User clicks → logs in to AutoGPT → accounts are linked
|
||||
5. Future messages are forwarded to CoPilot and responses streamed back
|
||||
|
||||
## Setup
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
npm install
|
||||
|
||||
# Copy env template
|
||||
cp .env.example .env
|
||||
|
||||
# Configure at least one platform adapter (e.g. Discord)
|
||||
# Edit .env with your bot tokens
|
||||
|
||||
# Run in development
|
||||
npm run dev
|
||||
```
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
src/
|
||||
├── index.ts # Standalone entry point
|
||||
├── config.ts # Environment-based configuration
|
||||
├── bot.ts # Core bot logic (Chat SDK handlers)
|
||||
├── platform-api.ts # AutoGPT platform API client
|
||||
└── api/ # Serverless API routes (Vercel)
|
||||
├── _bot.ts # Singleton bot instance
|
||||
├── webhooks/ # Platform webhook endpoints
|
||||
│ ├── discord.ts
|
||||
│ ├── telegram.ts
|
||||
│ └── slack.ts
|
||||
└── gateway/
|
||||
└── discord.ts # Gateway cron for Discord messages
|
||||
```
|
||||
|
||||
## Deployment
|
||||
|
||||
### Standalone (Docker/PM2)
|
||||
```bash
|
||||
npm run build
|
||||
npm start
|
||||
```
|
||||
|
||||
### Serverless (Vercel)
|
||||
Deploy to Vercel. Webhook URLs:
|
||||
- Discord: `https://your-app.vercel.app/api/webhooks/discord`
|
||||
- Telegram: `https://your-app.vercel.app/api/webhooks/telegram`
|
||||
- Slack: `https://your-app.vercel.app/api/webhooks/slack`
|
||||
|
||||
For Discord messages (Gateway), add a cron job in `vercel.json`:
|
||||
```json
|
||||
{
|
||||
"crons": [{ "path": "/api/gateway/discord", "schedule": "*/9 * * * *" }]
|
||||
}
|
||||
```
|
||||
|
||||
## Dependencies
|
||||
|
||||
- [Chat SDK](https://chat-sdk.dev) — Cross-platform bot abstraction
|
||||
- AutoGPT Platform API — Account linking + CoPilot chat
|
||||
27
autogpt_platform/copilot-bot/package.json
Normal file
27
autogpt_platform/copilot-bot/package.json
Normal file
@@ -0,0 +1,27 @@
|
||||
{
|
||||
"name": "@autogpt/copilot-bot",
|
||||
"version": "0.1.0",
|
||||
"description": "Multi-platform CoPilot bot service using Vercel Chat SDK",
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"dev": "tsx watch src/index.ts",
|
||||
"start": "tsx src/index.ts",
|
||||
"build": "tsc",
|
||||
"lint": "tsc --noEmit"
|
||||
},
|
||||
"dependencies": {
|
||||
"chat": "^4.23.0",
|
||||
"@chat-adapter/discord": "^4.23.0",
|
||||
"@chat-adapter/telegram": "^4.23.0",
|
||||
"@chat-adapter/slack": "^4.23.0",
|
||||
"@chat-adapter/state-memory": "^4.23.0",
|
||||
"@chat-adapter/state-redis": "^4.23.0",
|
||||
"dotenv": "^16.4.0",
|
||||
"eventsource-parser": "^3.0.1"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/node": "^22.0.0",
|
||||
"tsx": "^4.19.0",
|
||||
"typescript": "^5.7.0"
|
||||
}
|
||||
}
|
||||
31
autogpt_platform/copilot-bot/src/api/_bot.ts
Normal file
31
autogpt_platform/copilot-bot/src/api/_bot.ts
Normal file
@@ -0,0 +1,31 @@
|
||||
/**
|
||||
* Singleton bot instance for serverless environments.
|
||||
*
|
||||
* In serverless (Vercel), each request may hit a cold or warm instance.
|
||||
* We create the bot once per instance and reuse it across requests.
|
||||
*/
|
||||
|
||||
import { loadConfig } from "../config.js";
|
||||
import { createBot } from "../bot.js";
|
||||
|
||||
let _botPromise: ReturnType<typeof createBot> | null = null;
|
||||
|
||||
export async function getBotInstance() {
|
||||
if (!_botPromise) {
|
||||
const config = loadConfig();
|
||||
|
||||
let stateAdapter;
|
||||
if (config.redisUrl) {
|
||||
const { createRedisState } = await import("@chat-adapter/state-redis");
|
||||
stateAdapter = createRedisState({ url: config.redisUrl });
|
||||
} else {
|
||||
const { createMemoryState } = await import("@chat-adapter/state-memory");
|
||||
stateAdapter = createMemoryState();
|
||||
}
|
||||
|
||||
_botPromise = createBot(config, stateAdapter);
|
||||
console.log("[bot] Instance created (serverless)");
|
||||
}
|
||||
|
||||
return _botPromise;
|
||||
}
|
||||
39
autogpt_platform/copilot-bot/src/api/gateway/discord.ts
Normal file
39
autogpt_platform/copilot-bot/src/api/gateway/discord.ts
Normal file
@@ -0,0 +1,39 @@
|
||||
/**
|
||||
* Discord Gateway cron endpoint.
|
||||
*
|
||||
* In serverless environments, Discord's Gateway WebSocket needs a
|
||||
* persistent connection to receive messages. This endpoint is called
|
||||
* by a cron job every 9 minutes to maintain the connection.
|
||||
*/
|
||||
|
||||
import { getBotInstance } from "../_bot.js";
|
||||
|
||||
export async function GET(request: Request) {
|
||||
// Verify cron secret in production
|
||||
const authHeader = request.headers.get("authorization");
|
||||
if (
|
||||
process.env.CRON_SECRET &&
|
||||
authHeader !== `Bearer ${process.env.CRON_SECRET}`
|
||||
) {
|
||||
return new Response("Unauthorized", { status: 401 });
|
||||
}
|
||||
|
||||
const bot = await getBotInstance();
|
||||
await bot.initialize();
|
||||
|
||||
const discord = bot.getAdapter("discord");
|
||||
if (!discord) {
|
||||
return new Response("Discord adapter not configured", { status: 404 });
|
||||
}
|
||||
|
||||
const baseUrl = process.env.WEBHOOK_BASE_URL ?? "http://localhost:3000";
|
||||
const webhookUrl = `${baseUrl}/api/webhooks/discord`;
|
||||
const durationMs = 9 * 60 * 1000; // 9 minutes (matches cron interval)
|
||||
|
||||
return (discord as any).startGatewayListener(
|
||||
{},
|
||||
durationMs,
|
||||
undefined,
|
||||
webhookUrl
|
||||
);
|
||||
}
|
||||
17
autogpt_platform/copilot-bot/src/api/webhooks/discord.ts
Normal file
17
autogpt_platform/copilot-bot/src/api/webhooks/discord.ts
Normal file
@@ -0,0 +1,17 @@
|
||||
/**
|
||||
* Discord webhook endpoint.
|
||||
* Deploy as: POST /api/webhooks/discord
|
||||
*/
|
||||
|
||||
import { getBotInstance } from "../_bot.js";
|
||||
|
||||
export async function POST(request: Request) {
|
||||
const bot = await getBotInstance();
|
||||
const handler = bot.webhooks.discord;
|
||||
|
||||
if (!handler) {
|
||||
return new Response("Discord adapter not configured", { status: 404 });
|
||||
}
|
||||
|
||||
return handler(request);
|
||||
}
|
||||
17
autogpt_platform/copilot-bot/src/api/webhooks/slack.ts
Normal file
17
autogpt_platform/copilot-bot/src/api/webhooks/slack.ts
Normal file
@@ -0,0 +1,17 @@
|
||||
/**
|
||||
* Slack webhook endpoint.
|
||||
* Deploy as: POST /api/webhooks/slack
|
||||
*/
|
||||
|
||||
import { getBotInstance } from "../_bot.js";
|
||||
|
||||
export async function POST(request: Request) {
|
||||
const bot = await getBotInstance();
|
||||
const handler = bot.webhooks.slack;
|
||||
|
||||
if (!handler) {
|
||||
return new Response("Slack adapter not configured", { status: 404 });
|
||||
}
|
||||
|
||||
return handler(request);
|
||||
}
|
||||
17
autogpt_platform/copilot-bot/src/api/webhooks/telegram.ts
Normal file
17
autogpt_platform/copilot-bot/src/api/webhooks/telegram.ts
Normal file
@@ -0,0 +1,17 @@
|
||||
/**
|
||||
* Telegram webhook endpoint.
|
||||
* Deploy as: POST /api/webhooks/telegram
|
||||
*/
|
||||
|
||||
import { getBotInstance } from "../_bot.js";
|
||||
|
||||
export async function POST(request: Request) {
|
||||
const bot = await getBotInstance();
|
||||
const handler = bot.webhooks.telegram;
|
||||
|
||||
if (!handler) {
|
||||
return new Response("Telegram adapter not configured", { status: 404 });
|
||||
}
|
||||
|
||||
return handler(request);
|
||||
}
|
||||
232
autogpt_platform/copilot-bot/src/bot.ts
Normal file
232
autogpt_platform/copilot-bot/src/bot.ts
Normal file
@@ -0,0 +1,232 @@
|
||||
/**
|
||||
* CoPilot Bot — Multi-platform bot using Vercel Chat SDK.
|
||||
*
|
||||
* Handles:
|
||||
* - Account linking (prompts unlinked users to link)
|
||||
* - Message routing to CoPilot API
|
||||
* - Streaming responses back to the user
|
||||
*/
|
||||
|
||||
import { Chat, Message } from "chat";
|
||||
import type { Adapter, StateAdapter, Thread } from "chat";
|
||||
import { PlatformAPI } from "./platform-api.js";
|
||||
import type { Config } from "./config.js";
|
||||
|
||||
// Thread state persisted across messages
|
||||
export interface BotThreadState {
|
||||
/** Linked AutoGPT user ID */
|
||||
userId?: string;
|
||||
/** CoPilot chat session ID for this thread */
|
||||
sessionId?: string;
|
||||
/** Pending link token (if user hasn't linked yet) */
|
||||
pendingLinkToken?: string;
|
||||
}
|
||||
|
||||
type BotThread = Thread<BotThreadState>;
|
||||
|
||||
export async function createBot(config: Config, stateAdapter: StateAdapter) {
|
||||
const api = new PlatformAPI(config.autogptApiUrl);
|
||||
|
||||
// Build adapters based on config
|
||||
const adapters: Record<string, Adapter> = {};
|
||||
|
||||
if (config.discord) {
|
||||
const { createDiscordAdapter } = await import("@chat-adapter/discord");
|
||||
adapters.discord = createDiscordAdapter();
|
||||
}
|
||||
|
||||
if (config.telegram) {
|
||||
const { createTelegramAdapter } = await import("@chat-adapter/telegram");
|
||||
adapters.telegram = createTelegramAdapter();
|
||||
}
|
||||
|
||||
if (config.slack) {
|
||||
const { createSlackAdapter } = await import("@chat-adapter/slack");
|
||||
adapters.slack = createSlackAdapter();
|
||||
}
|
||||
|
||||
if (Object.keys(adapters).length === 0) {
|
||||
throw new Error(
|
||||
"No adapters enabled. Set at least one of: " +
|
||||
"DISCORD_BOT_TOKEN, TELEGRAM_BOT_TOKEN, SLACK_BOT_TOKEN"
|
||||
);
|
||||
}
|
||||
|
||||
const bot = new Chat<typeof adapters, BotThreadState>({
|
||||
userName: "copilot",
|
||||
adapters,
|
||||
state: stateAdapter,
|
||||
streamingUpdateIntervalMs: 500,
|
||||
fallbackStreamingPlaceholderText: "Thinking...",
|
||||
});
|
||||
|
||||
// ── New mention (first message in a thread) ──────────────────────
|
||||
|
||||
bot.onNewMention(async (thread, message) => {
|
||||
const adapterName = getAdapterName(thread);
|
||||
const platformUserId = message.author.userId;
|
||||
|
||||
console.log(
|
||||
`[bot] New mention from ${adapterName}:${platformUserId} in ${thread.id}`
|
||||
);
|
||||
|
||||
// Check if user is linked
|
||||
const resolved = await api.resolve(adapterName, platformUserId);
|
||||
|
||||
if (!resolved.linked) {
|
||||
await handleUnlinkedUser(thread, message, adapterName, api);
|
||||
return;
|
||||
}
|
||||
|
||||
// User is linked — subscribe and handle the message
|
||||
await thread.subscribe();
|
||||
await thread.setState({ userId: resolved.user_id });
|
||||
|
||||
await handleCoPilotMessage(thread, message.text, resolved.user_id!, api);
|
||||
});
|
||||
|
||||
// ── Subscribed messages (follow-ups in a thread) ─────────────────
|
||||
|
||||
bot.onSubscribedMessage(async (thread, message) => {
|
||||
const state = await thread.state;
|
||||
|
||||
if (!state?.userId) {
|
||||
// Somehow lost state — re-resolve
|
||||
const adapterName = getAdapterName(thread);
|
||||
const resolved = await api.resolve(adapterName, message.author.userId);
|
||||
|
||||
if (!resolved.linked) {
|
||||
await handleUnlinkedUser(thread, message, adapterName, api);
|
||||
return;
|
||||
}
|
||||
|
||||
await thread.setState({ userId: resolved.user_id });
|
||||
await handleCoPilotMessage(
|
||||
thread,
|
||||
message.text,
|
||||
resolved.user_id!,
|
||||
api
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
await handleCoPilotMessage(thread, message.text, state.userId, api);
|
||||
});
|
||||
|
||||
return bot;
|
||||
}
|
||||
|
||||
// ── Helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Get the adapter/platform name from a thread.
|
||||
* Thread ID format is "adapter:channel:thread".
|
||||
*/
|
||||
function getAdapterName(thread: BotThread): string {
|
||||
const parts = thread.id.split(":");
|
||||
return parts[0] ?? "unknown";
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle an unlinked user — create a link token and send them a prompt.
|
||||
*/
|
||||
async function handleUnlinkedUser(
|
||||
thread: BotThread,
|
||||
message: Message,
|
||||
platform: string,
|
||||
api: PlatformAPI
|
||||
) {
|
||||
console.log(
|
||||
`[bot] Unlinked user ${platform}:${message.author.userId}, sending link prompt`
|
||||
);
|
||||
|
||||
try {
|
||||
const linkResult = await api.createLinkToken({
|
||||
platform,
|
||||
platformUserId: message.author.userId,
|
||||
platformUsername: message.author.fullName ?? message.author.userName,
|
||||
});
|
||||
|
||||
await thread.post(
|
||||
`👋 To use CoPilot, link your AutoGPT account first.\n\n` +
|
||||
`🔗 **Link your account:** ${linkResult.link_url}\n\n` +
|
||||
`_This link expires in 30 minutes._`
|
||||
);
|
||||
|
||||
// Store the pending token so we could poll later if needed
|
||||
await thread.setState({ pendingLinkToken: linkResult.token });
|
||||
} catch (err: unknown) {
|
||||
const errMsg = err instanceof Error ? err.message : String(err);
|
||||
if (errMsg.includes("409")) {
|
||||
// Already linked (race condition) — retry resolve
|
||||
const resolved = await api.resolve(platform, message.author.userId);
|
||||
if (resolved.linked) {
|
||||
await thread.subscribe();
|
||||
await thread.setState({ userId: resolved.user_id });
|
||||
await handleCoPilotMessage(
|
||||
thread,
|
||||
message.text,
|
||||
resolved.user_id!,
|
||||
api
|
||||
);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
console.error("[bot] Failed to create link token:", err);
|
||||
await thread.post(
|
||||
"Sorry, I couldn't set up account linking right now. Please try again later."
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Forward a message to CoPilot and stream the response back.
|
||||
*/
|
||||
async function handleCoPilotMessage(
|
||||
thread: BotThread,
|
||||
text: string,
|
||||
userId: string,
|
||||
api: PlatformAPI
|
||||
) {
|
||||
const state = await thread.state;
|
||||
let sessionId = state?.sessionId;
|
||||
|
||||
console.log(
|
||||
`[bot] Message from user ${userId.slice(-8)}: ${text.slice(0, 100)}`
|
||||
);
|
||||
|
||||
await thread.startTyping();
|
||||
|
||||
try {
|
||||
// Create a session if we don't have one
|
||||
if (!sessionId) {
|
||||
sessionId = await api.createChatSession(userId);
|
||||
await thread.setState({ ...state, sessionId });
|
||||
console.log(`[bot] Created session ${sessionId} for user ${userId.slice(-8)}`);
|
||||
}
|
||||
|
||||
// Stream CoPilot response — collect chunks, then post.
|
||||
// We collect first because thread.post() with an empty stream
|
||||
// causes Discord "Cannot send an empty message" errors.
|
||||
const stream = api.streamChat(userId, text, sessionId);
|
||||
let response = "";
|
||||
for await (const chunk of stream) {
|
||||
response += chunk;
|
||||
}
|
||||
|
||||
if (response.trim()) {
|
||||
await thread.post(response);
|
||||
} else {
|
||||
await thread.post(
|
||||
"I processed your message but didn't generate a response. Please try again."
|
||||
);
|
||||
}
|
||||
} catch (err: unknown) {
|
||||
const errMsg = err instanceof Error ? err.message : String(err);
|
||||
console.error(`[bot] CoPilot error for user ${userId.slice(-8)}:`, errMsg);
|
||||
await thread.post(
|
||||
"Sorry, I ran into an issue processing your message. Please try again."
|
||||
);
|
||||
}
|
||||
}
|
||||
54
autogpt_platform/copilot-bot/src/config.ts
Normal file
54
autogpt_platform/copilot-bot/src/config.ts
Normal file
@@ -0,0 +1,54 @@
|
||||
import "dotenv/config";
|
||||
|
||||
export interface Config {
|
||||
/** AutoGPT platform API base URL */
|
||||
autogptApiUrl: string;
|
||||
|
||||
/** Whether each adapter is enabled (based on env vars being set) */
|
||||
discord: boolean;
|
||||
telegram: boolean;
|
||||
slack: boolean;
|
||||
|
||||
/** Use Redis for state (production) or in-memory (dev) */
|
||||
redisUrl?: string;
|
||||
}
|
||||
|
||||
export function loadConfig(): Config {
|
||||
const isProduction = process.env.NODE_ENV === "production";
|
||||
|
||||
if (isProduction) {
|
||||
requireEnv("PLATFORM_BOT_API_KEY");
|
||||
|
||||
const hasAdapter =
|
||||
!!process.env.DISCORD_BOT_TOKEN ||
|
||||
!!process.env.TELEGRAM_BOT_TOKEN ||
|
||||
!!process.env.SLACK_BOT_TOKEN;
|
||||
|
||||
if (!hasAdapter) {
|
||||
throw new Error(
|
||||
"Production requires at least one adapter token: " +
|
||||
"DISCORD_BOT_TOKEN, TELEGRAM_BOT_TOKEN, or SLACK_BOT_TOKEN"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
autogptApiUrl: env("AUTOGPT_API_URL", "http://localhost:8006"),
|
||||
discord: !!process.env.DISCORD_BOT_TOKEN,
|
||||
telegram: !!process.env.TELEGRAM_BOT_TOKEN,
|
||||
slack: !!process.env.SLACK_BOT_TOKEN,
|
||||
redisUrl: process.env.REDIS_URL,
|
||||
};
|
||||
}
|
||||
|
||||
function env(key: string, fallback: string): string {
|
||||
return process.env[key] ?? fallback;
|
||||
}
|
||||
|
||||
function requireEnv(key: string): string {
|
||||
const value = process.env[key];
|
||||
if (!value) {
|
||||
throw new Error(`Missing required environment variable: ${key}`);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
171
autogpt_platform/copilot-bot/src/index.ts
Normal file
171
autogpt_platform/copilot-bot/src/index.ts
Normal file
@@ -0,0 +1,171 @@
|
||||
/**
|
||||
* CoPilot Bot — Entry point (standalone / long-running).
|
||||
*
|
||||
* Starts an HTTP server for webhook handling and connects to
|
||||
* the Discord Gateway for receiving messages.
|
||||
*/
|
||||
|
||||
// Load .env BEFORE any other imports so env vars are available
|
||||
// when Chat SDK adapters auto-detect credentials at import time.
|
||||
import { config } from "dotenv";
|
||||
config();
|
||||
|
||||
import { loadConfig } from "./config.js";
|
||||
import { createBot } from "./bot.js";
|
||||
|
||||
const PORT = parseInt(process.env.PORT ?? "3001", 10);
|
||||
|
||||
async function main() {
|
||||
console.log("🤖 CoPilot Bot starting...\n");
|
||||
|
||||
const config = loadConfig();
|
||||
|
||||
// Log which adapters are enabled
|
||||
const enabled = [
|
||||
config.discord && "Discord",
|
||||
config.telegram && "Telegram",
|
||||
config.slack && "Slack",
|
||||
].filter(Boolean);
|
||||
|
||||
console.log(`📡 Adapters: ${enabled.join(", ") || "none"}`);
|
||||
console.log(`🔗 API: ${config.autogptApiUrl}`);
|
||||
console.log(`💾 State: ${config.redisUrl ? "Redis" : "In-memory"}`);
|
||||
console.log(`🌐 Port: ${PORT}\n`);
|
||||
|
||||
// Create state adapter
|
||||
let stateAdapter;
|
||||
if (config.redisUrl) {
|
||||
const { createRedisState } = await import("@chat-adapter/state-redis");
|
||||
stateAdapter = createRedisState({ url: config.redisUrl });
|
||||
} else {
|
||||
const { createMemoryState } = await import("@chat-adapter/state-memory");
|
||||
stateAdapter = createMemoryState();
|
||||
}
|
||||
|
||||
// Create the bot
|
||||
const bot = await createBot(config, stateAdapter);
|
||||
|
||||
// Start HTTP server for webhooks
|
||||
await startNodeServer(bot, PORT);
|
||||
|
||||
// Start Discord Gateway if enabled
|
||||
if (config.discord) {
|
||||
await bot.initialize();
|
||||
const discord = bot.getAdapter("discord") as any;
|
||||
|
||||
if (discord?.startGatewayListener) {
|
||||
const webhookUrl = `http://localhost:${PORT}/api/webhooks/discord`;
|
||||
console.log(`🔌 Starting Discord Gateway → ${webhookUrl}`);
|
||||
|
||||
// Run gateway in background, restart on disconnect
|
||||
const runGateway = async () => {
|
||||
while (true) {
|
||||
try {
|
||||
// Track background tasks so the listener stays alive
|
||||
const pendingTasks: Promise<unknown>[] = [];
|
||||
const waitUntil = (task: Promise<unknown>) => {
|
||||
pendingTasks.push(task);
|
||||
};
|
||||
|
||||
const response = await discord.startGatewayListener(
|
||||
{ waitUntil },
|
||||
10 * 60 * 1000, // 10 minutes
|
||||
undefined,
|
||||
webhookUrl
|
||||
);
|
||||
|
||||
// Wait for any background tasks to complete
|
||||
if (pendingTasks.length > 0) {
|
||||
await Promise.allSettled(pendingTasks);
|
||||
}
|
||||
|
||||
console.log("[gateway] Listener ended, restarting...");
|
||||
} catch (err) {
|
||||
console.error("[gateway] Error, restarting in 5s:", err);
|
||||
await new Promise((r) => setTimeout(r, 5000));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Don't await — run in background
|
||||
runGateway();
|
||||
}
|
||||
}
|
||||
|
||||
console.log("\n✅ CoPilot Bot ready.\n");
|
||||
|
||||
// Graceful shutdown
|
||||
process.on("SIGINT", () => {
|
||||
console.log("\n🛑 Shutting down...");
|
||||
process.exit(0);
|
||||
});
|
||||
process.on("SIGTERM", () => {
|
||||
console.log("\n🛑 Shutting down...");
|
||||
process.exit(0);
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Start a simple HTTP server using Node's built-in http module.
|
||||
* Routes webhook requests to the Chat SDK bot.
|
||||
*/
|
||||
async function startNodeServer(bot: any, port: number) {
|
||||
const { createServer } = await import("http");
|
||||
|
||||
const server = createServer(async (req, res) => {
|
||||
const url = new URL(req.url ?? "/", `http://localhost:${port}`);
|
||||
|
||||
// Collect body
|
||||
const chunks: Buffer[] = [];
|
||||
for await (const chunk of req) {
|
||||
chunks.push(chunk as Buffer);
|
||||
}
|
||||
const body = Buffer.concat(chunks);
|
||||
|
||||
// Build a standard Request object for Chat SDK
|
||||
const headers = new Headers();
|
||||
for (const [key, value] of Object.entries(req.headers)) {
|
||||
if (value) headers.set(key, Array.isArray(value) ? value[0] : value);
|
||||
}
|
||||
|
||||
const request = new Request(url.toString(), {
|
||||
method: req.method ?? "POST",
|
||||
headers,
|
||||
body: req.method !== "GET" && req.method !== "HEAD" ? body : undefined,
|
||||
});
|
||||
|
||||
// Route to the correct adapter webhook
|
||||
// URL: /api/webhooks/{platform}
|
||||
const parts = url.pathname.split("/");
|
||||
const platform = parts[parts.length - 1];
|
||||
const handler = platform ? (bot.webhooks as any)[platform] : undefined;
|
||||
|
||||
if (!handler) {
|
||||
res.writeHead(404);
|
||||
res.end("Not found");
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const response: Response = await handler(request);
|
||||
res.writeHead(response.status, Object.fromEntries(response.headers));
|
||||
const responseBody = await response.arrayBuffer();
|
||||
res.end(Buffer.from(responseBody));
|
||||
} catch (err) {
|
||||
console.error(`[http] Error handling ${url.pathname}:`, err);
|
||||
res.writeHead(500);
|
||||
res.end("Internal error");
|
||||
}
|
||||
});
|
||||
|
||||
server.listen(port, () => {
|
||||
console.log(`🌐 HTTP server listening on http://localhost:${port}`);
|
||||
});
|
||||
|
||||
return server;
|
||||
}
|
||||
|
||||
main().catch((err) => {
|
||||
console.error("Fatal error:", err);
|
||||
process.exit(1);
|
||||
});
|
||||
219
autogpt_platform/copilot-bot/src/platform-api.ts
Normal file
219
autogpt_platform/copilot-bot/src/platform-api.ts
Normal file
@@ -0,0 +1,219 @@
|
||||
/**
|
||||
* Client for the AutoGPT Platform Linking & Chat APIs.
|
||||
*
|
||||
* Handles:
|
||||
* - Resolving platform users → AutoGPT accounts
|
||||
* - Creating link tokens for unlinked users
|
||||
* - Checking link token status
|
||||
* - Creating chat sessions and streaming messages
|
||||
*/
|
||||
|
||||
export interface ResolveResult {
|
||||
linked: boolean;
|
||||
user_id?: string;
|
||||
platform_username?: string;
|
||||
}
|
||||
|
||||
export interface LinkTokenResult {
|
||||
token: string;
|
||||
expires_at: string;
|
||||
link_url: string;
|
||||
}
|
||||
|
||||
export interface LinkTokenStatus {
|
||||
status: "pending" | "linked" | "expired";
|
||||
user_id?: string;
|
||||
}
|
||||
|
||||
export class PlatformAPI {
|
||||
constructor(private baseUrl: string) {}
|
||||
|
||||
/**
|
||||
* Check if a platform user is linked to an AutoGPT account.
|
||||
*/
|
||||
async resolve(
|
||||
platform: string,
|
||||
platformUserId: string
|
||||
): Promise<ResolveResult> {
|
||||
const res = await fetch(`${this.baseUrl}/api/platform-linking/resolve`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
...this.botHeaders(),
|
||||
},
|
||||
body: JSON.stringify({
|
||||
platform: platform.toUpperCase(),
|
||||
platform_user_id: platformUserId,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!res.ok) {
|
||||
throw new Error(
|
||||
`Platform resolve failed: ${res.status} ${await res.text()}`
|
||||
);
|
||||
}
|
||||
|
||||
return res.json();
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a link token for an unlinked platform user.
|
||||
*/
|
||||
async createLinkToken(params: {
|
||||
platform: string;
|
||||
platformUserId: string;
|
||||
platformUsername?: string;
|
||||
channelId?: string;
|
||||
}): Promise<LinkTokenResult> {
|
||||
const res = await fetch(`${this.baseUrl}/api/platform-linking/tokens`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
...this.botHeaders(),
|
||||
},
|
||||
body: JSON.stringify({
|
||||
platform: params.platform.toUpperCase(),
|
||||
platform_user_id: params.platformUserId,
|
||||
platform_username: params.platformUsername,
|
||||
channel_id: params.channelId,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!res.ok) {
|
||||
throw new Error(
|
||||
`Create link token failed: ${res.status} ${await res.text()}`
|
||||
);
|
||||
}
|
||||
|
||||
return res.json();
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a link token has been consumed.
|
||||
*/
|
||||
async getLinkTokenStatus(token: string): Promise<LinkTokenStatus> {
|
||||
const res = await fetch(
|
||||
`${this.baseUrl}/api/platform-linking/tokens/${token}/status`,
|
||||
{ headers: this.botHeaders() }
|
||||
);
|
||||
|
||||
if (!res.ok) {
|
||||
throw new Error(
|
||||
`Link token status failed: ${res.status} ${await res.text()}`
|
||||
);
|
||||
}
|
||||
|
||||
return res.json();
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new CoPilot chat session for a linked user.
|
||||
* Uses the bot chat proxy (no user JWT needed).
|
||||
*/
|
||||
async createChatSession(userId: string): Promise<string> {
|
||||
const res = await fetch(
|
||||
`${this.baseUrl}/api/platform-linking/chat/session`,
|
||||
{
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
...this.botHeaders(),
|
||||
},
|
||||
body: JSON.stringify({
|
||||
user_id: userId,
|
||||
message: "session_init",
|
||||
}),
|
||||
}
|
||||
);
|
||||
|
||||
if (!res.ok) {
|
||||
throw new Error(
|
||||
`Create chat session failed: ${res.status} ${await res.text()}`
|
||||
);
|
||||
}
|
||||
|
||||
const data = await res.json();
|
||||
return data.session_id;
|
||||
}
|
||||
|
||||
/**
|
||||
* Stream a chat message to CoPilot on behalf of a linked user.
|
||||
* Uses the bot chat proxy — authenticated via bot API key.
|
||||
* Yields text chunks from the SSE stream.
|
||||
*/
|
||||
async *streamChat(
|
||||
userId: string,
|
||||
message: string,
|
||||
sessionId?: string
|
||||
): AsyncGenerator<string> {
|
||||
const res = await fetch(
|
||||
`${this.baseUrl}/api/platform-linking/chat/stream`,
|
||||
{
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
Accept: "text/event-stream",
|
||||
...this.botHeaders(),
|
||||
},
|
||||
body: JSON.stringify({
|
||||
user_id: userId,
|
||||
message,
|
||||
session_id: sessionId,
|
||||
}),
|
||||
}
|
||||
);
|
||||
|
||||
if (!res.ok) {
|
||||
throw new Error(
|
||||
`Stream chat failed: ${res.status} ${await res.text()}`
|
||||
);
|
||||
}
|
||||
|
||||
if (!res.body) {
|
||||
throw new Error("No response body for SSE stream");
|
||||
}
|
||||
|
||||
// Parse SSE stream
|
||||
const decoder = new TextDecoder();
|
||||
const reader = res.body.getReader();
|
||||
let buffer = "";
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
const lines = buffer.split("\n");
|
||||
buffer = lines.pop() ?? "";
|
||||
|
||||
for (const line of lines) {
|
||||
if (line.startsWith("data: ")) {
|
||||
const data = line.slice(6).trim();
|
||||
if (data === "[DONE]") return;
|
||||
|
||||
try {
|
||||
const parsed = JSON.parse(data);
|
||||
// Backend sends: text-delta (streaming chunks), text-start/text-end,
|
||||
// start/finish (lifecycle), start-step/finish-step, error, etc.
|
||||
if (parsed.type === "text-delta" && parsed.delta) {
|
||||
yield parsed.delta;
|
||||
} else if (parsed.type === "error" && parsed.content) {
|
||||
yield `Error: ${parsed.content}`;
|
||||
}
|
||||
// Ignore start/finish/step lifecycle events — they carry no text
|
||||
} catch {
|
||||
// Non-JSON data line — skip
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private botHeaders(): Record<string, string> {
|
||||
const key = process.env.PLATFORM_BOT_API_KEY;
|
||||
if (key) {
|
||||
return { "X-Bot-API-Key": key };
|
||||
}
|
||||
return {};
|
||||
}
|
||||
}
|
||||
19
autogpt_platform/copilot-bot/tsconfig.json
Normal file
19
autogpt_platform/copilot-bot/tsconfig.json
Normal file
@@ -0,0 +1,19 @@
|
||||
{
|
||||
"compilerOptions": {
|
||||
"target": "ES2022",
|
||||
"module": "ES2022",
|
||||
"moduleResolution": "bundler",
|
||||
"strict": true,
|
||||
"esModuleInterop": true,
|
||||
"skipLibCheck": true,
|
||||
"forceConsistentCasingInFileNames": true,
|
||||
"outDir": "dist",
|
||||
"rootDir": "src",
|
||||
"declaration": true,
|
||||
"resolveJsonModule": true,
|
||||
"jsx": "react-jsx",
|
||||
"jsxImportSource": "chat"
|
||||
},
|
||||
"include": ["src/**/*"],
|
||||
"exclude": ["node_modules", "dist"]
|
||||
}
|
||||
@@ -7,6 +7,7 @@ const config: StorybookConfig = {
|
||||
"../src/components/atoms/**/*.stories.@(js|jsx|mjs|ts|tsx)",
|
||||
"../src/components/molecules/**/*.stories.@(js|jsx|mjs|ts|tsx)",
|
||||
"../src/components/ai-elements/**/*.stories.@(js|jsx|mjs|ts|tsx)",
|
||||
"../src/components/renderers/**/*.stories.@(js|jsx|mjs|ts|tsx)",
|
||||
],
|
||||
addons: [
|
||||
"@storybook/addon-a11y",
|
||||
|
||||
@@ -25,7 +25,11 @@ Sentry.init({
|
||||
// Suppress cross-origin stylesheet errors from Sentry Replay (rrweb)
|
||||
// serializing DOM snapshots with cross-origin stylesheets
|
||||
// (e.g., from browser extensions or CDN-loaded CSS)
|
||||
ignoreErrors: [/Not allowed to access cross-origin stylesheet/],
|
||||
ignoreErrors: [
|
||||
/Not allowed to access cross-origin stylesheet/,
|
||||
// Sentry SDK internal issue on some mobile browsers
|
||||
/Error invoking postEvent: Method not found/,
|
||||
],
|
||||
|
||||
// Add optional integrations for additional features
|
||||
integrations: [
|
||||
|
||||
@@ -0,0 +1,230 @@
|
||||
"use client";
|
||||
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { AuthCard } from "@/components/auth/AuthCard";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { CheckCircle, LinkBreak, Spinner } from "@phosphor-icons/react";
|
||||
import { useParams } from "next/navigation";
|
||||
import { useEffect, useState } from "react";
|
||||
|
||||
/** Platform display names */
|
||||
const PLATFORM_NAMES: Record<string, string> = {
|
||||
DISCORD: "Discord",
|
||||
TELEGRAM: "Telegram",
|
||||
SLACK: "Slack",
|
||||
TEAMS: "Teams",
|
||||
WHATSAPP: "WhatsApp",
|
||||
GITHUB: "GitHub",
|
||||
LINEAR: "Linear",
|
||||
};
|
||||
|
||||
type LinkState =
|
||||
| { status: "loading" }
|
||||
| { status: "not-authenticated" }
|
||||
| { status: "ready" }
|
||||
| { status: "linking" }
|
||||
| { status: "success"; platform: string }
|
||||
| { status: "error"; message: string };
|
||||
|
||||
export default function PlatformLinkPage() {
|
||||
const params = useParams();
|
||||
const token = params.token as string;
|
||||
const { user, isUserLoading } = useSupabase();
|
||||
|
||||
const [state, setState] = useState<LinkState>({ status: "loading" });
|
||||
|
||||
// Determine initial state based on auth.
|
||||
// Token validity is checked server-side on confirm — no need for a
|
||||
// separate status call (which requires bot API key anyway).
|
||||
useEffect(() => {
|
||||
if (!token) return;
|
||||
// Wait for Supabase to finish loading before deciding auth state
|
||||
if (isUserLoading) return;
|
||||
|
||||
if (!user) {
|
||||
setState({ status: "not-authenticated" });
|
||||
} else {
|
||||
setState({ status: "ready" });
|
||||
}
|
||||
}, [token, user, isUserLoading]);
|
||||
|
||||
async function handleLink() {
|
||||
setState({ status: "linking" });
|
||||
|
||||
const controller = new AbortController();
|
||||
const timeout = setTimeout(() => controller.abort(), 30_000);
|
||||
|
||||
try {
|
||||
// The proxy injects auth from the server-side Supabase session cookie,
|
||||
// so we don't need to send Authorization ourselves.
|
||||
const res = await fetch(
|
||||
`/api/proxy/api/platform-linking/tokens/${token}/confirm`,
|
||||
{
|
||||
method: "POST",
|
||||
body: JSON.stringify({}),
|
||||
headers: { "Content-Type": "application/json" },
|
||||
signal: controller.signal,
|
||||
},
|
||||
);
|
||||
|
||||
if (!res.ok) {
|
||||
const data = await res.json().catch(() => null);
|
||||
const message =
|
||||
data?.detail ??
|
||||
"Failed to link your account. The link may have expired.";
|
||||
setState({ status: "error", message });
|
||||
return;
|
||||
}
|
||||
|
||||
const data = await res.json();
|
||||
const platformName = PLATFORM_NAMES[data.platform] ?? data.platform;
|
||||
setState({ status: "success", platform: platformName });
|
||||
} catch (err) {
|
||||
if (err instanceof DOMException && err.name === "AbortError") {
|
||||
setState({
|
||||
status: "error",
|
||||
message:
|
||||
"Request timed out. Please go back to your chat and try again.",
|
||||
});
|
||||
} else {
|
||||
setState({
|
||||
status: "error",
|
||||
message: "Something went wrong. Please try again.",
|
||||
});
|
||||
}
|
||||
} finally {
|
||||
clearTimeout(timeout);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex h-full min-h-[85vh] flex-col items-center justify-center py-10">
|
||||
{state.status === "loading" && <LoadingView />}
|
||||
|
||||
{state.status === "not-authenticated" && (
|
||||
<NotAuthenticatedView token={token} />
|
||||
)}
|
||||
|
||||
{state.status === "ready" && <ReadyView onLink={handleLink} />}
|
||||
|
||||
{state.status === "linking" && <LinkingView />}
|
||||
|
||||
{state.status === "success" && <SuccessView platform={state.platform} />}
|
||||
|
||||
{state.status === "error" && <ErrorView message={state.message} />}
|
||||
|
||||
<div className="mt-8 text-center text-xs text-muted-foreground">
|
||||
<p>Powered by AutoGPT Platform</p>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function LoadingView() {
|
||||
return (
|
||||
<AuthCard title="Link your account">
|
||||
<div className="flex flex-col items-center gap-4">
|
||||
<Spinner size={48} className="animate-spin text-primary" />
|
||||
<Text variant="body-medium" className="text-muted-foreground">
|
||||
Verifying link...
|
||||
</Text>
|
||||
</div>
|
||||
</AuthCard>
|
||||
);
|
||||
}
|
||||
|
||||
function NotAuthenticatedView({ token }: { token: string }) {
|
||||
const loginUrl = `/login?next=${encodeURIComponent(`/link/${token}`)}`;
|
||||
|
||||
return (
|
||||
<AuthCard title="Link your account">
|
||||
<div className="flex w-full flex-col items-center gap-6">
|
||||
<Text
|
||||
variant="body-medium"
|
||||
className="text-center text-muted-foreground"
|
||||
>
|
||||
Sign in to your AutoGPT account to link it with your chat platform.
|
||||
</Text>
|
||||
<Button as="NextLink" href={loginUrl} className="w-full">
|
||||
Sign in to continue
|
||||
</Button>
|
||||
<AuthCard.BottomText
|
||||
text="Don't have an account?"
|
||||
link={{ text: "Sign up", href: `/signup?next=/link/${token}` }}
|
||||
/>
|
||||
</div>
|
||||
</AuthCard>
|
||||
);
|
||||
}
|
||||
|
||||
function ReadyView({ onLink }: { onLink: () => void }) {
|
||||
return (
|
||||
<AuthCard title="Link your account">
|
||||
<div className="flex w-full flex-col items-center gap-6">
|
||||
<div className="rounded-xl bg-slate-50 p-6 text-center">
|
||||
<Text variant="body-medium" className="text-muted-foreground">
|
||||
Connect your chat platform account to AutoGPT to use CoPilot.
|
||||
</Text>
|
||||
</div>
|
||||
<Button onClick={onLink} className="w-full">
|
||||
Link account
|
||||
</Button>
|
||||
</div>
|
||||
</AuthCard>
|
||||
);
|
||||
}
|
||||
|
||||
function LinkingView() {
|
||||
return (
|
||||
<AuthCard title="Linking...">
|
||||
<div className="flex flex-col items-center gap-4">
|
||||
<Spinner size={48} className="animate-spin text-primary" />
|
||||
<Text variant="body-medium" className="text-muted-foreground">
|
||||
Connecting your accounts...
|
||||
</Text>
|
||||
</div>
|
||||
</AuthCard>
|
||||
);
|
||||
}
|
||||
|
||||
function SuccessView({ platform }: { platform: string }) {
|
||||
return (
|
||||
<AuthCard title="Account linked!">
|
||||
<div className="flex w-full flex-col items-center gap-6">
|
||||
<div className="flex h-16 w-16 items-center justify-center rounded-full bg-green-100">
|
||||
<CheckCircle size={40} weight="fill" className="text-green-600" />
|
||||
</div>
|
||||
<Text
|
||||
variant="body-medium"
|
||||
className="text-center text-muted-foreground"
|
||||
>
|
||||
Your <strong>{platform}</strong> account is now linked to AutoGPT.
|
||||
<br />
|
||||
You can close this page and go back to your chat.
|
||||
</Text>
|
||||
</div>
|
||||
</AuthCard>
|
||||
);
|
||||
}
|
||||
|
||||
function ErrorView({ message }: { message: string }) {
|
||||
return (
|
||||
<AuthCard title="Link failed">
|
||||
<div className="flex w-full flex-col items-center gap-6">
|
||||
<div className="flex h-16 w-16 items-center justify-center rounded-full bg-red-100">
|
||||
<LinkBreak size={40} weight="bold" className="text-red-600" />
|
||||
</div>
|
||||
<Text
|
||||
variant="body-medium"
|
||||
className="text-center text-muted-foreground"
|
||||
>
|
||||
{message}
|
||||
</Text>
|
||||
<Text variant="small" className="text-center text-muted-foreground">
|
||||
Go back to your chat and ask the bot for a new link.
|
||||
</Text>
|
||||
</div>
|
||||
</AuthCard>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { Input } from "@/components/__legacy__/ui/input";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { MagnifyingGlass } from "@phosphor-icons/react";
|
||||
|
||||
export interface AdminUserSearchProps {
|
||||
/** Current search query value (controlled). Falls back to internal state if omitted. */
|
||||
value?: string;
|
||||
/** Called when the input text changes */
|
||||
onChange?: (value: string) => void;
|
||||
/** Called when the user presses Enter or clicks the search button */
|
||||
onSearch: (query: string) => void;
|
||||
/** Placeholder text for the input */
|
||||
placeholder?: string;
|
||||
/** Disables the input and button while a search is in progress */
|
||||
isLoading?: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Shared admin user search input.
|
||||
* Supports searching users by name, email, or partial/fuzzy text.
|
||||
* Can be used as controlled (value + onChange) or uncontrolled (internal state).
|
||||
*/
|
||||
export function AdminUserSearch({
|
||||
value: controlledValue,
|
||||
onChange,
|
||||
onSearch,
|
||||
placeholder = "Search users by Name or Email...",
|
||||
isLoading = false,
|
||||
}: AdminUserSearchProps) {
|
||||
const [internalValue, setInternalValue] = useState("");
|
||||
|
||||
const isControlled = controlledValue !== undefined;
|
||||
const currentValue = isControlled ? controlledValue : internalValue;
|
||||
|
||||
function handleChange(newValue: string) {
|
||||
if (isControlled) {
|
||||
onChange?.(newValue);
|
||||
} else {
|
||||
setInternalValue(newValue);
|
||||
}
|
||||
}
|
||||
|
||||
function handleSearch() {
|
||||
onSearch(currentValue.trim());
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex w-full items-center gap-2">
|
||||
<Input
|
||||
placeholder={placeholder}
|
||||
aria-label={placeholder}
|
||||
value={currentValue}
|
||||
onChange={(e) => handleChange(e.target.value)}
|
||||
onKeyDown={(e) => e.key === "Enter" && handleSearch()}
|
||||
disabled={isLoading}
|
||||
/>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="small"
|
||||
onClick={handleSearch}
|
||||
disabled={isLoading || !currentValue.trim()}
|
||||
loading={isLoading}
|
||||
>
|
||||
{isLoading ? "Searching..." : <MagnifyingGlass size={16} />}
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
"use client";
|
||||
|
||||
export function formatTokens(tokens: number): string {
|
||||
if (tokens >= 1_000_000) return `${(tokens / 1_000_000).toFixed(1)}M`;
|
||||
if (tokens >= 1_000) return `${(tokens / 1_000).toFixed(0)}K`;
|
||||
return tokens.toString();
|
||||
}
|
||||
|
||||
export function UsageBar({ used, limit }: { used: number; limit: number }) {
|
||||
if (limit === 0) {
|
||||
return <span className="text-sm text-gray-500">Unlimited</span>;
|
||||
}
|
||||
const pct = Math.min(Math.max(0, (used / limit) * 100), 100);
|
||||
const color =
|
||||
pct >= 90 ? "bg-red-500" : pct >= 70 ? "bg-yellow-500" : "bg-green-500";
|
||||
|
||||
return (
|
||||
<div className="space-y-1">
|
||||
<div className="flex justify-between text-sm">
|
||||
<span>{formatTokens(used)} used</span>
|
||||
<span>{formatTokens(limit)} limit</span>
|
||||
</div>
|
||||
<div className="h-2 w-full rounded-full bg-gray-200">
|
||||
<div
|
||||
className={`h-2 rounded-full ${color}`}
|
||||
style={{ width: `${pct}%` }}
|
||||
/>
|
||||
</div>
|
||||
<div className="text-right text-xs text-gray-500">
|
||||
{pct.toFixed(1)}% used
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
import { Sidebar } from "@/components/__legacy__/Sidebar";
|
||||
import { Users, DollarSign, UserSearch, FileText } from "lucide-react";
|
||||
import { Gauge } from "@phosphor-icons/react/dist/ssr";
|
||||
|
||||
import { IconSliders } from "@/components/__legacy__/ui/icons";
|
||||
|
||||
@@ -21,6 +22,11 @@ const sidebarLinkGroups = [
|
||||
href: "/admin/impersonation",
|
||||
icon: <UserSearch className="h-6 w-6" />,
|
||||
},
|
||||
{
|
||||
text: "Rate Limits",
|
||||
href: "/admin/rate-limits",
|
||||
icon: <Gauge className="h-6 w-6" />,
|
||||
},
|
||||
{
|
||||
text: "Execution Analytics",
|
||||
href: "/admin/execution-analytics",
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import type { UserRateLimitResponse } from "@/app/api/__generated__/models/userRateLimitResponse";
|
||||
import { UsageBar } from "../../components/UsageBar";
|
||||
|
||||
interface Props {
|
||||
data: UserRateLimitResponse;
|
||||
onReset: (resetWeekly: boolean) => Promise<void>;
|
||||
/** Override the outer container classes (default: bordered card). */
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export function RateLimitDisplay({ data, onReset, className }: Props) {
|
||||
const [isResetting, setIsResetting] = useState(false);
|
||||
const [resetWeekly, setResetWeekly] = useState(false);
|
||||
|
||||
async function handleReset() {
|
||||
const msg = resetWeekly
|
||||
? "Reset both daily and weekly usage counters to zero?"
|
||||
: "Reset daily usage counter to zero?";
|
||||
if (!window.confirm(msg)) return;
|
||||
|
||||
setIsResetting(true);
|
||||
try {
|
||||
await onReset(resetWeekly);
|
||||
} finally {
|
||||
setIsResetting(false);
|
||||
}
|
||||
}
|
||||
|
||||
const nothingToReset = resetWeekly
|
||||
? data.daily_tokens_used === 0 && data.weekly_tokens_used === 0
|
||||
: data.daily_tokens_used === 0;
|
||||
|
||||
return (
|
||||
<div className={className ?? "rounded-md border bg-white p-6"}>
|
||||
<h2 className="mb-1 text-lg font-semibold">
|
||||
Rate Limits for {data.user_email ?? data.user_id}
|
||||
</h2>
|
||||
{data.user_email && (
|
||||
<p className="mb-4 text-xs text-gray-500">User ID: {data.user_id}</p>
|
||||
)}
|
||||
{!data.user_email && <div className="mb-4" />}
|
||||
|
||||
<div className="grid grid-cols-2 gap-6">
|
||||
<div className="space-y-2">
|
||||
<h3 className="text-sm font-medium text-gray-700">Daily Usage</h3>
|
||||
<UsageBar
|
||||
used={data.daily_tokens_used}
|
||||
limit={data.daily_token_limit}
|
||||
/>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<h3 className="text-sm font-medium text-gray-700">Weekly Usage</h3>
|
||||
<UsageBar
|
||||
used={data.weekly_tokens_used}
|
||||
limit={data.weekly_token_limit}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="mt-6 flex items-center gap-3 border-t pt-4">
|
||||
<select
|
||||
aria-label="Reset scope"
|
||||
value={resetWeekly ? "both" : "daily"}
|
||||
onChange={(e) => setResetWeekly(e.target.value === "both")}
|
||||
className="rounded-md border bg-white px-3 py-1.5 text-sm"
|
||||
disabled={isResetting}
|
||||
>
|
||||
<option value="daily">Reset daily only</option>
|
||||
<option value="both">Reset daily + weekly</option>
|
||||
</select>
|
||||
<Button
|
||||
variant="outline"
|
||||
onClick={handleReset}
|
||||
disabled={isResetting || nothingToReset}
|
||||
>
|
||||
{isResetting ? "Resetting..." : "Reset Usage"}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
"use client";
|
||||
|
||||
import { AdminUserSearch } from "../../components/AdminUserSearch";
|
||||
import { RateLimitDisplay } from "./RateLimitDisplay";
|
||||
import { useRateLimitManager } from "./useRateLimitManager";
|
||||
|
||||
export function RateLimitManager() {
|
||||
const {
|
||||
isSearching,
|
||||
isLoadingRateLimit,
|
||||
searchResults,
|
||||
selectedUser,
|
||||
rateLimitData,
|
||||
handleSearch,
|
||||
handleSelectUser,
|
||||
handleReset,
|
||||
} = useRateLimitManager();
|
||||
|
||||
return (
|
||||
<div className="space-y-6">
|
||||
<div className="rounded-md border bg-white p-6">
|
||||
<label className="mb-2 block text-sm font-medium">Search User</label>
|
||||
<AdminUserSearch
|
||||
onSearch={handleSearch}
|
||||
placeholder="Search by name, email, or user ID..."
|
||||
isLoading={isSearching}
|
||||
/>
|
||||
<p className="mt-1.5 text-xs text-gray-500">
|
||||
Exact email or user ID does a direct lookup. Partial text searches
|
||||
user history.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{/* User selection list -- always require explicit selection */}
|
||||
{searchResults.length >= 1 && !selectedUser && (
|
||||
<div className="rounded-md border bg-white p-4">
|
||||
<h3 className="mb-2 text-sm font-medium text-gray-700">
|
||||
Select a user ({searchResults.length}{" "}
|
||||
{searchResults.length === 1 ? "result" : "results"})
|
||||
</h3>
|
||||
<ul className="divide-y">
|
||||
{searchResults.map((user) => (
|
||||
<li key={user.user_id}>
|
||||
<button
|
||||
className="w-full px-2 py-2 text-left text-sm hover:bg-gray-100"
|
||||
onClick={() => handleSelectUser(user)}
|
||||
>
|
||||
<span className="font-medium">{user.user_email}</span>
|
||||
<span className="ml-2 text-xs text-gray-500">
|
||||
{user.user_id}
|
||||
</span>
|
||||
</button>
|
||||
</li>
|
||||
))}
|
||||
</ul>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Show selected user */}
|
||||
{selectedUser && searchResults.length >= 1 && (
|
||||
<div className="rounded-md border border-blue-200 bg-blue-50 px-4 py-2 text-sm">
|
||||
Selected:{" "}
|
||||
<span className="font-medium">{selectedUser.user_email}</span>
|
||||
<span className="ml-2 text-xs text-gray-500">
|
||||
{selectedUser.user_id}
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{isLoadingRateLimit && (
|
||||
<div className="py-4 text-center text-sm text-gray-500">
|
||||
Loading rate limits...
|
||||
</div>
|
||||
)}
|
||||
|
||||
{rateLimitData && (
|
||||
<RateLimitDisplay data={rateLimitData} onReset={handleReset} />
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,212 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import type { UserRateLimitResponse } from "@/app/api/__generated__/models/userRateLimitResponse";
|
||||
import {
|
||||
getV2GetUserRateLimit,
|
||||
getV2GetAllUsersHistory,
|
||||
postV2ResetUserRateLimitUsage,
|
||||
} from "@/app/api/__generated__/endpoints/admin/admin";
|
||||
|
||||
export interface UserOption {
|
||||
user_id: string;
|
||||
user_email: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true when the input looks like a complete email address.
|
||||
* Used to decide whether to call the direct email lookup endpoint
|
||||
* vs. the broader user-history search.
|
||||
*/
|
||||
function looksLikeEmail(input: string): boolean {
|
||||
return /^[^\s@]+@[^\s@]+\.[^\s@]+$/.test(input);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true when the input looks like a UUID (user ID).
|
||||
*/
|
||||
function looksLikeUuid(input: string): boolean {
|
||||
return /^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i.test(
|
||||
input,
|
||||
);
|
||||
}
|
||||
|
||||
export function useRateLimitManager() {
|
||||
const { toast } = useToast();
|
||||
const [isSearching, setIsSearching] = useState(false);
|
||||
const [isLoadingRateLimit, setIsLoadingRateLimit] = useState(false);
|
||||
const [searchResults, setSearchResults] = useState<UserOption[]>([]);
|
||||
const [selectedUser, setSelectedUser] = useState<UserOption | null>(null);
|
||||
const [rateLimitData, setRateLimitData] =
|
||||
useState<UserRateLimitResponse | null>(null);
|
||||
|
||||
/** Direct lookup by email or user ID via the rate-limit endpoint. */
|
||||
async function handleDirectLookup(trimmed: string) {
|
||||
setIsSearching(true);
|
||||
setSearchResults([]);
|
||||
setSelectedUser(null);
|
||||
setRateLimitData(null);
|
||||
|
||||
try {
|
||||
const params = looksLikeEmail(trimmed)
|
||||
? { email: trimmed }
|
||||
: { user_id: trimmed };
|
||||
const response = await getV2GetUserRateLimit(params);
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to fetch rate limit");
|
||||
}
|
||||
setRateLimitData(response.data);
|
||||
setSelectedUser({
|
||||
user_id: response.data.user_id,
|
||||
user_email: response.data.user_email ?? response.data.user_id,
|
||||
});
|
||||
} catch (error) {
|
||||
console.error("Error fetching rate limit:", error);
|
||||
const hint = looksLikeEmail(trimmed)
|
||||
? "No user found with that email address."
|
||||
: "Check the user ID and try again.";
|
||||
toast({
|
||||
title: "Error",
|
||||
description: `Failed to fetch rate limits. ${hint}`,
|
||||
variant: "destructive",
|
||||
});
|
||||
setRateLimitData(null);
|
||||
} finally {
|
||||
setIsSearching(false);
|
||||
}
|
||||
}
|
||||
|
||||
/** Fuzzy name/email search via the spending-history endpoint. */
|
||||
async function handleFuzzySearch(trimmed: string) {
|
||||
setIsSearching(true);
|
||||
setSearchResults([]);
|
||||
setSelectedUser(null);
|
||||
setRateLimitData(null);
|
||||
|
||||
try {
|
||||
const response = await getV2GetAllUsersHistory({
|
||||
search: trimmed,
|
||||
page: 1,
|
||||
page_size: 50,
|
||||
});
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to search users");
|
||||
}
|
||||
|
||||
// Deduplicate by user_id to get unique users
|
||||
const seen = new Set<string>();
|
||||
const users: UserOption[] = [];
|
||||
for (const tx of response.data.history) {
|
||||
if (!seen.has(tx.user_id)) {
|
||||
seen.add(tx.user_id);
|
||||
users.push({
|
||||
user_id: tx.user_id,
|
||||
user_email: String(tx.user_email ?? tx.user_id),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (users.length === 0) {
|
||||
toast({
|
||||
title: "No results",
|
||||
description: "No users found matching your search.",
|
||||
});
|
||||
}
|
||||
|
||||
// Always show the result list so the user explicitly picks a match.
|
||||
// The history endpoint paginates transactions, not users, so a single
|
||||
// page may not be authoritative -- avoid auto-selecting.
|
||||
setSearchResults(users);
|
||||
} catch (error) {
|
||||
console.error("Error searching users:", error);
|
||||
toast({
|
||||
title: "Error",
|
||||
description: "Failed to search users.",
|
||||
variant: "destructive",
|
||||
});
|
||||
} finally {
|
||||
setIsSearching(false);
|
||||
}
|
||||
}
|
||||
|
||||
async function handleSearch(query: string) {
|
||||
const trimmed = query.trim();
|
||||
if (!trimmed) return;
|
||||
|
||||
// Direct lookup when the input is a full email or UUID.
|
||||
// This avoids the spending-history indirection and works even for
|
||||
// users who have no credit transaction history.
|
||||
if (looksLikeEmail(trimmed) || looksLikeUuid(trimmed)) {
|
||||
await handleDirectLookup(trimmed);
|
||||
} else {
|
||||
await handleFuzzySearch(trimmed);
|
||||
}
|
||||
}
|
||||
|
||||
async function fetchRateLimit(userId: string) {
|
||||
setIsLoadingRateLimit(true);
|
||||
try {
|
||||
const response = await getV2GetUserRateLimit({ user_id: userId });
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to fetch rate limit");
|
||||
}
|
||||
setRateLimitData(response.data);
|
||||
} catch (error) {
|
||||
console.error("Error fetching rate limit:", error);
|
||||
toast({
|
||||
title: "Error",
|
||||
description: "Failed to fetch user rate limit.",
|
||||
variant: "destructive",
|
||||
});
|
||||
setRateLimitData(null);
|
||||
} finally {
|
||||
setIsLoadingRateLimit(false);
|
||||
}
|
||||
}
|
||||
|
||||
async function handleSelectUser(user: UserOption) {
|
||||
setSelectedUser(user);
|
||||
setRateLimitData(null);
|
||||
await fetchRateLimit(user.user_id);
|
||||
}
|
||||
|
||||
async function handleReset(resetWeekly: boolean) {
|
||||
if (!rateLimitData) return;
|
||||
|
||||
try {
|
||||
const response = await postV2ResetUserRateLimitUsage({
|
||||
user_id: rateLimitData.user_id,
|
||||
reset_weekly: resetWeekly,
|
||||
});
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to reset usage");
|
||||
}
|
||||
setRateLimitData(response.data);
|
||||
toast({
|
||||
title: "Success",
|
||||
description: resetWeekly
|
||||
? "Daily and weekly usage reset to zero."
|
||||
: "Daily usage reset to zero.",
|
||||
});
|
||||
} catch (error) {
|
||||
console.error("Error resetting rate limit:", error);
|
||||
toast({
|
||||
title: "Error",
|
||||
description: "Failed to reset rate limit usage.",
|
||||
variant: "destructive",
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
isSearching,
|
||||
isLoadingRateLimit,
|
||||
searchResults,
|
||||
selectedUser,
|
||||
rateLimitData,
|
||||
handleSearch,
|
||||
handleSelectUser,
|
||||
handleReset,
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
import { withRoleAccess } from "@/lib/withRoleAccess";
|
||||
import { RateLimitManager } from "./components/RateLimitManager";
|
||||
|
||||
function RateLimitsDashboard() {
|
||||
return (
|
||||
<div className="mx-auto p-6">
|
||||
<div className="flex flex-col gap-4">
|
||||
<div>
|
||||
<h1 className="text-3xl font-bold">User Rate Limits</h1>
|
||||
<p className="text-gray-500">
|
||||
Check and manage CoPilot rate limits per user
|
||||
</p>
|
||||
</div>
|
||||
<RateLimitManager />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default async function RateLimitsDashboardPage() {
|
||||
"use server";
|
||||
const withAdminAccess = await withRoleAccess(["admin"]);
|
||||
const ProtectedDashboard = await withAdminAccess(RateLimitsDashboard);
|
||||
return <ProtectedDashboard />;
|
||||
}
|
||||
@@ -11,6 +11,7 @@ import { PaginationControls } from "../../../../../components/__legacy__/ui/pagi
|
||||
import { SearchAndFilterAdminSpending } from "./SearchAndFilterAdminSpending";
|
||||
import { getUsersTransactionHistory } from "@/app/(platform)/admin/spending/actions";
|
||||
import { AdminAddMoneyButton } from "./AddMoneyButton";
|
||||
import { RateLimitModal } from "./RateLimitModal";
|
||||
import { CreditTransactionType } from "@/lib/autogpt-server-api";
|
||||
|
||||
export async function AdminUserGrantHistory({
|
||||
@@ -80,10 +81,7 @@ export async function AdminUserGrantHistory({
|
||||
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
<SearchAndFilterAdminSpending
|
||||
initialStatus={initialStatus}
|
||||
initialSearch={initialSearch}
|
||||
/>
|
||||
<SearchAndFilterAdminSpending initialSearch={initialSearch} />
|
||||
|
||||
<div className="rounded-md border bg-white">
|
||||
<Table>
|
||||
@@ -105,7 +103,7 @@ export async function AdminUserGrantHistory({
|
||||
{history.length === 0 ? (
|
||||
<TableRow>
|
||||
<TableCell
|
||||
colSpan={8}
|
||||
colSpan={9}
|
||||
className="py-10 text-center text-gray-500"
|
||||
>
|
||||
No transactions found
|
||||
@@ -114,7 +112,7 @@ export async function AdminUserGrantHistory({
|
||||
) : (
|
||||
history.map((transaction) => (
|
||||
<TableRow
|
||||
key={transaction.user_id}
|
||||
key={`${transaction.user_id}-${transaction.transaction_time}`}
|
||||
className="hover:bg-gray-50"
|
||||
>
|
||||
<TableCell className="font-medium">
|
||||
@@ -147,25 +145,29 @@ export async function AdminUserGrantHistory({
|
||||
${transaction.current_balance / 100}
|
||||
</TableCell> */}
|
||||
<TableCell className="text-right">
|
||||
<AdminAddMoneyButton
|
||||
userId={transaction.user_id}
|
||||
userEmail={
|
||||
transaction.user_email ?? "User Email wasn't attached"
|
||||
}
|
||||
currentBalance={transaction.current_balance}
|
||||
defaultAmount={
|
||||
transaction.transaction_type ===
|
||||
CreditTransactionType.USAGE
|
||||
? -transaction.amount
|
||||
: undefined
|
||||
}
|
||||
defaultComments={
|
||||
transaction.transaction_type ===
|
||||
CreditTransactionType.USAGE
|
||||
? "Refund for usage"
|
||||
: undefined
|
||||
}
|
||||
/>
|
||||
<div className="flex items-center justify-end gap-2">
|
||||
<RateLimitModal
|
||||
userId={transaction.user_id}
|
||||
userEmail={transaction.user_email ?? ""}
|
||||
/>
|
||||
<AdminAddMoneyButton
|
||||
userId={transaction.user_id}
|
||||
userEmail={transaction.user_email ?? ""}
|
||||
currentBalance={transaction.current_balance}
|
||||
defaultAmount={
|
||||
transaction.transaction_type ===
|
||||
CreditTransactionType.USAGE
|
||||
? -transaction.amount
|
||||
: undefined
|
||||
}
|
||||
defaultComments={
|
||||
transaction.transaction_type ===
|
||||
CreditTransactionType.USAGE
|
||||
? "Refund for usage"
|
||||
: undefined
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))
|
||||
|
||||
@@ -0,0 +1,138 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useEffect } from "react";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
DialogDescription,
|
||||
} from "@/components/__legacy__/ui/dialog";
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import type { UserRateLimitResponse } from "@/app/api/__generated__/models/userRateLimitResponse";
|
||||
import {
|
||||
getV2GetUserRateLimit,
|
||||
postV2ResetUserRateLimitUsage,
|
||||
} from "@/app/api/__generated__/endpoints/admin/admin";
|
||||
import { Gauge } from "@phosphor-icons/react";
|
||||
import { RateLimitDisplay } from "../../rate-limits/components/RateLimitDisplay";
|
||||
|
||||
export function RateLimitModal({
|
||||
userId,
|
||||
userEmail,
|
||||
}: {
|
||||
userId: string;
|
||||
userEmail: string;
|
||||
}) {
|
||||
const { toast } = useToast();
|
||||
const [open, setOpen] = useState(false);
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
const [rateLimitData, setRateLimitData] =
|
||||
useState<UserRateLimitResponse | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
if (!open) {
|
||||
setRateLimitData(null);
|
||||
return;
|
||||
}
|
||||
|
||||
async function fetchRateLimit() {
|
||||
setIsLoading(true);
|
||||
try {
|
||||
const response = await getV2GetUserRateLimit({ user_id: userId });
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to fetch rate limit");
|
||||
}
|
||||
setRateLimitData(response.data);
|
||||
} catch (error) {
|
||||
console.error("Error fetching rate limit:", error);
|
||||
toast({
|
||||
title: "Error",
|
||||
description: "Failed to fetch user rate limit.",
|
||||
variant: "destructive",
|
||||
});
|
||||
setRateLimitData(null);
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
}
|
||||
|
||||
fetchRateLimit();
|
||||
}, [open, userId, toast]);
|
||||
|
||||
async function handleReset(resetWeekly: boolean) {
|
||||
if (!rateLimitData) return;
|
||||
|
||||
try {
|
||||
const response = await postV2ResetUserRateLimitUsage({
|
||||
user_id: rateLimitData.user_id,
|
||||
reset_weekly: resetWeekly,
|
||||
});
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to reset usage");
|
||||
}
|
||||
setRateLimitData(response.data);
|
||||
toast({
|
||||
title: "Success",
|
||||
description: resetWeekly
|
||||
? "Daily and weekly usage reset to zero."
|
||||
: "Daily usage reset to zero.",
|
||||
});
|
||||
} catch (error) {
|
||||
console.error("Error resetting rate limit:", error);
|
||||
toast({
|
||||
title: "Error",
|
||||
description: "Failed to reset rate limit usage.",
|
||||
variant: "destructive",
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<Button
|
||||
size="small"
|
||||
variant="outline"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
setOpen(true);
|
||||
}}
|
||||
>
|
||||
<Gauge size={16} className="mr-1" />
|
||||
Rate Limits
|
||||
</Button>
|
||||
|
||||
<Dialog open={open} onOpenChange={setOpen}>
|
||||
<DialogContent className="sm:max-w-lg">
|
||||
<DialogHeader>
|
||||
<DialogTitle>Rate Limits</DialogTitle>
|
||||
<DialogDescription>
|
||||
CoPilot rate limits for {userEmail || userId}
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
{isLoading && (
|
||||
<div className="py-8 text-center text-gray-500">
|
||||
Loading rate limits...
|
||||
</div>
|
||||
)}
|
||||
|
||||
{!isLoading && rateLimitData && (
|
||||
<RateLimitDisplay
|
||||
data={rateLimitData}
|
||||
onReset={handleReset}
|
||||
className="space-y-4"
|
||||
/>
|
||||
)}
|
||||
|
||||
{!isLoading && !rateLimitData && (
|
||||
<div className="py-8 text-center text-gray-500">
|
||||
No rate limit data available for this user.
|
||||
</div>
|
||||
)}
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
</>
|
||||
);
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user